diff --git a/xbout/bout_info.py b/xbout/bout_info.py new file mode 100644 index 00000000..24e91e39 --- /dev/null +++ b/xbout/bout_info.py @@ -0,0 +1,17 @@ +_BOUT_PER_PROC_VARIABLES = [ + "wall_time", + "wtime", + "wtime_rhs", + "wtime_invert", + "wtime_comms", + "wtime_io", + "wtime_per_rhs", + "wtime_per_rhs_e", + "wtime_per_rhs_i", + "PE_XIND", + "PE_YIND", + "MYPE", +] +_BOUT_PER_PROC_VARIABLES_REQUIRED_FROM_RESTARTS = ["hist_hi", "tt"] +_BOUT_TIME_DEPENDENT_META_VARS = ["iteration"] +_BOUT_VARIABLE_ATTRIBUTES = ["cell_location", "direction_y", "direction_z", "bout_type"] diff --git a/xbout/boutdataarray.py b/xbout/boutdataarray.py index 7e57cf1e..9aea1d43 100644 --- a/xbout/boutdataarray.py +++ b/xbout/boutdataarray.py @@ -11,13 +11,16 @@ from xarray import register_dataarray_accessor from .geometries import apply_geometry +from .load import open_boutdataset from .plotting.animate import animate_poloidal, animate_pcolormesh, animate_line from .plotting import plotfuncs from .plotting.utils import _create_norm from .region import _from_region from .utils import ( _add_cartesian_coordinates, - _update_metadata_increased_resolution, + _make_1d_xcoord, + _update_metadata_increased_x_resolution, + _update_metadata_increased_y_resolution, _get_bounding_surfaces, ) @@ -156,7 +159,11 @@ def to_field_aligned(self): f"argument to open_boutdataset()?" ) - result = self._shift_z(self.data[zShift_coord]) + # zShift may have NaNs in the corners. These should not affect any useful + # results, but may cause parts or all of arrays to be filled with NaN, even + # where the entries should not depend on the NaN values. Replace NaN with 0 to + # avoid this. + result = self._shift_z(self.data[zShift_coord].fillna(0.0)) result.attrs["direction_y"] = "Aligned" return result @@ -188,7 +195,11 @@ def from_field_aligned(self): f"argument to open_boutdataset()?" ) - result = self._shift_z(-self.data[zShift_coord]) + # zShift may have NaNs in the corners. These should not affect any useful + # results, but may cause parts or all of arrays to be filled with NaN, even + # where the entries should not depend on the NaN values. Replace NaN with 0 to + # avoid this. + result = self._shift_z(-self.data[zShift_coord].fillna(0.0)) result.attrs["direction_y"] = "Standard" return result @@ -240,6 +251,8 @@ def interpolate_parallel( self, region=None, *, + poloidal_distance=None, + dy=None, n=None, toroidal_points=None, method="cubic", @@ -249,15 +262,29 @@ def interpolate_parallel( Interpolate in the parallel direction to get a higher resolution version of the variable. + Note: when using poloidal_distance for interpolation, have to convert to numpy + arrays for calculation. This means that dask cannot be used to parallelise this + calculation, so may be slow for large Datasets. + Parameters ---------- region : str, optional By default, return a result with all regions interpolated separately and then combined. If an explicit region argument is passed, then return the variable - from only that region. + from only that region. If the DataArray has already been restricted to a + single region, pass `region=False` to skip calling `from_region()` again. + poloidal_distance : 2d array, optional + Poloidal distance values to interpolate to - interpolation is calculated as + a function of poloidal distance along psi contours. Should have the same + radial grid size as the input. If not given, `n` is used instead. + dy : 2d array, optional + New values of `dy`, corresponding to the values of `poloidal_distance`. 
+            Required if `poloidal_distance` is passed.
         n : int, optional
             The factor to increase the resolution by. Defaults to the value set by
             BoutDataset.setupParallelInterp(), or 10 if that has not been called.
+            If `n` is used, interpolation is onto a linearly spaced grid in grid-index
+            space.
         toroidal_points : int or sequence of int, optional
             If int, number of toroidal points to output, applies a stride to toroidal
             direction to save memory usage. If sequence of int, the indexes of toroidal
@@ -279,11 +306,32 @@
         if region is None:
             # Call the single-region version of this method for each region, and combine
             # the results together
+
+            # apply_ufunc() wrapping scipy.interpolate.interp1d() fails if the data is a
+            # dask array, so load into memory
+            self.data.load()
+            xcoord = self.data.metadata["bout_xdim"]
+            ycoord = self.data.metadata["bout_ydim"]
+            if poloidal_distance is None:
+                poloidal_distance_parts = [None for _ in self._regions]
+            else:
+                poloidal_distance.load()
+                poloidal_distance_parts = [
+                    poloidal_distance.bout.from_region(
+                        region, with_guards={xcoord: 2, ycoord: 0}
+                    )
+                    .isel({ycoord: 0}, drop=True)
+                    .data
+                    for region in self._regions
+                ]
             parts = [
                 self.interpolate_parallel(
-                    region, n=n, toroidal_points=toroidal_points, method=method
+                    region,
+                    poloidal_distance=this_poloidal_distance,
+                    n=n,
+                    toroidal_points=toroidal_points,
+                    method=method,
                 ).bout.to_dataset()
-                for region in self._regions
+                for (region, this_poloidal_distance) in zip(
+                    self._regions, poloidal_distance_parts
+                )
             ]

             # 'region' is not the same for all parts, and should not exist in the result,
@@ -303,15 +351,16 @@
             result = apply_geometry(result, self.data.geometry)
             return result[self.data.name]

-        # Select a particular 'region' and interpolate to higher parallel resolution
-        da = self.data
-        region = da.bout._regions[region]
+        da = self.data.copy()
         tcoord = da.metadata["bout_tdim"]
         xcoord = da.metadata["bout_xdim"]
         ycoord = da.metadata["bout_ydim"]
         zcoord = da.metadata["bout_zdim"]

-        da = da.bout.from_region(region.name, with_guards={xcoord: 0, ycoord: 2})
+        if region is not False:
+            # Select a particular 'region' and interpolate to higher parallel resolution
+            region = da.bout._regions[region]
+            da = da.bout.from_region(region.name, with_guards={xcoord: 0, ycoord: 2})

         if zcoord in da.dims and da.direction_y != "Aligned":
             aligned_input = False
@@ -319,29 +368,277 @@
         else:
             aligned_input = True

-        if n is None:
-            n = self.fine_interpolation_factor
+        if poloidal_distance is not None:
+            # apply_ufunc() wrapping scipy.interpolate.interp1d() fails if the data is a
+            # dask array, so load into memory
+            da.load()
+            poloidal_distance.load()
+
+            poloidal_distance = poloidal_distance.copy()
+            # Need to delete the xcoord 'indexer', because it is not present on 'result',
+            # so it would cause an error in apply_ufunc() if it was present.
+            del poloidal_distance[xcoord]
+            # Need to delete ycoord to avoid a clash below
+            del poloidal_distance[ycoord]
+
+            if n is not None:
+                raise ValueError(
+                    f"poloidal_distance and n cannot both be passed, got "
+                    f"poloidal_distance={poloidal_distance} and n={n}"
+                )
+            if dy is None:
+                raise ValueError(
+                    "It is required to pass dy if poloidal_distance is passed"
+                )
+
+            from scipy.interpolate import interp1d
+
+            def y_interp_func(
+                data, poloidal_distance_in, poloidal_distance_out, method=None
+            ):
+                interp_func = interp1d(
+                    poloidal_distance_in, data, kind=method, assume_sorted=True
+                )
+                return interp_func(poloidal_distance_out)
+
+            # Need to give a different name to the output dimension to avoid a clash
+            new_ycoord = ycoord + "_interpolate_to_new_grid_new_ycoord"
+            poloidal_distance = poloidal_distance.rename({ycoord: new_ycoord})
+            result = xr.apply_ufunc(
+                y_interp_func,
+                da,
+                da["poloidal_distance"],
+                poloidal_distance,
+                method,
+                input_core_dims=[[ycoord], [ycoord], [new_ycoord], []],
+                output_core_dims=[[new_ycoord]],
+                exclude_dims=set([ycoord]),
+                vectorize=True,
+                dask="parallelized",
+                dask_gufunc_kwargs={
+                    "output_sizes": {new_ycoord: poloidal_distance.sizes[new_ycoord]}
+                },
+            )
+            # Rename new_ycoord back to ycoord for the output
+            result = result.rename({new_ycoord: ycoord})
+
+            # Transpose to the original dimension order
+            result = result.transpose(*da.dims)
+
+            result.attrs = da.attrs.copy()
+            da = result

-        da = da.chunk({ycoord: None})

-        ny_fine = n * region.ny
-        dy = (region.yupper - region.ylower) / ny_fine
+            da = _update_metadata_increased_y_resolution(da)

-        myg = da.metadata["MYG"]
-        if da.metadata["keep_yboundaries"] and region.connection_lower_y is None:
-            ybndry_lower = myg
+            da["dy"] = dy
         else:
-            ybndry_lower = 0
-        if da.metadata["keep_yboundaries"] and region.connection_upper_y is None:
-            ybndry_upper = myg
+            if n is None:
+                n = self.fine_interpolation_factor
+
+            da = da.chunk({ycoord: None})
+
+            ny_fine = n * region.ny
+            dy = (region.yupper - region.ylower) / ny_fine
+
+            myg = da.metadata["MYG"]
+            if da.metadata["keep_yboundaries"] and region.connection_lower_y is None:
+                ybndry_lower = myg
+            else:
+                ybndry_lower = 0
+            if da.metadata["keep_yboundaries"] and region.connection_upper_y is None:
+                ybndry_upper = myg
+            else:
+                ybndry_upper = 0
+
+            y_fine = np.linspace(
+                region.ylower - (ybndry_lower - 0.5) * dy,
+                region.yupper + (ybndry_upper - 0.5) * dy,
+                ny_fine + ybndry_lower + ybndry_upper,
+            )
+
+            # This prevents da.interp() from being very slow.
+            # Apparently large attrs (i.e. regions) on a coordinate which is passed as
+            # an argument to dask.array.map_blocks() slow things down, maybe because
+            # coordinates are numpy arrays, not dask arrays?
+            # Slow-down was introduced in d062fa9e75c02fbfdd46e5d1104b9b12f034448f when
+            # _add_attrs_to_var(updated_ds, ycoord) was added in geometries.py
+            da[ycoord].attrs = {}
+
+            da = da.interp(
+                {ycoord: y_fine.data},
+                assume_sorted=True,
+                method=method,
+                kwargs={"fill_value": "extrapolate"},
+            )
+
+            da = _update_metadata_increased_y_resolution(da, n=n)
+
+            # Modify dy to be consistent with the higher resolution grid
+            dy_array = xr.DataArray(
+                np.full([da.sizes[xcoord], da.sizes[ycoord]], dy), dims=[xcoord, ycoord]
+            )
+            da["dy"] = da["dy"].copy(data=dy_array)
+
+        # Remove regions which have incorrect information for the high-resolution
+        # grid.
+        # New regions will be generated when creating a new Dataset in
+        # BoutDataset.getHighParallelResVars
+        del da.attrs["regions"]
+
+        if not aligned_input:
+            # Want output in non-aligned coordinates
+            da = da.bout.from_field_aligned()
+
+        if toroidal_points is not None and zcoord in da.sizes:
+            if isinstance(toroidal_points, int):
+                nz = len(da[zcoord])
+                zstride = (nz + toroidal_points - 1) // toroidal_points
+                da = da.isel(**{zcoord: slice(None, None, zstride)})
+            else:
+                da = da.isel(**{zcoord: toroidal_points})
+
+        return da
+
+    def interpolate_radial(
+        self,
+        region=None,
+        *,
+        psi=None,
+        dx=None,
+        n=None,
+        method="cubic",
+        return_dataset=False,
+    ):
+        """
+        Interpolate in the radial direction to get a higher resolution version of the
+        variable.
+
+        Parameters
+        ----------
+        region : str, optional
+            By default, return a result with all regions interpolated separately and
+            then combined. If an explicit region argument is passed, then return the
+            variable from only that region. If the DataArray has already been restricted
+            to a single region, pass `region=False` to skip calling `from_region()`
+            again.
+        psi : 1d or 2d array, optional
+            Values of `psixy` to interpolate data to. If not given, `n` is used instead.
+            If `psi` is given, it must be a 1d array with psi values for the region if
+            `region` is passed, and otherwise must be a 2d {x,y} array.
+        dx : 1d array, optional
+            New values of `dx`, corresponding to the values of `psi`. Required if `psi`
+            is passed.
+        n : int, optional
+            The factor to increase the resolution by. Defaults to the value set by
+            BoutDataset.setupParallelInterp(), or 10 if that has not been called.
+        method : str, optional
+            The interpolation method to use. Options from xarray.DataArray.interp(),
+            currently: linear, nearest, zero, slinear, quadratic, cubic. Default is
+            'cubic'.
+        return_dataset : bool, optional
+            If this is set to True, return a Dataset containing this variable as a
+            member (by default returns a DataArray). Only used when region=None.
+
+        Returns
+        -------
+        A new DataArray containing a high-resolution version of the variable. (If
+        return_dataset=True, instead returns a Dataset containing the DataArray.)
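+
+        Examples
+        --------
+        A minimal usage sketch (the Dataset `ds`, the variable "n" and the arrays
+        `psi_new` and `dx_new` are illustrative, not part of this API):
+
+        >>> # Increase the radial resolution by a factor of 8
+        >>> n_hi = ds["n"].bout.interpolate_radial(n=8)
+        >>> # Or interpolate onto psi values taken from a new grid file, passing the
+        >>> # matching dx values
+        >>> n_new = ds["n"].bout.interpolate_radial(psi=psi_new, dx=dx_new)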
+ """ + + if psi is not None and n is not None: + raise ValueError(f"Cannot pass both psi and n, got psi={psi}, n={n}") + + tcoord = self.data.metadata["bout_tdim"] + xcoord = self.data.metadata["bout_xdim"] + ycoord = self.data.metadata["bout_ydim"] + zcoord = self.data.metadata["bout_zdim"] + + if region is None: + # Call the single-region version of this method for each region, and combine + # the results together + if psi is None: + psi_parts = [None for _ in self._regions] + else: + psi_parts = [ + psi.bout.from_region(region, with_guards={xcoord: 2, ycoord: 0}) + .isel({ycoord: 0}, drop=True) + .data + for region in self._regions + ] + parts = [ + self.interpolate_radial( + region, psi=this_psi, n=n, method=method + ).bout.to_dataset() + for (region, this_psi) in zip(self._regions, psi_parts) + ] + + # 'region' is not the same for all parts, and should not exist in the + # result, so delete before merging + for part in parts: + if "region" in part.attrs: + del part.attrs["region"] + if "region" in part[self.data.name].attrs: + del part[self.data.name].attrs["region"] + + result = xr.combine_by_coords(parts, combine_attrs="drop_conflicts") + + _make_1d_xcoord(result) + + if return_dataset: + return result + else: + # Extract the DataArray to return + # Cannot call apply_geometry here, because we have not set ixseps1, + # ixseps2, which are needed to create the 'regions'. + return result[self.data.name] + + da = self.data + + if region is not False: + # Select a particular 'region' and interpolate to higher parallel resolution + region = da.bout._regions[region] + da = da.bout.from_region(region.name, with_guards={xcoord: 2, ycoord: 0}) + + da = da.chunk({xcoord: None}) + + old_psi = da["psi_poloidal"].isel({ycoord: 0}, drop=True).values + + if psi is not None: + if dx is None: + raise ValueError("It is required to pass dx if psi is passed") else: - ybndry_upper = 0 + # Do a rough approximation to the boundary values - expect accurate + # interpolations to be done by passing psi from a new grid file + if n is None: + n = self.fine_interpolation_factor + mxg = da.metadata["MXG"] + if da.metadata["keep_xboundaries"] and region.connection_inner_x is None: + xbndry_lower = mxg + else: + xbndry_lower = 0 + if da.metadata["keep_xboundaries"] and region.connection_outer_x is None: + xbndry_upper = mxg + else: + xbndry_upper = 0 - y_fine = np.linspace( - region.ylower - (ybndry_lower - 0.5) * dy, - region.yupper + (ybndry_upper - 0.5) * dy, - ny_fine + ybndry_lower + ybndry_upper, - ) + nx_fine = n * region.nx + dx = (region.xouter - region.xinner) / nx_fine + + psi = np.linspace( + region.xinner - (xbndry_lower - 0.5) * dx, + region.xouter + (xbndry_upper - 0.5) * dx, + nx_fine + xbndry_lower + xbndry_upper, + ) + + # Modify dx to be consistent with the higher resolution grid + dx_array = xr.full_like(da["dx"], dx) + + # Use psi as a 1d x-coordinate for this interpolation. psixy depends only on x + # in each region (although it may be a different function of x in different + # regions). + del da[xcoord] + da[xcoord] = old_psi # This prevents da.interp() from being very slow. # Apparently large attrs (i.e. regions) on a coordinate which is passed as an @@ -349,44 +646,184 @@ def interpolate_parallel( # are numpy arrays, not dask arrays? 
         # Slow-down was introduced in d062fa9e75c02fbfdd46e5d1104b9b12f034448f when
         # _add_attrs_to_var(updated_ds, ycoord) was added in geometries.py
-        da[ycoord].attrs = {}
+        da[xcoord].attrs = {}

         da = da.interp(
-            {ycoord: y_fine},
+            {xcoord: psi},
             assume_sorted=True,
             method=method,
             kwargs={"fill_value": "extrapolate"},
         )

-        da = _update_metadata_increased_resolution(da, n)
+        da = _update_metadata_increased_x_resolution(da)

-        # Modify dy to be consistent with the higher resolution grid
-        dy_array = xr.DataArray(
-            np.full([da.sizes[xcoord], da.sizes[ycoord]], dy), dims=[xcoord, ycoord]
-        )
-        da["dy"] = da["dy"].copy(data=dy_array)
+        da["dx"][:] = dx.broadcast_like(da["dx"]).data

         # Remove regions which have incorrect information for the high-resolution grid.
         # New regions will be generated when creating a new Dataset in
         # BoutDataset.getHighParallelResVars
         del da.attrs["regions"]

-        if not aligned_input:
-            # Want output in non-aligned coordinates
-            da = da.bout.from_field_aligned()
-
-        if toroidal_points is not None and zcoord in da.sizes:
-            if isinstance(toroidal_points, int):
-                nz = len(da[zcoord])
-                zstride = (nz + toroidal_points - 1) // toroidal_points
-                da = da.isel(**{zcoord: slice(None, None, zstride)})
-            else:
-                da = da.isel(**{zcoord: toroidal_points})
+        # Remove x-coordinate, will recreate x-coordinate for combined DataArray
+        del da[xcoord]

         return da

-    def add_cartesian_coordinates(self):
-        return _add_cartesian_coordinates(self.data)
+    def interpolate_to_new_grid(
+        self,
+        new_gridfile,
+        *,
+        field_aligned_radial_interpolation=False,
+        method="cubic",
+        return_dataset=False,
+    ):
+        """
+        Interpolate the DataArray onto a new set of grid points, given by a grid file.
+
+        The grid file is assumed to represent the same equilibrium as the one
+        associated with the original DataArray, so that psi-values and poloidal
+        distances along psi-contours of the equilibrium are the same.
+
+        Note: poloidal_distance is used for parallel interpolation inside this method.
+        For this, the data has to be converted to numpy arrays for the calculation.
+        This means that dask cannot be used to parallelise that part of the
+        calculation, so this method may be slow for large Datasets.
+
+        Parameters
+        ----------
+        new_gridfile : str, pathlib.Path or Dataset
+            Path to a new grid file, or grid file opened as a Dataset.
+        field_aligned_radial_interpolation : bool, default False
+            If set to True, transform to field-aligned grid for radial interpolation
+            (parallel interpolation is always on field-aligned grid). Probably less
+            accurate, at least in some parts of the grid where integrated shear is
+            high, but may (especially if most of the turbulence is at the outboard
+            midplane) produce a result that is better field-aligned and so creates
+            less of an initial transient when restarting.
+        method : str, optional
+            The interpolation method to use. Options from xarray.DataArray.interp(),
+            currently: linear, nearest, zero, slinear, quadratic, cubic. Default is
+            'cubic'.
+        return_dataset : bool, default False
+            Return the result as a Dataset containing the new DataArray.
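+
+        Examples
+        --------
+        A minimal usage sketch (the Dataset `ds`, the variable "n" and the file name
+        "newgrid.nc" are illustrative, not part of this API):
+
+        >>> # Interpolate onto the grid described by "newgrid.nc"
+        >>> n_new = ds["n"].bout.interpolate_to_new_grid("newgrid.nc")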
+ """ + + da = self.data + + if not isinstance(new_gridfile, xr.Dataset): + new_gridfile = open_boutdataset( + new_gridfile, + keep_xboundaries=da.metadata["keep_xboundaries"], + keep_yboundaries=da.metadata["keep_yboundaries"], + drop_variables=["theta"], + info=False, + geometry=self.data.geometry, + ) + + xcoord = da.metadata["bout_xdim"] + ycoord = da.metadata["bout_ydim"] + zcoord = da.metadata["bout_zdim"] + + # apply_unfunc() of scipy.interp1d() fails with dask arrays, so load + da.load() + new_gridfile["poloidal_distance"].load() + + parts = [] + for region in self._regions: + # Note, need to set 0 x-guards here. If we include x-guards in the radial + # interpolation, poloidal_distance gets messed up at the edges for the + # parallel interpolation because poloidal_distance does not have to be + # consistent between different regions. + part = da.bout.from_region(region, with_guards={xcoord: 0, ycoord: 2}) + + # Radial interpolation first, because the psi coordinate is 1d (in each + # region), so does not need to be interpolated in y-direction, whereas + # poloidal_distance would need to be interpolated to the original + # DataArray's radial grid points. + psi_part = ( + new_gridfile["psi_poloidal"] + .bout.from_region(region, with_guards={xcoord: 0, ycoord: 0}) + .isel({ycoord: 0}, drop=True) + ) + dx_part = ( + new_gridfile["dx"] + .bout.from_region(region, with_guards={xcoord: 0, ycoord: 0}) + .isel({ycoord: 0}, drop=True) + ) + + if field_aligned_radial_interpolation and zcoord in part.dims: + part = part.bout.to_field_aligned() + + part = part.bout.interpolate_radial( + False, + psi=psi_part, + dx=dx_part, + method=method, + return_dataset=return_dataset, + ) + + poloidal_distance_part = new_gridfile["poloidal_distance"].bout.from_region( + region, with_guards={xcoord: 0, ycoord: 0} + ) + dy_part = new_gridfile["dy"].bout.from_region( + region, with_guards={xcoord: 0, ycoord: 0} + ) + + # apply_unfunc() of scipy.interp1d() fails with dask arrays, so load + part.load() + poloidal_distance_part.load() + + part = part.bout.interpolate_parallel( + False, + poloidal_distance=poloidal_distance_part, + dy=dy_part, + method=method, + return_dataset=return_dataset, + ) + + if field_aligned_radial_interpolation and zcoord in part.dims: + part = part.bout.from_field_aligned() + + # Get theta coordinate from new_gridfile, as interpolated versions may not + # be consistent between different regions. 
+ part["theta"] = poloidal_distance_part["theta"] + + # 'region' is not the same for all parts, and should not exist in the + # result, so delete + if "region" in part.attrs: + del part.attrs["region"] + + parts.append(part.to_dataset(name=self.data.name)) + + result = xr.combine_by_coords(parts, combine_attrs="drop_conflicts") + + # Get attributes from original DataArray, then update for increased resolution + result.attrs = self.data.attrs + + result = _update_metadata_increased_x_resolution( + result, + ixseps1=new_gridfile.metadata["ixseps1"], + ixseps2=new_gridfile.metadata["ixseps2"], + nx=new_gridfile.metadata["nx"], + ) + result = _update_metadata_increased_y_resolution( + result, + jyseps1_1=new_gridfile.metadata["jyseps1_1"], + jyseps2_1=new_gridfile.metadata["jyseps2_1"], + jyseps1_2=new_gridfile.metadata["jyseps1_2"], + jyseps2_2=new_gridfile.metadata["jyseps2_2"], + ny_inner=new_gridfile.metadata["ny_inner"], + ny=new_gridfile.metadata["ny"], + ) + + _make_1d_xcoord(result) + + if return_dataset: + return result + else: + # Extract the DataArray to return + result = apply_geometry(result, self.data.geometry) + return result[self.data.name] def add_cartesian_coordinates(self): """ diff --git a/xbout/boutdataset.py b/xbout/boutdataset.py index 6540fe78..d4636f93 100644 --- a/xbout/boutdataset.py +++ b/xbout/boutdataset.py @@ -18,6 +18,7 @@ from dask.diagnostics import ProgressBar from .geometries import apply_geometry +from .load import open_boutdataset from .plotting.animate import ( animate_poloidal, animate_pcolormesh, @@ -266,8 +267,208 @@ def find_with_dims(first_var, dims): return ds - def add_cartesian_coordinates(self): - return _add_cartesian_coordinates(self.data) + def interpolate_radial(self, variables, **kwargs): + """ + Interpolate in the parallel direction to get a higher resolution version of the + variable. + + Note that the high-resolution variables are all loaded into memory, so most + likely it is necessary to select only a small number. The toroidal_points + argument can also be used to reduce the memory demand. + + Parameters + ---------- + variables : str or sequence of str or ... + The names of the variables to interpolate. If 'variables=...' is passed + explicitly, then interpolate all variables in the Dataset. + psi : 1d array, optional + Values of `psixy` to interpolate data to. If not given use `n` instead. If + `psi` is given, it must be a 1d array with psi values for the region if + `region` is passed and otherwise must be a 2d {x,y} array. + dx : 1d array, optional + New values of `dx`, corresponding to the values of `psi`. Required if `psi` + is passed. + n : int, optional + The factor to increase the resolution by. Defaults to the value set by + BoutDataset.setupParallelInterp(), or 10 if that has not been called. + method : str, optional + The interpolation method to use. Options from xarray.DataArray.interp(), + currently: linear, nearest, zero, slinear, quadratic, cubic. Default is + 'cubic'. + + Returns + ------- + A new Dataset containing a high-resolution versions of the variables. The new + Dataset is a valid BoutDataset, although containing only the specified + variables. + """ + + if variables is ...: + variables = [v for v in self.data] + + if isinstance(variables, str): + variables = [variables] + if isinstance(variables, tuple): + variables = list(variables) + + # Need to start with a Dataset with attrs as merge() drops the attrs of the + # passed-in argument. 
+        # Make sure the first variable has all dimensions so we don't lose any
+        # coordinates
+        def find_with_dims(first_var, dims):
+            if first_var is None:
+                dims = set(dims)
+                for v in variables:
+                    if set(self.data[v].dims) == dims:
+                        first_var = v
+                        break
+            return first_var
+
+        tcoord = self.data.metadata.get("bout_tdim", "t")
+        zcoord = self.data.metadata.get("bout_zdim", "z")
+        first_var = find_with_dims(None, self.data.dims)
+        first_var = find_with_dims(first_var, set(self.data.dims) - {tcoord})
+        first_var = find_with_dims(first_var, set(self.data.dims) - {zcoord})
+        first_var = find_with_dims(first_var, set(self.data.dims) - {tcoord, zcoord})
+        if first_var is None:
+            raise ValueError(
+                f"Could not find variable to interpolate with both "
+                f"{self.data.metadata.get('bout_xdim', 'x')} and "
+                f"{self.data.metadata.get('bout_ydim', 'y')} dimensions"
+            )
+        variables.remove(first_var)
+        ds = self.data[first_var].bout.interpolate_radial(return_dataset=True, **kwargs)
+        xcoord = ds.metadata.get("bout_xdim", "x")
+        ycoord = ds.metadata.get("bout_ydim", "y")
+        for var in variables:
+            da = self.data[var]
+            if xcoord in da.dims and ycoord in da.dims:
+                ds = ds.merge(da.bout.interpolate_radial(return_dataset=True, **kwargs))
+            elif xcoord not in da.dims:
+                ds[var] = da
+            # Can't interpolate a variable that depends on x but not y, so just skip
+
+        # Apply geometry
+        ds = apply_geometry(ds, ds.geometry)
+
+        return ds
+
+    def interpolate_to_new_grid(self, variables, new_gridfile, **kwargs):
+        """
+        Interpolate the Dataset onto a new set of grid points, given by a grid file.
+
+        The grid file is assumed to represent the same equilibrium as the one
+        associated with the original Dataset, so that psi-values and poloidal
+        distances along psi-contours of the equilibrium are the same.
+
+        Parameters
+        ----------
+        variables : str or sequence of str or ...
+            The names of the variables to interpolate. If 'variables=...' is passed
+            explicitly, then interpolate all variables in the Dataset.
+        new_gridfile : str, pathlib.Path or Dataset
+            Path to a new grid file, or grid file opened as a Dataset.
+        field_aligned_radial_interpolation : bool, default False
+            If set to True, transform to field-aligned grid for radial interpolation
+            (parallel interpolation is always on field-aligned grid). Probably less
+            accurate, at least in some parts of the grid where integrated shear is
+            high, but may (especially if most of the turbulence is at the outboard
+            midplane) produce a result that is better field-aligned and so creates
+            less of an initial transient when restarting.
+        method : str, optional
+            The interpolation method to use. Options from xarray.Dataset.interp(),
+            currently: linear, nearest, zero, slinear, quadratic, cubic. Default is
+            'cubic'.
+
+        Returns
+        -------
+        A new Dataset containing the variables interpolated to the new grid. The new
+        Dataset is a valid BoutDataset, although containing only the specified
+        variables.
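+
+        Examples
+        --------
+        A minimal usage sketch (the Dataset `ds`, the variable names and the file
+        name "newgrid.nc" are illustrative, not part of this API):
+
+        >>> # Interpolate selected variables onto the grid in "newgrid.nc"
+        >>> ds_new = ds.bout.interpolate_to_new_grid(["n", "T"], "newgrid.nc")
+        >>> # Pass variables=... to interpolate every variable in the Dataset
+        >>> ds_new = ds.bout.interpolate_to_new_grid(..., "newgrid.nc")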
+ """ + + print("Interpolating to new grid:") + + if variables is ...: + variables = [v for v in self.data] + + if not isinstance(new_gridfile, xr.Dataset): + new_gridfile = open_boutdataset( + new_gridfile, + keep_xboundaries=self.data.metadata["keep_xboundaries"], + keep_yboundaries=self.data.metadata["keep_yboundaries"], + drop_variables=["theta"], + info=False, + geometry=self.data.geometry, + ) + + if isinstance(variables, str): + variables = [variables] + if isinstance(variables, tuple): + variables = list(variables) + + # Need to start with a Dataset with attrs as merge() drops the attrs of the + # passed-in argument. + # Make sure the first variable has all dimensions so we don't lose any + # coordinates + def find_with_dims(first_var, dims): + if first_var is None: + dims = set(dims) + for v in variables: + if set(self.data[v].dims) == dims: + first_var = v + break + return first_var + + tcoord = self.data.metadata.get("bout_tdim", "t") + zcoord = self.data.metadata.get("bout_zdim", "z") + first_var = find_with_dims(None, self.data.dims) + first_var = find_with_dims(first_var, set(self.data.dims) - set(tcoord)) + first_var = find_with_dims(first_var, set(self.data.dims) - set(zcoord)) + first_var = find_with_dims( + first_var, set(self.data.dims) - set([tcoord, zcoord]) + ) + if first_var is None: + raise ValueError( + f"Could not find variable to interpolate with both " + f"{ds.metadata.get('bout_xdim', 'x')} and " + f"{ds.metadata.get('bout_ydim', 'y')} dimensions" + ) + variables.remove(first_var) + print(first_var) + ds = self.data[first_var].bout.interpolate_to_new_grid( + new_gridfile, return_dataset=True, **kwargs + ) + xcoord = ds.metadata.get("bout_xdim", "x") + ycoord = ds.metadata.get("bout_ydim", "y") + for var in variables: + print(var) + da = self.data[var] + if xcoord in da.dims and ycoord in da.dims: + ds = ds.merge( + da.bout.interpolate_to_new_grid( + new_gridfile, return_dataset=True, **kwargs + ) + ) + elif xcoord in da.dims: + print( + f"{var} depends on x but not y, so do not know how to interpolate " + f"to new grid" + ) + elif ycoord in da.dims: + print( + f"{var} depends on y but not x, so do not know how to interpolate " + f"to new grid" + ) + else: + ds[var] = da + + # Apply geometry + ds = apply_geometry(ds, ds.geometry) + + return ds def integrate_midpoints(self, variable, *, dims=None, cumulative_t=False): """ diff --git a/xbout/geometries.py b/xbout/geometries.py index 1cc6d727..4ee712b0 100644 --- a/xbout/geometries.py +++ b/xbout/geometries.py @@ -7,6 +7,7 @@ from .region import Region, _create_regions_toroidal, _create_single_region from .utils import ( _add_attrs_to_var, + _make_1d_xcoord, _set_attrs_on_all_vars, _set_as_coord, _1d_coord_from_spacing, @@ -139,21 +140,7 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None): updated_ds = updated_ds.drop_vars("t_array") if xcoord not in updated_ds.coords: - # Make index 'x' a coordinate, useful for handling global indexing - # Note we have to use the index value, not the value calculated from 'dx' because - # 'dx' may not be consistent between different regions (e.g. core and PFR). - # For some geometries xcoord may have already been created by - # add_geometry_coords, in which case we do not need this. 
-        nx = updated_ds.dims[xcoord]
-
-        # can't use commented out version, uncommented one works around xarray bug
-        # removing attrs
-        # https://github.com/pydata/xarray/issues/4415
-        # https://github.com/pydata/xarray/issues/4393
-        # updated_ds = updated_ds.assign_coords(**{xcoord: np.arange(nx)})
-        updated_ds[xcoord] = (xcoord, np.arange(nx))
-
-        _add_attrs_to_var(updated_ds, xcoord)
+        _make_1d_xcoord(updated_ds)

     if ycoord not in updated_ds.coords:
         ny = updated_ds.dims[ycoord]
diff --git a/xbout/tests/test_utils.py b/xbout/tests/test_utils.py
index a7c9b8e1..b5255bcc 100644
--- a/xbout/tests/test_utils.py
+++ b/xbout/tests/test_utils.py
@@ -6,7 +6,7 @@
 from xbout.utils import (
     _set_attrs_on_all_vars,
-    _update_metadata_increased_resolution,
+    _update_metadata_increased_y_resolution,
     _1d_coord_from_spacing,
 )

@@ -54,7 +54,7 @@ def test__update_metadata_increased_resolution(self):
             "MYSUB": 7,
         }

-        da = _update_metadata_increased_resolution(da, 3)
+        da = _update_metadata_increased_y_resolution(da, n=3)

         assert da.metadata["jyseps1_1"] == 5
         assert da.metadata["jyseps2_1"] == 8
diff --git a/xbout/utils.py b/xbout/utils.py
index 9587fcb4..07577b99 100644
--- a/xbout/utils.py
+++ b/xbout/utils.py
@@ -6,6 +6,8 @@
 import numpy as np
 import xarray as xr

+from .bout_info import _BOUT_VARIABLE_ATTRIBUTES
+

 def _set_attrs_on_all_vars(ds, key, attr_data, copy=False):
     ds.attrs[key] = attr_data
@@ -87,7 +89,56 @@ def _separate_metadata(ds):
     return ds.drop_vars(scalar_vars), metadata


-def _update_metadata_increased_resolution(da, n):
+def _update_metadata_increased_x_resolution(da, *, ixseps1=None, ixseps2=None, nx=None):
+    """
+    Update the metadata variables to account for a change in x-direction resolution.
+
+    Parameters
+    ----------
+    da : DataArray
+        The variable to update
+    ixseps1 : int, optional
+        The value to give to ixseps1. If not given, ixseps1 is set to -1.
+    ixseps2 : int, optional
+        The value to give to ixseps2. If not given, ixseps2 is set to -1.
+    nx : int, optional
+        The value to give to nx. If not given, nx is set to -1.
+    """
+
+    # Take deepcopy to ensure we do not alter metadata of other variables
+    da.attrs["metadata"] = deepcopy(da.metadata)
+
+    def set_var(var, value):
+        if value is None:
+            da.metadata[var] = -1
+        else:
+            da.metadata[var] = value
+
+    set_var("ixseps1", ixseps1)
+    set_var("ixseps2", ixseps2)
+    set_var("nx", nx)
+    if nx is not None:
+        da.metadata["MXSUB"] = nx - 2 * da.metadata["MXG"]
+
+    # Update attrs of coordinates to be consistent with da
+    for coord in da.coords:
+        da[coord].attrs = {}
+        _add_attrs_to_var(da, coord)
+
+    return da
+
+
+def _update_metadata_increased_y_resolution(
+    da,
+    *,
+    n=None,
+    jyseps1_1=None,
+    jyseps2_1=None,
+    jyseps1_2=None,
+    jyseps2_2=None,
+    ny_inner=None,
+    ny=None,
+):
     """
     Update the metadata variables to account for a y-direction resolution increased by
     a factor n.
@@ -96,29 +147,48 @@
     ----------
     da : DataArray
         The variable to update
-    n : int
-        The factor to increase the y-resolution by
+    n : int, optional
+        The factor to increase the y-resolution by. If n is not given, y-dependent
+        metadata variables are set to -1, assuming they will be corrected later.
+    jyseps1_1, jyseps2_1, jyseps1_2, jyseps2_2, ny_inner, ny : int, optional
+        Metadata variables for the y-grid. Should not be passed if `n` is passed.
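+
+    Examples
+    --------
+    A minimal sketch, matching the usage in test_utils.py:
+
+    >>> # Triple the y-resolution, scaling jyseps*, ny, ny_inner and MYSUB
+    >>> da = _update_metadata_increased_y_resolution(da, n=3)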
""" # Take deepcopy to ensure we do not alter metadata of other variables da.attrs["metadata"] = deepcopy(da.metadata) - def update_jyseps(name): + def update_jyseps(name, value): # If any jyseps<=0, need to leave as is if da.metadata[name] > 0: - da.metadata[name] = n * (da.metadata[name] + 1) - 1 - - update_jyseps("jyseps1_1") - update_jyseps("jyseps2_1") - update_jyseps("jyseps1_2") - update_jyseps("jyseps2_2") - - def update_ny(name): - da.metadata[name] = n * da.metadata[name] + if n is None: + if value is None: + da.metadata[name] = -1 + else: + da.metadata[name] = value + else: + if value is not None: + raise ValueError(f"n set, but value also passed to {name}") + da.metadata[name] = n * (da.metadata[name] + 1) - 1 + + update_jyseps("jyseps1_1", jyseps1_1) + update_jyseps("jyseps2_1", jyseps2_1) + update_jyseps("jyseps1_2", jyseps1_2) + update_jyseps("jyseps2_2", jyseps2_2) + + def update_ny(name, value): + if n is None: + if value is None: + da.metadata[name] = -1 + else: + da.metadata[name] = value + else: + if value is not None: + raise ValueError(f"n set, but value also passed to {name}") + da.metadata[name] = n * da.metadata[name] - update_ny("ny") - update_ny("ny_inner") - update_ny("MYSUB") + update_ny("ny", ny) + update_ny("ny_inner", ny_inner) + update_ny("MYSUB", None) # Update attrs of coordinates to be consistent with da for coord in da.coords: @@ -222,6 +292,25 @@ def _1d_coord_from_spacing(spacing, dim, ds=None, *, origin_at=None): return xr.Variable(dim, coord_values) +def _make_1d_xcoord(ds_or_da): + # Make index 'x' a coordinate, useful for handling global indexing + # Note we have to use the index value, not the value calculated from 'dx' because + # 'dx' may not be consistent between different regions (e.g. core and PFR). + # For some geometries xcoord may have already been created by + # add_geometry_coords, in which case we do not need this. + xcoord = ds_or_da.metadata["bout_xdim"] + nx = ds_or_da.dims[xcoord] + + # can't use commented out version, uncommented one works around xarray bug + # removing attrs + # https://github.com/pydata/xarray/issues/4415 + # https://github.com/pydata/xarray/issues/4393 + # updated_ds = updated_ds.assign_coords(**{xcoord: np.arange(nx)}) + ds_or_da[xcoord] = (xcoord, np.arange(nx)) + + _add_attrs_to_var(ds_or_da, xcoord) + + def _add_cartesian_coordinates(ds): # Add Cartesian X and Y coordinates if they do not exist already # Works on either BoutDataset or BoutDataArray @@ -503,8 +592,13 @@ def _split_into_restarts(ds, variables, savepath, nxpe, nype, tind, prefix, over for v in variables: data_variable = ds_slice[v].variable - # delete attrs so we don't try to save metadata dict to restart files - data_variable.attrs = {} + # delete attrs, except for those that were created by BOUT++, so we + # don't try to save metadata dict to restart files + data_variable.attrs = { + k: v + for k, v in data_variable.attrs.items() + if k in _BOUT_VARIABLE_ATTRIBUTES + } restart_ds[v] = data_variable for v in ds.metadata: