Skip to content

Commit 545481f

Browse files
author
Weatherbench authors
committed
Adds option to compute any climatological quantile.
Potentially breaking change: Removes add_statistic_suffix option. Now, 'mean' will always be without suffix, and all other statistics have a suffix. PiperOrigin-RevId: 580574874
1 parent 1167715 commit 545481f

File tree

1 file changed

+41
-24
lines changed

1 file changed

+41
-24
lines changed

scripts/compute_climatology.py

+41-24
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"""
3737
import ast
3838
import functools
39-
from typing import Optional
39+
from typing import Callable, Optional, Union
4040

4141
from absl import app
4242
from absl import flags
@@ -98,13 +98,9 @@
9898
STATISTICS = flags.DEFINE_list(
9999
'statistics',
100100
['mean'],
101-
help='Statistics to compute from "mean", "std", "seeps".',
102-
)
103-
ADD_STATISTIC_SUFFIX = flags.DEFINE_bool(
104-
'add_statistic_suffix',
105-
False,
106-
'Add suffix of statistic to variable name. Required for >1 statistic.',
101+
help='Statistics to compute from "mean", "std", "seeps", "quantile".',
107102
)
103+
QUANTILES = flags.DEFINE_list('quantiles', [], 'List of quantiles to compute.')
108104
METHOD = flags.DEFINE_string(
109105
'method',
110106
'explicit',
@@ -126,6 +122,23 @@
126122
)
127123

128124

125+
class Quantile:
126+
"""Compute quantiles."""
127+
128+
def __init__(self, quantiles: list[float]):
129+
self.quantiles = quantiles
130+
131+
def compute(
132+
self,
133+
ds: xr.Dataset,
134+
dim: tuple[str],
135+
weights: Optional[xr.Dataset] = None,
136+
):
137+
if weights is not None:
138+
ds = ds.weighted(weights) # pytype: disable=wrong-arg-types
139+
return ds.quantile(self.quantiles, dim=dim)
140+
141+
129142
class SEEPSThreshold:
130143
"""Compute SEEPS thresholds (heav/light) and fraction of dry grid points."""
131144

@@ -201,23 +214,25 @@ def compute_stat_chunk(
201214
frequency: str,
202215
window_size: int,
203216
clim_years: slice,
204-
statistic: str = 'mean',
217+
statistic: Union[str, Callable[..., xr.Dataset]] = 'mean',
205218
hour_interval: Optional[int] = None,
206-
add_statistic_suffix: bool = False,
219+
quantiles: Optional[list[float]] = None,
207220
) -> tuple[xbeam.Key, xr.Dataset]:
208221
"""Compute climatology on a chunk."""
209-
if statistic not in ['mean', 'std']:
222+
if statistic not in ['mean', 'std', 'quantile']:
210223
raise NotImplementedError(f'stat {statistic} not implemented.')
211224
offsets = dict(dayofyear=0)
212225
if frequency == 'hourly':
213226
offsets['hour'] = 0
214227
clim_key = obs_key.with_offsets(time=None, **offsets)
215-
if add_statistic_suffix:
228+
if statistic != 'mean':
216229
clim_key = clim_key.replace(
217230
vars={f'{var}_{statistic}' for var in clim_key.vars}
218231
)
219232
for var in obs_chunk:
220233
obs_chunk = obs_chunk.rename({var: f'{var}_{statistic}'})
234+
if statistic == 'quantile':
235+
statistic = Quantile(quantiles).compute
221236
compute_kwargs = {
222237
'obs': obs_chunk,
223238
'window_size': window_size,
@@ -246,12 +261,6 @@ def compute_stat_chunk(
246261

247262

248263
def main(argv: list[str]) -> None:
249-
non_seeps_stats = [stat for stat in STATISTICS.value if stat != 'seeps']
250-
if not ADD_STATISTIC_SUFFIX.value and len(non_seeps_stats) > 1:
251-
raise ValueError(
252-
'--add_statistic_suffix is required for >1 non-SEEPS statistics.'
253-
)
254-
255264
obs, input_chunks = xbeam.open_zarr(INPUT_PATH.value)
256265

257266
# Convert object-type coordinates to string.
@@ -312,14 +321,22 @@ def _compute_seeps(kv):
312321
(var,) = k.vars
313322
return var in seeps_dry_threshold_mm.keys()
314323

324+
quantiles = [float(q) for q in QUANTILES.value]
315325
for stat in STATISTICS.value:
316-
if stat != 'seeps':
317-
if ADD_STATISTIC_SUFFIX.value:
318-
for var in raw_vars:
319-
clim_template = clim_template.assign(
320-
{f'{var}_{stat}': clim_template[var]}
326+
if stat not in ['seeps', 'mean']:
327+
for var in raw_vars:
328+
if stat == 'quantile':
329+
quantile_dim = xr.DataArray(
330+
quantiles, name='quantile', dims=['quantile']
321331
)
322-
if ADD_STATISTIC_SUFFIX.value:
332+
temp = clim_template[var].expand_dims(quantile=quantile_dim)
333+
if 'hour' in temp.dims:
334+
temp = temp.transpose('hour', 'quantile', ...)
335+
else:
336+
temp = clim_template[var]
337+
clim_template = clim_template.assign({f'{var}_{stat}': temp})
338+
# Mean has no suffix. Delete no suffix variables if no mean required
339+
if 'mean' not in STATISTICS.value:
323340
for var in raw_vars:
324341
clim_template = clim_template.drop(var)
325342

@@ -370,7 +387,7 @@ def _compute_seeps(kv):
370387
window_size=WINDOW_SIZE.value,
371388
clim_years=slice(str(START_YEAR.value), str(END_YEAR.value)),
372389
statistic=stat,
373-
add_statistic_suffix=ADD_STATISTIC_SUFFIX.value,
390+
quantiles=quantiles,
374391
**stat_kwargs,
375392
)
376393
)

0 commit comments

Comments
 (0)