Skip to content

XGrid localization #2082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 16, 2025
Merged

XGrid localization #2082

merged 6 commits into from
Jul 16, 2025

Conversation

VeckoTheGecko
Copy link
Contributor

@VeckoTheGecko VeckoTheGecko commented Jul 11, 2025

Currently we have a search method which returns a particle position relative to the F points. However, since our data can be defined on a staggered grid, it's important to "localize" this particle position to the grid for the array of interest. This mainly applies for C grids and when working with MITgcm and NEMO where their F points and C points are defined differently relative to each other (see diagram in docs or in #2037) .

This PR introduces this grid localization. This is really just the first draft to get feedback (code is a bit more messy than I would like - and there are no tests). This good localization will help with writing interpolators.

  • Chose the correct base branch (v4-dev for v4 changes)
  • Fixes None
  • Added tests (not yet)
  • Added documentation

@VeckoTheGecko
Copy link
Contributor Author

VeckoTheGecko commented Jul 11, 2025

Below is a small testing script along with output. data_c with the -1 index seems a bit strange, but I guess it makes sense (if vertical positions are defined on the cell centers and the particle is at the surface then -1 with bcoord 0.5 is expected). Thoughts @erikvansebille ?

# %%
import numpy as np
import xarray as xr
from pprint import pprint

from parcels import xgcm
from parcels._datasets.structured.generic import X, Y, Z
from parcels.xgrid import XGrid

T = 2
Z = Y = X = 3
TIME = xr.date_range("2000", "2001", T)

ds_mitgcm = xr.Dataset(
    {
        "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
        "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
        "U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
        "V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
        "U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
        "V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
    },
    coords={
        "XG": (
            ["XG"],
            np.arange(0, X),
            {"axis": "X", "c_grid_axis_shift": -0.5},
        ),
        "XC": (["XC"], np.arange(0, X) + 0.5, {"axis": "X"}),
        "YG": (
            ["YG"],
            np.arange(0, Y),
            {"axis": "Y", "c_grid_axis_shift": -0.5},
        ),
        "YC": (
            ["YC"],
            np.arange(0, Y) + 0.5,
            {"axis": "Y"},
        ),
        "ZG": (
            ["ZG"],
            np.arange(Z),
            {"axis": "Z", "c_grid_axis_shift": -0.5},
        ),
        "ZC": (
            ["ZC"],
            np.arange(Z) + 0.5,
            {"axis": "Z"},
        ),
        "lon": (["XG"], np.arange(0, X)),
        "lat": (["YG"], np.arange(0, Y)),
        "depth": (["ZG"], np.arange(Z)),
        "time": (["time"], TIME, {"axis": "T"}),
    },
)

ds_nemo = xr.Dataset(
    {
        "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
        "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
        "U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
        "V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
        "U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
        "V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
    },
    coords={
        "XG": (
            ["XG"],
            np.arange(0, X),
            {"axis": "X", "c_grid_axis_shift": 0.5},
        ),
        "XC": (["XC"], np.arange(0, X) - 0.5, {"axis": "X"}),
        "YG": (
            ["YG"],
            np.arange(0, Y),
            {"axis": "Y", "c_grid_axis_shift": 0.5},
        ),
        "YC": (
            ["YC"],
            np.arange(0, Y) - 0.5,
            {"axis": "Y"},
        ),
        "ZG": (
            ["ZG"],
            np.arange(Z),
            {"axis": "Z", "c_grid_axis_shift": -0.5},
        ),
        "ZC": (
            ["ZC"],
            np.arange(Z) + 0.5,
            {"axis": "Z"},
        ),
        "lon": (["XG"], np.arange(0, X)),
        "lat": (["YG"], np.arange(0, Y)),
        "depth": (["ZG"], np.arange(Z)),
        "time": (["time"], TIME, {"axis": "T"}),
    },
)


grid_mitgcm = XGrid(xgcm.Grid(ds_mitgcm, periodic=False))
grid_nemo = XGrid(xgcm.Grid(ds_nemo, periodic=False))

print("XGCM repr of MITgcm grid:")
print(grid_mitgcm.xgcm_grid)
print("\nXGCM repr of NEMO grid:")
print(grid_nemo.xgcm_grid)


# %%
def show_point_on_grid(grid, z, y, x):
    """Pretty printing of some info"""

    position = grid.search(z, y, x)
    print(f"Position wrt. fpoints (lon/lat grid):")
    pprint(position)

    for da in grid.xgcm_grid._ds.data_vars.values():
        local_position = grid.localize(position, da.dims)
        print(f"On {da.name=} with {da.dims=}, local position:")
        pprint(local_position)


print("----Working with MITgcm grid----")
show_point_on_grid(grid_mitgcm, 0, 0.8, 0.8)

print("\n----Working with NEMO grid----")
show_point_on_grid(grid_nemo, 0, 0.8, 0.8)

output:

XGCM repr of MITgcm grid:
<parcels.Grid>
Y Axis (not periodic, boundary=None):
  * center   YC --> left
  * left     YG --> center
X Axis (not periodic, boundary=None):
  * center   XC --> left
  * left     XG --> center
Z Axis (not periodic, boundary=None):
  * center   ZC --> left
  * left     ZG --> center
T Axis (not periodic, boundary=None):
  * center   time

XGCM repr of NEMO grid:
<parcels.Grid>
Y Axis (not periodic, boundary=None):
  * center   YC --> right
  * right    YG --> center
X Axis (not periodic, boundary=None):
  * center   XC --> right
  * right    XG --> center
Z Axis (not periodic, boundary=None):
  * center   ZC --> left
  * left     ZG --> center
T Axis (not periodic, boundary=None):
  * center   time
----Working with MITgcm grid----
Position wrt. fpoints (lon/lat grid):
{'X': (np.int64(0), np.float64(0.8)),
 'Y': (np.int64(0), np.float64(0.8)),
 'Z': (np.int64(0), np.float64(0.0))}
On da.name='data_g' with da.dims=('time', 'ZG', 'YG', 'XG'), local position:
{'XG': (np.int64(0), np.float64(0.8)),
 'YG': (np.int64(0), np.float64(0.8)),
 'ZG': (np.int64(0), np.float64(0.0))}
On da.name='data_c' with da.dims=('time', 'ZC', 'YC', 'XC'), local position:
{'XC': (np.int64(0), np.float64(0.30000000000000004)),
 'YC': (np.int64(0), np.float64(0.30000000000000004)),
 'ZC': (np.int64(-1), np.float64(0.5))}
On da.name='U (A grid)' with da.dims=('time', 'ZG', 'YG', 'XG'), local position:
{'XG': (np.int64(0), np.float64(0.8)),
 'YG': (np.int64(0), np.float64(0.8)),
 'ZG': (np.int64(0), np.float64(0.0))}
On da.name='V (A grid)' with da.dims=('time', 'ZG', 'YG', 'XG'), local position:
{'XG': (np.int64(0), np.float64(0.8)),
 'YG': (np.int64(0), np.float64(0.8)),
 'ZG': (np.int64(0), np.float64(0.0))}
On da.name='U (C grid)' with da.dims=('time', 'ZG', 'YC', 'XG'), local position:
{'XG': (np.int64(0), np.float64(0.8)),
 'YC': (np.int64(0), np.float64(0.30000000000000004)),
 'ZG': (np.int64(0), np.float64(0.0))}
On da.name='V (C grid)' with da.dims=('time', 'ZG', 'YG', 'XC'), local position:
{'XC': (np.int64(0), np.float64(0.30000000000000004)),
 'YG': (np.int64(0), np.float64(0.8)),
 'ZG': (np.int64(0), np.float64(0.0))}

----Working with NEMO grid----
Position wrt. fpoints (lon/lat grid):
{'X': (np.int64(0), np.float64(0.8)),
 'Y': (np.int64(0), np.float64(0.8)),
 'Z': (np.int64(0), np.float64(0.0))}
On da.name='data_g' with da.dims=('time', 'ZG', 'YG', 'XG'), local position:
{'XG': (np.int64(0), np.float64(0.8)),
 'YG': (np.int64(0), np.float64(0.8)),
 'ZG': (np.int64(0), np.float64(0.0))}
On da.name='data_c' with da.dims=('time', 'ZC', 'YC', 'XC'), local position:
{'XC': (np.int64(1), np.float64(0.30000000000000004)),
 'YC': (np.int64(1), np.float64(0.30000000000000004)),
 'ZC': (np.int64(-1), np.float64(0.5))}
On da.name='U (A grid)' with da.dims=('time', 'ZG', 'YG', 'XG'), local position:
{'XG': (np.int64(0), np.float64(0.8)),
 'YG': (np.int64(0), np.float64(0.8)),
 'ZG': (np.int64(0), np.float64(0.0))}
On da.name='V (A grid)' with da.dims=('time', 'ZG', 'YG', 'XG'), local position:
{'XG': (np.int64(0), np.float64(0.8)),
 'YG': (np.int64(0), np.float64(0.8)),
 'ZG': (np.int64(0), np.float64(0.0))}
On da.name='U (C grid)' with da.dims=('time', 'ZG', 'YC', 'XG'), local position:
{'XG': (np.int64(0), np.float64(0.8)),
 'YC': (np.int64(1), np.float64(0.30000000000000004)),
 'ZG': (np.int64(0), np.float64(0.0))}
On da.name='V (C grid)' with da.dims=('time', 'ZG', 'YG', 'XC'), local position:
{'XC': (np.int64(1), np.float64(0.30000000000000004)),
 'YG': (np.int64(0), np.float64(0.8)),
 'ZG': (np.int64(0), np.float64(0.0))}

@erikvansebille
Copy link
Member

Note that some datasets have an attribute c_grid_axis_shift, see for example below. If this attribute exists (but that's unfortunately not guaranteed), we can use it in this function too?

"YG": (
["YG"],
np.linspace(0, 5000, Y, dtype="float64"),
{
"standard_name": "latitude_at_f_location",
"long_name": "latitude",
"units": "degrees_north",
"coordinate": "YG XG",
"axis": "Y",
"c_grid_axis_shift": -0.5,

@VeckoTheGecko
Copy link
Contributor Author

Note that some datasets have an attribute c_grid_axis_shift, see for example below. If this attribute exists (but that's unfortunately not guaranteed), we can use it in this function too?

Yes, this is already taken into account by xgcm during grid ingestion to determine the nature of the grid staggering. In fact, the only difference between ds_mitgcm and ds_nemo is the value of this attribute and the corresponding offset.

@VeckoTheGecko
Copy link
Contributor Author

VeckoTheGecko commented Jul 11, 2025

if we're comfortable with this approach, I'll go and clear up some stuff (e.g., variable naming since it might be a bit confusing)

@VeckoTheGecko VeckoTheGecko force-pushed the grid-localization branch 2 times, most recently from f3949e4 to 0d0b0e5 Compare July 14, 2025 11:34
@VeckoTheGecko VeckoTheGecko marked this pull request as ready for review July 14, 2025 11:35
Copy link
Member

@erikvansebille erikvansebille left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. And how would a user use this in an interpolation method? Or will that come in a next PR?

@VeckoTheGecko
Copy link
Contributor Author

VeckoTheGecko commented Jul 15, 2025

I realised that this localization step isn't actually needed.

image

If a point is defined at (in X) index, bcoord = 0,0.8 wrt. the f-points, that will be at 0, 0.3 wrt. the C points. However, since the dual grid already has dimension coordinates that are offset by 0.5 (i.e., 0.5, 1.5 ...) we can just use xarray's interpolation functionality with 0.8 on this grid in order to get this point no matter the grid.

Still we need to build tooling around getting the dimension names that correspond to the axes (i.e, for da.dims == ["XC", "YC"] the grid needs to tell us "X" -> "XC" and "Y" -> "YC". I'll put up another PR

@github-project-automation github-project-automation bot moved this from Backlog to Done in Parcels development Jul 15, 2025
@VeckoTheGecko VeckoTheGecko reopened this Jul 16, 2025
@github-project-automation github-project-automation bot moved this from Done to Backlog in Parcels development Jul 16, 2025
@VeckoTheGecko
Copy link
Contributor Author

I realised that this localization step isn't actually needed.

I realised that this is wrong actually 😅 . Localisation will be needed when writing curvilinear interpolators such as in the following being investigated in #2081 :

def XTriCurviLinear(
    field: Field,
    ti: int,
    position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
    tau: np.float32 | np.float64,
    t: np.float32 | np.float64,
    z: np.float32 | np.float64,
    y: np.float32 | np.float64,
    x: np.float32 | np.float64,
):
    """Trilinear interpolation on a curvilinear grid."""
    xi, xsi = position["X"]
    yi, eta = position["Y"]
    zi, zeta = position["Z"]
    data = field.data
    axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)

    return (
        (
            (1 - xsi) * (1 - eta) * data.isel({axis_dim["Y"]: yi, axis_dim["X"]: xi})
            + xsi * (1 - eta) * data.isel({axis_dim["Y"]: yi, axis_dim["X"]: xi + 1})
            + xsi * eta * data.isel({axis_dim["Y"]: yi + 1, axis_dim["X"]: xi + 1})
            + (1 - xsi) * eta * data.isel({axis_dim["Y"]: yi + 1, axis_dim["X"]: xi})
        )
        .interp(time=t, **{axis_dim["Z"]: zi + zeta})
        .values
    )

Here, xsi and eta need to be adjusted to match how the data is defined (e.g., if the data is defined on the cell centers).

Perhaps there are other ways to achieve this without a localization step - but I need to get more familiar with Xarray internals and coordinate aware interpolation for that. I'll mark this as unstable API subject to change in the docstring - this is an isolated change we can remove later.

@erikvansebille any additional thoughts?

@erikvansebille
Copy link
Member

Yep agree; let's keep it in for now

@VeckoTheGecko
Copy link
Contributor Author

VeckoTheGecko commented Jul 16, 2025

Also wrapped into this numpydoc!=1.9.0 so that our docs build again on v4-dev (didn't bother with v3 since I assume they'll issue a fix before we need to update v3 docs again) Actually, I'll quickly backport to v3 as well

@VeckoTheGecko VeckoTheGecko merged commit b510f11 into v4-dev Jul 16, 2025
9 checks passed
@VeckoTheGecko VeckoTheGecko deleted the grid-localization branch July 16, 2025 09:58
@github-project-automation github-project-automation bot moved this from Backlog to Done in Parcels development Jul 16, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

2 participants