diff --git a/xbout/load.py b/xbout/load.py index 3a82a617..bd36497e 100644 --- a/xbout/load.py +++ b/xbout/load.py @@ -1031,7 +1031,11 @@ def _trim(ds, *, guards, keep_boundaries, nxpe, nype, is_restart): ): trimmed_ds = trimmed_ds.drop_vars(name) - to_drop = _BOUT_PER_PROC_VARIABLES + if ds["MYPE"] == 0: + # Keep per-process variables from the root process + to_drop = None + else: + to_drop = _BOUT_PER_PROC_VARIABLES return trimmed_ds.drop_vars(to_drop, errors="ignore") diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index e115d434..5f944c16 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -752,9 +752,9 @@ def create_or_update_plot(plot_objects=None, tind=None, this_save_as=None): X, Y, Z, scalars=data, vmin=vmin, vmax=vmax, **kwargs ) else: - plot_objects[ - region_name + str(i) - ].mlab_source.scalars = data + plot_objects[region_name + str(i)].mlab_source.scalars = ( + data + ) if mayavi_view is not None: mlab.view(*mayavi_view) diff --git a/xbout/tests/test_against_collect.py b/xbout/tests/test_against_collect.py index 5f22cf97..a1bf3da3 100644 --- a/xbout/tests/test_against_collect.py +++ b/xbout/tests/test_against_collect.py @@ -220,5 +220,4 @@ def test_new_collect_indexing_slice(self, tmp_path_factory): @pytest.mark.skip -class test_speed_against_old_collect: - ... +class test_speed_against_old_collect: ... diff --git a/xbout/tests/test_load.py b/xbout/tests/test_load.py index d8766236..bb4c917e 100644 --- a/xbout/tests/test_load.py +++ b/xbout/tests/test_load.py @@ -472,8 +472,7 @@ def test_combine_along_y(self, tmp_path_factory, bout_xyt_example_files): xrt.assert_identical(actual, fake) @pytest.mark.skip - def test_combine_along_t(self): - ... + def test_combine_along_t(self): ... @pytest.mark.parametrize( "bout_v5,metric_3D", [(False, False), (True, False), (True, True)] @@ -623,8 +622,7 @@ def test_drop_vars(self, tmp_path_factory, bout_xyt_example_files): assert "n" in ds.keys() @pytest.mark.skip - def test_combine_along_tx(self): - ... + def test_combine_along_tx(self): ... def test_restarts(self): datapath = Path(__file__).parent.joinpath( diff --git a/xbout/utils.py b/xbout/utils.py index 32be7edc..66dd9593 100644 --- a/xbout/utils.py +++ b/xbout/utils.py @@ -167,12 +167,16 @@ def _1d_coord_from_spacing(spacing, dim, ds=None, *, origin_at=None): ) point_to_use = { - spacing.metadata["bout_xdim"]: spacing.metadata.get("MXG", 0) - if spacing.metadata["keep_xboundaries"] - else 0, - spacing.metadata["bout_ydim"]: spacing.metadata.get("MYG", 0) - if spacing.metadata["keep_yboundaries"] - else 0, + spacing.metadata["bout_xdim"]: ( + spacing.metadata.get("MXG", 0) + if spacing.metadata["keep_xboundaries"] + else 0 + ), + spacing.metadata["bout_ydim"]: ( + spacing.metadata.get("MYG", 0) + if spacing.metadata["keep_yboundaries"] + else 0 + ), spacing.metadata["bout_zdim"]: spacing.metadata.get("MZG", 0), }