Skip to content
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,6 @@ venv.bak/

# mypy
.mypy_cache/

# Generated files:
_version.py
4 changes: 3 additions & 1 deletion xbout/boutdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,9 @@ def to_restart(
variables : str or sequence of str, optional
The evolving variables needed in the restart files. If not given explicitly,
all time-evolving variables in the Dataset will be used, which may result in
larger restart files than necessary.
larger restart files than necessary. If there is no time-dimension in the
Dataset (e.g. if it was loaded from restart files), then all variables will
be added if this argument is not given explicitly.
savepath : str, default '.'
Directory to save the created restart files under
nxpe : int, optional
Expand Down
29 changes: 28 additions & 1 deletion xbout/geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,19 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):

# Get extra geometry information from grid file if it's not in the dump files
ds = _add_vars_from_grid(
ds, grid, ["psixy", "Rxy", "Zxy"], optional_variables=["Bpxy", "Brxy", "Bzxy"]
ds,
grid,
["psixy", "Rxy", "Zxy"],
optional_variables=[
"Bpxy",
"Brxy",
"Bzxy",
"poloidal_distance",
"poloidal_distance_ylow",
"total_poloidal_distance",
"zShift",
"zShift_ylow",
],
)

if "t" in ds.dims:
Expand Down Expand Up @@ -403,8 +415,23 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
else:
ds = ds.set_coords(("Rxy", "Zxy"))

# Rename zShift_ylow if it was added from grid file, to be consistent with name if
# it was added from dump file
if "zShift_CELL_YLOW" in ds and "zShift_ylow" in ds:
# Remove redundant copy
del ds["zShift_ylow"]
elif "zShift_ylow" in ds:
ds = ds.rename(zShift_ylow="zShift_CELL_YLOW")

if "poloidal_distance" in ds:
ds = ds.set_coords(
["poloidal_distance", "poloidal_distance_ylow", "total_poloidal_distance"]
)

# Add zShift as a coordinate, so that it gets interpolated along with a variable
ds = _set_as_coord(ds, "zShift")
if "zShift_CELL_YLOW" in ds:
ds = _set_as_coord(ds, "zShift_CELL_YLOW")

ds = _create_regions_toroidal(ds)

Expand Down
37 changes: 27 additions & 10 deletions xbout/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def attrs_remove_section(obj, section):
chunks=chunks,
keep_xboundaries=keep_xboundaries,
keep_yboundaries=keep_yboundaries,
**kwargs,
)
else:
raise ValueError(f"internal error: unexpected input_type={input_type}")
Expand Down Expand Up @@ -899,7 +900,7 @@ def _get_limit(side, dim, keep_boundaries, boundaries, guards):
return limit


def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2):
def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2, **kwargs):
"""
Opens a single grid file. Implements slightly different logic for
boundaries to deal with different conventions in a BOUT grid file.
Expand All @@ -917,7 +918,9 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2):

if _is_path(datapath):
gridfilepath = Path(datapath)
grid = xr.open_dataset(gridfilepath, engine=_check_filetype(gridfilepath))
grid = xr.open_dataset(
gridfilepath, engine=_check_filetype(gridfilepath), **kwargs
)
else:
grid = datapath

Expand All @@ -933,17 +936,28 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2):
)
grid = grid.drop_dims(unrecognised_dims)

if not keep_xboundaries:
if keep_xboundaries:
# Set MXG so that it is picked up in metadata - needed for applying geometry,
# etc.
grid["MXG"] = mxg
else:
xboundaries = mxg
if xboundaries > 0:
grid = grid.isel(x=slice(xboundaries, -xboundaries, None))
if not keep_yboundaries:
try:
yboundaries = int(grid["y_boundary_guards"])
except KeyError:
# y_boundary_guards variable not in grid file - older grid files
# never had y-boundary cells
yboundaries = 0
# Set MXG so that it is picked up in metadata - needed for applying geometry,
# etc.
grid["MXG"] = 0
try:
yboundaries = int(grid["y_boundary_guards"])
except KeyError:
# y_boundary_guards variable not in grid file - older grid files
# never had y-boundary cells
yboundaries = 0
if keep_yboundaries:
# Set MYG so that it is picked up in metadata - needed for applying geometry,
# etc.
grid["MYG"] = yboundaries
else:
if yboundaries > 0:
# Remove y-boundary cells from first divertor target
grid = grid.isel(y=slice(yboundaries, -yboundaries, None))
Expand All @@ -960,6 +974,9 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2):
compat="identical",
join="exact",
)
# Set MYG so that it is picked up in metadata - needed for applying geometry,
# etc.
grid["MYG"] = 0

if "z" in grid_chunks and "z" not in grid.dims:
del grid_chunks["z"]
Expand Down
34 changes: 34 additions & 0 deletions xbout/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,22 @@ def _concat_lower_guards(da, da_global, mxg, myg):
da_lower[xcoord].data[...] = new_xcoord.data
da_lower[ycoord].data[...] = new_ycoord.data

if "poloidal_distance" in da.coords and myg > 0:
# Special handling for core regions to deal with branch cut
if "core" in region.name:
# import pdb; pdb.set_trace()
# Try to detect whether there is branch cut at lower boundary: if there is
# poloidal_distance_ylow should be zero at the boundary of this region
poloidal_distance_bottom = da["poloidal_distance_ylow"].isel({ycoord: 0})
if all(abs(poloidal_distance_bottom) < 1.0e-16):
# Offset so that the poloidal_distance in da_lower is continuous from
# the poloidal_distance in this region.
# Expect there to be y-boundary cells in the Dataset, this will probably
# fail if there are not.
total_poloidal_distance = da["total_poloidal_distance"]
da_lower["poloidal_distance"] -= total_poloidal_distance
da_lower["poloidal_distance_ylow"] -= total_poloidal_distance

save_regions = da.bout._regions
da = xr.concat((da_lower, da), ycoord, join="exact")
# xr.concat takes attributes from the first variable (for xarray>=0.15.0, keeps attrs
Expand Down Expand Up @@ -1668,6 +1684,24 @@ def _concat_upper_guards(da, da_global, mxg, myg):
da_upper[xcoord].data[...] = new_xcoord.data
da_upper[ycoord].data[...] = new_ycoord.data

if "poloidal_distance" in da.coords and myg > 0:
# Special handling for core regions to deal with branch cut
if "core" in region.name:
# import pdb; pdb.set_trace()
# Try to detect whether there is branch cut at upper boundary: if there is
# poloidal_distance_ylow should be zero at the boundary of da_upper
poloidal_distance_bottom = da_upper["poloidal_distance_ylow"].isel(
{ycoord: 0}
)
if all(abs(poloidal_distance_bottom) < 1.0e-16):
# Offset so that the poloidal_distance in da_upper is continuous from
# the poloidal_distance in this region.
# Expect there to be y-boundary cells in the Dataset, this will probably
# fail if there are not.
total_poloidal_distance = da["total_poloidal_distance"]
da_upper["poloidal_distance"] += total_poloidal_distance
da_upper["poloidal_distance_ylow"] += total_poloidal_distance

save_regions = da.bout._regions
da = xr.concat((da, da_upper), ycoord, join="exact")
# xarray<0.15.0 only keeps attrs that are the same on all variables passed to concat
Expand Down
10 changes: 7 additions & 3 deletions xbout/tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def create_example_grid_file(tmp_path_factory):
"""

# Create grid dataset
arr = np.arange(6).reshape(2, 3)
arr = np.arange(15).reshape(5, 3)
grid = DataArray(data=arr, name="arr", dims=["x", "y"]).to_dataset()
grid["dy"] = DataArray(np.ones((2, 3)), dims=["x", "y"])
grid["dy"] = DataArray(np.ones((5, 3)), dims=["x", "y"])
grid = grid.set_coords(["dy"])

# Create temporary directory
Expand All @@ -44,7 +44,11 @@ def test_open_grid(self, create_example_grid_file):
def test_open_grid_extra_dims(self, create_example_grid_file, tmp_path_factory):
example_grid = open_dataset(create_example_grid_file)

new_var = DataArray(name="new", data=[[1, 2], [8, 9]], dims=["x", "w"])
new_var = DataArray(
name="new",
data=[[1, 2], [8, 9], [16, 17], [27, 28], [37, 38]],
dims=["x", "w"],
)

dodgy_grid_directory = tmp_path_factory.mktemp("dodgy_grid")
dodgy_grid_path = dodgy_grid_directory.joinpath("dodgy_grid.nc")
Expand Down
71 changes: 37 additions & 34 deletions xbout/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,28 +344,17 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over

ny_inner = ds.metadata["ny_inner"]

# These variables need to be saved to restart files in addition to evolving ones
restart_metadata_vars = [
"zperiod",
"MZSUB",
"MXG",
"MYG",
"MZG",
"nx",
"ny",
"nz",
"MZ",
"NZPE",
"ixseps1",
"ixseps2",
"jyseps1_1",
"jyseps2_1",
"jyseps1_2",
"jyseps2_2",
"ny_inner",
"ZMAX",
"ZMIN",
"BOUT_VERSION",
# These metadata variables are created by xBOUT, so should not be saved to restart
# files
restart_exclude_metadata_vars = [
"bout_tdim",
"bout_xdim",
"bout_ydim",
"bout_zdim",
"fine_interpolation_factor",
"is_restart",
"keep_xboundaries",
"keep_yboundaries",
]

if variables is None:
Expand All @@ -374,7 +363,8 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over
# variables
variables = [v for v in ds if "t" in ds[v].dims]
else:
# No time dimension, so just save all variables
# No time dimension in Dataset, so cannot distinguish time-evolving
# variables: just include all variables
variables = [v for v in ds]

# Add extra variables always needed
Expand All @@ -395,25 +385,28 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over
"g_13",
"g_23",
"J",
"zShift",
]:
if v not in variables:
if v not in variables and v in ds:
variables.append(v)

# number of points in the domain on each processor, not including guard or boundary
# points
mxsub = (ds.metadata["nx"] - 2 * mxg) // nxpe
mysub = ds.metadata["ny"] // nype

if "hist_hi" in ds.metadata:
hist_hi = ds.metadata["hist_hi"]
else:
# hist_hi represents the number of iterations before the restart. Attempt to
# reconstruct here
iteration = ds.metadata.get("iteration", -1)
# hist_hi represents the number of iterations before the restart. Attempt to
# reconstruct here
iteration = ds.metadata.get("iteration", -1)
if "t" in ds.dims:
nt = ds.sizes["t"]
hist_hi = iteration - (nt - tind)
if hist_hi < 0:
hist_hi = -1
elif "hist_hi" in ds.metadata:
hist_hi = ds.metadata["hist_hi"]
else:
hist_hi = -1

has_second_divertor = ds.metadata["jyseps2_1"] != ds.metadata["jyseps1_2"]

Expand All @@ -422,8 +415,11 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over
ds = ds.isel({"t": tind}).persist()
tt = ds["t"].values.flatten()[0]
else:
# If loaded from restart files, "tt" should be a scalar in metadata
tt = ds.metadata["tt"]
if "tt" in ds.metadata:
# If loaded from restart files, "tt" should be a scalar in metadata
tt = ds.metadata["tt"]
else:
tt = 0.0

ds = _pad_x_boundaries(ds)
ds = _pad_y_boundaries(ds)
Expand All @@ -448,15 +444,22 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over
data_variable.attrs = {}

restart_ds[v] = data_variable
for v in restart_metadata_vars:
restart_ds[v] = ds.metadata[v]
for v in ds.metadata:
if v not in restart_exclude_metadata_vars:
restart_ds[v] = ds.metadata[v]

# These variables need to be altered, because they depend on the number of
# files and/or the rank of this file.
restart_ds["MXSUB"] = mxsub
restart_ds["MYSUB"] = mysub
restart_ds["NXPE"] = nxpe
restart_ds["NYPE"] = nype
restart_ds["PE_XIND"] = xproc
restart_ds["PE_YIND"] = yproc
restart_ds["hist_hi"] = hist_hi
restart_ds["PE_XIND"] = xproc
restart_ds["PE_YIND"] = yproc
restart_ds["MYPE"] = yproc * nxpe + xproc

# tt is the simulation time where the restart happens
restart_ds["tt"] = tt
Expand Down