Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/parcels/_core/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from parcels._core.xgrid import _DEFAULT_XGCM_KWARGS, XGrid
from parcels._logger import logger
from parcels._typing import Mesh
from parcels.interpolators import XConstantField

if TYPE_CHECKING:
from parcels._core.basegrid import BaseGrid
Expand Down Expand Up @@ -116,7 +117,7 @@ def add_field(self, field: Field, name: str | None = None):

self.fields[name] = field

def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"):
"""Wrapper function to add a Field that is constant in space,
useful e.g. when using constant horizontal diffusivity

Expand All @@ -134,16 +135,15 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
correction for zonal velocity U near the poles.
2. flat: No conversion, lat/lon are assumed to be in m.
"""
ds = xr.Dataset({name: (["time", "lat", "lon", "depth"], np.full((1, 1, 1, 1), value))})
grid = XGrid(xgcm.Grid(ds, **_DEFAULT_XGCM_KWARGS))
self.add_field(
Field(
name,
ds[name],
grid,
interp_method=None, # TODO : Need to define an interpolation method for constants
)
ds = xr.Dataset(
{name: (["lat", "lon"], np.full((1, 1), value))},
coords={"lat": (["lat"], [0], {"axis": "Y"}), "lon": (["lon"], [0], {"axis": "X"})},
)
xgrid = xgcm.Grid(
ds, coords={"X": {"left": "lon"}, "Y": {"left": "lat"}}, autoparse_metadata=False, **_DEFAULT_XGCM_KWARGS
)
grid = XGrid(xgrid, mesh=mesh)
self.add_field(Field(name, ds[name], grid, interp_method=XConstantField))

def add_constant(self, name, value):
"""Add a constant to the FieldSet. Note that all constants are
Expand Down
12 changes: 11 additions & 1 deletion src/parcels/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"CGrid_Velocity",
"UXPiecewiseConstantFace",
"UXPiecewiseLinearNode",
"XConstantField",
"XFreeslip",
"XLinear",
"XLinearInvdistLandTracer",
Expand Down Expand Up @@ -135,6 +136,15 @@ def XLinear(
return value.compute() if is_dask_collection(value) else value


def XConstantField(
particle_positions: dict[str, float | np.ndarray],
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
field: Field,
):
"""Returning the single value of a Constant Field (with a size=(1,1,1,1) array)"""
return field.data[0, 0, 0, 0].values


def CGrid_Velocity(
particle_positions: dict[str, float | np.ndarray],
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
Expand Down Expand Up @@ -598,7 +608,7 @@ def XLinearInvdistLandTracer(
all_land_mask = nb_land == 4 * lenZ * lenT
values[all_land_mask] = 0.0

not_all_land = ~all_land_mask
not_all_land = np.asarray(~all_land_mask, dtype=bool)
if np.any(not_all_land):
i_grid = np.arange(2)[None, None, None, :, None]
j_grid = np.arange(2)[None, None, :, None, None]
Expand Down
8 changes: 3 additions & 5 deletions tests/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@ def test_fieldKh_Brownian(mesh):
grid = XGrid.from_dataset(ds, mesh=mesh)
U = Field("U", ds["U"], grid, interp_method=XLinear)
V = Field("V", ds["V"], grid, interp_method=XLinear)
ds["Kh_zonal"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_zonal))
ds["Kh_meridional"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_meridional))
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, interp_method=XLinear)
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, interp_method=XLinear)
UV = VectorField("UV", U, V)
fieldset = FieldSet([U, V, UV, Kh_zonal, Kh_meridional])
fieldset = FieldSet([U, V, UV])
fieldset.add_constant_field("Kh_zonal", kh_zonal, mesh=mesh)
fieldset.add_constant_field("Kh_meridional", kh_meridional, mesh=mesh)

npart = 100
runtime = np.timedelta64(2, "h")
Expand Down
1 change: 0 additions & 1 deletion tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def test_fieldset_add_constant_field(fieldset):
lat = ds["lat"].mean()
lon = ds["lon"].mean()

pytest.xfail(reason="Not yet implemented interpolation.")
assert fieldset.test_constant_field[time, z, lat, lon] == 1.0


Expand Down