Skip to content

Commit f2dd91e

Browse files
authored
Merge pull request #227 from boutproject/load-Br-Bz
Better test for re-applying toroidal coordinates; load Bpxy, Brxy, Bzxy if possible
2 parents 76533af + ce1c2d8 commit f2dd91e

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

xbout/geometries.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None):
5656
ds = _set_attrs_on_all_vars(ds, "geometry", "")
5757
updated_ds = ds
5858
else:
59-
ds = _set_attrs_on_all_vars(ds, "geometry", geometry_name)
60-
6159
try:
6260
add_geometry_coords = REGISTERED_GEOMETRIES[geometry_name]
6361
except KeyError:
@@ -81,6 +79,11 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None):
8179
else:
8280
updated_ds = add_geometry_coords(ds)
8381

82+
# Set "geometry" attribute after adding coordinates, so that functions in
83+
# `REGISTERED_GEOMETRIES` can check if ds.attrs["geometry"] is already defined
84+
# to see if they are being applied for the first time or re-applied.
85+
updated_ds = _set_attrs_on_all_vars(updated_ds, "geometry", geometry_name)
86+
8487
del ds
8588

8689
# Set dimension names if they were not set by add_geometry_coords(). Dimensions
@@ -298,7 +301,7 @@ def _set_default_toroidal_coordinates(coordinates, ds):
298301
return coordinates
299302

300303

301-
def _add_vars_from_grid(ds, grid, variables):
304+
def _add_vars_from_grid(ds, grid, variables, *, optional_variables=None):
302305
# Get extra geometry information from grid file if it's not in the dump files
303306
for v in variables:
304307
if v not in ds:
@@ -317,6 +320,23 @@ def _add_vars_from_grid(ds, grid, variables):
317320
ds[v] = (grid[v].dims, grid[v].values)
318321

319322
_add_attrs_to_var(ds, v)
323+
324+
if optional_variables is not None:
325+
for v in optional_variables:
326+
if v not in ds:
327+
if grid is None:
328+
continue
329+
if v in grid:
330+
# ds[v] = grid[v]
331+
# Work around issue where xarray drops attributes on
332+
# coordinates when a new DataArray is assigned to the
333+
# Dataset, see https://github.com/pydata/xarray/issues/4415
334+
# https://github.com/pydata/xarray/issues/4393
335+
# This way adds as a 'Variable' instead of as a 'DataArray'
336+
ds[v] = (grid[v].dims, grid[v].values)
337+
338+
_add_attrs_to_var(ds, v)
339+
320340
return ds
321341

322342

@@ -325,7 +345,7 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
325345

326346
coordinates = _set_default_toroidal_coordinates(coordinates, ds)
327347

328-
if set(coordinates.values()).issubset(set(ds.coords).union(ds.dims)):
348+
if ds.attrs.get("geometry", None) == "toroidal":
329349
# Loading a Dataset which already had the coordinates created for it
330350
ds = _create_regions_toroidal(ds)
331351
return ds
@@ -343,7 +363,9 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
343363
)
344364

345365
# Get extra geometry information from grid file if it's not in the dump files
346-
ds = _add_vars_from_grid(ds, grid, ["psixy", "Rxy", "Zxy"])
366+
ds = _add_vars_from_grid(
367+
ds, grid, ["psixy", "Rxy", "Zxy"], optional_variables=["Bpxy", "Brxy", "Bzxy"]
368+
)
347369

348370
if "t" in ds.dims:
349371
# Rename 't' if user requested it

0 commit comments

Comments
 (0)