Skip to content

Commit

Permalink
Adds option to compute any climatological quantile.
Browse files Browse the repository at this point in the history
Potentially breaking change: Removes add_statistic_suffix option. Now, 'mean' will always be without suffix, and all other statistics have a suffix.

PiperOrigin-RevId: 580158565
  • Loading branch information
Weatherbench authors committed Nov 8, 2023
1 parent 1167715 commit 1e8e595
Showing 1 changed file with 41 additions and 24 deletions.
65 changes: 41 additions & 24 deletions scripts/compute_climatology.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"""
import ast
import functools
from typing import Optional
from typing import Callable, Optional, Union

from absl import app
from absl import flags
Expand Down Expand Up @@ -98,13 +98,9 @@
STATISTICS = flags.DEFINE_list(
'statistics',
['mean'],
help='Statistics to compute from "mean", "std", "seeps".',
)
ADD_STATISTIC_SUFFIX = flags.DEFINE_bool(
'add_statistic_suffix',
False,
'Add suffix of statistic to variable name. Required for >1 statistic.',
help='Statistics to compute from "mean", "std", "seeps", "quantile".',
)
QUANTILES = flags.DEFINE_list('quantiles', [], 'List of quantiles to compute.')
METHOD = flags.DEFINE_string(
'method',
'explicit',
Expand All @@ -126,6 +122,23 @@
)


class Quantile:
"""Compute quantiles."""

def __init__(self, quantiles: list[float]):
self.quantiles = quantiles

def compute(
self,
ds: xr.Dataset,
dim: tuple[str],
weights: Optional[xr.Dataset] = None,
):
if weights is not None:
ds = ds.weighted(weights) # pytype: disable=wrong-arg-types
return ds.quantile(self.quantiles, dim=dim)


class SEEPSThreshold:
"""Compute SEEPS thresholds (heav/light) and fraction of dry grid points."""

Expand Down Expand Up @@ -201,23 +214,25 @@ def compute_stat_chunk(
frequency: str,
window_size: int,
clim_years: slice,
statistic: str = 'mean',
statistic: Union[str, Callable[..., xr.Dataset]] = 'mean',
hour_interval: Optional[int] = None,
add_statistic_suffix: bool = False,
quantiles: Optional[list[float]] = None,
) -> tuple[xbeam.Key, xr.Dataset]:
"""Compute climatology on a chunk."""
if statistic not in ['mean', 'std']:
if statistic not in ['mean', 'std', 'quantile']:
raise NotImplementedError(f'stat {statistic} not implemented.')
offsets = dict(dayofyear=0)
if frequency == 'hourly':
offsets['hour'] = 0
clim_key = obs_key.with_offsets(time=None, **offsets)
if add_statistic_suffix:
if statistic != 'mean':
clim_key = clim_key.replace(
vars={f'{var}_{statistic}' for var in clim_key.vars}
)
for var in obs_chunk:
obs_chunk = obs_chunk.rename({var: f'{var}_{statistic}'})
if statistic == 'quantile':
statistic = Quantile(quantiles).compute
compute_kwargs = {
'obs': obs_chunk,
'window_size': window_size,
Expand Down Expand Up @@ -246,12 +261,6 @@ def compute_stat_chunk(


def main(argv: list[str]) -> None:
non_seeps_stats = [stat for stat in STATISTICS.value if stat != 'seeps']
if not ADD_STATISTIC_SUFFIX.value and len(non_seeps_stats) > 1:
raise ValueError(
'--add_statistic_suffix is required for >1 non-SEEPS statistics.'
)

obs, input_chunks = xbeam.open_zarr(INPUT_PATH.value)

# Convert object-type coordinates to string.
Expand Down Expand Up @@ -312,14 +321,22 @@ def _compute_seeps(kv):
(var,) = k.vars
return var in seeps_dry_threshold_mm.keys()

quantiles = [float(q) for q in QUANTILES.value]
for stat in STATISTICS.value:
if stat != 'seeps':
if ADD_STATISTIC_SUFFIX.value:
for var in raw_vars:
clim_template = clim_template.assign(
{f'{var}_{stat}': clim_template[var]}
if stat not in ['seeps', 'mean']:
for var in raw_vars:
if stat == 'quantile':
quantile_dim = xr.DataArray(
quantiles, name='quantile', dims=['quantile']
)
if ADD_STATISTIC_SUFFIX.value:
temp = clim_template[var].expand_dims(quantile=quantile_dim)
if 'hour' in temp.dims:
temp = temp.transpose('hour', 'quantile', ...)
else:
temp = clim_template[var]
clim_template = clim_template.assign({f'{var}_{stat}': temp})
# Mean has no suffix. Delete no suffix variables if no mean required
if 'mean' not in STATISTICS.value:
for var in raw_vars:
clim_template = clim_template.drop(var)

Expand Down Expand Up @@ -370,7 +387,7 @@ def _compute_seeps(kv):
window_size=WINDOW_SIZE.value,
clim_years=slice(str(START_YEAR.value), str(END_YEAR.value)),
statistic=stat,
add_statistic_suffix=ADD_STATISTIC_SUFFIX.value,
quantiles=quantiles,
**stat_kwargs,
)
)
Expand Down

0 comments on commit 1e8e595

Please sign in to comment.