Skip to content

Commit 2e7ed45

Browse files
shoyer, Weatherbench2 authors
authored and
Weatherbench2 authors
committed
Refactor expand_climatology to make it more scalable and to preserve spatial chunks
PiperOrigin-RevId: 575355512
1 parent 364c28c commit 2e7ed45

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

scripts/expand_climatology.py

+23-13
Original file line number | Diff line number | Diff line change
@@ -76,26 +76,33 @@
7676

7777

7878
def select_climatology(
79-
time_slice: slice, climatology: xarray.Dataset, time_index: pd.DatetimeIndex
80-
) -> abc.Iterable[tuple[xbeam.Key, xarray.Dataset]]:
79+
variable_name_and_time_slice: tuple[str, slice],
80+
climatology: xarray.Dataset,
81+
time_index: pd.DatetimeIndex,
82+
base_chunks: dict[str, int],
83+
) -> abc.Iterator[tuple[xbeam.Key, xarray.Dataset]]:
8184
"""Select climatology data matching time_index[time_slice]."""
85+
variable_name, time_slice = variable_name_and_time_slice
8286
chunk_times = time_index[time_slice]
8387
times_array = xarray.DataArray(
8488
chunk_times, dims=['time'], coords={'time': chunk_times}
8589
)
8690
if 'hour' in climatology.coords:
87-
chunk = climatology.sel(
91+
sliced = climatology[[variable_name]].sel(
8892
dayofyear=times_array.dt.dayofyear, hour=times_array.dt.hour
8993
)
90-
del chunk.coords['dayofyear']
91-
del chunk.coords['hour']
94+
del sliced.coords['dayofyear']
95+
del sliced.coords['hour']
9296
else:
93-
chunk = climatology.sel(dayofyear=times_array.dt.dayofyear)
94-
del chunk.coords['dayofyear']
97+
sliced = climatology[[variable_name]].sel(
98+
dayofyear=times_array.dt.dayofyear
99+
)
100+
del sliced.coords['dayofyear']
95101

96-
for variable_name in chunk:
97-
key = xbeam.Key({'time': time_slice.start}, vars={variable_name}) # pytype: disable=wrong-arg-types
98-
yield key, chunk[[variable_name]]
102+
key = xbeam.Key({'time': time_slice.start}, vars={variable_name})
103+
sliced = sliced.compute()
104+
target_chunks = {k: v for k, v in base_chunks.items() if k in sliced.dims}
105+
yield from xbeam.split_chunks(key, sliced, target_chunks)
99106

100107

101108
def main(argv: list[str]) -> None:
@@ -124,8 +131,9 @@ def main(argv: list[str]) -> None:
124131
time_chunk_size = TIME_CHUNK_SIZE.value
125132

126133
time_chunk_count = math.ceil(times.size / time_chunk_size)
127-
128-
output_chunks = {dim: -1 for dim in input_chunks if dim not in time_dims}
134+
variables = list(climatology.keys())
135+
base_chunks = {k: v for k, v in input_chunks.items() if k not in time_dims}
136+
output_chunks = dict(base_chunks)
129137
output_chunks['time'] = time_chunk_size
130138

131139
# Beam type checking is broken with Python 3.10:
@@ -137,7 +145,9 @@ def main(argv: list[str]) -> None:
137145
root
138146
| beam.Create([i * time_chunk_size for i in range(time_chunk_count)])
139147
| beam.Map(lambda start: slice(start, start + time_chunk_size))
140-
| beam.FlatMap(select_climatology, climatology, times)
148+
| beam.FlatMap(lambda index: [(v, index) for v in variables])
149+
| beam.Reshuffle()
150+
| beam.FlatMap(select_climatology, climatology, times, base_chunks)
141151
| xbeam.ChunksToZarr(
142152
OUTPUT_PATH.value, template=template, zarr_chunks=output_chunks
143153
)

scripts/expand_climatology_test.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ def test(self):
5252
input_path = self.create_tempdir('input_path').full_path
5353
output_path = self.create_tempdir('output_path').full_path
5454

55-
climatology.chunk({'dayofyear': 31}).to_zarr(input_path)
55+
climatology.chunk({'dayofyear': 31, 'level': 1}).to_zarr(input_path)
5656

5757
with flagsaver.flagsaver(
5858
input_path=input_path,
@@ -65,6 +65,8 @@ def test(self):
6565

6666
actual = xarray.open_zarr(output_path)
6767
xarray.testing.assert_allclose(actual, expected)
68+
self.assertEqual(actual.chunks['time'][0], 4 * 31)
69+
self.assertEqual(actual.chunks['level'][0], 1)
6870

6971

7072
if __name__ == '__main__':

0 commit comments

Comments
 (0)