Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681163784
  • Loading branch information
langmore authored and Weatherbench2 authors committed Oct 14, 2024
1 parent de3f56e commit a09ccba
Show file tree
Hide file tree
Showing 8 changed files with 413 additions and 111 deletions.
12 changes: 11 additions & 1 deletion scripts/compute_averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@
'If empty, compute on all data_vars of --input_path'
),
)
SKIPNA = flags.DEFINE_boolean(
'skipna',
False,
help=(
'Whether to skip NaN data points (in forecasts and observations) when'
' evaluating.'
),
)
FANOUT = flags.DEFINE_integer(
'fanout',
None,
Expand Down Expand Up @@ -138,7 +146,9 @@ def main(argv: list[str]):

(
chunked
| xbeam.Mean(AVERAGING_DIMS.value, skipna=False, fanout=FANOUT.value)
| xbeam.Mean(
AVERAGING_DIMS.value, skipna=SKIPNA.value, fanout=FANOUT.value
)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value,
template,
Expand Down
10 changes: 9 additions & 1 deletion scripts/compute_ensemble_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@
' all variables are selected.'
),
)
SKIPNA = flags.DEFINE_boolean(
'skipna',
False,
help=(
'Whether to skip NaN data points (in forecasts and observations) when'
' evaluating.'
),
)


# pylint: disable=expression-not-assigned
Expand Down Expand Up @@ -123,7 +131,7 @@ def main(argv: list[str]):
split_vars=True,
num_threads=NUM_THREADS.value,
)
| xbeam.Mean(REALIZATION_NAME.value, skipna=False)
| xbeam.Mean(REALIZATION_NAME.value, skipna=SKIPNA.value)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value,
template,
Expand Down
15 changes: 13 additions & 2 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@
' "2m_temperature"}'
),
)
SKIPNA = flags.DEFINE_boolean(
'skipna',
False,
help=(
'Whether to skip NaN data points (in forecasts and observations) when'
' evaluating.'
),
)
PRESSURE_LEVEL_SUFFIXES = flags.DEFINE_bool(
'pressure_level_suffixes',
False,
Expand Down Expand Up @@ -630,14 +638,17 @@ def main(argv: list[str]) -> None:
eval_configs,
runner=RUNNER.value,
input_chunks=INPUT_CHUNKS.value,
skipna=SKIPNA.value,
fanout=FANOUT.value,
num_threads=NUM_THREADS.value,
argv=argv,
)
else:
evaluation.evaluate_in_memory(data_config, eval_configs)
evaluation.evaluate_in_memory(
data_config, eval_configs, skipna=SKIPNA.value
)


if __name__ == '__main__':
app.run(main)
flags.mark_flag_as_required('output_path')
flags.mark_flags_as_required(['output_path', 'obs_path'])
34 changes: 25 additions & 9 deletions scripts/resample_in_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@
' use the last time in --input_path.'
),
)
SKIPNA = flags.DEFINE_boolean(
'skipna',
False,
help=(
'Whether to skip NaN data points (in forecasts and observations) when'
' evaluating.'
),
)
WORKING_CHUNKS = flag_utils.DEFINE_chunks(
'working_chunks',
'',
Expand Down Expand Up @@ -182,6 +190,7 @@ def resample_in_time_chunk(
min_vars: list[str],
max_vars: list[str],
add_mean_suffix: bool,
skipna: bool = False,
) -> tuple[xbeam.Key, xr.Dataset]:
"""Resample a data chunk in time and return a requested time statistic.
Expand All @@ -196,6 +205,8 @@ def resample_in_time_chunk(
max_vars: Variables to compute the max of.
add_mean_suffix: Whether to add a "_mean" suffix to variables after
computing the mean.
skipna: Whether to skip NaN values in both forecasts and observations during
evaluation.
Returns:
The resampled data chunk and its key.
Expand All @@ -207,21 +218,23 @@ def resample_in_time_chunk(
for chunk_var in chunk.data_vars:
if chunk_var in mean_vars:
rsmp_chunks.append(
resample_in_time_core(chunk, method, period, 'mean').rename(
resample_in_time_core(
chunk, method, period, 'mean', skipna=skipna
).rename(
{chunk_var: f'{chunk_var}_mean' if add_mean_suffix else chunk_var}
)
)
if chunk_var in min_vars:
rsmp_chunks.append(
resample_in_time_core(chunk, method, period, 'min').rename(
{chunk_var: f'{chunk_var}_min'}
)
resample_in_time_core(
chunk, method, period, 'min', skipna=skipna
).rename({chunk_var: f'{chunk_var}_min'})
)
if chunk_var in max_vars:
rsmp_chunks.append(
resample_in_time_core(chunk, method, period, 'max').rename(
{chunk_var: f'{chunk_var}_max'}
)
resample_in_time_core(
chunk, method, period, 'max', skipna=skipna
).rename({chunk_var: f'{chunk_var}_max'})
)

return rsmp_key, xr.merge(rsmp_chunks)
Expand All @@ -232,6 +245,7 @@ def resample_in_time_core(
method: str,
period: pd.Timedelta,
statistic: str,
skipna: bool,
) -> t.Union[xr.Dataset, xr.DataArray]:
"""Core call to xarray resample or rolling."""
if method == 'rolling':
Expand All @@ -245,12 +259,12 @@ def resample_in_time_core(
{TIME_DIM.value: period // delta_t}, center=False, min_periods=None
),
statistic,
)(skipna=False)
)(skipna=skipna)
elif method == 'resample':
return getattr(
chunk.resample({TIME_DIM.value: period}, label='left'),
statistic,
)(skipna=False)
)(skipna=skipna)
else:
raise ValueError(f'Unhandled {method=}')

Expand Down Expand Up @@ -301,6 +315,7 @@ def main(argv: abc.Sequence[str]) -> None:
METHOD.value,
period,
statistic='mean',
skipna=SKIPNA.value,
)[TIME_DIM.value]
else:
rsmp_times = ds[TIME_DIM.value]
Expand Down Expand Up @@ -369,6 +384,7 @@ def main(argv: abc.Sequence[str]) -> None:
min_vars=min_vars,
max_vars=max_vars,
add_mean_suffix=ADD_MEAN_SUFFIX.value,
skipna=SKIPNA.value,
)
)
| 'RechunkToOutputChunks'
Expand Down
35 changes: 27 additions & 8 deletions weatherbench2/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def _metric_and_region_loop(
forecast: xr.Dataset,
truth: xr.Dataset,
eval_config: config.Eval,
skipna: bool,
compute_chunk: bool = False,
) -> xr.Dataset:
"""Compute metric results looping over metrics and regions in eval config."""
Expand All @@ -412,16 +413,18 @@ def _metric_and_region_loop(
region_dim = xr.DataArray(
[region_name], coords={'region': [region_name]}
)
tmp_result = eval_fn(forecast=forecast, truth=truth, region=region)
tmp_result = eval_fn(
forecast=forecast, truth=truth, region=region, skipna=skipna
)
tmp_results.append(
tmp_result.expand_dims({'metric': metric_dim, 'region': region_dim})
)
logging.info(f'Logging region done: {region_name}')
result = xr.concat(tmp_results, 'region')
else:
result = eval_fn(forecast=forecast, truth=truth).expand_dims(
{'metric': metric_dim}
)
result = eval_fn(
forecast=forecast, truth=truth, skipna=skipna
).expand_dims({'metric': metric_dim})
results.append(result)
logging.info(f'Logging metric done: {name}')
results = xr.merge(results)
Expand All @@ -432,6 +435,7 @@ def _evaluate_all_metrics(
eval_name: str,
eval_config: config.Eval,
data_config: config.Data,
skipna: bool,
) -> None:
"""Evaluate a set of eval metrics in memory."""
forecast, truth, climatology = open_forecast_and_truth_datasets(
Expand Down Expand Up @@ -463,7 +467,7 @@ def _evaluate_all_metrics(
if data_config.by_init:
truth = truth.sel(time=forecast.valid_time)

results = _metric_and_region_loop(forecast, truth, eval_config)
results = _metric_and_region_loop(forecast, truth, eval_config, skipna=skipna)

logging.info(f'Logging Evaluation complete:\n{results}')

Expand All @@ -475,6 +479,7 @@ def _evaluate_all_metrics(
def evaluate_in_memory(
data_config: config.Data,
eval_configs: dict[str, config.Eval],
skipna: bool = False,
) -> None:
"""Run evaluation in memory.
Expand All @@ -498,9 +503,11 @@ def evaluate_in_memory(
Args:
data_config: config.Data instance.
eval_configs: Dictionary of config.Eval instances.
skipna: Whether to skip NaN values in both forecasts and observations during
evaluation.
"""
for eval_name, eval_config in eval_configs.items():
_evaluate_all_metrics(eval_name, eval_config, data_config)
_evaluate_all_metrics(eval_name, eval_config, data_config, skipna=skipna)


@dataclasses.dataclass
Expand Down Expand Up @@ -547,13 +554,17 @@ class _EvaluateAllMetrics(beam.PTransform):
eval_config: config.Eval instance.
data_config: config.Data instance.
input_chunks: Chunks to use for input files.
skipna: Whether to skip NaN values in both forecasts and observations during
evaluation.
fanout: Fanout parameter for Beam combiners.
num_threads: Number of threads for reading/writing files.
"""

eval_name: str
eval_config: config.Eval
data_config: config.Data
input_chunks: abc.Mapping[str, int]
skipna: bool
fanout: Optional[int] = None
num_threads: Optional[int] = None

Expand All @@ -565,7 +576,11 @@ def _evaluate_chunk(
forecast, truth = forecast_and_truth
logging.info(f'Logging _evaluate_chunk Key: {key}')
results = _metric_and_region_loop(
forecast, truth, self.eval_config, compute_chunk=True
forecast,
truth,
self.eval_config,
compute_chunk=True,
skipna=self.skipna,
)
dropped_dims = [dim for dim in key.offsets if dim not in results.dims]
result_key = key.with_offsets(**{dim: None for dim in dropped_dims})
Expand Down Expand Up @@ -709,7 +724,7 @@ def _evaluate(
forecast_pipeline |= 'TemporalMean' >> xbeam.Mean(
dim='init_time' if self.data_config.by_init else 'time',
fanout=self.fanout,
skipna=False,
skipna=self.skipna,
)

return forecast_pipeline
Expand All @@ -733,6 +748,7 @@ def evaluate_with_beam(
fanout: Optional[int] = None,
num_threads: Optional[int] = None,
argv: Optional[list[str]] = None,
skipna: bool = False,
) -> None:
"""Run evaluation with a Beam pipeline.
Expand Down Expand Up @@ -761,6 +777,8 @@ def evaluate_with_beam(
fanout: Beam CombineFn fanout.
num_threads: Number of threads to use for reading/writing data.
argv: Other arguments to pass into the Beam pipeline.
skipna: Whether to skip NaN values in both forecasts and observations during
evaluation.
"""

with beam.Pipeline(runner=runner, argv=argv) as root:
Expand All @@ -776,6 +794,7 @@ def evaluate_with_beam(
input_chunks,
fanout=fanout,
num_threads=num_threads,
skipna=skipna,
)
| f'save_{eval_name}'
>> _SaveOutputs(
Expand Down
Loading

0 comments on commit a09ccba

Please sign in to comment.