Skip to content

Commit 21f20b9

Browse files
langmore authored and Weatherbench2 authors committed
No public description
PiperOrigin-RevId: 588176564
1 parent 75ccb1d commit 21f20b9

File tree

6 files changed

+123
-64
lines changed

6 files changed

+123
-64
lines changed

scripts/evaluate.py

+29 -15
Original file line number | Diff line number | Diff line change
@@ -229,9 +229,15 @@
229229
)
230230

231231

232-
def _wind_vector_rmse():
233-
"""Defines Wind Vector RMSEs if U/V components are in variables."""
234-
wind_vector_rmse = []
232+
def _wind_vector_error(err_type: str):
233+
"""Defines Wind Vector [R]MSEs if U/V components are in variables."""
234+
if err_type == 'mse':
235+
cls = metrics.WindVectorMSE
236+
elif err_type == 'rmse':
237+
cls = metrics.WindVectorRMSESqrtBeforeTimeAvg
238+
else:
239+
raise ValueError(f'Unrecognized {err_type=}')
240+
wind_vector_error = []
235241
available = set(VARIABLES.value).union(DERIVED_VARIABLES.value)
236242
for u_name, v_name, vector_name in [
237243
('u_component_of_wind', 'v_component_of_wind', 'wind_vector'),
@@ -248,14 +254,14 @@ def _wind_vector_rmse():
248254
),
249255
]:
250256
if u_name in available and v_name in available:
251-
wind_vector_rmse.append(
252-
metrics.WindVectorRMSE(
257+
wind_vector_error.append(
258+
cls(
253259
u_name=u_name,
254260
v_name=v_name,
255261
vector_name=vector_name,
256262
)
257263
)
258-
return wind_vector_rmse
264+
return wind_vector_error
259265

260266

261267
def main(argv: list[str]) -> None:
@@ -349,12 +355,16 @@ def main(argv: list[str]) -> None:
349355
climatology = evaluation.make_latitude_increasing(climatology)
350356

351357
deterministic_metrics = {
352-
'rmse': metrics.RMSE(wind_vector_rmse=_wind_vector_rmse()),
353-
'mse': metrics.MSE(),
358+
'mse': metrics.MSE(wind_vector_mse=_wind_vector_error('mse')),
354359
'acc': metrics.ACC(climatology=climatology),
355360
'bias': metrics.Bias(),
356361
'mae': metrics.MAE(),
357362
}
363+
rmse_metrics = {
364+
'rmse_sqrt_before_time_avg': metrics.RMSESqrtBeforeTimeAvg(
365+
wind_vector_rmse=_wind_vector_error('rmse')
366+
),
367+
}
358368
spatial_metrics = {
359369
'bias': metrics.SpatialBias(),
360370
'mse': metrics.SpatialMSE(),
@@ -404,7 +414,7 @@ def main(argv: list[str]) -> None:
404414
output_format='zarr',
405415
),
406416
'deterministic_temporal': config.Eval(
407-
metrics=deterministic_metrics,
417+
metrics=deterministic_metrics | rmse_metrics,
408418
against_analysis=False,
409419
regions=regions,
410420
derived_variables=derived_variables,
@@ -427,15 +437,9 @@ def main(argv: list[str]) -> None:
427437
ensemble_dim=ENSEMBLE_DIM.value
428438
),
429439
'crps_skill': metrics.CRPSSkill(ensemble_dim=ENSEMBLE_DIM.value),
430-
'ensemble_mean_rmse': metrics.EnsembleMeanRMSE(
431-
ensemble_dim=ENSEMBLE_DIM.value
432-
),
433440
'ensemble_mean_mse': metrics.EnsembleMeanMSE(
434441
ensemble_dim=ENSEMBLE_DIM.value
435442
),
436-
'ensemble_stddev': metrics.EnsembleStddev(
437-
ensemble_dim=ENSEMBLE_DIM.value
438-
),
439443
'ensemble_variance': metrics.EnsembleVariance(
440444
ensemble_dim=ENSEMBLE_DIM.value
441445
),
@@ -459,6 +463,16 @@ def main(argv: list[str]) -> None:
459463
'energy_score_skill': metrics.EnergyScoreSkill(
460464
ensemble_dim=ENSEMBLE_DIM.value
461465
),
466+
'ensemble_mean_rmse_sqrt_before_time_avg': (
467+
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
468+
ensemble_dim=ENSEMBLE_DIM.value
469+
)
470+
),
471+
'ensemble_stddev_sqrt_before_time_avg': (
472+
metrics.EnsembleStddevSqrtBeforeTimeAvg(
473+
ensemble_dim=ENSEMBLE_DIM.value
474+
)
475+
),
462476
},
463477
against_analysis=False,
464478
derived_variables=derived_variables,

scripts/evaluate_test.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -93,8 +93,8 @@ def _test(self, use_beam=True, input_chunks=None):
9393
evaluate.main([])
9494

9595
for config_name in eval_configs:
96-
expected_sizes_2d = {'metric': 5, 'lead_time': 4, 'region': 4}
97-
expected_sizes_3d = {'metric': 5, 'lead_time': 4, 'region': 4, 'level': 3}
96+
expected_sizes_2d = {'metric': 4, 'lead_time': 4, 'region': 4}
97+
expected_sizes_3d = {'metric': 4, 'lead_time': 4, 'region': 4, 'level': 3}
9898

9999
with self.subTest(config_name):
100100
results_path = os.path.join(output_dir, f'{config_name}.nc')

weatherbench2/evaluation_test.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -86,13 +86,13 @@ def test_in_memory_and_beam_consistency(self):
8686
eval_configs = {
8787
'forecast_vs_era': config.Eval(
8888
metrics={
89-
'rmse': metrics.RMSE(),
89+
'rmse': metrics.RMSESqrtBeforeTimeAvg(),
9090
'acc': metrics.ACC(climatology=climatology),
9191
},
9292
against_analysis=False,
9393
),
9494
'forecast_vs_era_by_region': config.Eval(
95-
metrics={'rmse': metrics.RMSE()},
95+
metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
9696
against_analysis=False,
9797
regions=regions,
9898
),
@@ -101,7 +101,7 @@ def test_in_memory_and_beam_consistency(self):
101101
against_analysis=False,
102102
),
103103
'forecast_vs_era_temporal': config.Eval(
104-
metrics={'rmse': metrics.RMSE()},
104+
metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
105105
against_analysis=False,
106106
temporal_mean=False,
107107
),

weatherbench2/metrics.py

+70 -37
Original file line number | Diff line number | Diff line change
@@ -135,8 +135,8 @@ def _spatial_average_l2_norm(
135135

136136

137137
@dataclasses.dataclass
138-
class WindVectorRMSE(Metric):
139-
"""Compute wind vector RMSE. See WB2 paper for definition.
138+
class WindVectorMSE(Metric):
139+
"""Compute wind vector mean square error. See WB2 paper for definition.
140140
141141
Attributes:
142142
u_name: Name of U component.
@@ -155,25 +155,57 @@ def compute_chunk(
155155
region: t.Optional[Region] = None,
156156
) -> xr.Dataset:
157157
diff = forecast - truth
158-
result = np.sqrt(
159-
_spatial_average(
160-
diff[self.u_name] ** 2 + diff[self.v_name] ** 2,
161-
region=region,
162-
)
158+
result = _spatial_average(
159+
diff[self.u_name] ** 2 + diff[self.v_name] ** 2,
160+
region=region,
163161
)
164162
return result
165163

166164

167165
@dataclasses.dataclass
168-
class RMSE(Metric):
166+
class WindVectorRMSESqrtBeforeTimeAvg(Metric):
167+
"""Compute wind vector RMSE. See WB2 paper for definition.
168+
169+
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
170+
Most users will prefer to use WindVectorMSE and then take a square root in
171+
user code after running the evaluate script.
172+
173+
Attributes:
174+
u_name: Name of U component.
175+
v_name: Name of V component.
176+
vector_name: Name of wind vector to be computed.
177+
"""
178+
179+
u_name: str
180+
v_name: str
181+
vector_name: str
182+
183+
def compute_chunk(
184+
self,
185+
forecast: xr.Dataset,
186+
truth: xr.Dataset,
187+
region: t.Optional[Region] = None,
188+
) -> xr.Dataset:
189+
mse = WindVectorMSE(
190+
u_name=self.u_name, v_name=self.v_name, vector_name=self.vector_name
191+
).compute_chunk(forecast, truth, region=region)
192+
return np.sqrt(mse)
193+
194+
195+
@dataclasses.dataclass
196+
class RMSESqrtBeforeTimeAvg(Metric):
169197
"""Root mean squared error.
170198
199+
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
200+
Most users will prefer to use MSE and then take a square root in user
201+
code after running the evaluate script.
202+
171203
Attributes:
172-
wind_vector_rmse: Optionally provide list of WindVectorRMSE instances to
173-
compute.
204+
wind_vector_rmse: Optionally provide list of WindVectorRMSESqrtBeforeTimeAvg
205+
instances to compute.
174206
"""
175207

176-
wind_vector_rmse: t.Optional[list[WindVectorRMSE]] = None
208+
wind_vector_rmse: t.Optional[list[WindVectorRMSESqrtBeforeTimeAvg]] = None
177209

178210
def compute_chunk(
179211
self,
@@ -192,15 +224,28 @@ def compute_chunk(
192224

193225
@dataclasses.dataclass
194226
class MSE(Metric):
195-
"""Mean squared error."""
227+
"""Mean squared error.
228+
229+
Attributes:
230+
wind_vector_mse: Optionally provide list of WindVectorMSE instances to
231+
compute.
232+
"""
233+
234+
wind_vector_mse: t.Optional[list[WindVectorMSE]] = None
196235

197236
def compute_chunk(
198237
self,
199238
forecast: xr.Dataset,
200239
truth: xr.Dataset,
201240
region: t.Optional[Region] = None,
202241
) -> xr.Dataset:
203-
return _spatial_average((forecast - truth) ** 2, region=region)
242+
results = _spatial_average((forecast - truth) ** 2, region=region)
243+
if self.wind_vector_mse is not None:
244+
for wv in self.wind_vector_mse:
245+
results[wv.vector_name] = wv.compute_chunk(
246+
forecast, truth, region=region
247+
)
248+
return results
204249

205250

206251
@dataclasses.dataclass
@@ -717,14 +762,15 @@ def _pointwise_gaussian_crps(
717762

718763

719764
@dataclasses.dataclass
720-
class EnsembleStddev(EnsembleMetric):
765+
class EnsembleStddevSqrtBeforeTimeAvg(EnsembleMetric):
721766
"""The standard deviation of an ensemble of forecasts.
722767
723-
This forms the SPREAD component of the traditional spread-skill-ratio. See
724-
[Garg & Rasp & Thuerey, 2022].
768+
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
769+
Most users will prefer to use EnsembleVariance and then take a square root in
770+
user code after running the evaluate script.
725771
726772
Given predictive ensemble Xₜ at times t = (1,..., T),
727-
EnsembleStddev := (1 / T) Σₜ ‖σ(Xₜ)‖
773+
EnsembleStddevSqrtBeforeTimeAvg := (1 / T) Σₜ ‖σ(Xₜ)‖
728774
Above σ(Xₜ) is element-wise standard deviation, and ‖⋅‖ is an area-weighted
729775
L2 norm.
730776
@@ -737,15 +783,6 @@ class EnsembleStddev(EnsembleMetric):
737783
738784
NaN values propagate through and result in NaN in the corresponding output
739785
position.
740-
741-
We use the unbiased estimator of σ(Xₜ) (dividing by n_ensemble - 1). If
742-
n_ensemble = 1, we return zero for the stddev. This choice allows
743-
EnsembleStddev to behave in the spread-skill-ratio as expected.
744-
745-
References:
746-
[Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
747-
for probabilistic medium-range weather forecasting along with deep learning
748-
baseline models.
749786
"""
750787

751788
def compute_chunk(
@@ -754,7 +791,7 @@ def compute_chunk(
754791
truth: xr.Dataset,
755792
region: t.Optional[Region] = None,
756793
) -> xr.Dataset:
757-
"""EnsembleStddev, averaged over space, for a time chunk of data."""
794+
"""Ensemble Stddev, averaged over space, for a time chunk of data."""
758795
del truth # unused
759796
n_ensemble = _get_n_ensemble(forecast, self.ensemble_dim)
760797

@@ -825,15 +862,16 @@ def compute_chunk(
825862

826863

827864
@dataclasses.dataclass
828-
class EnsembleMeanRMSE(EnsembleMetric):
865+
class EnsembleMeanRMSESqrtBeforeTimeAvg(EnsembleMetric):
829866
"""RMSE between the ensemble mean and ground truth.
830867
831-
This forms the SKILL component of the traditional spread-skill-ratio. See
832-
[Garg & Rasp & Thuerey, 2022].
868+
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
869+
Most users will prefer to use EnsembleMeanMSE and then take a square root in
870+
user code after running the evaluate script.
833871
834872
Given ground truth Yₜ, and predictive ensemble Xₜ, both at times
835873
t = (1,..., T),
836-
EnsembleMeanRMSE := (1 / T) Σₜ ‖Y - E(Xₜ)‖.
874+
EnsembleMeanRMSESqrtBeforeTimeAvg := (1 / T) Σₜ ‖Yₜ - E(Xₜ)‖.
837875
Above, `E` is ensemble average, and ‖⋅‖ is an area-weighted L2 norm.
838876
839877
Estimation is done separately for each tendency, level, and lag time.
@@ -845,11 +883,6 @@ class EnsembleMeanRMSE(EnsembleMetric):
845883
846884
NaN values propagate through and result in NaN in the corresponding output
847885
position.
848-
849-
References:
850-
[Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
851-
for probabilistic medium-range weather forecasting along with deep learning
852-
baseline models.
853886
"""
854887

855888
def compute_chunk(
@@ -1005,7 +1038,7 @@ def compute_chunk(
10051038

10061039

10071040
# TODO(shoyer): Consider adding WindVectorEnergyScore based on a pair of wind
1008-
# components, as a sort of probabilistic variant of WindVectorRMSE.
1041+
# components, as a sort of probabilistic variant of WindVectorMSE.
10091042

10101043

10111044
@dataclasses.dataclass

weatherbench2/metrics_test.py

+18 -6
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def test_get_lat_weights(self):
7070
xr.testing.assert_allclose(expected, weights)
7171

7272
def test_wind_vector_rmse(self):
73-
wv = metrics.WindVectorRMSE(
73+
wv = metrics.WindVectorRMSESqrtBeforeTimeAvg(
7474
u_name='u_component_of_wind',
7575
v_name='v_component_of_wind',
7676
vector_name='wind_vector',
@@ -119,7 +119,7 @@ def test_wind_vector_rmse(self):
119119
dict(testcase_name='nan', invalid_value=np.nan),
120120
)
121121
def test_rmse_over_invalid_region(self, invalid_value):
122-
rmse = metrics.RMSE()
122+
rmse = metrics.RMSESqrtBeforeTimeAvg()
123123
truth = xr.Dataset(
124124
{'wind_speed': ('latitude', [0.0, invalid_value, 0.0])},
125125
coords={'latitude': [-45, 0, 45]},
@@ -460,8 +460,12 @@ class EnsembleMeanRMSEAndStddevTest(parameterized.TestCase):
460460
def test_on_random_dataset(self, ensemble_size):
461461
truth, forecast = get_random_truth_and_forecast(ensemble_size=ensemble_size)
462462

463-
rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth)
464-
ensemble_stddev = metrics.EnsembleStddev().compute_chunk(forecast, truth)
463+
rmse = metrics.EnsembleMeanRMSESqrtBeforeTimeAvg().compute_chunk(
464+
forecast, truth
465+
)
466+
ensemble_stddev = metrics.EnsembleStddevSqrtBeforeTimeAvg().compute_chunk(
467+
forecast, truth
468+
)
465469

466470
for dataset in [rmse, ensemble_stddev]:
467471
self.assertEqual(
@@ -496,15 +500,23 @@ def test_effect_of_large_bias_on_rmse(self):
496500
truth, forecast = get_random_truth_and_forecast(ensemble_size=10)
497501
truth += 1000
498502

499-
mean_rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth).mean()
503+
mean_rmse = (
504+
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg()
505+
.compute_chunk(forecast, truth)
506+
.mean()
507+
)
500508

501509
# Dominated by bias of 1000
502510
np.testing.assert_allclose(1000, mean_rmse.geopotential.values, rtol=1e-3)
503511

504512
def test_perfect_prediction_zero_rmse(self):
505513
truth, unused_forecast = get_random_truth_and_forecast(ensemble_size=10)
506514
forecast = truth.expand_dims({metrics.REALIZATION: 1})
507-
mean_rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth).mean()
515+
mean_rmse = (
516+
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg()
517+
.compute_chunk(forecast, truth)
518+
.mean()
519+
)
508520

509521
xr.testing.assert_allclose(xr.zeros_like(mean_rmse), mean_rmse)
510522

weatherbench2/regions_test.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def testLandRegion(self):
4343
lsm = lsm.where(lsm.latitude < 1.0, 1)
4444
land_region = regions.LandRegion(lsm)
4545

46-
rmse = metrics.RMSE()
46+
rmse = metrics.RMSESqrtBeforeTimeAvg()
4747

4848
results = rmse.compute(forecast, truth, region=land_region)
4949
np.testing.assert_allclose(results['2m_temperature'].values, 0.0)

0 commit comments

Comments
 (0)