diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index 172faab..c6c62ee 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -229,9 +229,15 @@
 )
 
 
-def _wind_vector_rmse():
-  """Defines Wind Vector RMSEs if U/V components are in variables."""
-  wind_vector_rmse = []
+def _wind_vector_error(err_type: str):
+  """Defines Wind Vector [R]MSEs if U/V components are in variables."""
+  if err_type == 'mse':
+    cls = metrics.WindVectorMSE
+  elif err_type == 'rmse':
+    cls = metrics.WindVectorRMSESqrtBeforeTimeAvg
+  else:
+    raise ValueError(f'Unrecognized {err_type=}')
+  wind_vector_error = []
   available = set(VARIABLES.value).union(DERIVED_VARIABLES.value)
   for u_name, v_name, vector_name in [
       ('u_component_of_wind', 'v_component_of_wind', 'wind_vector'),
@@ -248,14 +254,14 @@ def _wind_vector_rmse():
       ),
   ]:
     if u_name in available and v_name in available:
-      wind_vector_rmse.append(
-          metrics.WindVectorRMSE(
+      wind_vector_error.append(
+          cls(
               u_name=u_name,
               v_name=v_name,
               vector_name=vector_name,
           )
       )
-  return wind_vector_rmse
+  return wind_vector_error
 
 
 def main(argv: list[str]) -> None:
@@ -349,12 +355,16 @@ def main(argv: list[str]) -> None:
     climatology = evaluation.make_latitude_increasing(climatology)
 
   deterministic_metrics = {
-      'rmse': metrics.RMSE(wind_vector_rmse=_wind_vector_rmse()),
-      'mse': metrics.MSE(),
+      'mse': metrics.MSE(wind_vector_mse=_wind_vector_error('mse')),
       'acc': metrics.ACC(climatology=climatology),
       'bias': metrics.Bias(),
       'mae': metrics.MAE(),
   }
+  rmse_metrics = {
+      'rmse_sqrt_before_time_avg': metrics.RMSESqrtBeforeTimeAvg(
+          wind_vector_rmse=_wind_vector_error('rmse')
+      ),
+  }
   spatial_metrics = {
       'bias': metrics.SpatialBias(),
       'mse': metrics.SpatialMSE(),
@@ -404,7 +414,7 @@ def main(argv: list[str]) -> None:
           output_format='zarr',
       ),
       'deterministic_temporal': config.Eval(
-          metrics=deterministic_metrics,
+          metrics=deterministic_metrics | rmse_metrics,
          against_analysis=False,
          regions=regions,
          derived_variables=derived_variables,
@@ -427,15 +437,9 @@ def main(argv: list[str]) -> None:
               ensemble_dim=ENSEMBLE_DIM.value
           ),
           'crps_skill': metrics.CRPSSkill(ensemble_dim=ENSEMBLE_DIM.value),
-          'ensemble_mean_rmse': metrics.EnsembleMeanRMSE(
-              ensemble_dim=ENSEMBLE_DIM.value
-          ),
           'ensemble_mean_mse': metrics.EnsembleMeanMSE(
               ensemble_dim=ENSEMBLE_DIM.value
           ),
-          'ensemble_stddev': metrics.EnsembleStddev(
-              ensemble_dim=ENSEMBLE_DIM.value
-          ),
           'ensemble_variance': metrics.EnsembleVariance(
               ensemble_dim=ENSEMBLE_DIM.value
           ),
@@ -459,6 +463,16 @@ def main(argv: list[str]) -> None:
           'energy_score_skill': metrics.EnergyScoreSkill(
               ensemble_dim=ENSEMBLE_DIM.value
           ),
+          'ensemble_mean_rmse_sqrt_before_time_avg': (
+              metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
+                  ensemble_dim=ENSEMBLE_DIM.value
+              )
+          ),
+          'ensemble_stddev_sqrt_before_time_avg': (
+              metrics.EnsembleStddevSqrtBeforeTimeAvg(
+                  ensemble_dim=ENSEMBLE_DIM.value
+              )
+          ),
       },
       against_analysis=False,
       derived_variables=derived_variables,
diff --git a/scripts/evaluate_test.py b/scripts/evaluate_test.py
index 46cdfd3..54556ca 100644
--- a/scripts/evaluate_test.py
+++ b/scripts/evaluate_test.py
@@ -93,8 +93,8 @@ def _test(self, use_beam=True, input_chunks=None):
     evaluate.main([])
 
     for config_name in eval_configs:
-      expected_sizes_2d = {'metric': 5, 'lead_time': 4, 'region': 4}
-      expected_sizes_3d = {'metric': 5, 'lead_time': 4, 'region': 4, 'level': 3}
+      expected_sizes_2d = {'metric': 4, 'lead_time': 4, 'region': 4}
+      expected_sizes_3d = {'metric': 4, 'lead_time': 4, 'region': 4, 'level': 3}
       with self.subTest(config_name):
         results_path = os.path.join(output_dir, f'{config_name}.nc')
diff --git a/weatherbench2/evaluation_test.py b/weatherbench2/evaluation_test.py
index a94951e..1dd770f 100644
--- a/weatherbench2/evaluation_test.py
+++ b/weatherbench2/evaluation_test.py
@@ -86,13 +86,13 @@ def test_in_memory_and_beam_consistency(self):
     eval_configs = {
         'forecast_vs_era': config.Eval(
             metrics={
-                'rmse': metrics.RMSE(),
+                'rmse': metrics.RMSESqrtBeforeTimeAvg(),
                 'acc': metrics.ACC(climatology=climatology),
             },
             against_analysis=False,
         ),
         'forecast_vs_era_by_region': config.Eval(
-            metrics={'rmse': metrics.RMSE()},
+            metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
             against_analysis=False,
             regions=regions,
         ),
@@ -101,7 +101,7 @@
             against_analysis=False,
         ),
         'forecast_vs_era_temporal': config.Eval(
-            metrics={'rmse': metrics.RMSE()},
+            metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
             against_analysis=False,
             temporal_mean=False,
         ),
diff --git a/weatherbench2/metrics.py b/weatherbench2/metrics.py
index bffc988..665edf9 100644
--- a/weatherbench2/metrics.py
+++ b/weatherbench2/metrics.py
@@ -135,8 +135,8 @@ def _spatial_average_l2_norm(
 
 
 @dataclasses.dataclass
-class WindVectorRMSE(Metric):
-  """Compute wind vector RMSE. See WB2 paper for definition.
+class WindVectorMSE(Metric):
+  """Compute wind vector mean square error. See WB2 paper for definition.
 
   Attributes:
     u_name: Name of U component.
@@ -155,25 +155,57 @@ def compute_chunk(
       region: t.Optional[Region] = None,
   ) -> xr.Dataset:
     diff = forecast - truth
-    result = np.sqrt(
-        _spatial_average(
-            diff[self.u_name] ** 2 + diff[self.v_name] ** 2,
-            region=region,
-        )
+    result = _spatial_average(
+        diff[self.u_name] ** 2 + diff[self.v_name] ** 2,
+        region=region,
     )
     return result
 
 
 @dataclasses.dataclass
-class RMSE(Metric):
+class WindVectorRMSESqrtBeforeTimeAvg(Metric):
+  """Compute wind vector RMSE. See WB2 paper for definition.
+
+  This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
+  Most users will prefer to use WindVectorMSE and then take a square root in
+  user code after running the evaluate script.
+
+  Attributes:
+    u_name: Name of U component.
+    v_name: Name of V component.
+    vector_name: Name of wind vector to be computed.
+  """
+
+  u_name: str
+  v_name: str
+  vector_name: str
+
+  def compute_chunk(
+      self,
+      forecast: xr.Dataset,
+      truth: xr.Dataset,
+      region: t.Optional[Region] = None,
+  ) -> xr.Dataset:
+    mse = WindVectorMSE(
+        u_name=self.u_name, v_name=self.v_name, vector_name=self.vector_name
+    ).compute_chunk(forecast, truth, region=region)
+    return np.sqrt(mse)
+
+
+@dataclasses.dataclass
+class RMSESqrtBeforeTimeAvg(Metric):
   """Root mean squared error.
 
+  This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
+  Most users will prefer to use MSE and then take a square root in user
+  code after running the evaluate script.
+
   Attributes:
-    wind_vector_rmse: Optionally provide list of WindVectorRMSE instances to
-      compute.
+    wind_vector_rmse: Optionally provide list of WindVectorRMSESqrtBeforeTimeAvg
+      instances to compute.
   """
 
-  wind_vector_rmse: t.Optional[list[WindVectorRMSE]] = None
+  wind_vector_rmse: t.Optional[list[WindVectorRMSESqrtBeforeTimeAvg]] = None
 
   def compute_chunk(
       self,
@@ -192,7 +224,14 @@ def compute_chunk(
 
 @dataclasses.dataclass
 class MSE(Metric):
-  """Mean squared error."""
+  """Mean squared error.
+
+  Attributes:
+    wind_vector_mse: Optionally provide list of WindVectorMSE instances to
+      compute.
+  """
+
+  wind_vector_mse: t.Optional[list[WindVectorMSE]] = None
 
   def compute_chunk(
       self,
@@ -200,7 +239,13 @@ def compute_chunk(
       truth: xr.Dataset,
       region: t.Optional[Region] = None,
   ) -> xr.Dataset:
-    return _spatial_average((forecast - truth) ** 2, region=region)
+    results = _spatial_average((forecast - truth) ** 2, region=region)
+    if self.wind_vector_mse is not None:
+      for wv in self.wind_vector_mse:
+        results[wv.vector_name] = wv.compute_chunk(
+            forecast, truth, region=region
+        )
+    return results
 
 
 @dataclasses.dataclass
@@ -717,14 +762,15 @@ def _pointwise_gaussian_crps(
 
 
 @dataclasses.dataclass
-class EnsembleStddev(EnsembleMetric):
+class EnsembleStddevSqrtBeforeTimeAvg(EnsembleMetric):
   """The standard deviation of an ensemble of forecasts.
 
-  This forms the SPREAD component of the traditional spread-skill-ratio. See
-  [Garg & Rasp & Thuerey, 2022].
+  This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
+  Most users will prefer to use EnsembleVariance and then take a square root in
+  user code after running the evaluate script.
 
   Given predictive ensemble Xₜ at times t = (1,..., T),
-    EnsembleStddev := (1 / T) Σₜ ‖σ(Xₜ)‖
+    EnsembleStddevSqrtBeforeTimeAvg := (1 / T) Σₜ ‖σ(Xₜ)‖
   Above σ(Xₜ) is element-wise standard deviation, and ‖⋅‖ is an area-weighted
   L2 norm.
 
@@ -737,15 +783,6 @@ class EnsembleStddev(EnsembleMetric):
 
   NaN values propagate through and result in NaN in the corresponding output
   position.
-
-  We use the unbiased estimator of σ(Xₜ) (dividing by n_ensemble - 1). If
-  n_ensemble = 1, we return zero for the stddev. This choice allows
-  EnsembleStddev to behave in the spread-skill-ratio as expected.
-
-  References:
-    [Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
-    for probabilistic medium-range weather forecasting along with deep learning
-    baseline models.
   """
 
   def compute_chunk(
@@ -754,7 +791,7 @@ def compute_chunk(
       truth: xr.Dataset,
       region: t.Optional[Region] = None,
   ) -> xr.Dataset:
-    """EnsembleStddev, averaged over space, for a time chunk of data."""
+    """Ensemble Stddev, averaged over space, for a time chunk of data."""
     del truth  # unused
     n_ensemble = _get_n_ensemble(forecast, self.ensemble_dim)
 
@@ -825,15 +862,16 @@ def compute_chunk(
 
 
 @dataclasses.dataclass
-class EnsembleMeanRMSE(EnsembleMetric):
+class EnsembleMeanRMSESqrtBeforeTimeAvg(EnsembleMetric):
   """RMSE between the ensemble mean and ground truth.
 
-  This forms the SKILL component of the traditional spread-skill-ratio. See
-  [Garg & Rasp & Thuerey, 2022].
+  This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
+  Most users will prefer to use EnsembleMeanMSE and then take a square root in
+  user code after running the evaluate script.
 
   Given ground truth Yₜ, and predictive ensemble Xₜ, both at times
   t = (1,..., T),
-    EnsembleMeanRMSE := (1 / T) Σₜ ‖Y - E(Xₜ)‖.
+    EnsembleMeanRMSESqrtBeforeTimeAvg := (1 / T) Σₜ ‖Yₜ - E(Xₜ)‖.
   Above, `E` is ensemble average, and ‖⋅‖ is an area-weighted L2 norm.
   Estimation is done separately for each tendency, level, and lag time.
 
@@ -845,11 +883,6 @@ class EnsembleMeanRMSE(EnsembleMetric):
 
   NaN values propagate through and result in NaN in the corresponding output
   position.
-
-  References:
-    [Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
-    for probabilistic medium-range weather forecasting along with deep learning
-    baseline models.
   """
 
   def compute_chunk(
@@ -1005,7 +1038,7 @@ def compute_chunk(
 
 
 # TODO(shoyer): Consider adding WindVectorEnergyScore based on a pair of wind
-# components, as a sort of probabilistic variant of WindVectorRMSE.
+# components, as a sort of probabilistic variant of WindVectorMSE.
 
 
 @dataclasses.dataclass
diff --git a/weatherbench2/metrics_test.py b/weatherbench2/metrics_test.py
index 836a6dd..74e0e42 100644
--- a/weatherbench2/metrics_test.py
+++ b/weatherbench2/metrics_test.py
@@ -70,7 +70,7 @@ def test_get_lat_weights(self):
     xr.testing.assert_allclose(expected, weights)
 
   def test_wind_vector_rmse(self):
-    wv = metrics.WindVectorRMSE(
+    wv = metrics.WindVectorRMSESqrtBeforeTimeAvg(
         u_name='u_component_of_wind',
         v_name='v_component_of_wind',
         vector_name='wind_vector',
@@ -119,7 +119,7 @@ def test_wind_vector_rmse(self):
       dict(testcase_name='nan', invalid_value=np.nan),
   )
   def test_rmse_over_invalid_region(self, invalid_value):
-    rmse = metrics.RMSE()
+    rmse = metrics.RMSESqrtBeforeTimeAvg()
     truth = xr.Dataset(
         {'wind_speed': ('latitude', [0.0, invalid_value, 0.0])},
         coords={'latitude': [-45, 0, 45]},
@@ -460,8 +460,12 @@ class EnsembleMeanRMSEAndStddevTest(parameterized.TestCase):
   def test_on_random_dataset(self, ensemble_size):
     truth, forecast = get_random_truth_and_forecast(ensemble_size=ensemble_size)
 
-    rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth)
-    ensemble_stddev = metrics.EnsembleStddev().compute_chunk(forecast, truth)
+    rmse = metrics.EnsembleMeanRMSESqrtBeforeTimeAvg().compute_chunk(
+        forecast, truth
+    )
+    ensemble_stddev = metrics.EnsembleStddevSqrtBeforeTimeAvg().compute_chunk(
+        forecast, truth
+    )
 
     for dataset in [rmse, ensemble_stddev]:
       self.assertEqual(
@@ -496,7 +500,11 @@ def test_effect_of_large_bias_on_rmse(self):
     truth, forecast = get_random_truth_and_forecast(ensemble_size=10)
     truth += 1000
 
-    mean_rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth).mean()
+    mean_rmse = (
+        metrics.EnsembleMeanRMSESqrtBeforeTimeAvg()
+        .compute_chunk(forecast, truth)
+        .mean()
+    )
 
     # Dominated by bias of 1000
     np.testing.assert_allclose(1000, mean_rmse.geopotential.values, rtol=1e-3)
@@ -504,7 +512,11 @@ def test_perfect_prediction_zero_rmse(self):
     truth, unused_forecast = get_random_truth_and_forecast(ensemble_size=10)
     forecast = truth.expand_dims({metrics.REALIZATION: 1})
 
-    mean_rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth).mean()
+    mean_rmse = (
+        metrics.EnsembleMeanRMSESqrtBeforeTimeAvg()
+        .compute_chunk(forecast, truth)
+        .mean()
+    )
 
     xr.testing.assert_allclose(xr.zeros_like(mean_rmse), mean_rmse)
diff --git a/weatherbench2/regions_test.py b/weatherbench2/regions_test.py
index 0d8960d..1dc0397 100644
--- a/weatherbench2/regions_test.py
+++ b/weatherbench2/regions_test.py
@@ -43,7 +43,7 @@ def testLandRegion(self):
     lsm = lsm.where(lsm.latitude < 1.0, 1)
     land_region = regions.LandRegion(lsm)
 
-    rmse = metrics.RMSE()
+    rmse = metrics.RMSESqrtBeforeTimeAvg()
     results = rmse.compute(forecast, truth, region=land_region)
 
     np.testing.assert_allclose(results['2m_temperature'].values, 0.0)