Skip to content

Commit 5944af3

Browse files
Implement XConstantField interpolation
1 parent c069c66 commit 5944af3

File tree

4 files changed

+28
-10
lines changed

4 files changed

+28
-10
lines changed

src/parcels/_core/fieldset.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from parcels._core.xgrid import _DEFAULT_XGCM_KWARGS, XGrid
1818
from parcels._logger import logger
1919
from parcels._typing import Mesh
20+
from parcels.interpolators import XConstantField
2021

2122
if TYPE_CHECKING:
2223
from parcels._core.basegrid import BaseGrid
@@ -116,7 +117,7 @@ def add_field(self, field: Field, name: str | None = None):
116117

117118
self.fields[name] = field
118119

119-
def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
120+
def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"):
120121
"""Wrapper function to add a Field that is constant in space,
121122
useful e.g. when using constant horizontal diffusivity
122123
@@ -134,14 +135,24 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
134135
correction for zonal velocity U near the poles.
135136
2. flat: No conversion, lat/lon are assumed to be in m.
136137
"""
137-
ds = xr.Dataset({name: (["time", "lat", "lon", "depth"], np.full((1, 1, 1, 1), value))})
138-
grid = XGrid(xgcm.Grid(ds, **_DEFAULT_XGCM_KWARGS))
138+
ds = xr.Dataset(
139+
{name: (["time", "depth", "lat", "lon"], np.full((1, 1, 1, 1), value))},
140+
coords={
141+
"time": (["time"], [np.timedelta64(0, "s")], {"axis": "T"}),
142+
"depth": (["depth"], [0], {"axis": "Z"}),
143+
"lat": (["lat"], [0], {"axis": "Y", "c_grid_axis_shift": -0.5}),
144+
"lon": (["lon"], [0], {"axis": "X", "c_grid_axis_shift": -0.5}),
145+
"lat_C": (["lat_C"], [0.5], {"axis": "Y"}), # TODO why is this needed?
146+
"lon_C": (["lon_C"], [0.5], {"axis": "X"}), # TODO why is this needed?
147+
},
148+
)
149+
grid = XGrid(xgcm.Grid(ds, **_DEFAULT_XGCM_KWARGS), mesh=mesh)
139150
self.add_field(
140151
Field(
141152
name,
142153
ds[name],
143154
grid,
144-
interp_method=None, # TODO : Need to define an interpolation method for constants
155+
interp_method=XConstantField,
145156
)
146157
)
147158

src/parcels/interpolators.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"CGrid_Velocity",
2121
"UXPiecewiseConstantFace",
2222
"UXPiecewiseLinearNode",
23+
"XConstantField",
2324
"XFreeslip",
2425
"XLinear",
2526
"XLinearInvdistLandTracer",
@@ -135,6 +136,15 @@ def XLinear(
135136
return value.compute() if is_dask_collection(value) else value
136137

137138

139+
def XConstantField(
140+
particle_positions: dict[str, float | np.ndarray],
141+
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
142+
field: Field,
143+
):
144+
"""Returning the single value of a Constant Field (with a size=(1,1,1,1) array)"""
145+
return field.data[0, 0, 0, 0].values
146+
147+
138148
def CGrid_Velocity(
139149
particle_positions: dict[str, float | np.ndarray],
140150
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],

tests/test_diffusion.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,10 @@ def test_fieldKh_Brownian(mesh):
2323
grid = XGrid.from_dataset(ds, mesh=mesh)
2424
U = Field("U", ds["U"], grid, interp_method=XLinear)
2525
V = Field("V", ds["V"], grid, interp_method=XLinear)
26-
ds["Kh_zonal"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_zonal))
27-
ds["Kh_meridional"] = (["time", "depth", "YG", "XG"], np.full((2, 1, 2, 2), kh_meridional))
28-
Kh_zonal = Field("Kh_zonal", ds["Kh_zonal"], grid=grid, interp_method=XLinear)
29-
Kh_meridional = Field("Kh_meridional", ds["Kh_meridional"], grid=grid, interp_method=XLinear)
3026
UV = VectorField("UV", U, V)
31-
fieldset = FieldSet([U, V, UV, Kh_zonal, Kh_meridional])
27+
fieldset = FieldSet([U, V, UV])
28+
fieldset.add_constant_field("Kh_zonal", kh_zonal, mesh=mesh)
29+
fieldset.add_constant_field("Kh_meridional", kh_meridional, mesh=mesh)
3230

3331
npart = 100
3432
runtime = np.timedelta64(2, "h")

tests/test_fieldset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def test_fieldset_add_constant_field(fieldset):
5959
lat = ds["lat"].mean()
6060
lon = ds["lon"].mean()
6161

62-
pytest.xfail(reason="Not yet implemented interpolation.")
6362
assert fieldset.test_constant_field[time, z, lat, lon] == 1.0
6463

6564

0 commit comments

Comments
 (0)