Skip to content

Commit 2e6804a

Browse files
Make sure that supplementary variables and weights have same chunks as parent cube (#2637)
Co-authored-by: Bouwe Andela <b.andela@esciencecenter.nl>
1 parent fc45c72 commit 2e6804a

File tree

5 files changed

+152
-53
lines changed

5 files changed

+152
-53
lines changed

esmvalcore/iris_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,14 +247,14 @@ def rechunk_cube(
247247
cube:
248248
Input cube.
249249
complete_coords:
250-
(Names of) coordinates along which the output cubes should not be
250+
(Names of) coordinates along which the output cube should not be
251251
chunked.
252252
remaining_dims:
253253
Chunksize of the remaining dimensions.
254254
255255
Returns
256256
-------
257-
Cube
257+
iris.cube.Cube
258258
Rechunked cube. This will always be a copy of the input cube.
259259
260260
"""

esmvalcore/preprocessor/_supplementary_vars.py

Lines changed: 57 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
"""Preprocessor functions for ancillary variables and cell measures."""
22

33
import logging
4-
from typing import Iterable
4+
from collections.abc import Callable
5+
from typing import Iterable, Literal
56

67
import iris.coords
7-
import iris.cube
8+
from iris.cube import Cube
89

910
logger = logging.getLogger(__name__)
1011

1112
PREPROCESSOR_SUPPLEMENTARIES = {}
1213

1314

14-
def register_supplementaries(variables, required):
15+
def register_supplementaries(
16+
variables: list[str],
17+
required: Literal["require_at_least_one", "prefer_at_least_one"],
18+
) -> Callable:
1519
"""Register supplementary variables required for a preprocessor function.
1620
1721
Parameters
1822
----------
19-
variables: :obj:`list` of :obj`str`
23+
variables:
2024
List of variable names.
2125
required:
2226
How strong the requirement is. Can be 'require_at_least_one' if at
@@ -39,16 +43,25 @@ def wrapper(func):
3943
return wrapper
4044

4145

42-
def add_cell_measure(cube, cell_measure_cube, measure):
43-
"""Add a cube as a cell_measure in the cube containing the data.
46+
def add_cell_measure(
47+
cube: Cube,
48+
cell_measure_cube: Cube,
49+
measure: Literal["area", "volume"],
50+
) -> None:
51+
"""Add cell measure to cube (in-place).
52+
53+
Note
54+
----
55+
This assumes that the cell measure spans the rightmost dimensions of the
56+
cube.
4457
4558
Parameters
4659
----------
47-
cube: iris.cube.Cube
60+
cube:
4861
Iris cube with input data.
49-
cell_measure_cube: iris.cube.Cube
62+
cell_measure_cube:
5063
Iris cube with cell measure data.
51-
measure: str
64+
measure:
5265
Name of the measure, can be 'area' or 'volume'.
5366
5467
Returns
@@ -65,47 +78,62 @@ def add_cell_measure(cube, cell_measure_cube, measure):
6578
raise ValueError(
6679
f"measure name must be 'area' or 'volume', got {measure} instead"
6780
)
68-
measure = iris.coords.CellMeasure(
69-
cell_measure_cube.core_data(),
81+
coord_dims = tuple(
82+
range(cube.ndim - len(cell_measure_cube.shape), cube.ndim)
83+
)
84+
cell_measure_data = cell_measure_cube.core_data()
85+
if cell_measure_cube.has_lazy_data():
86+
cube_chunks = tuple(cube.lazy_data().chunks[d] for d in coord_dims)
87+
cell_measure_data = cell_measure_data.rechunk(cube_chunks)
88+
cell_measure = iris.coords.CellMeasure(
89+
cell_measure_data,
7090
standard_name=cell_measure_cube.standard_name,
7191
units=cell_measure_cube.units,
7292
measure=measure,
7393
var_name=cell_measure_cube.var_name,
7494
attributes=cell_measure_cube.attributes,
7595
)
76-
start_dim = cube.ndim - len(measure.shape)
77-
cube.add_cell_measure(measure, range(start_dim, cube.ndim))
96+
cube.add_cell_measure(cell_measure, coord_dims)
7897
logger.debug(
7998
"Added %s as cell measure in cube of %s.",
8099
cell_measure_cube.var_name,
81100
cube.var_name,
82101
)
83102

84103

85-
def add_ancillary_variable(cube, ancillary_cube):
86-
"""Add cube as an ancillary variable in the cube containing the data.
104+
def add_ancillary_variable(cube: Cube, ancillary_cube: Cube) -> None:
105+
"""Add ancillary variable to cube (in-place).
106+
107+
Note
108+
----
109+
This assumes that the ancillary variable spans the rightmost dimensions of
110+
the cube.
87111
88112
Parameters
89113
----------
90-
cube: iris.cube.Cube
114+
cube:
91115
Iris cube with input data.
92-
ancillary_cube: iris.cube.Cube
116+
ancillary_cube:
93117
Iris cube with ancillary data.
94118
95119
Returns
96120
-------
97121
iris.cube.Cube
98122
Cube with added ancillary variables
99123
"""
124+
coord_dims = tuple(range(cube.ndim - len(ancillary_cube.shape), cube.ndim))
125+
ancillary_data = ancillary_cube.core_data()
126+
if ancillary_cube.has_lazy_data():
127+
cube_chunks = tuple(cube.lazy_data().chunks[d] for d in coord_dims)
128+
ancillary_data = ancillary_data.rechunk(cube_chunks)
100129
ancillary_var = iris.coords.AncillaryVariable(
101-
ancillary_cube.core_data(),
130+
ancillary_data,
102131
standard_name=ancillary_cube.standard_name,
103132
units=ancillary_cube.units,
104133
var_name=ancillary_cube.var_name,
105134
attributes=ancillary_cube.attributes,
106135
)
107-
start_dim = cube.ndim - len(ancillary_var.shape)
108-
cube.add_ancillary_variable(ancillary_var, range(start_dim, cube.ndim))
136+
cube.add_ancillary_variable(ancillary_var, coord_dims)
109137
logger.debug(
110138
"Added %s as ancillary variable in cube of %s.",
111139
ancillary_cube.var_name,
@@ -114,10 +142,10 @@ def add_ancillary_variable(cube, ancillary_cube):
114142

115143

116144
def add_supplementary_variables(
117-
cube: iris.cube.Cube,
118-
supplementary_cubes: Iterable[iris.cube.Cube],
119-
) -> iris.cube.Cube:
120-
"""Add ancillary variables and/or cell measures.
145+
cube: Cube,
146+
supplementary_cubes: Iterable[Cube],
147+
) -> Cube:
148+
"""Add ancillary variables and/or cell measures to cube (in-place).
121149
122150
Parameters
123151
----------
@@ -131,7 +159,7 @@ def add_supplementary_variables(
131159
iris.cube.Cube
132160
Cube with added ancillary variables and/or cell measures.
133161
"""
134-
measure_names = {
162+
measure_names: dict[str, Literal["area", "volume"]] = {
135163
"areacella": "area",
136164
"areacello": "area",
137165
"volcello": "volume",
@@ -145,15 +173,14 @@ def add_supplementary_variables(
145173
return cube
146174

147175

148-
def remove_supplementary_variables(cube: iris.cube.Cube):
149-
"""Remove supplementary variables.
176+
def remove_supplementary_variables(cube: Cube) -> Cube:
177+
"""Remove supplementary variables from cube (in-place).
150178
151-
Strip cell measures or ancillary variables from the cube containing the
152-
data.
179+
Strip cell measures or ancillary variables from the cube.
153180
154181
Parameters
155182
----------
156-
cube: iris.cube.Cube
183+
cube:
157184
Iris cube with data and cell measures or ancillary variables.
158185
159186
Returns

esmvalcore/preprocessor/_volume.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def extract_volume(
104104
return cube.extract(z_constraint)
105105

106106

107-
def calculate_volume(cube: Cube) -> da.core.Array:
107+
def calculate_volume(cube: Cube) -> np.ndarray | da.Array:
108108
"""Calculate volume from a cube.
109109
110110
This function is used when the 'ocean_volume' cell measure can't be found.
@@ -119,13 +119,13 @@ def calculate_volume(cube: Cube) -> da.core.Array:
119119
120120
Parameters
121121
----------
122-
cube: iris.cube.Cube
122+
cube:
123123
input cube.
124124
125125
Returns
126126
-------
127-
dask.array.core.Array
128-
Grid volumes.
127+
np.ndarray | dask.array.Array
128+
Grid volume.
129129
130130
"""
131131
# Load depth field and figure out which dim is which
@@ -158,7 +158,11 @@ def calculate_volume(cube: Cube) -> da.core.Array:
158158
# Calculate Z-direction thickness
159159
thickness = depth.core_bounds()[..., 1] - depth.core_bounds()[..., 0]
160160
if cube.has_lazy_data():
161-
thickness = da.array(thickness)
161+
z_chunks = tuple(cube.lazy_data().chunks[d] for d in z_dim)
162+
if isinstance(thickness, da.Array):
163+
thickness = thickness.rechunk(z_chunks)
164+
else:
165+
thickness = da.asarray(thickness, chunks=z_chunks)
162166

163167
# Get or calculate the horizontal areas of the cube
164168
has_cell_measure = bool(cube.cell_measures("cell_area"))
@@ -182,6 +186,8 @@ def calculate_volume(cube: Cube) -> da.core.Array:
182186
thickness, cube.shape, z_dim, chunks=chunks
183187
)
184188
grid_volume = area_arr * thickness_arr
189+
if cube.has_lazy_data():
190+
grid_volume = grid_volume.rechunk(chunks)
185191

186192
return grid_volume
187193

@@ -403,7 +409,10 @@ def axis_statistics(
403409
def _add_axis_stats_weights_coord(cube, coord, coord_dims):
404410
"""Add weights for axis_statistics to cube (in-place)."""
405411
weights = np.abs(coord.lazy_bounds()[:, 1] - coord.lazy_bounds()[:, 0])
406-
if not cube.has_lazy_data():
412+
if cube.has_lazy_data():
413+
coord_chunks = tuple(cube.lazy_data().chunks[d] for d in coord_dims)
414+
weights = weights.rechunk(coord_chunks)
415+
else:
407416
weights = weights.compute()
408417
weights_coord = AuxCoord(
409418
weights,

tests/integration/preprocessor/_supplementary_vars/test_add_supplementary_variables.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
:func:`esmvalcore.preprocessor._supplementary_vars` module.
55
"""
66

7+
import dask.array as da
78
import iris
89
import iris.fileformats
910
import numpy as np
@@ -112,20 +113,37 @@ def setUp(self):
112113
],
113114
)
114115

116+
@pytest.mark.parametrize("lazy", [True, False])
115117
@pytest.mark.parametrize("var_name", ["areacella", "areacello"])
116-
def test_add_cell_measure_area(self, var_name):
118+
def test_add_cell_measure_area(self, var_name, lazy):
117119
"""Test add area fx variables as cell measures."""
120+
if lazy:
121+
self.fx_area.data = self.fx_area.lazy_data()
122+
self.new_cube_data = da.array(self.new_cube_data).rechunk((1, 2))
118123
self.fx_area.var_name = var_name
119124
self.fx_area.standard_name = "cell_area"
120125
self.fx_area.units = "m2"
121126
cube = iris.cube.Cube(
122127
self.new_cube_data, dim_coords_and_dims=self.coords_spec
123128
)
129+
124130
cube = add_supplementary_variables(cube, [self.fx_area])
125-
assert cube.cell_measure(self.fx_area.standard_name) is not None
126131

127-
def test_add_cell_measure_volume(self):
132+
assert cube.has_lazy_data() is lazy
133+
assert cube.cell_measures(self.fx_area.standard_name)
134+
cell_measure = cube.cell_measure(self.fx_area.standard_name)
135+
assert cell_measure.has_lazy_data() is lazy
136+
if lazy:
137+
assert cell_measure.lazy_data().chunks == cube.lazy_data().chunks
138+
139+
@pytest.mark.parametrize("lazy", [True, False])
140+
def test_add_cell_measure_volume(self, lazy):
128141
"""Test add volume as cell measure."""
142+
if lazy:
143+
self.fx_volume.data = self.fx_volume.lazy_data()
144+
self.new_cube_3D_data = da.array(self.new_cube_3D_data).rechunk(
145+
(1, 2, 3)
146+
)
129147
self.fx_volume.var_name = "volcello"
130148
self.fx_volume.standard_name = "ocean_volume"
131149
self.fx_volume.units = "m3"
@@ -137,8 +155,15 @@ def test_add_cell_measure_volume(self):
137155
(self.lons, 2),
138156
],
139157
)
158+
140159
cube = add_supplementary_variables(cube, [self.fx_volume])
141-
assert cube.cell_measure(self.fx_volume.standard_name) is not None
160+
161+
assert cube.has_lazy_data() is lazy
162+
assert cube.cell_measures(self.fx_volume.standard_name)
163+
cell_measure = cube.cell_measure(self.fx_volume.standard_name)
164+
assert cell_measure.has_lazy_data() is lazy
165+
if lazy:
166+
assert cell_measure.lazy_data().chunks == cube.lazy_data().chunks
142167

143168
def test_no_cell_measure(self):
144169
"""Test no cell measure is added."""
@@ -153,16 +178,27 @@ def test_no_cell_measure(self):
153178
cube = add_supplementary_variables(cube, [])
154179
assert cube.cell_measures() == []
155180

156-
def test_add_supplementary_vars(self):
157-
"""Test invalid variable is not added as cell measure."""
181+
@pytest.mark.parametrize("lazy", [True, False])
182+
def test_add_ancillary_vars(self, lazy):
183+
"""Test adding ancillary variables."""
184+
if lazy:
185+
self.fx_area.data = self.fx_area.lazy_data()
186+
self.new_cube_data = da.array(self.new_cube_data).rechunk((1, 2))
158187
self.fx_area.var_name = "sftlf"
159188
self.fx_area.standard_name = "land_area_fraction"
160189
self.fx_area.units = "%"
161190
cube = iris.cube.Cube(
162191
self.new_cube_data, dim_coords_and_dims=self.coords_spec
163192
)
193+
164194
cube = add_supplementary_variables(cube, [self.fx_area])
165-
assert cube.ancillary_variable(self.fx_area.standard_name) is not None
195+
196+
assert cube.has_lazy_data() is lazy
197+
assert cube.ancillary_variables(self.fx_area.standard_name)
198+
anc_var = cube.ancillary_variable(self.fx_area.standard_name)
199+
assert anc_var.has_lazy_data() is lazy
200+
if lazy:
201+
assert anc_var.lazy_data().chunks == cube.lazy_data().chunks
166202

167203
def test_wrong_shape(self, monkeypatch):
168204
"""Test variable is not added if it's not broadcastable to cube."""

0 commit comments

Comments
 (0)