Skip to content

Commit b2cc9ef

Browse files
committed
Allow geometry to be changed when reloading an xBOUT-saved Dataset
If a Dataset is loaded with xBOUT using one geometry, and then saved (e.g. using geometry=None just to squash the data into a single file), and then reloaded, previously the geometry was re-applied from whatever was set at the first load. Now it is possible to override the geometry while reloading (e.g. to apply toroidal geometry when opening the squashed file).
1 parent fa3ab8b commit b2cc9ef

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

xbout/load.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,26 @@ def attrs_remove_section(obj, section):
175175
ds = _add_options(ds, inputfilepath)
176176

177177
# If geometry was set, apply geometry again
178-
if "geometry" in ds.attrs:
178+
if geometry is not None:
179+
if "geometry" != ds.attrs.get("geometry", None):
180+
warn(
181+
f'open_boutdataset() called with geometry="{geometry}", but we are '
182+
f"reloading a Dataset that was saved after being loaded with "
183+
f'geometry="{ds.attrs.get("geometry", None)}". Applying '
184+
f'geometry="{geometry}" from the argument.'
185+
)
186+
if gridfilepath is not None:
187+
grid = _open_grid(
188+
gridfilepath,
189+
chunks=chunks,
190+
keep_xboundaries=keep_xboundaries,
191+
keep_yboundaries=keep_yboundaries,
192+
mxg=ds.metadata["MXG"],
193+
)
194+
else:
195+
grid = None
196+
ds = geometries.apply_geometry(ds, geometry, grid=grid)
197+
elif "geometry" in ds.attrs:
179198
ds = geometries.apply_geometry(ds, ds.attrs["geometry"])
180199
else:
181200
ds = geometries.apply_geometry(ds, None)

xbout/tests/test_boutdataset.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,20 +1378,12 @@ def test_save_all(self, tmpdir_factory, bout_xyt_example_files):
13781378

13791379
@pytest.mark.parametrize("geometry", [None, "toroidal"])
13801380
def test_reload_all(self, tmpdir_factory, bout_xyt_example_files, geometry):
1381-
if geometry is not None:
1382-
grid = "grid"
1383-
else:
1384-
grid = None
1385-
13861381
# Create data
13871382
path = bout_xyt_example_files(
1388-
tmpdir_factory, nxpe=4, nype=5, nt=1, grid=grid, write_to_disk=True
1383+
tmpdir_factory, nxpe=4, nype=5, nt=1, grid="grid", write_to_disk=True
13891384
)
13901385

1391-
if grid is not None:
1392-
gridpath = str(Path(path).parent) + "/grid.nc"
1393-
else:
1394-
gridpath = None
1386+
gridpath = str(Path(path).parent) + "/grid.nc"
13951387

13961388
# Load it as a boutdataset
13971389
if geometry is None:
@@ -1400,14 +1392,14 @@ def test_reload_all(self, tmpdir_factory, bout_xyt_example_files, geometry):
14001392
datapath=path,
14011393
inputfilepath=None,
14021394
geometry=geometry,
1403-
gridfilepath=gridpath,
1395+
gridfilepath=None if geometry is None else gridpath,
14041396
)
14051397
else:
14061398
original = open_boutdataset(
14071399
datapath=path,
14081400
inputfilepath=None,
14091401
geometry=geometry,
1410-
gridfilepath=gridpath,
1402+
gridfilepath=None if geometry is None else gridpath,
14111403
)
14121404

14131405
# Save it to a netCDF file
@@ -1419,6 +1411,25 @@ def test_reload_all(self, tmpdir_factory, bout_xyt_example_files, geometry):
14191411

14201412
xrt.assert_identical(original.load(), recovered.load())
14211413

1414+
# Check if we can load with a different geometry argument
1415+
for reload_geometry in [None, "toroidal"]:
1416+
if reload_geometry is None or geometry == reload_geometry:
1417+
recovered = open_boutdataset(
1418+
savepath,
1419+
geometry=reload_geometry,
1420+
gridfilepath=None if reload_geometry is None else gridpath,
1421+
)
1422+
xrt.assert_identical(original.load(), recovered.load())
1423+
else:
1424+
# Expect a warning because we change the geometry
1425+
print("here", gridpath)
1426+
with pytest.warns(UserWarning):
1427+
recovered = open_boutdataset(
1428+
savepath, geometry=reload_geometry, gridfilepath=gridpath
1429+
)
1430+
# Datasets won't be exactly the same because different geometry was
1431+
# applied
1432+
14221433
@pytest.mark.parametrize("save_dtype", [np.float64, np.float32])
14231434
@pytest.mark.parametrize(
14241435
"separate_vars", [False, pytest.param(True, marks=pytest.mark.long)]

0 commit comments

Comments
 (0)