diff --git a/.gitignore b/.gitignore index a58eab0b..907b66d0 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,6 @@ venv.bak/ # mypy .mypy_cache/ + +# Generated files: +_version.py diff --git a/xbout/boutdataset.py b/xbout/boutdataset.py index 05788fcc..e2d2f0c1 100644 --- a/xbout/boutdataset.py +++ b/xbout/boutdataset.py @@ -790,7 +790,9 @@ def to_restart( variables : str or sequence of str, optional The evolving variables needed in the restart files. If not given explicitly, all time-evolving variables in the Dataset will be used, which may result in - larger restart files than necessary. + larger restart files than necessary. If there is no time-dimension in the + Dataset (e.g. if it was loaded from restart files), then all variables will + be added if this argument is not given explicitly. savepath : str, default '.' Directory to save the created restart files under nxpe : int, optional diff --git a/xbout/geometries.py b/xbout/geometries.py index da3996be..6774c4fa 100644 --- a/xbout/geometries.py +++ b/xbout/geometries.py @@ -364,7 +364,19 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): # Get extra geometry information from grid file if it's not in the dump files ds = _add_vars_from_grid( - ds, grid, ["psixy", "Rxy", "Zxy"], optional_variables=["Bpxy", "Brxy", "Bzxy"] + ds, + grid, + ["psixy", "Rxy", "Zxy"], + optional_variables=[ + "Bpxy", + "Brxy", + "Bzxy", + "poloidal_distance", + "poloidal_distance_ylow", + "total_poloidal_distance", + "zShift", + "zShift_ylow", + ], ) if "t" in ds.dims: @@ -403,8 +415,23 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): else: ds = ds.set_coords(("Rxy", "Zxy")) + # Rename zShift_ylow if it was added from grid file, to be consistent with name if + # it was added from dump file + if "zShift_CELL_YLOW" in ds and "zShift_ylow" in ds: + # Remove redundant copy + del ds["zShift_ylow"] + elif "zShift_ylow" in ds: + ds = ds.rename(zShift_ylow="zShift_CELL_YLOW") + + if "poloidal_distance" in ds: + ds = ds.set_coords( + ["poloidal_distance", "poloidal_distance_ylow", "total_poloidal_distance"] + ) + # Add zShift as a coordinate, so that it gets interpolated along with a variable ds = _set_as_coord(ds, "zShift") + if "zShift_CELL_YLOW" in ds: + ds = _set_as_coord(ds, "zShift_CELL_YLOW") ds = _create_regions_toroidal(ds) diff --git a/xbout/load.py b/xbout/load.py index 802c643c..e4e5d9b7 100644 --- a/xbout/load.py +++ b/xbout/load.py @@ -260,6 +260,7 @@ def attrs_remove_section(obj, section): chunks=chunks, keep_xboundaries=keep_xboundaries, keep_yboundaries=keep_yboundaries, + **kwargs, ) else: raise ValueError(f"internal error: unexpected input_type={input_type}") @@ -899,7 +900,7 @@ def _get_limit(side, dim, keep_boundaries, boundaries, guards): return limit -def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2): +def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2, **kwargs): """ Opens a single grid file. Implements slightly different logic for boundaries to deal with different conventions in a BOUT grid file. @@ -917,7 +918,9 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2): if _is_path(datapath): gridfilepath = Path(datapath) - grid = xr.open_dataset(gridfilepath, engine=_check_filetype(gridfilepath)) + grid = xr.open_dataset( + gridfilepath, engine=_check_filetype(gridfilepath), **kwargs + ) else: grid = datapath @@ -933,17 +936,28 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2): ) grid = grid.drop_dims(unrecognised_dims) - if not keep_xboundaries: + if keep_xboundaries: + # Set MXG so that it is picked up in metadata - needed for applying geometry, + # etc. + grid["MXG"] = mxg + else: xboundaries = mxg if xboundaries > 0: grid = grid.isel(x=slice(xboundaries, -xboundaries, None)) - if not keep_yboundaries: - try: - yboundaries = int(grid["y_boundary_guards"]) - except KeyError: - # y_boundary_guards variable not in grid file - older grid files - # never had y-boundary cells - yboundaries = 0 + # Set MXG so that it is picked up in metadata - needed for applying geometry, + # etc. + grid["MXG"] = 0 + try: + yboundaries = int(grid["y_boundary_guards"]) + except KeyError: + # y_boundary_guards variable not in grid file - older grid files + # never had y-boundary cells + yboundaries = 0 + if keep_yboundaries: + # Set MYG so that it is picked up in metadata - needed for applying geometry, + # etc. + grid["MYG"] = yboundaries + else: if yboundaries > 0: # Remove y-boundary cells from first divertor target grid = grid.isel(y=slice(yboundaries, -yboundaries, None)) @@ -960,6 +974,9 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2): compat="identical", join="exact", ) + # Set MYG so that it is picked up in metadata - needed for applying geometry, + # etc. + grid["MYG"] = 0 if "z" in grid_chunks and "z" not in grid.dims: del grid_chunks["z"] diff --git a/xbout/region.py b/xbout/region.py index 12391b12..56634f3e 100644 --- a/xbout/region.py +++ b/xbout/region.py @@ -1568,6 +1568,22 @@ def _concat_lower_guards(da, da_global, mxg, myg): da_lower[xcoord].data[...] = new_xcoord.data da_lower[ycoord].data[...] = new_ycoord.data + if "poloidal_distance" in da.coords and myg > 0: + # Special handling for core regions to deal with branch cut + if "core" in region.name: + # import pdb; pdb.set_trace() + # Try to detect whether there is branch cut at lower boundary: if there is + # poloidal_distance_ylow should be zero at the boundary of this region + poloidal_distance_bottom = da["poloidal_distance_ylow"].isel({ycoord: 0}) + if all(abs(poloidal_distance_bottom) < 1.0e-16): + # Offset so that the poloidal_distance in da_lower is continuous from + # the poloidal_distance in this region. + # Expect there to be y-boundary cells in the Dataset, this will probably + # fail if there are not. + total_poloidal_distance = da["total_poloidal_distance"] + da_lower["poloidal_distance"] -= total_poloidal_distance + da_lower["poloidal_distance_ylow"] -= total_poloidal_distance + save_regions = da.bout._regions da = xr.concat((da_lower, da), ycoord, join="exact") # xr.concat takes attributes from the first variable (for xarray>=0.15.0, keeps attrs @@ -1668,6 +1684,24 @@ def _concat_upper_guards(da, da_global, mxg, myg): da_upper[xcoord].data[...] = new_xcoord.data da_upper[ycoord].data[...] = new_ycoord.data + if "poloidal_distance" in da.coords and myg > 0: + # Special handling for core regions to deal with branch cut + if "core" in region.name: + # import pdb; pdb.set_trace() + # Try to detect whether there is branch cut at upper boundary: if there is + # poloidal_distance_ylow should be zero at the boundary of da_upper + poloidal_distance_bottom = da_upper["poloidal_distance_ylow"].isel( + {ycoord: 0} + ) + if all(abs(poloidal_distance_bottom) < 1.0e-16): + # Offset so that the poloidal_distance in da_upper is continuous from + # the poloidal_distance in this region. + # Expect there to be y-boundary cells in the Dataset, this will probably + # fail if there are not. + total_poloidal_distance = da["total_poloidal_distance"] + da_upper["poloidal_distance"] += total_poloidal_distance + da_upper["poloidal_distance_ylow"] += total_poloidal_distance + save_regions = da.bout._regions da = xr.concat((da, da_upper), ycoord, join="exact") # xarray<0.15.0 only keeps attrs that are the same on all variables passed to concat diff --git a/xbout/tests/test_grid.py b/xbout/tests/test_grid.py index b4021bb5..72da6bc8 100644 --- a/xbout/tests/test_grid.py +++ b/xbout/tests/test_grid.py @@ -17,9 +17,9 @@ def create_example_grid_file(tmp_path_factory): """ # Create grid dataset - arr = np.arange(6).reshape(2, 3) + arr = np.arange(15).reshape(5, 3) grid = DataArray(data=arr, name="arr", dims=["x", "y"]).to_dataset() - grid["dy"] = DataArray(np.ones((2, 3)), dims=["x", "y"]) + grid["dy"] = DataArray(np.ones((5, 3)), dims=["x", "y"]) grid = grid.set_coords(["dy"]) # Create temporary directory @@ -44,7 +44,11 @@ def test_open_grid(self, create_example_grid_file): def test_open_grid_extra_dims(self, create_example_grid_file, tmp_path_factory): example_grid = open_dataset(create_example_grid_file) - new_var = DataArray(name="new", data=[[1, 2], [8, 9]], dims=["x", "w"]) + new_var = DataArray( + name="new", + data=[[1, 2], [8, 9], [16, 17], [27, 28], [37, 38]], + dims=["x", "w"], + ) dodgy_grid_directory = tmp_path_factory.mktemp("dodgy_grid") dodgy_grid_path = dodgy_grid_directory.joinpath("dodgy_grid.nc") diff --git a/xbout/utils.py b/xbout/utils.py index 4e398b10..3c67723d 100644 --- a/xbout/utils.py +++ b/xbout/utils.py @@ -344,28 +344,17 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over ny_inner = ds.metadata["ny_inner"] - # These variables need to be saved to restart files in addition to evolving ones - restart_metadata_vars = [ - "zperiod", - "MZSUB", - "MXG", - "MYG", - "MZG", - "nx", - "ny", - "nz", - "MZ", - "NZPE", - "ixseps1", - "ixseps2", - "jyseps1_1", - "jyseps2_1", - "jyseps1_2", - "jyseps2_2", - "ny_inner", - "ZMAX", - "ZMIN", - "BOUT_VERSION", + # These metadata variables are created by xBOUT, so should not be saved to restart + # files + restart_exclude_metadata_vars = [ + "bout_tdim", + "bout_xdim", + "bout_ydim", + "bout_zdim", + "fine_interpolation_factor", + "is_restart", + "keep_xboundaries", + "keep_yboundaries", ] if variables is None: @@ -374,7 +363,8 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over # variables variables = [v for v in ds if "t" in ds[v].dims] else: - # No time dimension, so just save all variables + # No time dimension in Dataset, so cannot distinguish time-evolving + # variables: just include all variables variables = [v for v in ds] # Add extra variables always needed @@ -395,8 +385,9 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over "g_13", "g_23", "J", + "zShift", ]: - if v not in variables: + if v not in variables and v in ds: variables.append(v) # number of points in the domain on each processor, not including guard or boundary @@ -404,16 +395,18 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over mxsub = (ds.metadata["nx"] - 2 * mxg) // nxpe mysub = ds.metadata["ny"] // nype - if "hist_hi" in ds.metadata: - hist_hi = ds.metadata["hist_hi"] - else: - # hist_hi represents the number of iterations before the restart. Attempt to - # reconstruct here - iteration = ds.metadata.get("iteration", -1) + # hist_hi represents the number of iterations before the restart. Attempt to + # reconstruct here + iteration = ds.metadata.get("iteration", -1) + if "t" in ds.dims: nt = ds.sizes["t"] hist_hi = iteration - (nt - tind) if hist_hi < 0: hist_hi = -1 + elif "hist_hi" in ds.metadata: + hist_hi = ds.metadata["hist_hi"] + else: + hist_hi = -1 has_second_divertor = ds.metadata["jyseps2_1"] != ds.metadata["jyseps1_2"] @@ -422,8 +415,11 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over ds = ds.isel({"t": tind}).persist() tt = ds["t"].values.flatten()[0] else: - # If loaded from restart files, "tt" should be a scalar in metadata - tt = ds.metadata["tt"] + if "tt" in ds.metadata: + # If loaded from restart files, "tt" should be a scalar in metadata + tt = ds.metadata["tt"] + else: + tt = 0.0 ds = _pad_x_boundaries(ds) ds = _pad_y_boundaries(ds) @@ -448,8 +444,12 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over data_variable.attrs = {} restart_ds[v] = data_variable - for v in restart_metadata_vars: - restart_ds[v] = ds.metadata[v] + for v in ds.metadata: + if v not in restart_exclude_metadata_vars: + restart_ds[v] = ds.metadata[v] + + # These variables need to be altered, because they depend on the number of + # files and/or the rank of this file. restart_ds["MXSUB"] = mxsub restart_ds["MYSUB"] = mysub restart_ds["NXPE"] = nxpe @@ -457,6 +457,9 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over restart_ds["PE_XIND"] = xproc restart_ds["PE_YIND"] = yproc restart_ds["hist_hi"] = hist_hi + restart_ds["PE_XIND"] = xproc + restart_ds["PE_YIND"] = yproc + restart_ds["MYPE"] = yproc * nxpe + xproc # tt is the simulation time where the restart happens restart_ds["tt"] = tt