Skip to content

Commit 8ca5f1a

Browse files
authored
Merge pull request #232 from boutproject/better-gridfile-restart-handling
Improvements to grid- and restart-file loading
2 parents f2dd91e + 27a75f7 commit 8ca5f1a

File tree

7 files changed

+139
-49
lines changed

7 files changed

+139
-49
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,6 @@ venv.bak/
105105

106106
# mypy
107107
.mypy_cache/
108+
109+
# Generated files:
110+
_version.py

xbout/boutdataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,9 @@ def to_restart(
790790
variables : str or sequence of str, optional
791791
The evolving variables needed in the restart files. If not given explicitly,
792792
all time-evolving variables in the Dataset will be used, which may result in
793-
larger restart files than necessary.
793+
larger restart files than necessary. If there is no time-dimension in the
794+
Dataset (e.g. if it was loaded from restart files), then all variables will
795+
be added if this argument is not given explicitly.
794796
savepath : str, default '.'
795797
Directory to save the created restart files under
796798
nxpe : int, optional

xbout/geometries.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,19 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
364364

365365
# Get extra geometry information from grid file if it's not in the dump files
366366
ds = _add_vars_from_grid(
367-
ds, grid, ["psixy", "Rxy", "Zxy"], optional_variables=["Bpxy", "Brxy", "Bzxy"]
367+
ds,
368+
grid,
369+
["psixy", "Rxy", "Zxy"],
370+
optional_variables=[
371+
"Bpxy",
372+
"Brxy",
373+
"Bzxy",
374+
"poloidal_distance",
375+
"poloidal_distance_ylow",
376+
"total_poloidal_distance",
377+
"zShift",
378+
"zShift_ylow",
379+
],
368380
)
369381

370382
if "t" in ds.dims:
@@ -403,8 +415,23 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
403415
else:
404416
ds = ds.set_coords(("Rxy", "Zxy"))
405417

418+
# Rename zShift_ylow if it was added from grid file, to be consistent with name if
419+
# it was added from dump file
420+
if "zShift_CELL_YLOW" in ds and "zShift_ylow" in ds:
421+
# Remove redundant copy
422+
del ds["zShift_ylow"]
423+
elif "zShift_ylow" in ds:
424+
ds = ds.rename(zShift_ylow="zShift_CELL_YLOW")
425+
426+
if "poloidal_distance" in ds:
427+
ds = ds.set_coords(
428+
["poloidal_distance", "poloidal_distance_ylow", "total_poloidal_distance"]
429+
)
430+
406431
# Add zShift as a coordinate, so that it gets interpolated along with a variable
407432
ds = _set_as_coord(ds, "zShift")
433+
if "zShift_CELL_YLOW" in ds:
434+
ds = _set_as_coord(ds, "zShift_CELL_YLOW")
408435

409436
ds = _create_regions_toroidal(ds)
410437

xbout/load.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def attrs_remove_section(obj, section):
260260
chunks=chunks,
261261
keep_xboundaries=keep_xboundaries,
262262
keep_yboundaries=keep_yboundaries,
263+
**kwargs,
263264
)
264265
else:
265266
raise ValueError(f"internal error: unexpected input_type={input_type}")
@@ -899,7 +900,7 @@ def _get_limit(side, dim, keep_boundaries, boundaries, guards):
899900
return limit
900901

901902

902-
def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2):
903+
def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2, **kwargs):
903904
"""
904905
Opens a single grid file. Implements slightly different logic for
905906
boundaries to deal with different conventions in a BOUT grid file.
@@ -917,7 +918,9 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2):
917918

918919
if _is_path(datapath):
919920
gridfilepath = Path(datapath)
920-
grid = xr.open_dataset(gridfilepath, engine=_check_filetype(gridfilepath))
921+
grid = xr.open_dataset(
922+
gridfilepath, engine=_check_filetype(gridfilepath), **kwargs
923+
)
921924
else:
922925
grid = datapath
923926

@@ -933,17 +936,28 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2):
933936
)
934937
grid = grid.drop_dims(unrecognised_dims)
935938

936-
if not keep_xboundaries:
939+
if keep_xboundaries:
940+
# Set MXG so that it is picked up in metadata - needed for applying geometry,
941+
# etc.
942+
grid["MXG"] = mxg
943+
else:
937944
xboundaries = mxg
938945
if xboundaries > 0:
939946
grid = grid.isel(x=slice(xboundaries, -xboundaries, None))
940-
if not keep_yboundaries:
941-
try:
942-
yboundaries = int(grid["y_boundary_guards"])
943-
except KeyError:
944-
# y_boundary_guards variable not in grid file - older grid files
945-
# never had y-boundary cells
946-
yboundaries = 0
947+
# Set MXG so that it is picked up in metadata - needed for applying geometry,
948+
# etc.
949+
grid["MXG"] = 0
950+
try:
951+
yboundaries = int(grid["y_boundary_guards"])
952+
except KeyError:
953+
# y_boundary_guards variable not in grid file - older grid files
954+
# never had y-boundary cells
955+
yboundaries = 0
956+
if keep_yboundaries:
957+
# Set MYG so that it is picked up in metadata - needed for applying geometry,
958+
# etc.
959+
grid["MYG"] = yboundaries
960+
else:
947961
if yboundaries > 0:
948962
# Remove y-boundary cells from first divertor target
949963
grid = grid.isel(y=slice(yboundaries, -yboundaries, None))
@@ -960,6 +974,9 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2):
960974
compat="identical",
961975
join="exact",
962976
)
977+
# Set MYG so that it is picked up in metadata - needed for applying geometry,
978+
# etc.
979+
grid["MYG"] = 0
963980

964981
if "z" in grid_chunks and "z" not in grid.dims:
965982
del grid_chunks["z"]

xbout/region.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,22 @@ def _concat_lower_guards(da, da_global, mxg, myg):
15681568
da_lower[xcoord].data[...] = new_xcoord.data
15691569
da_lower[ycoord].data[...] = new_ycoord.data
15701570

1571+
if "poloidal_distance" in da.coords and myg > 0:
1572+
# Special handling for core regions to deal with branch cut
1573+
if "core" in region.name:
1574+
# import pdb; pdb.set_trace()
1575+
# Try to detect whether there is branch cut at lower boundary: if there is
1576+
# poloidal_distance_ylow should be zero at the boundary of this region
1577+
poloidal_distance_bottom = da["poloidal_distance_ylow"].isel({ycoord: 0})
1578+
if all(abs(poloidal_distance_bottom) < 1.0e-16):
1579+
# Offset so that the poloidal_distance in da_lower is continuous from
1580+
# the poloidal_distance in this region.
1581+
# Expect there to be y-boundary cells in the Dataset, this will probably
1582+
# fail if there are not.
1583+
total_poloidal_distance = da["total_poloidal_distance"]
1584+
da_lower["poloidal_distance"] -= total_poloidal_distance
1585+
da_lower["poloidal_distance_ylow"] -= total_poloidal_distance
1586+
15711587
save_regions = da.bout._regions
15721588
da = xr.concat((da_lower, da), ycoord, join="exact")
15731589
# xr.concat takes attributes from the first variable (for xarray>=0.15.0, keeps attrs
@@ -1668,6 +1684,24 @@ def _concat_upper_guards(da, da_global, mxg, myg):
16681684
da_upper[xcoord].data[...] = new_xcoord.data
16691685
da_upper[ycoord].data[...] = new_ycoord.data
16701686

1687+
if "poloidal_distance" in da.coords and myg > 0:
1688+
# Special handling for core regions to deal with branch cut
1689+
if "core" in region.name:
1690+
# import pdb; pdb.set_trace()
1691+
# Try to detect whether there is branch cut at upper boundary: if there is
1692+
# poloidal_distance_ylow should be zero at the boundary of da_upper
1693+
poloidal_distance_bottom = da_upper["poloidal_distance_ylow"].isel(
1694+
{ycoord: 0}
1695+
)
1696+
if all(abs(poloidal_distance_bottom) < 1.0e-16):
1697+
# Offset so that the poloidal_distance in da_upper is continuous from
1698+
# the poloidal_distance in this region.
1699+
# Expect there to be y-boundary cells in the Dataset, this will probably
1700+
# fail if there are not.
1701+
total_poloidal_distance = da["total_poloidal_distance"]
1702+
da_upper["poloidal_distance"] += total_poloidal_distance
1703+
da_upper["poloidal_distance_ylow"] += total_poloidal_distance
1704+
16711705
save_regions = da.bout._regions
16721706
da = xr.concat((da, da_upper), ycoord, join="exact")
16731707
# xarray<0.15.0 only keeps attrs that are the same on all variables passed to concat

xbout/tests/test_grid.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ def create_example_grid_file(tmp_path_factory):
1717
"""
1818

1919
# Create grid dataset
20-
arr = np.arange(6).reshape(2, 3)
20+
arr = np.arange(15).reshape(5, 3)
2121
grid = DataArray(data=arr, name="arr", dims=["x", "y"]).to_dataset()
22-
grid["dy"] = DataArray(np.ones((2, 3)), dims=["x", "y"])
22+
grid["dy"] = DataArray(np.ones((5, 3)), dims=["x", "y"])
2323
grid = grid.set_coords(["dy"])
2424

2525
# Create temporary directory
@@ -44,7 +44,11 @@ def test_open_grid(self, create_example_grid_file):
4444
def test_open_grid_extra_dims(self, create_example_grid_file, tmp_path_factory):
4545
example_grid = open_dataset(create_example_grid_file)
4646

47-
new_var = DataArray(name="new", data=[[1, 2], [8, 9]], dims=["x", "w"])
47+
new_var = DataArray(
48+
name="new",
49+
data=[[1, 2], [8, 9], [16, 17], [27, 28], [37, 38]],
50+
dims=["x", "w"],
51+
)
4852

4953
dodgy_grid_directory = tmp_path_factory.mktemp("dodgy_grid")
5054
dodgy_grid_path = dodgy_grid_directory.joinpath("dodgy_grid.nc")

xbout/utils.py

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -344,28 +344,17 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over
344344

345345
ny_inner = ds.metadata["ny_inner"]
346346

347-
# These variables need to be saved to restart files in addition to evolving ones
348-
restart_metadata_vars = [
349-
"zperiod",
350-
"MZSUB",
351-
"MXG",
352-
"MYG",
353-
"MZG",
354-
"nx",
355-
"ny",
356-
"nz",
357-
"MZ",
358-
"NZPE",
359-
"ixseps1",
360-
"ixseps2",
361-
"jyseps1_1",
362-
"jyseps2_1",
363-
"jyseps1_2",
364-
"jyseps2_2",
365-
"ny_inner",
366-
"ZMAX",
367-
"ZMIN",
368-
"BOUT_VERSION",
347+
# These metadata variables are created by xBOUT, so should not be saved to restart
348+
# files
349+
restart_exclude_metadata_vars = [
350+
"bout_tdim",
351+
"bout_xdim",
352+
"bout_ydim",
353+
"bout_zdim",
354+
"fine_interpolation_factor",
355+
"is_restart",
356+
"keep_xboundaries",
357+
"keep_yboundaries",
369358
]
370359

371360
if variables is None:
@@ -374,7 +363,8 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over
374363
# variables
375364
variables = [v for v in ds if "t" in ds[v].dims]
376365
else:
377-
# No time dimension, so just save all variables
366+
# No time dimension in Dataset, so cannot distinguish time-evolving
367+
# variables: just include all variables
378368
variables = [v for v in ds]
379369

380370
# Add extra variables always needed
@@ -395,25 +385,28 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over
395385
"g_13",
396386
"g_23",
397387
"J",
388+
"zShift",
398389
]:
399-
if v not in variables:
390+
if v not in variables and v in ds:
400391
variables.append(v)
401392

402393
# number of points in the domain on each processor, not including guard or boundary
403394
# points
404395
mxsub = (ds.metadata["nx"] - 2 * mxg) // nxpe
405396
mysub = ds.metadata["ny"] // nype
406397

407-
if "hist_hi" in ds.metadata:
408-
hist_hi = ds.metadata["hist_hi"]
409-
else:
410-
# hist_hi represents the number of iterations before the restart. Attempt to
411-
# reconstruct here
412-
iteration = ds.metadata.get("iteration", -1)
398+
# hist_hi represents the number of iterations before the restart. Attempt to
399+
# reconstruct here
400+
iteration = ds.metadata.get("iteration", -1)
401+
if "t" in ds.dims:
413402
nt = ds.sizes["t"]
414403
hist_hi = iteration - (nt - tind)
415404
if hist_hi < 0:
416405
hist_hi = -1
406+
elif "hist_hi" in ds.metadata:
407+
hist_hi = ds.metadata["hist_hi"]
408+
else:
409+
hist_hi = -1
417410

418411
has_second_divertor = ds.metadata["jyseps2_1"] != ds.metadata["jyseps1_2"]
419412

@@ -422,8 +415,11 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over
422415
ds = ds.isel({"t": tind}).persist()
423416
tt = ds["t"].values.flatten()[0]
424417
else:
425-
# If loaded from restart files, "tt" should be a scalar in metadata
426-
tt = ds.metadata["tt"]
418+
if "tt" in ds.metadata:
419+
# If loaded from restart files, "tt" should be a scalar in metadata
420+
tt = ds.metadata["tt"]
421+
else:
422+
tt = 0.0
427423

428424
ds = _pad_x_boundaries(ds)
429425
ds = _pad_y_boundaries(ds)
@@ -448,15 +444,22 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over
448444
data_variable.attrs = {}
449445

450446
restart_ds[v] = data_variable
451-
for v in restart_metadata_vars:
452-
restart_ds[v] = ds.metadata[v]
447+
for v in ds.metadata:
448+
if v not in restart_exclude_metadata_vars:
449+
restart_ds[v] = ds.metadata[v]
450+
451+
# These variables need to be altered, because they depend on the number of
452+
# files and/or the rank of this file.
453453
restart_ds["MXSUB"] = mxsub
454454
restart_ds["MYSUB"] = mysub
455455
restart_ds["NXPE"] = nxpe
456456
restart_ds["NYPE"] = nype
457457
restart_ds["PE_XIND"] = xproc
458458
restart_ds["PE_YIND"] = yproc
459459
restart_ds["hist_hi"] = hist_hi
460+
restart_ds["PE_XIND"] = xproc
461+
restart_ds["PE_YIND"] = yproc
462+
restart_ds["MYPE"] = yproc * nxpe + xproc
460463

461464
# tt is the simulation time where the restart happens
462465
restart_ds["tt"] = tt

0 commit comments

Comments (0)