Skip to content

Commit f5a4f73

Browse files
langmoreWeatherbench2 authors
authored and
Weatherbench2 authors
committed
Update RMSE metrics to be named RMSESqrtBeforeTimeAvg to warn users of their operation. Add WindVectorMSE to make MSE more complete as the default choice.
Subsequent CL will remove RMSE from default metrics. PiperOrigin-RevId: 587741031
1 parent 75ccb1d commit f5a4f73

File tree

5 files changed

+110
-57
lines changed

5 files changed

+110
-57
lines changed

scripts/evaluate.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,15 @@
229229
)
230230

231231

232-
def _wind_vector_rmse():
233-
"""Defines Wind Vector RMSEs if U/V components are in variables."""
234-
wind_vector_rmse = []
232+
def _wind_vector_error(err_type: str):
233+
"""Defines Wind Vector [R]MSEs if U/V components are in variables."""
234+
if err_type == 'mse':
235+
cls = metrics.WindVectorMSE
236+
elif err_type == 'rmse':
237+
cls = metrics.WindVectorRMSESqrtBeforeTimeAvg
238+
else:
239+
raise ValueError(f'Unrecognized {err_type=}')
240+
wind_vector_error = []
235241
available = set(VARIABLES.value).union(DERIVED_VARIABLES.value)
236242
for u_name, v_name, vector_name in [
237243
('u_component_of_wind', 'v_component_of_wind', 'wind_vector'),
@@ -248,14 +254,14 @@ def _wind_vector_rmse():
248254
),
249255
]:
250256
if u_name in available and v_name in available:
251-
wind_vector_rmse.append(
252-
metrics.WindVectorRMSE(
257+
wind_vector_error.append(
258+
cls(
253259
u_name=u_name,
254260
v_name=v_name,
255261
vector_name=vector_name,
256262
)
257263
)
258-
return wind_vector_rmse
264+
return wind_vector_error
259265

260266

261267
def main(argv: list[str]) -> None:
@@ -349,8 +355,10 @@ def main(argv: list[str]) -> None:
349355
climatology = evaluation.make_latitude_increasing(climatology)
350356

351357
deterministic_metrics = {
352-
'rmse': metrics.RMSE(wind_vector_rmse=_wind_vector_rmse()),
353-
'mse': metrics.MSE(),
358+
'rmse': metrics.RMSESqrtBeforeTimeAvg(
359+
wind_vector_rmse=_wind_vector_error('rmse')
360+
),
361+
'mse': metrics.MSE(wind_vector_mse=_wind_vector_error('mse')),
354362
'acc': metrics.ACC(climatology=climatology),
355363
'bias': metrics.Bias(),
356364
'mae': metrics.MAE(),
@@ -427,13 +435,13 @@ def main(argv: list[str]) -> None:
427435
ensemble_dim=ENSEMBLE_DIM.value
428436
),
429437
'crps_skill': metrics.CRPSSkill(ensemble_dim=ENSEMBLE_DIM.value),
430-
'ensemble_mean_rmse': metrics.EnsembleMeanRMSE(
438+
'ensemble_mean_rmse': metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
431439
ensemble_dim=ENSEMBLE_DIM.value
432440
),
433441
'ensemble_mean_mse': metrics.EnsembleMeanMSE(
434442
ensemble_dim=ENSEMBLE_DIM.value
435443
),
436-
'ensemble_stddev': metrics.EnsembleStddev(
444+
'ensemble_stddev': metrics.EnsembleStddevSqrtBeforeTimeAvg(
437445
ensemble_dim=ENSEMBLE_DIM.value
438446
),
439447
'ensemble_variance': metrics.EnsembleVariance(

weatherbench2/evaluation_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,13 @@ def test_in_memory_and_beam_consistency(self):
8686
eval_configs = {
8787
'forecast_vs_era': config.Eval(
8888
metrics={
89-
'rmse': metrics.RMSE(),
89+
'rmse': metrics.RMSESqrtBeforeTimeAvg(),
9090
'acc': metrics.ACC(climatology=climatology),
9191
},
9292
against_analysis=False,
9393
),
9494
'forecast_vs_era_by_region': config.Eval(
95-
metrics={'rmse': metrics.RMSE()},
95+
metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
9696
against_analysis=False,
9797
regions=regions,
9898
),
@@ -101,7 +101,7 @@ def test_in_memory_and_beam_consistency(self):
101101
against_analysis=False,
102102
),
103103
'forecast_vs_era_temporal': config.Eval(
104-
metrics={'rmse': metrics.RMSE()},
104+
metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
105105
against_analysis=False,
106106
temporal_mean=False,
107107
),

weatherbench2/metrics.py

+70-37
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ def _spatial_average_l2_norm(
135135

136136

137137
@dataclasses.dataclass
138-
class WindVectorRMSE(Metric):
139-
"""Compute wind vector RMSE. See WB2 paper for definition.
138+
class WindVectorMSE(Metric):
139+
"""Compute wind vector mean square error. See WB2 paper for definition.
140140
141141
Attributes:
142142
u_name: Name of U component.
@@ -155,25 +155,57 @@ def compute_chunk(
155155
region: t.Optional[Region] = None,
156156
) -> xr.Dataset:
157157
diff = forecast - truth
158-
result = np.sqrt(
159-
_spatial_average(
160-
diff[self.u_name] ** 2 + diff[self.v_name] ** 2,
161-
region=region,
162-
)
158+
result = _spatial_average(
159+
diff[self.u_name] ** 2 + diff[self.v_name] ** 2,
160+
region=region,
163161
)
164162
return result
165163

166164

167165
@dataclasses.dataclass
168-
class RMSE(Metric):
166+
class WindVectorRMSESqrtBeforeTimeAvg(Metric):
167+
"""Compute wind vector RMSE. See WB2 paper for definition.
168+
169+
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
170+
Most users will prefer to use WindVectorMSE and then take a square root in
171+
user code after running the evaluate script.
172+
173+
Attributes:
174+
u_name: Name of U component.
175+
v_name: Name of V component.
176+
vector_name: Name of wind vector to be computed.
177+
"""
178+
179+
u_name: str
180+
v_name: str
181+
vector_name: str
182+
183+
def compute_chunk(
184+
self,
185+
forecast: xr.Dataset,
186+
truth: xr.Dataset,
187+
region: t.Optional[Region] = None,
188+
) -> xr.Dataset:
189+
mse = WindVectorMSE(
190+
u_name=self.u_name, v_name=self.v_name, vector_name=self.vector_name
191+
).compute_chunk(forecast, truth, region=region)
192+
return np.sqrt(mse)
193+
194+
195+
@dataclasses.dataclass
196+
class RMSESqrtBeforeTimeAvg(Metric):
169197
"""Root mean squared error.
170198
199+
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
200+
Most users will prefer to use MSE and then take a square root in user
201+
code after running the evaluate script.
202+
171203
Attributes:
172-
wind_vector_rmse: Optionally provide list of WindVectorRMSE instances to
173-
compute.
204+
wind_vector_rmse: Optionally provide list of WindVectorRMSESqrtBeforeTimeAvg
205+
instances to compute.
174206
"""
175207

176-
wind_vector_rmse: t.Optional[list[WindVectorRMSE]] = None
208+
wind_vector_rmse: t.Optional[list[WindVectorRMSESqrtBeforeTimeAvg]] = None
177209

178210
def compute_chunk(
179211
self,
@@ -192,15 +224,28 @@ def compute_chunk(
192224

193225
@dataclasses.dataclass
194226
class MSE(Metric):
195-
"""Mean squared error."""
227+
"""Mean squared error.
228+
229+
Attributes:
230+
wind_vector_mse: Optionally provide list of WindVectorMSE instances to
231+
compute.
232+
"""
233+
234+
wind_vector_mse: t.Optional[list[WindVectorMSE]] = None
196235

197236
def compute_chunk(
198237
self,
199238
forecast: xr.Dataset,
200239
truth: xr.Dataset,
201240
region: t.Optional[Region] = None,
202241
) -> xr.Dataset:
203-
return _spatial_average((forecast - truth) ** 2, region=region)
242+
results = _spatial_average((forecast - truth) ** 2, region=region)
243+
if self.wind_vector_mse is not None:
244+
for wv in self.wind_vector_mse:
245+
results[wv.vector_name] = wv.compute_chunk(
246+
forecast, truth, region=region
247+
)
248+
return results
204249

205250

206251
@dataclasses.dataclass
@@ -717,14 +762,15 @@ def _pointwise_gaussian_crps(
717762

718763

719764
@dataclasses.dataclass
720-
class EnsembleStddev(EnsembleMetric):
765+
class EnsembleStddevSqrtBeforeTimeAvg(EnsembleMetric):
721766
"""The standard deviation of an ensemble of forecasts.
722767
723-
This forms the SPREAD component of the traditional spread-skill-ratio. See
724-
[Garg & Rasp & Thuerey, 2022].
768+
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
769+
Most users will prefer to use EnsembleVariance and then take a square root in
770+
user code after running the evaluate script.
725771
726772
Given predictive ensemble Xₜ at times t = (1,..., T),
727-
EnsembleStddev := (1 / T) Σₜ ‖σ(Xₜ)‖
773+
EnsembleStddevSqrtBeforeTimeAvg := (1 / T) Σₜ ‖σ(Xₜ)‖
728774
Above σ(Xₜ) is element-wise standard deviation, and ‖⋅‖ is an area-weighted
729775
L2 norm.
730776
@@ -737,15 +783,6 @@ class EnsembleStddev(EnsembleMetric):
737783
738784
NaN values propagate through and result in NaN in the corresponding output
739785
position.
740-
741-
We use the unbiased estimator of σ(Xₜ) (dividing by n_ensemble - 1). If
742-
n_ensemble = 1, we return zero for the stddev. This choice allows
743-
EnsembleStddev to behave in the spread-skill-ratio as expected.
744-
745-
References:
746-
[Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
747-
for probabilistic medium-range weather forecasting along with deep learning
748-
baseline models.
749786
"""
750787

751788
def compute_chunk(
@@ -754,7 +791,7 @@ def compute_chunk(
754791
truth: xr.Dataset,
755792
region: t.Optional[Region] = None,
756793
) -> xr.Dataset:
757-
"""EnsembleStddev, averaged over space, for a time chunk of data."""
794+
"""Ensemble Stddev, averaged over space, for a time chunk of data."""
758795
del truth # unused
759796
n_ensemble = _get_n_ensemble(forecast, self.ensemble_dim)
760797

@@ -825,15 +862,16 @@ def compute_chunk(
825862

826863

827864
@dataclasses.dataclass
828-
class EnsembleMeanRMSE(EnsembleMetric):
865+
class EnsembleMeanRMSESqrtBeforeTimeAvg(EnsembleMetric):
829866
"""RMSE between the ensemble mean and ground truth.
830867
831-
This forms the SKILL component of the traditional spread-skill-ratio. See
832-
[Garg & Rasp & Thuerey, 2022].
868+
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
869+
Most users will prefer to use EnsembleMeanMSE and then take a square root in
870+
user code after running the evaluate script.
833871
834872
Given ground truth Yₜ, and predictive ensemble Xₜ, both at times
835873
t = (1,..., T),
836-
EnsembleMeanRMSE := (1 / T) Σₜ ‖Y - E(Xₜ)‖.
874+
EnsembleMeanRMSESqrtBeforeTimeAvg := (1 / T) Σₜ ‖Y - E(Xₜ)‖.
837875
Above, `E` is ensemble average, and ‖⋅‖ is an area-weighted L2 norm.
838876
839877
Estimation is done separately for each tendency, level, and lag time.
@@ -845,11 +883,6 @@ class EnsembleMeanRMSE(EnsembleMetric):
845883
846884
NaN values propagate through and result in NaN in the corresponding output
847885
position.
848-
849-
References:
850-
[Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
851-
for probabilistic medium-range weather forecasting along with deep learning
852-
baseline models.
853886
"""
854887

855888
def compute_chunk(
@@ -1005,7 +1038,7 @@ def compute_chunk(
10051038

10061039

10071040
# TODO(shoyer): Consider adding WindVectorEnergyScore based on a pair of wind
1008-
# components, as a sort of probabilistic variant of WindVectorRMSE.
1041+
# components, as a sort of probabilistic variant of WindVectorMSE.
10091042

10101043

10111044
@dataclasses.dataclass

weatherbench2/metrics_test.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_get_lat_weights(self):
7070
xr.testing.assert_allclose(expected, weights)
7171

7272
def test_wind_vector_rmse(self):
73-
wv = metrics.WindVectorRMSE(
73+
wv = metrics.WindVectorRMSESqrtBeforeTimeAvg(
7474
u_name='u_component_of_wind',
7575
v_name='v_component_of_wind',
7676
vector_name='wind_vector',
@@ -119,7 +119,7 @@ def test_wind_vector_rmse(self):
119119
dict(testcase_name='nan', invalid_value=np.nan),
120120
)
121121
def test_rmse_over_invalid_region(self, invalid_value):
122-
rmse = metrics.RMSE()
122+
rmse = metrics.RMSESqrtBeforeTimeAvg()
123123
truth = xr.Dataset(
124124
{'wind_speed': ('latitude', [0.0, invalid_value, 0.0])},
125125
coords={'latitude': [-45, 0, 45]},
@@ -460,8 +460,12 @@ class EnsembleMeanRMSEAndStddevTest(parameterized.TestCase):
460460
def test_on_random_dataset(self, ensemble_size):
461461
truth, forecast = get_random_truth_and_forecast(ensemble_size=ensemble_size)
462462

463-
rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth)
464-
ensemble_stddev = metrics.EnsembleStddev().compute_chunk(forecast, truth)
463+
rmse = metrics.EnsembleMeanRMSESqrtBeforeTimeAvg().compute_chunk(
464+
forecast, truth
465+
)
466+
ensemble_stddev = metrics.EnsembleStddevSqrtBeforeTimeAvg().compute_chunk(
467+
forecast, truth
468+
)
465469

466470
for dataset in [rmse, ensemble_stddev]:
467471
self.assertEqual(
@@ -496,15 +500,23 @@ def test_effect_of_large_bias_on_rmse(self):
496500
truth, forecast = get_random_truth_and_forecast(ensemble_size=10)
497501
truth += 1000
498502

499-
mean_rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth).mean()
503+
mean_rmse = (
504+
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg()
505+
.compute_chunk(forecast, truth)
506+
.mean()
507+
)
500508

501509
# Dominated by bias of 1000
502510
np.testing.assert_allclose(1000, mean_rmse.geopotential.values, rtol=1e-3)
503511

504512
def test_perfect_prediction_zero_rmse(self):
505513
truth, unused_forecast = get_random_truth_and_forecast(ensemble_size=10)
506514
forecast = truth.expand_dims({metrics.REALIZATION: 1})
507-
mean_rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth).mean()
515+
mean_rmse = (
516+
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg()
517+
.compute_chunk(forecast, truth)
518+
.mean()
519+
)
508520

509521
xr.testing.assert_allclose(xr.zeros_like(mean_rmse), mean_rmse)
510522

weatherbench2/regions_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def testLandRegion(self):
4343
lsm = lsm.where(lsm.latitude < 1.0, 1)
4444
land_region = regions.LandRegion(lsm)
4545

46-
rmse = metrics.RMSE()
46+
rmse = metrics.RMSESqrtBeforeTimeAvg()
4747

4848
results = rmse.compute(forecast, truth, region=land_region)
4949
np.testing.assert_allclose(results['2m_temperature'].values, 0.0)

0 commit comments

Comments
 (0)