From a2fd459148fbc2809d6f09c781b8031597734131 Mon Sep 17 00:00:00 2001
From: Ian Langmore
Date: Mon, 4 Dec 2023 11:47:46 -0800
Subject: [PATCH] Remove RMSE from default metrics dicts. Change the metric
 name to include the `sqrt_before_time_avg` caveat.

Breaking change for users of top-level WB2 API.

PiperOrigin-RevId: 587793549
---
 scripts/evaluate.py      | 26 ++++++++++++++++----------
 scripts/evaluate_test.py |  4 ++--
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index 13e92c3..c6c62ee 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -355,14 +355,16 @@ def main(argv: list[str]) -> None:
   climatology = evaluation.make_latitude_increasing(climatology)
 
   deterministic_metrics = {
-      'rmse': metrics.RMSESqrtBeforeTimeAvg(
-          wind_vector_rmse=_wind_vector_error('rmse')
-      ),
       'mse': metrics.MSE(wind_vector_mse=_wind_vector_error('mse')),
       'acc': metrics.ACC(climatology=climatology),
       'bias': metrics.Bias(),
       'mae': metrics.MAE(),
   }
+  rmse_metrics = {
+      'rmse_sqrt_before_time_avg': metrics.RMSESqrtBeforeTimeAvg(
+          wind_vector_rmse=_wind_vector_error('rmse')
+      ),
+  }
   spatial_metrics = {
       'bias': metrics.SpatialBias(),
       'mse': metrics.SpatialMSE(),
@@ -412,7 +414,7 @@ def main(argv: list[str]) -> None:
           output_format='zarr',
       ),
       'deterministic_temporal': config.Eval(
-          metrics=deterministic_metrics,
+          metrics=deterministic_metrics | rmse_metrics,
           against_analysis=False,
           regions=regions,
           derived_variables=derived_variables,
@@ -435,15 +437,9 @@ def main(argv: list[str]) -> None:
                   ensemble_dim=ENSEMBLE_DIM.value
               ),
               'crps_skill': metrics.CRPSSkill(ensemble_dim=ENSEMBLE_DIM.value),
-              'ensemble_mean_rmse': metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
-                  ensemble_dim=ENSEMBLE_DIM.value
-              ),
               'ensemble_mean_mse': metrics.EnsembleMeanMSE(
                   ensemble_dim=ENSEMBLE_DIM.value
               ),
-              'ensemble_stddev': metrics.EnsembleStddevSqrtBeforeTimeAvg(
-                  ensemble_dim=ENSEMBLE_DIM.value
-              ),
               'ensemble_variance': metrics.EnsembleVariance(
                   ensemble_dim=ENSEMBLE_DIM.value
               ),
@@ -467,6 +463,16 @@ def main(argv: list[str]) -> None:
               'energy_score_skill': metrics.EnergyScoreSkill(
                   ensemble_dim=ENSEMBLE_DIM.value
               ),
+              'ensemble_mean_rmse_sqrt_before_time_avg': (
+                  metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
+                      ensemble_dim=ENSEMBLE_DIM.value
+                  )
+              ),
+              'ensemble_stddev_sqrt_before_time_avg': (
+                  metrics.EnsembleStddevSqrtBeforeTimeAvg(
+                      ensemble_dim=ENSEMBLE_DIM.value
+                  )
+              ),
           },
           against_analysis=False,
           derived_variables=derived_variables,
diff --git a/scripts/evaluate_test.py b/scripts/evaluate_test.py
index 46cdfd3..54556ca 100644
--- a/scripts/evaluate_test.py
+++ b/scripts/evaluate_test.py
@@ -93,8 +93,8 @@ def _test(self, use_beam=True, input_chunks=None):
     evaluate.main([])
 
     for config_name in eval_configs:
-      expected_sizes_2d = {'metric': 5, 'lead_time': 4, 'region': 4}
-      expected_sizes_3d = {'metric': 5, 'lead_time': 4, 'region': 4, 'level': 3}
+      expected_sizes_2d = {'metric': 4, 'lead_time': 4, 'region': 4}
+      expected_sizes_3d = {'metric': 4, 'lead_time': 4, 'region': 4, 'level': 3}
       with self.subTest(config_name):
         results_path = os.path.join(output_dir, f'{config_name}.nc')
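
Migration note: since this patch renames the metric key and drops RMSE from the
default deterministic metrics, here is a minimal sketch of how downstream users
of the top-level WB2 API might adapt. It assumes a results file written by
scripts/evaluate.py with a `metric` coordinate (as checked in evaluate_test.py);
the file path and variable names below are hypothetical.

    import xarray as xr

    # Hypothetical path; substitute the output of your own evaluation run for
    # the 'deterministic_temporal' config.
    results = xr.open_dataset('deterministic_temporal.nc')

    # Before this patch: results.sel(metric='rmse').
    # After it, the key carries the sqrt-before-time-average caveat, and the
    # metric is only included in the 'deterministic_temporal' config
    # (via deterministic_metrics | rmse_metrics).
    rmse = results.sel(metric='rmse_sqrt_before_time_avg')

    # 'mse' stays in every deterministic config, so the RMSE variant that
    # takes the square root *after* time averaging can still be derived.
    rmse_sqrt_after_time_avg = results.sel(metric='mse') ** 0.5

The two conventions differ: `RMSESqrtBeforeTimeAvg` takes the square root of
the spatially averaged error at each time before any averaging over time,
whereas the square root of a time-averaged MSE applies the root afterwards,
which is why the new key spells out the caveat.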