diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 91cc04c06..07753093f 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -17,6 +17,7 @@ from parcels._core.xgrid import _DEFAULT_XGCM_KWARGS, XGrid from parcels._logger import logger from parcels._typing import Mesh +from parcels.interpolators import XConstantField if TYPE_CHECKING: from parcels._core.basegrid import BaseGrid @@ -116,7 +117,7 @@ def add_field(self, field: Field, name: str | None = None): self.fields[name] = field - def add_constant_field(self, name: str, value, mesh: Mesh = "flat"): + def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"): """Wrapper function to add a Field that is constant in space, useful e.g. when using constant horizontal diffusivity @@ -134,16 +135,15 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"): correction for zonal velocity U near the poles. 2. flat: No conversion, lat/lon are assumed to be in m. """ - ds = xr.Dataset({name: (["time", "lat", "lon", "depth"], np.full((1, 1, 1, 1), value))}) - grid = XGrid(xgcm.Grid(ds, **_DEFAULT_XGCM_KWARGS)) - self.add_field( - Field( - name, - ds[name], - grid, - interp_method=None, # TODO : Need to define an interpolation method for constants - ) + ds = xr.Dataset( + {name: (["lat", "lon"], np.full((1, 1), value))}, + coords={"lat": (["lat"], [0], {"axis": "Y"}), "lon": (["lon"], [0], {"axis": "X"})}, + ) + xgrid = xgcm.Grid( + ds, coords={"X": {"left": "lon"}, "Y": {"left": "lat"}}, autoparse_metadata=False, **_DEFAULT_XGCM_KWARGS ) + grid = XGrid(xgrid, mesh=mesh) + self.add_field(Field(name, ds[name], grid, interp_method=XConstantField)) def add_constant(self, name, value): """Add a constant to the FieldSet. Note that all constants are diff --git a/src/parcels/interpolators.py b/src/parcels/interpolators.py index 51e16510f..fdfbb6211 100644 --- a/src/parcels/interpolators.py +++ b/src/parcels/interpolators.py @@ -20,6 +20,7 @@ "CGrid_Velocity", "UXPiecewiseConstantFace", "UXPiecewiseLinearNode", + "XConstantField", "XFreeslip", "XLinear", "XLinearInvdistLandTracer", @@ -135,6 +136,15 @@ def XLinear( return value.compute() if is_dask_collection(value) else value +def XConstantField( + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]], + field: Field, +): + """Returning the single value of a Constant Field (with a size=(1,1,1,1) array)""" + return field.data[0, 0, 0, 0].values + + def CGrid_Velocity( particle_positions: dict[str, float | np.ndarray], grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]], @@ -598,7 +608,7 @@ def XLinearInvdistLandTracer( all_land_mask = nb_land == 4 * lenZ * lenT values[all_land_mask] = 0.0 - not_all_land = ~all_land_mask + not_all_land = np.asarray(~all_land_mask, dtype=bool) if np.any(not_all_land): i_grid = np.arange(2)[None, None, None, :, None] j_grid = np.arange(2)[None, None, :, None, None] diff --git a/tests/test_diffusion.py b/tests/test_diffusion.py index d0091c7d6..3c9d63a86 100644 --- a/tests/test_diffusion.py +++ b/tests/test_diffusion.py @@ -23,12 +23,10 @@ def test_fieldKh_Brownian(mesh): grid = XGrid.from_dataset(ds, mesh=mesh) U = Field("U", ds["U"], grid, interp_method=XLinear) V = Field("V", ds["V"], grid, interp_method=XLinear) - ds["Kh_zonal"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_zonal)) - ds["Kh_meridional"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_meridional)) - Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, interp_method=XLinear) - Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, interp_method=XLinear) UV = VectorField("UV", U, V) - fieldset = FieldSet([U, V, UV, Kh_zonal, Kh_meridional]) + fieldset = FieldSet([U, V, UV]) + fieldset.add_constant_field("Kh_zonal", kh_zonal, mesh=mesh) + fieldset.add_constant_field("Kh_meridional", kh_meridional, mesh=mesh) npart = 100 runtime = np.timedelta64(2, "h") diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 5c140b44b..82498b567 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -59,7 +59,6 @@ def test_fieldset_add_constant_field(fieldset): lat = ds["lat"].mean() lon = ds["lon"].mean() - pytest.xfail(reason="Not yet implemented interpolation.") assert fieldset.test_constant_field[time, z, lat, lon] == 1.0