diff --git a/scripts/expand_climatology.py b/scripts/expand_climatology.py index fe2e614..699e9c8 100644 --- a/scripts/expand_climatology.py +++ b/scripts/expand_climatology.py @@ -76,26 +76,33 @@ def select_climatology( - time_slice: slice, climatology: xarray.Dataset, time_index: pd.DatetimeIndex -) -> abc.Iterable[tuple[xbeam.Key, xarray.Dataset]]: + variable_name_and_time_slice: tuple[str, slice], + climatology: xarray.Dataset, + time_index: pd.DatetimeIndex, + base_chunks: dict[str, int], +) -> abc.Iterator[tuple[xbeam.Key, xarray.Dataset]]: """Select climatology data matching time_index[time_slice].""" + variable_name, time_slice = variable_name_and_time_slice chunk_times = time_index[time_slice] times_array = xarray.DataArray( chunk_times, dims=['time'], coords={'time': chunk_times} ) if 'hour' in climatology.coords: - chunk = climatology.sel( + sliced = climatology[[variable_name]].sel( dayofyear=times_array.dt.dayofyear, hour=times_array.dt.hour ) - del chunk.coords['dayofyear'] - del chunk.coords['hour'] + del sliced.coords['dayofyear'] + del sliced.coords['hour'] else: - chunk = climatology.sel(dayofyear=times_array.dt.dayofyear) - del chunk.coords['dayofyear'] + sliced = climatology[[variable_name]].sel( + dayofyear=times_array.dt.dayofyear + ) + del sliced.coords['dayofyear'] - for variable_name in chunk: - key = xbeam.Key({'time': time_slice.start}, vars={variable_name}) # pytype: disable=wrong-arg-types - yield key, chunk[[variable_name]] + key = xbeam.Key({'time': time_slice.start}, vars={variable_name}) + sliced = sliced.compute() + target_chunks = {k: v for k, v in base_chunks.items() if k in sliced.dims} + yield from xbeam.split_chunks(key, sliced, target_chunks) def main(argv: list[str]) -> None: @@ -124,8 +131,9 @@ def main(argv: list[str]) -> None: time_chunk_size = TIME_CHUNK_SIZE.value time_chunk_count = math.ceil(times.size / time_chunk_size) - - output_chunks = {dim: -1 for dim in input_chunks if dim not in time_dims} + variables = list(climatology.keys()) + base_chunks = {k: v for k, v in input_chunks.items() if k not in time_dims} + output_chunks = dict(base_chunks) output_chunks['time'] = time_chunk_size # Beam type checking is broken with Python 3.10: @@ -137,7 +145,9 @@ def main(argv: list[str]) -> None: root | beam.Create([i * time_chunk_size for i in range(time_chunk_count)]) | beam.Map(lambda start: slice(start, start + time_chunk_size)) - | beam.FlatMap(select_climatology, climatology, times) + | beam.FlatMap(lambda index: [(v, index) for v in variables]) + | beam.Reshuffle() + | beam.FlatMap(select_climatology, climatology, times, base_chunks) | xbeam.ChunksToZarr( OUTPUT_PATH.value, template=template, zarr_chunks=output_chunks ) diff --git a/scripts/expand_climatology_test.py b/scripts/expand_climatology_test.py index e404cc3..5381fca 100644 --- a/scripts/expand_climatology_test.py +++ b/scripts/expand_climatology_test.py @@ -52,7 +52,7 @@ def test(self): input_path = self.create_tempdir('input_path').full_path output_path = self.create_tempdir('output_path').full_path - climatology.chunk({'dayofyear': 31}).to_zarr(input_path) + climatology.chunk({'dayofyear': 31, 'level': 1}).to_zarr(input_path) with flagsaver.flagsaver( input_path=input_path, @@ -65,6 +65,8 @@ def test(self): actual = xarray.open_zarr(output_path) xarray.testing.assert_allclose(actual, expected) + self.assertEqual(actual.chunks['time'][0], 4 * 31) + self.assertEqual(actual.chunks['level'][0], 1) if __name__ == '__main__':