76
76
77
77
78
78
def select_climatology(
    variable_name_and_time_slice: tuple[str, slice],
    climatology: xarray.Dataset,
    time_index: pd.DatetimeIndex,
    base_chunks: dict[str, int],
) -> abc.Iterator[tuple[xbeam.Key, xarray.Dataset]]:
  """Select climatology data matching time_index[time_slice].

  Args:
    variable_name_and_time_slice: pair of (name of the single data variable to
      select, slice into `time_index` covering one time chunk).
    climatology: dataset indexed by 'dayofyear' and, optionally, 'hour'.
    time_index: full index of output times; only `time_index[time_slice]` is
      used by this call.
    base_chunks: chunk sizes keyed by dimension name. Entries for dimensions
      that the selected data does not have are ignored.

  Yields:
    (key, dataset) pairs produced by `xbeam.split_chunks`, where the initial
    key has a 'time' offset of `time_slice.start` and covers only
    `variable_name`.
  """
  variable_name, time_slice = variable_name_and_time_slice
  chunk_times = time_index[time_slice]
  times_array = xarray.DataArray(
      chunk_times, dims=['time'], coords={'time': chunk_times}
  )

  # Index by day-of-year, and additionally by hour when the climatology has an
  # 'hour' coordinate. The indexers are DataArrays along 'time', so the
  # selection is laid out along 'time'.
  indexers = {'dayofyear': times_array.dt.dayofyear}
  if 'hour' in climatology.coords:
    indexers['hour'] = times_array.dt.hour
  sliced = climatology[[variable_name]].sel(**indexers)

  # Drop the coordinates that were only used to look values up.
  for coord_name in indexers:
    del sliced.coords[coord_name]

  key = xbeam.Key({'time': time_slice.start}, vars={variable_name})
  # Load the selection before it is split into smaller chunks.
  sliced = sliced.compute()
  target_chunks = {k: v for k, v in base_chunks.items() if k in sliced.dims}
  yield from xbeam.split_chunks(key, sliced, target_chunks)
99
106
100
107
101
108
def main (argv : list [str ]) -> None :
@@ -124,8 +131,9 @@ def main(argv: list[str]) -> None:
124
131
time_chunk_size = TIME_CHUNK_SIZE .value
125
132
126
133
time_chunk_count = math .ceil (times .size / time_chunk_size )
127
-
128
- output_chunks = {dim : - 1 for dim in input_chunks if dim not in time_dims }
134
+ variables = list (climatology .keys ())
135
+ base_chunks = {k : v for k , v in input_chunks .items () if k not in time_dims }
136
+ output_chunks = dict (base_chunks )
129
137
output_chunks ['time' ] = time_chunk_size
130
138
131
139
# Beam type checking is broken with Python 3.10:
@@ -137,7 +145,9 @@ def main(argv: list[str]) -> None:
137
145
root
138
146
| beam .Create ([i * time_chunk_size for i in range (time_chunk_count )])
139
147
| beam .Map (lambda start : slice (start , start + time_chunk_size ))
140
- | beam .FlatMap (select_climatology , climatology , times )
148
+ | beam .FlatMap (lambda index : [(v , index ) for v in variables ])
149
+ | beam .Reshuffle ()
150
+ | beam .FlatMap (select_climatology , climatology , times , base_chunks )
141
151
| xbeam .ChunksToZarr (
142
152
OUTPUT_PATH .value , template = template , zarr_chunks = output_chunks
143
153
)
0 commit comments