@@ -56,8 +56,6 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None):
5656 ds = _set_attrs_on_all_vars (ds , "geometry" , "" )
5757 updated_ds = ds
5858 else :
59- ds = _set_attrs_on_all_vars (ds , "geometry" , geometry_name )
60-
6159 try :
6260 add_geometry_coords = REGISTERED_GEOMETRIES [geometry_name ]
6361 except KeyError :
@@ -81,6 +79,11 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None):
8179 else :
8280 updated_ds = add_geometry_coords (ds )
8381
82+ # Set "geometry" attribute after adding coordinates, so that functions in
83+ # `REGISTERED_GEOMETRIES` can check if ds.attrs["geometry"] is already defined
84+ # to see if they are being applied for the first time or re-applied.
85+ updated_ds = _set_attrs_on_all_vars (updated_ds , "geometry" , geometry_name )
86+
8487 del ds
8588
8689 # Set dimension names if they were not set by add_geometry_coords(). Dimensions
@@ -298,7 +301,7 @@ def _set_default_toroidal_coordinates(coordinates, ds):
298301 return coordinates
299302
300303
301- def _add_vars_from_grid (ds , grid , variables ):
304+ def _add_vars_from_grid (ds , grid , variables , * , optional_variables = None ):
302305 # Get extra geometry information from grid file if it's not in the dump files
303306 for v in variables :
304307 if v not in ds :
@@ -317,6 +320,23 @@ def _add_vars_from_grid(ds, grid, variables):
317320 ds [v ] = (grid [v ].dims , grid [v ].values )
318321
319322 _add_attrs_to_var (ds , v )
323+
324+ if optional_variables is not None :
325+ for v in optional_variables :
326+ if v not in ds :
327+ if grid is None :
328+ continue
329+ if v in grid :
330+ # ds[v] = grid[v]
331+ # Work around issue where xarray drops attributes on
332+ # coordinates when a new DataArray is assigned to the
333+ # Dataset, see https://github.com/pydata/xarray/issues/4415
334+ # https://github.com/pydata/xarray/issues/4393
335+ # This way adds as a 'Variable' instead of as a 'DataArray'
336+ ds [v ] = (grid [v ].dims , grid [v ].values )
337+
338+ _add_attrs_to_var (ds , v )
339+
320340 return ds
321341
322342
@@ -325,7 +345,7 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
325345
326346 coordinates = _set_default_toroidal_coordinates (coordinates , ds )
327347
328- if set ( coordinates . values ()). issubset ( set ( ds .coords ). union ( ds . dims )) :
348+ if ds .attrs . get ( "geometry" , None ) == "toroidal" :
329349 # Loading a Dataset which already had the coordinates created for it
330350 ds = _create_regions_toroidal (ds )
331351 return ds
@@ -343,7 +363,9 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
343363 )
344364
345365 # Get extra geometry information from grid file if it's not in the dump files
346- ds = _add_vars_from_grid (ds , grid , ["psixy" , "Rxy" , "Zxy" ])
366+ ds = _add_vars_from_grid (
367+ ds , grid , ["psixy" , "Rxy" , "Zxy" ], optional_variables = ["Bpxy" , "Brxy" , "Bzxy" ]
368+ )
347369
348370 if "t" in ds .dims :
349371 # Rename 't' if user requested it
0 commit comments