|
36 | 36 | """
|
37 | 37 | import ast
|
38 | 38 | import functools
|
39 |
| -from typing import Optional |
| 39 | +from typing import Callable, Optional, Union |
40 | 40 |
|
41 | 41 | from absl import app
|
42 | 42 | from absl import flags
|
|
98 | 98 | STATISTICS = flags.DEFINE_list(
|
99 | 99 | 'statistics',
|
100 | 100 | ['mean'],
|
101 |
| - help='Statistics to compute from "mean", "std", "seeps".', |
102 |
| -) |
103 |
| -ADD_STATISTIC_SUFFIX = flags.DEFINE_bool( |
104 |
| - 'add_statistic_suffix', |
105 |
| - False, |
106 |
| - 'Add suffix of statistic to variable name. Required for >1 statistic.', |
| 101 | + help='Statistics to compute from "mean", "std", "seeps", "quantile".', |
107 | 102 | )
|
| 103 | +QUANTILES = flags.DEFINE_list('quantiles', [], 'List of quantiles to compute.') |
108 | 104 | METHOD = flags.DEFINE_string(
|
109 | 105 | 'method',
|
110 | 106 | 'explicit',
|
|
126 | 122 | )
|
127 | 123 |
|
128 | 124 |
|
| 125 | +class Quantile: |
| 126 | + """Compute quantiles.""" |
| 127 | + |
| 128 | + def __init__(self, quantiles: list[float]): |
| 129 | + self.quantiles = quantiles |
| 130 | + |
| 131 | + def compute( |
| 132 | + self, |
| 133 | + ds: xr.Dataset, |
| 134 | + dim: tuple[str], |
| 135 | + weights: Optional[xr.Dataset] = None, |
| 136 | + ): |
| 137 | + if weights is not None: |
| 138 | + ds = ds.weighted(weights) # pytype: disable=wrong-arg-types |
| 139 | + return ds.quantile(self.quantiles, dim=dim) |
| 140 | + |
| 141 | + |
129 | 142 | class SEEPSThreshold:
|
130 | 143 | """Compute SEEPS thresholds (heav/light) and fraction of dry grid points."""
|
131 | 144 |
|
@@ -201,23 +214,25 @@ def compute_stat_chunk(
|
201 | 214 | frequency: str,
|
202 | 215 | window_size: int,
|
203 | 216 | clim_years: slice,
|
204 |
| - statistic: str = 'mean', |
| 217 | + statistic: Union[str, Callable[..., xr.Dataset]] = 'mean', |
205 | 218 | hour_interval: Optional[int] = None,
|
206 |
| - add_statistic_suffix: bool = False, |
| 219 | + quantiles: Optional[list[float]] = None, |
207 | 220 | ) -> tuple[xbeam.Key, xr.Dataset]:
|
208 | 221 | """Compute climatology on a chunk."""
|
209 |
| - if statistic not in ['mean', 'std']: |
| 222 | + if statistic not in ['mean', 'std', 'quantile']: |
210 | 223 | raise NotImplementedError(f'stat {statistic} not implemented.')
|
211 | 224 | offsets = dict(dayofyear=0)
|
212 | 225 | if frequency == 'hourly':
|
213 | 226 | offsets['hour'] = 0
|
214 | 227 | clim_key = obs_key.with_offsets(time=None, **offsets)
|
215 |
| - if add_statistic_suffix: |
| 228 | + if statistic != 'mean': |
216 | 229 | clim_key = clim_key.replace(
|
217 | 230 | vars={f'{var}_{statistic}' for var in clim_key.vars}
|
218 | 231 | )
|
219 | 232 | for var in obs_chunk:
|
220 | 233 | obs_chunk = obs_chunk.rename({var: f'{var}_{statistic}'})
|
| 234 | + if statistic == 'quantile': |
| 235 | + statistic = Quantile(quantiles).compute |
221 | 236 | compute_kwargs = {
|
222 | 237 | 'obs': obs_chunk,
|
223 | 238 | 'window_size': window_size,
|
@@ -246,12 +261,6 @@ def compute_stat_chunk(
|
246 | 261 |
|
247 | 262 |
|
248 | 263 | def main(argv: list[str]) -> None:
|
249 |
| - non_seeps_stats = [stat for stat in STATISTICS.value if stat != 'seeps'] |
250 |
| - if not ADD_STATISTIC_SUFFIX.value and len(non_seeps_stats) > 1: |
251 |
| - raise ValueError( |
252 |
| - '--add_statistic_suffix is required for >1 non-SEEPS statistics.' |
253 |
| - ) |
254 |
| - |
255 | 264 | obs, input_chunks = xbeam.open_zarr(INPUT_PATH.value)
|
256 | 265 |
|
257 | 266 | # Convert object-type coordinates to string.
|
@@ -312,14 +321,22 @@ def _compute_seeps(kv):
|
312 | 321 | (var,) = k.vars
|
313 | 322 | return var in seeps_dry_threshold_mm.keys()
|
314 | 323 |
|
| 324 | + quantiles = [float(q) for q in QUANTILES.value] |
315 | 325 | for stat in STATISTICS.value:
|
316 |
| - if stat != 'seeps': |
317 |
| - if ADD_STATISTIC_SUFFIX.value: |
318 |
| - for var in raw_vars: |
319 |
| - clim_template = clim_template.assign( |
320 |
| - {f'{var}_{stat}': clim_template[var]} |
| 326 | + if stat not in ['seeps', 'mean']: |
| 327 | + for var in raw_vars: |
| 328 | + if stat == 'quantile': |
| 329 | + quantile_dim = xr.DataArray( |
| 330 | + quantiles, name='quantile', dims=['quantile'] |
321 | 331 | )
|
322 |
| - if ADD_STATISTIC_SUFFIX.value: |
| 332 | + temp = clim_template[var].expand_dims(quantile=quantile_dim) |
| 333 | + if 'hour' in temp.dims: |
| 334 | + temp = temp.transpose('hour', 'quantile', ...) |
| 335 | + else: |
| 336 | + temp = clim_template[var] |
| 337 | + clim_template = clim_template.assign({f'{var}_{stat}': temp}) |
| 338 | + # Mean has no suffix. Delete no suffix variables if no mean required |
| 339 | + if 'mean' not in STATISTICS.value: |
323 | 340 | for var in raw_vars:
|
324 | 341 | clim_template = clim_template.drop(var)
|
325 | 342 |
|
@@ -370,7 +387,7 @@ def _compute_seeps(kv):
|
370 | 387 | window_size=WINDOW_SIZE.value,
|
371 | 388 | clim_years=slice(str(START_YEAR.value), str(END_YEAR.value)),
|
372 | 389 | statistic=stat,
|
373 |
| - add_statistic_suffix=ADD_STATISTIC_SUFFIX.value, |
| 390 | + quantiles=quantiles, |
374 | 391 | **stat_kwargs,
|
375 | 392 | )
|
376 | 393 | )
|
|
0 commit comments