
Commit f34bf3d

langmore authored and Weatherbench2 authors committed
No public description
PiperOrigin-RevId: 629874668
1 parent 1d3cb4d commit f34bf3d

15 files changed (+748 -46 lines)

docs/source/command-line-scripts.md (+43 -1)

@@ -429,10 +429,52 @@ _Command options_:
 * `--working_chunks`: Spatial chunk sizes to use during time downsampling, e.g., "longitude=10,latitude=10". They may not include "time".
 * `--beam_runner`: Beam runner. Use `DirectRunner` for local execution.
 
+## Slice dataset
+Slices a Zarr file containing an xarray Dataset, using `.sel` and `.isel`.
+
+```
+usage: slice_dataset.py [-h]
+                        [--input_path INPUT_PATH]
+                        [--output_path OUTPUT_PATH]
+                        [--sel SEL]
+                        [--isel ISEL]
+                        [--drop_variables DROP_VARIABLES]
+                        [--keep_variables KEEP_VARIABLES]
+                        [--output_chunks OUTPUT_CHUNKS]
+                        [--runner RUNNER]
+```
+
+_Command options_:
+
+* `--input_path`: (required) Input Zarr path
+* `--output_path`: (required) Output Zarr path
+* `--sel`: Selection criteria to pass to `xarray.Dataset.sel`, given as key=value pairs with key = `VARNAME_{start,stop,step}`
+* `--isel`: Selection criteria to pass to `xarray.Dataset.isel`, given as key=value pairs with key = `VARNAME_{start,stop,step}`
+* `--drop_variables`: Comma-delimited list of variables to drop. If empty, no variables are dropped.
+* `--keep_variables`: Comma-delimited list of variables to keep. If empty, use `--drop_variables` to determine which variables to keep.
+* `--output_chunks`: Chunk sizes overriding input chunks.
+* `--runner`: Beam runner. Use `DirectRunner` for local execution.
+
+*Example*
+
+```bash
+python slice_dataset.py -- \
+  --input_path=gs://weatherbench2/datasets/ens/2018-64x32_equiangular_with_poles_conservative.zarr \
+  --output_path=PATH \
+  --sel="prediction_timedelta_stop=15 days,latitude_start=-33.33,latitude_stop=33.33" \
+  --isel="longitude_start=0,longitude_stop=180,longitude_step=40" \
+  --keep_variables=geopotential,temperature
+```
+
 ## Expand climatology
 
 `expand_climatology.py` takes a climatology dataset and expands it into a forecast-like format (`init_time` + `lead_time`). This is not currently used, as `evaluation.py` is able to do this on the fly, reducing the number of intermediate steps. We still include the script here in case others find it useful.
 
 ## Init to valid time conversion
 
-`compute_init_to_valid_time.py` converts a forecasts in init-time convention to valid-time convention. Since currently, we do all evaluation in the init-time format, this script is not used.
+`compute_init_to_valid_time.py` converts forecasts in init-time convention to valid-time convention. Since all evaluation is currently done in the init-time format, this script is not used.
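As a rough illustration of the `VARNAME_{start,stop,step}` convention used by `--sel` and `--isel` above, here is a minimal sketch. The `parse_slices` helper is hypothetical (not part of `slice_dataset.py`), and parsing of values like `15 days` is deferred to xarray's own label-based indexing:

```python
from collections import defaultdict

import xarray as xr


def parse_slices(spec: str, as_int: bool = False) -> dict[str, slice]:
  """Turns 'latitude_start=-33.33,latitude_stop=33.33' into slice kwargs."""
  bounds = defaultdict(dict)
  for item in spec.split(','):
    key, value = item.split('=', 1)
    name, _, part = key.rpartition('_')  # part is 'start', 'stop' or 'step'
    try:
      parsed = int(value) if as_int else float(value)
    except ValueError:
      parsed = value  # e.g. '15 days' for a timedelta coordinate
    bounds[name][part] = parsed
  return {
      name: slice(b.get('start'), b.get('stop'), b.get('step'))
      for name, b in bounds.items()
  }


ds = xr.open_zarr('forecast.zarr')  # placeholder path
ds = ds.sel(parse_slices('prediction_timedelta_stop=15 days'))
ds = ds.isel(
    parse_slices('longitude_start=0,longitude_stop=180,longitude_step=40', as_int=True)
)
```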

scripts/compute_averages.py (+15 -2)

@@ -86,6 +86,11 @@
     None,
     help='Beam CombineFn fanout. Might be required for large dataset.',
 )
+NUM_THREADS = flags.DEFINE_integer(
+    'num_threads',
+    None,
+    help='Number of chunks to read/write in parallel per worker.',
+)


 # pylint: disable=expression-not-assigned
@@ -120,7 +125,10 @@ def main(argv: list[str]):

   with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
     chunked = root | xbeam.DatasetToChunks(
-        source_dataset, source_chunks, split_vars=True
+        source_dataset,
+        source_chunks,
+        split_vars=True,
+        num_threads=NUM_THREADS.value,
     )

     if weights is not None:
@@ -131,7 +139,12 @@ def main(argv: list[str]):
     (
         chunked
         | xbeam.Mean(AVERAGING_DIMS.value, skipna=False, fanout=FANOUT.value)
-        | xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks)
+        | xbeam.ChunksToZarr(
+            OUTPUT_PATH.value,
+            template,
+            target_chunks,
+            num_threads=NUM_THREADS.value,
+        )
     )

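Each script in this commit gains the same `--num_threads` flag and forwards it to both the Zarr read (`DatasetToChunks`) and write (`ChunksToZarr`) stages. A minimal standalone sketch of that shared pattern; the paths, chunk sizes, and dimension names are illustrative assumptions, not taken from any one script:

```python
import apache_beam as beam
import xarray as xr
import xarray_beam as xbeam

NUM_THREADS = 16  # stands in for the value of --num_threads

ds = xr.open_zarr('input.zarr')  # placeholder path
# Template for the output: the input with the averaged dimension dropped.
template = xbeam.make_template(ds.isel(time=0, drop=True))

with beam.Pipeline(runner='DirectRunner') as root:
  _ = (
      root
      # num_threads bounds how many chunks each worker reads in parallel...
      | xbeam.DatasetToChunks(
          ds, {'time': 8}, split_vars=True, num_threads=NUM_THREADS
      )
      | xbeam.Mean('time', skipna=False)
      # ...and how many chunks each worker writes in parallel.
      | xbeam.ChunksToZarr('output.zarr', template, num_threads=NUM_THREADS)
  )
```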
scripts/compute_climatology.py (+14 -2)

@@ -120,6 +120,11 @@
         'precipitation variable. In mm.'
     ),
 )
+NUM_THREADS = flags.DEFINE_integer(
+    'num_threads',
+    None,
+    help='Number of chunks to read/write in parallel per worker.',
+)


 class Quantile:
@@ -330,6 +335,10 @@ def _compute_seeps(kv):
     if stat not in ['seeps', 'mean']:
       for var in raw_vars:
         if stat == 'quantile':
+          if not quantiles:
+            raise ValueError(
+                'Cannot compute stat `quantile` without specifying --quantiles.'
+            )
           quantile_dim = xr.DataArray(
               quantiles, name='quantile', dims=['quantile']
           )
@@ -349,7 +358,10 @@ def _compute_seeps(kv):
     pcoll = (
         root
         | xbeam.DatasetToChunks(
-            obs, input_chunks, split_vars=True, num_threads=16
+            obs,
+            input_chunks,
+            split_vars=True,
+            num_threads=NUM_THREADS.value,
         )
         | 'RechunkIn'
         >> xbeam.Rechunk(  # pytype: disable=wrong-arg-types
@@ -412,7 +424,7 @@ def _compute_seeps(kv):
             OUTPUT_PATH.value,
             template=clim_template,
             zarr_chunks=output_chunks,
-            num_threads=16,
+            num_threads=NUM_THREADS.value,
         )
     )

scripts/compute_derived_variables.py (+15 -2)

@@ -116,6 +116,11 @@
 MAX_MEM_GB = flags.DEFINE_integer(
     'max_mem_gb', 1, help='Max memory for rechunking in GB.'
 )
+NUM_THREADS = flags.DEFINE_integer(
+    'num_threads',
+    None,
+    help='Number of chunks to read/write in parallel per worker.',
+)

 RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')

@@ -226,7 +231,12 @@ def _is_not_precip(kv: tuple[xbeam.Key, xr.Dataset]) -> bool:
     # so that with and without rechunking can be computed in parallel
     pcoll = (
         root
-        | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=False)
+        | xbeam.DatasetToChunks(
+            source_dataset,
+            source_chunks,
+            split_vars=False,
+            num_threads=NUM_THREADS.value,
+        )
         | beam.MapTuple(
             lambda k, v: (  # pylint: disable=g-long-lambda
                 k,
@@ -274,7 +284,10 @@ def _is_not_precip(kv: tuple[xbeam.Key, xr.Dataset]) -> bool:

     # Combined
     _ = pcoll | xbeam.ChunksToZarr(
-        OUTPUT_PATH.value, template, source_chunks, num_threads=16
+        OUTPUT_PATH.value,
+        template,
+        source_chunks,
+        num_threads=NUM_THREADS.value,
     )

scripts/compute_ensemble_mean.py (+17 -2)

@@ -61,6 +61,11 @@
     '2020-12-31',
     help='ISO 8601 timestamp (inclusive) at which to stop evaluation',
 )
+NUM_THREADS = flags.DEFINE_integer(
+    'num_threads',
+    None,
+    help='Number of chunks to read/write in parallel per worker.',
+)


 # pylint: disable=expression-not-assigned
@@ -88,9 +93,19 @@ def main(argv: list[str]):
   with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
     (
         root
-        | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=True)
+        | xbeam.DatasetToChunks(
+            source_dataset,
+            source_chunks,
+            split_vars=True,
+            num_threads=NUM_THREADS.value,
+        )
         | xbeam.Mean(REALIZATION_NAME.value, skipna=False)
-        | xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks)
+        | xbeam.ChunksToZarr(
+            OUTPUT_PATH.value,
+            template,
+            target_chunks,
+            num_threads=NUM_THREADS.value,
+        )
     )

scripts/compute_statistical_moments.py (+9 -1)

@@ -37,6 +37,11 @@
 RECHUNK_ITEMSIZE = flags.DEFINE_integer(
     'rechunk_itemsize', 4, help='Itemsize for rechunking.'
 )
+NUM_THREADS = flags.DEFINE_integer(
+    'num_threads',
+    None,
+    help='Number of chunks to read/write in parallel per worker.',
+)


 def moment_reduce(
@@ -143,7 +148,9 @@ def main(argv: list[str]) -> None:

   with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
     # Read
-    pcoll = root | xbeam.DatasetToChunks(obs, input_chunks, split_vars=True)
+    pcoll = root | xbeam.DatasetToChunks(
+        obs, input_chunks, split_vars=True, num_threads=NUM_THREADS.value
+    )

     # Branches to compute statistical moments
     pcolls = []
@@ -174,6 +181,7 @@ def main(argv: list[str]) -> None:
             OUTPUT_PATH.value,
             template=output_template,
             zarr_chunks=output_chunks,
+            num_threads=NUM_THREADS.value,
         )
     )

scripts/compute_zonal_energy_spectrum.py (+15 -2)

@@ -96,6 +96,11 @@
     None,
     help='Beam CombineFn fanout. Might be required for large dataset.',
 )
+NUM_THREADS = flags.DEFINE_integer(
+    'num_threads',
+    None,
+    help='Number of chunks to read/write in parallel per worker.',
+)

 RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')

@@ -196,7 +201,12 @@ def main(argv: list[str]) -> None:
   with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
     _ = (
         root
-        | xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=False)
+        | xbeam.DatasetToChunks(
+            source_dataset,
+            source_chunks,
+            split_vars=False,
+            num_threads=NUM_THREADS.value,
+        )
         | beam.MapTuple(
             lambda k, v: (  # pylint: disable=g-long-lambda
                 k,
@@ -207,7 +217,10 @@ def main(argv: list[str]) -> None:
         | beam.MapTuple(_strip_offsets)
         | xbeam.Mean(AVERAGING_DIMS.value, fanout=FANOUT.value)
         | xbeam.ChunksToZarr(
-            OUTPUT_PATH.value, template, output_chunks, num_threads=16
+            OUTPUT_PATH.value,
+            template,
+            output_chunks,
+            num_threads=NUM_THREADS.value,
         )
     )

scripts/convert_init_to_valid_time.py (+14 -2)

@@ -102,6 +102,11 @@
 INPUT_PATH = flags.DEFINE_string('input_path', None, help='zarr inputs')
 OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='zarr outputs')
 RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')
+NUM_THREADS = flags.DEFINE_integer(
+    'num_threads',
+    None,
+    help='Number of chunks to read/write in parallel per worker.',
+)

 TIME = 'time'
 DELTA = 'prediction_timedelta'
@@ -254,7 +259,9 @@ def main(argv: list[str]) -> None:
           source_ds.indexes[INIT],
       )
   )
-  p |= xarray_beam.DatasetToChunks(source_ds, input_chunks, split_vars=True)
+  p |= xarray_beam.DatasetToChunks(
+      source_ds, input_chunks, split_vars=True, num_threads=NUM_THREADS.value
+  )
   if input_chunks != split_chunks:
     p |= xarray_beam.SplitChunks(split_chunks)
   p |= beam.FlatMapTuple(
@@ -266,7 +273,12 @@
   p = (p, padding) | beam.Flatten()
   if input_chunks != split_chunks:
     p |= xarray_beam.ConsolidateChunks(output_chunks)
-  p |= xarray_beam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks)
+  p |= xarray_beam.ChunksToZarr(
+      OUTPUT_PATH.value,
+      template,
+      output_chunks,
+      num_threads=NUM_THREADS.value,
+  )


 if __name__ == '__main__':
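For context on the conversion this script performs: a forecast in init-time convention is indexed by initialization time (`time`) and lead time (`prediction_timedelta`), and the valid time is their sum. A toy sketch of that relationship, not the script's implementation:

```python
import numpy as np
import xarray as xr

fc = xr.Dataset(
    {'temperature': (('time', 'prediction_timedelta'), np.zeros((2, 3)))},
    coords={
        'time': np.array(['2020-01-01', '2020-01-02'], dtype='datetime64[ns]'),
        'prediction_timedelta': np.array([0, 6, 12], dtype='timedelta64[h]'),
    },
)
# Valid time = init time + lead time: a (time, prediction_timedelta) matrix.
valid_time = fc['time'] + fc['prediction_timedelta']
```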

scripts/expand_climatology.py (+9 -1)

@@ -72,6 +72,11 @@
     None,
     help='Desired integer chunk size. If not set, inferred from input chunks.',
 )
+NUM_THREADS = flags.DEFINE_integer(
+    'num_threads',
+    None,
+    help='Number of chunks to read/write in parallel per worker.',
+)
 RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')


@@ -149,7 +154,10 @@ def main(argv: list[str]) -> None:
         | beam.Reshuffle()
         | beam.FlatMap(select_climatology, climatology, times, base_chunks)
         | xbeam.ChunksToZarr(
-            OUTPUT_PATH.value, template=template, zarr_chunks=output_chunks
+            OUTPUT_PATH.value,
+            template=template,
+            zarr_chunks=output_chunks,
+            num_threads=NUM_THREADS.value,
         )
     )

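Conceptually, expanding a climatology into a forecast-like layout means looking up the climatological value at each target time. A rough sketch, assuming the climatology is indexed by `dayofyear` and `hour` coordinates (an assumption, not taken from this diff):

```python
import xarray as xr

clim = xr.open_zarr('climatology.zarr')  # placeholder path
times = xr.open_zarr('forecast.zarr')['time']  # target times to expand onto

# Vectorized lookup: the result gains the 'time' dimension of `times`.
expanded = clim.sel(dayofyear=times.dt.dayofyear, hour=times.dt.hour)
```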
scripts/regrid.py (+17 -2)

@@ -78,6 +78,11 @@
 LONGITUDE_NAME = flags.DEFINE_string(
     'longitude_name', 'longitude', help='Name of longitude dimension in dataset'
 )
+NUM_THREADS = flags.DEFINE_integer(
+    'num_threads',
+    None,
+    help='Number of chunks to read/write in parallel per worker.',
+)
 RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')


@@ -135,11 +140,21 @@ def main(argv):
   with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
     _ = (
         root
-        | xarray_beam.DatasetToChunks(source_ds, input_chunks, split_vars=True)
+        | xarray_beam.DatasetToChunks(
+            source_ds,
+            input_chunks,
+            split_vars=True,
+            num_threads=NUM_THREADS.value,
+        )
         | 'Regrid'
         >> beam.MapTuple(lambda k, v: (k, regridder.regrid_dataset(v)))
         | xarray_beam.ConsolidateChunks(output_chunks)
-        | xarray_beam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks)
+        | xarray_beam.ChunksToZarr(
+            OUTPUT_PATH.value,
+            template,
+            output_chunks,
+            num_threads=NUM_THREADS.value,
+        )
     )
