Skip to content

Commit

Permalink
Remove RMSE from default metrics dicts. Change the metric name to inc…
Browse files Browse the repository at this point in the history
…lude the `sqrt_before_time_avg` caveat.

Breaking change for users of top-level WB2 API.

PiperOrigin-RevId: 587793549
  • Loading branch information
langmore authored and Weatherbench2 authors committed Dec 8, 2023
1 parent 670e5aa commit a2fd459
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
26 changes: 16 additions & 10 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,14 +355,16 @@ def main(argv: list[str]) -> None:
climatology = evaluation.make_latitude_increasing(climatology)

deterministic_metrics = {
'rmse': metrics.RMSESqrtBeforeTimeAvg(
wind_vector_rmse=_wind_vector_error('rmse')
),
'mse': metrics.MSE(wind_vector_mse=_wind_vector_error('mse')),
'acc': metrics.ACC(climatology=climatology),
'bias': metrics.Bias(),
'mae': metrics.MAE(),
}
rmse_metrics = {
'rmse_sqrt_before_time_avg': metrics.RMSESqrtBeforeTimeAvg(
wind_vector_rmse=_wind_vector_error('rmse')
),
}
spatial_metrics = {
'bias': metrics.SpatialBias(),
'mse': metrics.SpatialMSE(),
Expand Down Expand Up @@ -412,7 +414,7 @@ def main(argv: list[str]) -> None:
output_format='zarr',
),
'deterministic_temporal': config.Eval(
metrics=deterministic_metrics,
metrics=deterministic_metrics | rmse_metrics,
against_analysis=False,
regions=regions,
derived_variables=derived_variables,
Expand All @@ -435,15 +437,9 @@ def main(argv: list[str]) -> None:
ensemble_dim=ENSEMBLE_DIM.value
),
'crps_skill': metrics.CRPSSkill(ensemble_dim=ENSEMBLE_DIM.value),
'ensemble_mean_rmse': metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_mean_mse': metrics.EnsembleMeanMSE(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_stddev': metrics.EnsembleStddevSqrtBeforeTimeAvg(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_variance': metrics.EnsembleVariance(
ensemble_dim=ENSEMBLE_DIM.value
),
Expand All @@ -467,6 +463,16 @@ def main(argv: list[str]) -> None:
'energy_score_skill': metrics.EnergyScoreSkill(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_mean_rmse_sqrt_before_time_avg': (
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
ensemble_dim=ENSEMBLE_DIM.value
)
),
'ensemble_stddev_sqrt_before_time_avg': (
metrics.EnsembleStddevSqrtBeforeTimeAvg(
ensemble_dim=ENSEMBLE_DIM.value
)
),
},
against_analysis=False,
derived_variables=derived_variables,
Expand Down
4 changes: 2 additions & 2 deletions scripts/evaluate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def _test(self, use_beam=True, input_chunks=None):
evaluate.main([])

for config_name in eval_configs:
expected_sizes_2d = {'metric': 5, 'lead_time': 4, 'region': 4}
expected_sizes_3d = {'metric': 5, 'lead_time': 4, 'region': 4, 'level': 3}
expected_sizes_2d = {'metric': 4, 'lead_time': 4, 'region': 4}
expected_sizes_3d = {'metric': 4, 'lead_time': 4, 'region': 4, 'level': 3}

with self.subTest(config_name):
results_path = os.path.join(output_dir, f'{config_name}.nc')
Expand Down

0 comments on commit a2fd459

Please sign in to comment.