Skip to content

Commit

Permalink
Refactor expand_climatology to make it more scalable and to preserve spatial chunks
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 576932947
  • Loading branch information
shoyer authored and Weatherbench2 authors committed Oct 26, 2023
1 parent 364c28c commit 235765a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
36 changes: 23 additions & 13 deletions scripts/expand_climatology.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,26 +76,33 @@


def select_climatology(
    variable_name_and_time_slice: tuple[str, slice],
    climatology: xarray.Dataset,
    time_index: pd.DatetimeIndex,
    base_chunks: dict[str, int],
) -> abc.Iterator[tuple[xbeam.Key, xarray.Dataset]]:
  """Select climatology data matching time_index[time_slice].

  Args:
    variable_name_and_time_slice: a (variable name, slice) pair. The name picks
      the single variable of `climatology` to expand; the slice picks the
      block of `time_index` to expand it onto.
    climatology: dataset with a 'dayofyear' coordinate and, optionally, an
      'hour' coordinate.
    time_index: the full index of output times; only the entries selected by
      the slice are used here.
    base_chunks: chunk sizes keyed by dimension name. Entries for dimensions
      the selected variable does not have are ignored.

  Yields:
    The (key, dataset) pairs produced by `xbeam.split_chunks` for the selected
    variable and time block. The starting key carries a 'time' offset equal to
    the start of the slice and names only the selected variable.
  """
  variable_name, time_slice = variable_name_and_time_slice
  chunk_times = time_index[time_slice]
  # Wrap the times in a DataArray with a 'time' dimension so that indexing the
  # climatology with its `.dt` fields yields a result indexed by 'time'.
  times_array = xarray.DataArray(
      chunk_times, dims=['time'], coords={'time': chunk_times}
  )
  if 'hour' in climatology.coords:
    sliced = climatology[[variable_name]].sel(
        dayofyear=times_array.dt.dayofyear, hour=times_array.dt.hour
    )
    # The lookup coordinates are not wanted on the expanded result.
    del sliced.coords['dayofyear']
    del sliced.coords['hour']
  else:
    sliced = climatology[[variable_name]].sel(
        dayofyear=times_array.dt.dayofyear
    )
    del sliced.coords['dayofyear']

  key = xbeam.Key({'time': time_slice.start}, vars={variable_name})
  # Evaluate the selection once, before it is split into smaller chunks.
  sliced = sliced.compute()
  # Only request chunking along dimensions this variable actually has.
  target_chunks = {k: v for k, v in base_chunks.items() if k in sliced.dims}
  yield from xbeam.split_chunks(key, sliced, target_chunks)


def main(argv: list[str]) -> None:
Expand Down Expand Up @@ -124,8 +131,9 @@ def main(argv: list[str]) -> None:
time_chunk_size = TIME_CHUNK_SIZE.value

time_chunk_count = math.ceil(times.size / time_chunk_size)

output_chunks = {dim: -1 for dim in input_chunks if dim not in time_dims}
variables = list(climatology.keys())
base_chunks = {k: v for k, v in input_chunks.items() if k not in time_dims}
output_chunks = dict(base_chunks)
output_chunks['time'] = time_chunk_size

# Beam type checking is broken with Python 3.10:
Expand All @@ -137,7 +145,9 @@ def main(argv: list[str]) -> None:
root
| beam.Create([i * time_chunk_size for i in range(time_chunk_count)])
| beam.Map(lambda start: slice(start, start + time_chunk_size))
| beam.FlatMap(select_climatology, climatology, times)
| beam.FlatMap(lambda index: [(v, index) for v in variables])
| beam.Reshuffle()
| beam.FlatMap(select_climatology, climatology, times, base_chunks)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value, template=template, zarr_chunks=output_chunks
)
Expand Down
4 changes: 3 additions & 1 deletion scripts/expand_climatology_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test(self):
input_path = self.create_tempdir('input_path').full_path
output_path = self.create_tempdir('output_path').full_path

climatology.chunk({'dayofyear': 31}).to_zarr(input_path)
climatology.chunk({'dayofyear': 31, 'level': 1}).to_zarr(input_path)

with flagsaver.flagsaver(
input_path=input_path,
Expand All @@ -65,6 +65,8 @@ def test(self):

actual = xarray.open_zarr(output_path)
xarray.testing.assert_allclose(actual, expected)
self.assertEqual(actual.chunks['time'][0], 4 * 31)
self.assertEqual(actual.chunks['level'][0], 1)


if __name__ == '__main__':
Expand Down

0 comments on commit 235765a

Please sign in to comment.