diff --git a/scripts/compute_climatology.py b/scripts/compute_climatology.py index 40b1857..9ab7f3f 100644 --- a/scripts/compute_climatology.py +++ b/scripts/compute_climatology.py @@ -36,7 +36,7 @@ """ import ast import functools -from typing import Optional +from typing import Callable, Optional, Union from absl import app from absl import flags @@ -98,13 +98,9 @@ STATISTICS = flags.DEFINE_list( 'statistics', ['mean'], - help='Statistics to compute from "mean", "std", "seeps".', -) -ADD_STATISTIC_SUFFIX = flags.DEFINE_bool( - 'add_statistic_suffix', - False, - 'Add suffix of statistic to variable name. Required for >1 statistic.', + help='Statistics to compute from "mean", "std", "seeps", "quantile".', ) +QUANTILES = flags.DEFINE_list('quantiles', [], 'List of quantiles to compute.') METHOD = flags.DEFINE_string( 'method', 'explicit', @@ -126,6 +122,23 @@ ) +class Quantile: + """Compute quantiles.""" + + def __init__(self, quantiles: list[float]): + self.quantiles = quantiles + + def compute( + self, + ds: xr.Dataset, + dim: tuple[str], + weights: Optional[xr.Dataset] = None, + ): + if weights is not None: + ds = ds.weighted(weights) # pytype: disable=wrong-arg-types + return ds.quantile(self.quantiles, dim=dim) + + class SEEPSThreshold: """Compute SEEPS thresholds (heav/light) and fraction of dry grid points.""" @@ -201,23 +214,25 @@ def compute_stat_chunk( frequency: str, window_size: int, clim_years: slice, - statistic: str = 'mean', + statistic: Union[str, Callable[..., xr.Dataset]] = 'mean', hour_interval: Optional[int] = None, - add_statistic_suffix: bool = False, + quantiles: Optional[list[float]] = None, ) -> tuple[xbeam.Key, xr.Dataset]: """Compute climatology on a chunk.""" - if statistic not in ['mean', 'std']: + if statistic not in ['mean', 'std', 'quantile']: raise NotImplementedError(f'stat {statistic} not implemented.') offsets = dict(dayofyear=0) if frequency == 'hourly': offsets['hour'] = 0 clim_key = obs_key.with_offsets(time=None, **offsets) - if add_statistic_suffix: + if statistic != 'mean': clim_key = clim_key.replace( vars={f'{var}_{statistic}' for var in clim_key.vars} ) for var in obs_chunk: obs_chunk = obs_chunk.rename({var: f'{var}_{statistic}'}) + if statistic == 'quantile': + statistic = Quantile(quantiles).compute compute_kwargs = { 'obs': obs_chunk, 'window_size': window_size, @@ -246,12 +261,6 @@ def compute_stat_chunk( def main(argv: list[str]) -> None: - non_seeps_stats = [stat for stat in STATISTICS.value if stat != 'seeps'] - if not ADD_STATISTIC_SUFFIX.value and len(non_seeps_stats) > 1: - raise ValueError( - '--add_statistic_suffix is required for >1 non-SEEPS statistics.' - ) - obs, input_chunks = xbeam.open_zarr(INPUT_PATH.value) # Convert object-type coordinates to string. @@ -312,14 +321,22 @@ def _compute_seeps(kv): (var,) = k.vars return var in seeps_dry_threshold_mm.keys() + quantiles = [float(q) for q in QUANTILES.value] for stat in STATISTICS.value: - if stat != 'seeps': - if ADD_STATISTIC_SUFFIX.value: - for var in raw_vars: - clim_template = clim_template.assign( - {f'{var}_{stat}': clim_template[var]} + if stat not in ['seeps', 'mean']: + for var in raw_vars: + if stat == 'quantile': + quantile_dim = xr.DataArray( + quantiles, name='quantile', dims=['quantile'] ) - if ADD_STATISTIC_SUFFIX.value: + temp = clim_template[var].expand_dims(quantile=quantile_dim) + if 'hour' in temp.dims: + temp = temp.transpose('hour', 'quantile', ...) + else: + temp = clim_template[var] + clim_template = clim_template.assign({f'{var}_{stat}': temp}) + # Mean has no suffix. Delete no suffix variables if no mean required + if 'mean' not in STATISTICS.value: for var in raw_vars: clim_template = clim_template.drop(var) @@ -370,7 +387,7 @@ def _compute_seeps(kv): window_size=WINDOW_SIZE.value, clim_years=slice(str(START_YEAR.value), str(END_YEAR.value)), statistic=stat, - add_statistic_suffix=ADD_STATISTIC_SUFFIX.value, + quantiles=quantiles, **stat_kwargs, ) )