diff --git a/scripts/expand_climatology.py b/scripts/expand_climatology.py index fe2e614..6d9681e 100644 --- a/scripts/expand_climatology.py +++ b/scripts/expand_climatology.py @@ -37,7 +37,6 @@ --job_name=expand-climatology-$USER ``` """ -from collections import abc import math from absl import app @@ -76,26 +75,28 @@ def select_climatology( - time_slice: slice, climatology: xarray.Dataset, time_index: pd.DatetimeIndex -) -> abc.Iterable[tuple[xbeam.Key, xarray.Dataset]]: + variable_name_and_time_slice: tuple[str, slice], + climatology: xarray.Dataset, + time_index: pd.DatetimeIndex, +) -> tuple[xbeam.Key, xarray.Dataset]: """Select climatology data matching time_index[time_slice].""" + variable_name, time_slice = variable_name_and_time_slice chunk_times = time_index[time_slice] times_array = xarray.DataArray( chunk_times, dims=['time'], coords={'time': chunk_times} ) if 'hour' in climatology.coords: - chunk = climatology.sel( + chunk = climatology[[variable_name]].sel( dayofyear=times_array.dt.dayofyear, hour=times_array.dt.hour ) del chunk.coords['dayofyear'] del chunk.coords['hour'] else: - chunk = climatology.sel(dayofyear=times_array.dt.dayofyear) + chunk = climatology[[variable_name]].sel(dayofyear=times_array.dt.dayofyear) del chunk.coords['dayofyear'] - for variable_name in chunk: - key = xbeam.Key({'time': time_slice.start}, vars={variable_name}) # pytype: disable=wrong-arg-types - yield key, chunk[[variable_name]] + key = xbeam.Key({'time': time_slice.start}, vars={variable_name}) + return key, chunk.compute() def main(argv: list[str]) -> None: @@ -125,6 +126,8 @@ def main(argv: list[str]) -> None: time_chunk_count = math.ceil(times.size / time_chunk_size) + variables = list(climatology.keys()) + output_chunks = {dim: -1 for dim in input_chunks if dim not in time_dims} output_chunks['time'] = time_chunk_size @@ -137,7 +140,9 @@ def main(argv: list[str]) -> None: root | beam.Create([i * time_chunk_size for i in range(time_chunk_count)]) | beam.Map(lambda start: slice(start, start + time_chunk_size)) - | beam.FlatMap(select_climatology, climatology, times) + | beam.FlatMap(lambda index: [(v, index) for v in variables]) + | beam.Reshuffle() + | beam.Map(select_climatology, climatology, times) | xbeam.ChunksToZarr( OUTPUT_PATH.value, template=template, zarr_chunks=output_chunks )