Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/using/combining.rst
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,16 @@ value as the smallest distance between two grid points in the global
dataset over the cutout area. If you do not want to use this feature,
you can set `min_distance_km=0`, or provide your own value.

Additionally, you can pass a `max_distance_km` parameter to the `cutout`
function. Any grid points in the global dataset that are further than
this distance from the nearest grid point in the LAM dataset will be excluded
from the cutout. This can be useful to limit the extent of the global
dataset to only include points within a certain radius of the LAM region,
reducing memory usage and computation time. For example, setting
`max_distance_km=1000.0` will only include global grid points within
1000 km of the LAM boundary. If no value is provided (the default), all
global grid points outside the LAM region will be included.

The plots below illustrate how the cutout differs if `min_distance_km`
is not given (top) or if `min_distance_km` is set to `0` (bottom). The
difference can be seen at the boundary between the two grids:
Expand Down
11 changes: 11 additions & 0 deletions src/anemoi/datasets/data/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __init__(
cropping_distance: float = 2.0,
neighbours: int = 5,
min_distance_km: float | None = None,
max_distance_km: float | None = None,
plot: bool | None = None,
) -> None:
"""Initializes a Cutout object for hierarchical management of Limited Area
Expand All @@ -172,6 +173,9 @@ def __init__(
Number of neighboring points to consider when constructing masks.
min_distance_km : float, optional
Minimum distance threshold in km between grid points.
max_distance_km : float, optional
Maximum distance threshold in km. Points further than this distance from the LAM
region will be excluded from the mask.
plot : bool, optional
Flag to enable or disable visualization plots.
"""
Expand All @@ -181,13 +185,16 @@ def __init__(
assert cropping_distance >= 0, "cropping_distance must be a non-negative number"
if min_distance_km is not None:
assert min_distance_km >= 0, "min_distance_km must be a non-negative number"
if max_distance_km is not None:
assert max_distance_km >= 0, "max_distance_km must be a non-negative number"

self.lams = datasets[:-1] # Assume the last dataset is the global one
self.globe = datasets[-1]
self.axis = axis
self.cropping_distance = cropping_distance
self.neighbours = neighbours
self.min_distance_km = min_distance_km
self.max_distance_km = max_distance_km
self._plot = plot
self.masks = [] # To store the masks for each LAM dataset
self.global_mask = np.ones(self.globe.shape[-1], dtype=bool)
Expand Down Expand Up @@ -219,6 +226,7 @@ def _initialize_masks(self) -> None:
self.globe.longitudes,
plot=self._plot,
min_distance_km=self.min_distance_km,
max_distance_km=self.max_distance_km,
cropping_distance=self.cropping_distance,
neighbours=self.neighbours,
)
Expand All @@ -245,6 +253,7 @@ def _initialize_masks(self) -> None:
lam_lons,
plot=self._plot,
min_distance_km=self.min_distance_km,
max_distance_km=self.max_distance_km,
cropping_distance=self.cropping_distance,
neighbours=self.neighbours,
)
Expand Down Expand Up @@ -484,6 +493,7 @@ def cutout_factory(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Dataset:
axis = kwargs.pop("axis", 3)
plot = kwargs.pop("plot", None)
min_distance_km = kwargs.pop("min_distance_km", None)
max_distance_km = kwargs.pop("max_distance_km", None)
cropping_distance = kwargs.pop("cropping_distance", 2.0)
neighbours = kwargs.pop("neighbours", 5)

Expand All @@ -498,6 +508,7 @@ def cutout_factory(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Dataset:
axis=axis,
neighbours=neighbours,
min_distance_km=min_distance_km,
max_distance_km=max_distance_km,
cropping_distance=cropping_distance,
plot=plot,
)._subset(**kwargs)
14 changes: 12 additions & 2 deletions src/anemoi/datasets/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def cutout_mask(
cropping_distance: float = 2.0,
neighbours: int = 5,
min_distance_km: int | float | None = None,
max_distance_km: int | float | None = None,
plot: str | None = None,
) -> NDArray[Any]:
"""Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons].
Expand All @@ -224,6 +225,9 @@ def cutout_mask(
Number of neighbours. Defaults to 5.
min_distance_km : Optional[Union[int, float]], optional
Minimum distance in kilometers. Defaults to None.
max_distance_km : Optional[Union[int, float]], optional
Maximum distance in kilometers. Points further than this distance from the LAM
region will be excluded from the mask. Defaults to None.
plot : Optional[str], optional
Path for saving the plot. Defaults to None.

Expand Down Expand Up @@ -305,7 +309,13 @@ def cutout_mask(

close = np.min(distance) <= min_distance

inside_lam.append(inside or close)
# Check if the point is within max_distance_km if specified
if max_distance_km is not None:
max_distance = max_distance_km / 6371.0
too_far = np.min(distance) > max_distance
inside_lam.append((inside or close) and not too_far)
else:
inside_lam.append(inside or close)

j = 0
inside_lam_array = np.array(inside_lam)
Expand Down Expand Up @@ -449,7 +459,7 @@ def nearest_grid_points(
source_longitudes: NDArray[Any],
target_latitudes: NDArray[Any],
target_longitudes: NDArray[Any],
max_distance: float = None,
max_distance: float | None = None,
k: int = 1,
) -> NDArray[Any]:
"""Find the nearest grid points from source to target coordinates.
Expand Down
135 changes: 135 additions & 0 deletions tests/test_grids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import numpy as np
import pytest

from anemoi.datasets.grids import cutout_mask


def test_cutout_mask_with_max_distance():
    """Check the mask returned by cutout_mask when max_distance_km is given."""
    # LAM region: an 11x11 regular grid spanning 44-46N and 0-2E.
    lat_axis = np.linspace(44.0, 46.0, 11)
    lon_axis = np.linspace(0.0, 2.0, 11)
    lam_lats, lam_lons = (grid.flatten() for grid in np.meshgrid(lat_axis, lon_axis))

    # Six global points at increasing distance from the LAM region; the
    # last one (lat=50.0) is roughly four degrees north of the LAM edge.
    global_lats = np.array([43.0, 44.0, 45.0, 45.5, 46.0, 50.0])
    global_lons = np.array([358.0, 359.0, 0.0, 1.0, 2.0, 0.0])

    mask = cutout_mask(
        lam_lats,
        lam_lons,
        global_lats,
        global_lons,
        max_distance_km=250.0,  # 250 km limit
    )

    # The final point (lat=50.0) lies beyond the 250 km limit and is
    # expected to be flagged in the returned mask.
    expected = np.array([False, True, True, True, True, True])
    assert isinstance(mask, np.ndarray)
    assert mask.shape == global_lats.shape
    assert np.array_equal(mask, expected)


def test_cutout_mask_with_min_distance():
    """Check the mask returned by cutout_mask when min_distance_km is given."""
    # LAM region: an 11x11 regular grid spanning 44-46N and 0-2E.
    lat_axis = np.linspace(44.0, 46.0, 11)
    lon_axis = np.linspace(0.0, 2.0, 11)
    lam_lats, lam_lons = (grid.flatten() for grid in np.meshgrid(lat_axis, lon_axis))

    # Three global points sitting on LAM grid nodes, plus one point just
    # outside the (46, 0) corner of the LAM region.
    global_lats = np.array([44.0, 45.0, 46.0, 46.01])
    global_lons = np.array([0.0, 1.0, 2.0, -0.01])

    mask = cutout_mask(
        lam_lats,
        lam_lons,
        global_lats,
        global_lons,
        min_distance_km=100.0,
    )

    # Every point is expected to be flagged: the last one is outside the
    # LAM region but far closer to it than min_distance_km.
    expected = np.array([True, True, True, True])
    assert isinstance(mask, np.ndarray)
    assert mask.shape == global_lats.shape
    assert np.array_equal(mask, expected)


def test_cutout_mask_array_shapes():
    """Check that cutout_mask rejects two-dimensional LAM coordinate arrays."""
    lam_lats_2d = np.array([[45.0, 45.0], [46.0, 46.0]])
    lam_lons_2d = np.array([[0.0, 1.0], [0.0, 1.0]])
    single_lat = np.array([45.0])
    single_lon = np.array([0.0])
    coords = (lam_lats_2d, lam_lons_2d, single_lat, single_lon)

    # The 2D LAM arrays are expected to trip an assertion inside cutout_mask.
    with pytest.raises(AssertionError):
        cutout_mask(*coords)


def test_cutout_mask_parameter_types():
    """Check that max_distance_km may be passed as an int or as a float."""
    # LAM region: an 11x11 regular grid spanning 44-46N and 0-2E.
    lat_axis = np.linspace(44.0, 46.0, 11)
    lon_axis = np.linspace(0.0, 2.0, 11)
    lam_lats, lam_lons = (grid.flatten() for grid in np.meshgrid(lat_axis, lon_axis))

    global_lats = np.array([45.0, 46.0])
    global_lons = np.array([0.0, 2.0])

    # Same limit, first as an int and then as a float; both calls must
    # return a NumPy array.
    for limit in (100, 100.0):
        result = cutout_mask(lam_lats, lam_lons, global_lats, global_lons, max_distance_km=limit)
        assert isinstance(result, np.ndarray)


def test_cutout_mask_large_grid():
    """Exercise cutout_mask on a larger grid with both distance limits set."""
    # LAM region: 11x11 grid spanning 40-50N and 0-10E.
    lam_lat_axis = np.linspace(40.0, 50.0, 11)
    lam_lon_axis = np.linspace(0.0, 10.0, 11)
    lam_lats, lam_lons = (grid.flatten() for grid in np.meshgrid(lam_lat_axis, lam_lon_axis))

    # Surrounding "global" grid: 21x21 points spanning 30-60N and 10W-20E.
    outer_lat_axis = np.linspace(30.0, 60.0, 21)
    outer_lon_axis = np.linspace(-10.0, 20.0, 21)
    global_lats, global_lons = (grid.flatten() for grid in np.meshgrid(outer_lat_axis, outer_lon_axis))

    mask = cutout_mask(
        lam_lats,
        lam_lons,
        global_lats,
        global_lons,
        min_distance_km=50.0,
        max_distance_km=300.0,
    )

    assert isinstance(mask, np.ndarray)
    assert mask.shape == (21 * 21,)  # one entry per flattened global point
    assert mask.dtype == bool
    # The mask must be mixed: at least one point flagged ...
    assert mask.any()
    # ... and at least one point left unflagged.
    assert not mask.all()
Loading