Commit a2fd459

langmore authored and Weatherbench2 authors committed
Remove RMSE from default metrics dicts. Change the metric name to include the sqrt_before_time_avg caveat.
Breaking change for users of the top-level WB2 API.

PiperOrigin-RevId: 587793549
1 parent 670e5aa commit a2fd459
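
Note for downstream users of the top-level API: evaluation configs that previously picked up the default 'rmse' entry now have to add it back explicitly under the renamed key. A minimal sketch, assuming the public weatherbench2.metrics import path and that the metric classes shown here can be constructed without the optional wind-vector arguments used in scripts/evaluate.py:

    # Sketch only, not the library's shipped configuration.
    from weatherbench2 import metrics

    # After this commit the default deterministic dict no longer contains RMSE.
    deterministic_metrics = {
        'mse': metrics.MSE(),
        'bias': metrics.Bias(),
        'mae': metrics.MAE(),
    }

    # Re-add it explicitly, under the key that spells out the aggregation order.
    rmse_metrics = {
        'rmse_sqrt_before_time_avg': metrics.RMSESqrtBeforeTimeAvg(),
    }

    all_metrics = deterministic_metrics | rmse_metrics  # dict union, Python 3.9+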

2 files changed, 18 insertions(+), 12 deletions(-)

scripts/evaluate.py

Lines changed: 16 additions & 10 deletions
@@ -355,14 +355,16 @@ def main(argv: list[str]) -> None:
   climatology = evaluation.make_latitude_increasing(climatology)
 
   deterministic_metrics = {
-      'rmse': metrics.RMSESqrtBeforeTimeAvg(
-          wind_vector_rmse=_wind_vector_error('rmse')
-      ),
       'mse': metrics.MSE(wind_vector_mse=_wind_vector_error('mse')),
       'acc': metrics.ACC(climatology=climatology),
       'bias': metrics.Bias(),
       'mae': metrics.MAE(),
   }
+  rmse_metrics = {
+      'rmse_sqrt_before_time_avg': metrics.RMSESqrtBeforeTimeAvg(
+          wind_vector_rmse=_wind_vector_error('rmse')
+      ),
+  }
   spatial_metrics = {
       'bias': metrics.SpatialBias(),
       'mse': metrics.SpatialMSE(),
@@ -412,7 +414,7 @@ def main(argv: list[str]) -> None:
           output_format='zarr',
       ),
       'deterministic_temporal': config.Eval(
-          metrics=deterministic_metrics,
+          metrics=deterministic_metrics | rmse_metrics,
           against_analysis=False,
           regions=regions,
           derived_variables=derived_variables,
@@ -435,15 +437,9 @@ def main(argv: list[str]) -> None:
                   ensemble_dim=ENSEMBLE_DIM.value
               ),
               'crps_skill': metrics.CRPSSkill(ensemble_dim=ENSEMBLE_DIM.value),
-              'ensemble_mean_rmse': metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
-                  ensemble_dim=ENSEMBLE_DIM.value
-              ),
               'ensemble_mean_mse': metrics.EnsembleMeanMSE(
                   ensemble_dim=ENSEMBLE_DIM.value
               ),
-              'ensemble_stddev': metrics.EnsembleStddevSqrtBeforeTimeAvg(
-                  ensemble_dim=ENSEMBLE_DIM.value
-              ),
               'ensemble_variance': metrics.EnsembleVariance(
                   ensemble_dim=ENSEMBLE_DIM.value
               ),
@@ -467,6 +463,16 @@ def main(argv: list[str]) -> None:
               'energy_score_skill': metrics.EnergyScoreSkill(
                   ensemble_dim=ENSEMBLE_DIM.value
               ),
+              'ensemble_mean_rmse_sqrt_before_time_avg': (
+                  metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
+                      ensemble_dim=ENSEMBLE_DIM.value
+                  )
+              ),
+              'ensemble_stddev_sqrt_before_time_avg': (
+                  metrics.EnsembleStddevSqrtBeforeTimeAvg(
+                      ensemble_dim=ENSEMBLE_DIM.value
+                  )
+              ),
           },
           against_analysis=False,
           derived_variables=derived_variables,
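
The renamed keys make the aggregation order explicit: for these metrics the square root is taken per time step before averaging over time, which is generally not the same number as the square root of the time-averaged MSE. A minimal numpy sketch of the distinction, using synthetic data rather than WB2's actual implementation:

    import numpy as np

    rng = np.random.default_rng(0)
    # Hypothetical spatially averaged squared errors, one value per time step.
    mse_per_time = rng.uniform(0.5, 2.0, size=100)

    rmse_sqrt_before_time_avg = np.sqrt(mse_per_time).mean()  # sqrt, then time mean
    rmse_sqrt_after_time_avg = np.sqrt(mse_per_time.mean())   # time mean, then sqrt

    # Since sqrt is concave, the first is never larger than the second
    # (Jensen's inequality); they coincide only if the per-time MSE is constant.
    print(rmse_sqrt_before_time_avg, rmse_sqrt_after_time_avg)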

scripts/evaluate_test.py

Lines changed: 2 additions & 2 deletions
@@ -93,8 +93,8 @@ def _test(self, use_beam=True, input_chunks=None):
     evaluate.main([])
 
     for config_name in eval_configs:
-      expected_sizes_2d = {'metric': 5, 'lead_time': 4, 'region': 4}
-      expected_sizes_3d = {'metric': 5, 'lead_time': 4, 'region': 4, 'level': 3}
+      expected_sizes_2d = {'metric': 4, 'lead_time': 4, 'region': 4}
+      expected_sizes_3d = {'metric': 4, 'lead_time': 4, 'region': 4, 'level': 3}
 
       with self.subTest(config_name):
         results_path = os.path.join(output_dir, f'{config_name}.nc')
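
The smaller expected 'metric' size reflects the removal of RMSE from the default deterministic metrics dict. In the written results the metric labels also change accordingly; an illustrative xarray snippet, assuming the per-config netCDF layout used in this test and that the metric dimension carries the metric names as coordinate labels (the output_dir path here is hypothetical):

    import os
    import xarray as xr

    output_dir = '/tmp/wb2_eval'  # hypothetical run directory
    results = xr.open_dataset(
        os.path.join(output_dir, 'deterministic_temporal.nc')
    )

    # Before this commit:  results.sel(metric='rmse')
    # Afterwards, when the RMSE metric is enabled, it is keyed by the new name:
    rmse = results.sel(metric='rmse_sqrt_before_time_avg')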
