Skip to content

Commit

Permalink
Refactor expand_climatology to make it more scalable.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 575355512
  • Loading branch information
shoyer authored and Weatherbench2 authors committed Oct 21, 2023
1 parent 131e05e commit 1e1b36b
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions scripts/expand_climatology.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
--job_name=expand-climatology-$USER
```
"""
from collections import abc
import math

from absl import app
Expand Down Expand Up @@ -76,26 +75,28 @@


def select_climatology(
time_slice: slice, climatology: xarray.Dataset, time_index: pd.DatetimeIndex
) -> abc.Iterable[tuple[xbeam.Key, xarray.Dataset]]:
variable_name_and_time_slice: tuple[str, slice],
climatology: xarray.Dataset,
time_index: pd.DatetimeIndex,
) -> tuple[xbeam.Key, xarray.Dataset]:
"""Select climatology data matching time_index[time_slice]."""
variable_name, time_slice = variable_name_and_time_slice
chunk_times = time_index[time_slice]
times_array = xarray.DataArray(
chunk_times, dims=['time'], coords={'time': chunk_times}
)
if 'hour' in climatology.coords:
chunk = climatology.sel(
chunk = climatology[[variable_name]].sel(
dayofyear=times_array.dt.dayofyear, hour=times_array.dt.hour
)
del chunk.coords['dayofyear']
del chunk.coords['hour']
else:
chunk = climatology.sel(dayofyear=times_array.dt.dayofyear)
chunk = climatology[[variable_name]].sel(dayofyear=times_array.dt.dayofyear)
del chunk.coords['dayofyear']

for variable_name in chunk:
key = xbeam.Key({'time': time_slice.start}, vars={variable_name}) # pytype: disable=wrong-arg-types
yield key, chunk[[variable_name]]
key = xbeam.Key({'time': time_slice.start}, vars={variable_name})
return key, chunk.compute()


def main(argv: list[str]) -> None:
Expand Down Expand Up @@ -125,6 +126,8 @@ def main(argv: list[str]) -> None:

time_chunk_count = math.ceil(times.size / time_chunk_size)

variables = list(climatology.keys())

output_chunks = {dim: -1 for dim in input_chunks if dim not in time_dims}
output_chunks['time'] = time_chunk_size

Expand All @@ -137,7 +140,9 @@ def main(argv: list[str]) -> None:
root
| beam.Create([i * time_chunk_size for i in range(time_chunk_count)])
| beam.Map(lambda start: slice(start, start + time_chunk_size))
| beam.FlatMap(select_climatology, climatology, times)
| beam.FlatMap(lambda index: [(v, index) for v in variables])
| beam.Reshuffle()
| beam.Map(select_climatology, climatology, times)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value, template=template, zarr_chunks=output_chunks
)
Expand Down

0 comments on commit 1e1b36b

Please sign in to comment.