Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588176564
  • Loading branch information
langmore authored and Weatherbench2 authors committed Dec 6, 2023
1 parent 75ccb1d commit 99ce4af
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 64 deletions.
44 changes: 29 additions & 15 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,15 @@
)


def _wind_vector_rmse():
"""Defines Wind Vector RMSEs if U/V components are in variables."""
wind_vector_rmse = []
def _wind_vector_error(err_type: str):
"""Defines Wind Vector [R]MSEs if U/V components are in variables."""
if err_type == 'mse':
cls = metrics.WindVectorMSE
elif err_type == 'rmse':
cls = metrics.WindVectorRMSESqrtBeforeTimeAvg
else:
raise ValueError(f'Unrecognized {err_type=}')
wind_vector_error = []
available = set(VARIABLES.value).union(DERIVED_VARIABLES.value)
for u_name, v_name, vector_name in [
('u_component_of_wind', 'v_component_of_wind', 'wind_vector'),
Expand All @@ -248,14 +254,14 @@ def _wind_vector_rmse():
),
]:
if u_name in available and v_name in available:
wind_vector_rmse.append(
metrics.WindVectorRMSE(
wind_vector_error.append(
cls(
u_name=u_name,
v_name=v_name,
vector_name=vector_name,
)
)
return wind_vector_rmse
return wind_vector_error


def main(argv: list[str]) -> None:
Expand Down Expand Up @@ -349,12 +355,16 @@ def main(argv: list[str]) -> None:
climatology = evaluation.make_latitude_increasing(climatology)

deterministic_metrics = {
'rmse': metrics.RMSE(wind_vector_rmse=_wind_vector_rmse()),
'mse': metrics.MSE(),
'mse': metrics.MSE(wind_vector_mse=_wind_vector_error('mse')),
'acc': metrics.ACC(climatology=climatology),
'bias': metrics.Bias(),
'mae': metrics.MAE(),
}
rmse_metrics = {
'rmse_sqrt_before_time_avg': metrics.RMSESqrtBeforeTimeAvg(
wind_vector_rmse=_wind_vector_error('rmse')
),
}
spatial_metrics = {
'bias': metrics.SpatialBias(),
'mse': metrics.SpatialMSE(),
Expand Down Expand Up @@ -404,7 +414,7 @@ def main(argv: list[str]) -> None:
output_format='zarr',
),
'deterministic_temporal': config.Eval(
metrics=deterministic_metrics,
metrics=deterministic_metrics | rmse_metrics,
against_analysis=False,
regions=regions,
derived_variables=derived_variables,
Expand All @@ -427,15 +437,9 @@ def main(argv: list[str]) -> None:
ensemble_dim=ENSEMBLE_DIM.value
),
'crps_skill': metrics.CRPSSkill(ensemble_dim=ENSEMBLE_DIM.value),
'ensemble_mean_rmse': metrics.EnsembleMeanRMSE(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_mean_mse': metrics.EnsembleMeanMSE(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_stddev': metrics.EnsembleStddev(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_variance': metrics.EnsembleVariance(
ensemble_dim=ENSEMBLE_DIM.value
),
Expand All @@ -459,6 +463,16 @@ def main(argv: list[str]) -> None:
'energy_score_skill': metrics.EnergyScoreSkill(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_mean_rmse_sqrt_before_time_avg': (
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
ensemble_dim=ENSEMBLE_DIM.value
)
),
'ensemble_stddev_sqrt_before_time_avg': (
metrics.EnsembleStddevSqrtBeforeTimeAvg(
ensemble_dim=ENSEMBLE_DIM.value
)
),
},
against_analysis=False,
derived_variables=derived_variables,
Expand Down
4 changes: 2 additions & 2 deletions scripts/evaluate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def _test(self, use_beam=True, input_chunks=None):
evaluate.main([])

for config_name in eval_configs:
expected_sizes_2d = {'metric': 5, 'lead_time': 4, 'region': 4}
expected_sizes_3d = {'metric': 5, 'lead_time': 4, 'region': 4, 'level': 3}
expected_sizes_2d = {'metric': 4, 'lead_time': 4, 'region': 4}
expected_sizes_3d = {'metric': 4, 'lead_time': 4, 'region': 4, 'level': 3}

with self.subTest(config_name):
results_path = os.path.join(output_dir, f'{config_name}.nc')
Expand Down
6 changes: 3 additions & 3 deletions weatherbench2/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def test_in_memory_and_beam_consistency(self):
eval_configs = {
'forecast_vs_era': config.Eval(
metrics={
'rmse': metrics.RMSE(),
'rmse': metrics.RMSESqrtBeforeTimeAvg(),
'acc': metrics.ACC(climatology=climatology),
},
against_analysis=False,
),
'forecast_vs_era_by_region': config.Eval(
metrics={'rmse': metrics.RMSE()},
metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
against_analysis=False,
regions=regions,
),
Expand All @@ -101,7 +101,7 @@ def test_in_memory_and_beam_consistency(self):
against_analysis=False,
),
'forecast_vs_era_temporal': config.Eval(
metrics={'rmse': metrics.RMSE()},
metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
against_analysis=False,
temporal_mean=False,
),
Expand Down
107 changes: 70 additions & 37 deletions weatherbench2/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def _spatial_average_l2_norm(


@dataclasses.dataclass
class WindVectorRMSE(Metric):
"""Compute wind vector RMSE. See WB2 paper for definition.
class WindVectorMSE(Metric):
"""Compute wind vector mean square error. See WB2 paper for definition.
Attributes:
u_name: Name of U component.
Expand All @@ -155,25 +155,57 @@ def compute_chunk(
region: t.Optional[Region] = None,
) -> xr.Dataset:
diff = forecast - truth
result = np.sqrt(
_spatial_average(
diff[self.u_name] ** 2 + diff[self.v_name] ** 2,
region=region,
)
result = _spatial_average(
diff[self.u_name] ** 2 + diff[self.v_name] ** 2,
region=region,
)
return result


@dataclasses.dataclass
class RMSE(Metric):
class WindVectorRMSESqrtBeforeTimeAvg(Metric):
    """Wind vector root mean squared error. See WB2 paper for definition.

    As a SqrtBeforeTimeAvg metric, the square root is applied to each chunk
    result, i.e. before any averaging over time. Most users will prefer
    WindVectorMSE, taking the square root in their own code once the evaluate
    script has finished.

    Attributes:
        u_name: Name of U component.
        v_name: Name of V component.
        vector_name: Name of wind vector to be computed.
    """

    u_name: str
    v_name: str
    vector_name: str

    def compute_chunk(
        self,
        forecast: xr.Dataset,
        truth: xr.Dataset,
        region: t.Optional[Region] = None,
    ) -> xr.Dataset:
        """Return the square root of the WindVectorMSE result for this chunk.

        Args:
            forecast: Forecast data, forwarded to WindVectorMSE.compute_chunk.
            truth: Ground truth data, forwarded to WindVectorMSE.compute_chunk.
            region: Optional region, forwarded to WindVectorMSE.compute_chunk.

        Returns:
            Element-wise square root of the wind vector mean squared error.
        """
        # Reuse the MSE metric configured with the same component names so the
        # two metrics cannot drift apart; only the final square root differs.
        mse_metric = WindVectorMSE(
            u_name=self.u_name,
            v_name=self.v_name,
            vector_name=self.vector_name,
        )
        mean_squared_error = mse_metric.compute_chunk(
            forecast, truth, region=region
        )
        return np.sqrt(mean_squared_error)


@dataclasses.dataclass
class RMSESqrtBeforeTimeAvg(Metric):
"""Root mean squared error.
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
Most users will prefer to use MSE and then take a square root in user
code after running the evaluate script.
Attributes:
wind_vector_rmse: Optionally provide list of WindVectorRMSE instances to
compute.
wind_vector_rmse: Optionally provide list of WindVectorRMSESqrtBeforeTimeAvg
instances to compute.
"""

wind_vector_rmse: t.Optional[list[WindVectorRMSE]] = None
wind_vector_rmse: t.Optional[list[WindVectorRMSESqrtBeforeTimeAvg]] = None

def compute_chunk(
self,
Expand All @@ -192,15 +224,28 @@ def compute_chunk(

@dataclasses.dataclass
class MSE(Metric):
"""Mean squared error."""
"""Mean squared error.
Attributes:
wind_vector_mse: Optionally provide list of WindVectorMSE instances to
compute.
"""

wind_vector_mse: t.Optional[list[WindVectorMSE]] = None

def compute_chunk(
self,
forecast: xr.Dataset,
truth: xr.Dataset,
region: t.Optional[Region] = None,
) -> xr.Dataset:
return _spatial_average((forecast - truth) ** 2, region=region)
results = _spatial_average((forecast - truth) ** 2, region=region)
if self.wind_vector_mse is not None:
for wv in self.wind_vector_mse:
results[wv.vector_name] = wv.compute_chunk(
forecast, truth, region=region
)
return results


@dataclasses.dataclass
Expand Down Expand Up @@ -717,14 +762,15 @@ def _pointwise_gaussian_crps(


@dataclasses.dataclass
class EnsembleStddev(EnsembleMetric):
class EnsembleStddevSqrtBeforeTimeAvg(EnsembleMetric):
"""The standard deviation of an ensemble of forecasts.
This forms the SPREAD component of the traditional spread-skill-ratio. See
[Garg & Rasp & Thuerey, 2022].
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
Most users will prefer to use EnsembleVariance and then take a square root in
user code after running the evaluate script.
Given predictive ensemble Xₜ at times t = (1,..., T),
EnsembleStddev := (1 / T) Σₜ ‖σ(Xₜ)‖
EnsembleStddevSqrtBeforeTimeAvg := (1 / T) Σₜ ‖σ(Xₜ)‖
Above σ(Xₜ) is element-wise standard deviation, and ‖⋅‖ is an area-weighted
L2 norm.
Expand All @@ -737,15 +783,6 @@ class EnsembleStddev(EnsembleMetric):
NaN values propagate through and result in NaN in the corresponding output
position.
We use the unbiased estimator of σ(Xₜ) (dividing by n_ensemble - 1). If
n_ensemble = 1, we return zero for the stddev. This choice allows
EnsembleStddev to behave in the spread-skill-ratio as expected.
References:
[Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
for probabilistic medium-range weather forecasting along with deep learning
baseline models.
"""

def compute_chunk(
Expand All @@ -754,7 +791,7 @@ def compute_chunk(
truth: xr.Dataset,
region: t.Optional[Region] = None,
) -> xr.Dataset:
"""EnsembleStddev, averaged over space, for a time chunk of data."""
"""Ensemble Stddev, averaged over space, for a time chunk of data."""
del truth # unused
n_ensemble = _get_n_ensemble(forecast, self.ensemble_dim)

Expand Down Expand Up @@ -825,15 +862,16 @@ def compute_chunk(


@dataclasses.dataclass
class EnsembleMeanRMSE(EnsembleMetric):
class EnsembleMeanRMSESqrtBeforeTimeAvg(EnsembleMetric):
"""RMSE between the ensemble mean and ground truth.
This forms the SKILL component of the traditional spread-skill-ratio. See
[Garg & Rasp & Thuerey, 2022].
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
Most users will prefer to use EnsembleMeanMSE and then take a square root in
user code after running the evaluate script.
Given ground truth Yₜ, and predictive ensemble Xₜ, both at times
t = (1,..., T),
EnsembleMeanRMSE := (1 / T) Σₜ ‖Y - E(Xₜ)‖.
EnsembleMeanRMSESqrtBeforeTimeAvg := (1 / T) Σₜ ‖Yₜ - E(Xₜ)‖.
Above, `E` is ensemble average, and ‖⋅‖ is an area-weighted L2 norm.
Estimation is done separately for each tendency, level, and lag time.
Expand All @@ -845,11 +883,6 @@ class EnsembleMeanRMSE(EnsembleMetric):
NaN values propagate through and result in NaN in the corresponding output
position.
References:
[Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
for probabilistic medium-range weather forecasting along with deep learning
baseline models.
"""

def compute_chunk(
Expand Down Expand Up @@ -1005,7 +1038,7 @@ def compute_chunk(


# TODO(shoyer): Consider adding WindVectorEnergyScore based on a pair of wind
# components, as a sort of probabilistic variant of WindVectorRMSE.
# components, as a sort of probabilistic variant of WindVectorMSE.


@dataclasses.dataclass
Expand Down
24 changes: 18 additions & 6 deletions weatherbench2/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_get_lat_weights(self):
xr.testing.assert_allclose(expected, weights)

def test_wind_vector_rmse(self):
wv = metrics.WindVectorRMSE(
wv = metrics.WindVectorRMSESqrtBeforeTimeAvg(
u_name='u_component_of_wind',
v_name='v_component_of_wind',
vector_name='wind_vector',
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_wind_vector_rmse(self):
dict(testcase_name='nan', invalid_value=np.nan),
)
def test_rmse_over_invalid_region(self, invalid_value):
rmse = metrics.RMSE()
rmse = metrics.RMSESqrtBeforeTimeAvg()
truth = xr.Dataset(
{'wind_speed': ('latitude', [0.0, invalid_value, 0.0])},
coords={'latitude': [-45, 0, 45]},
Expand Down Expand Up @@ -460,8 +460,12 @@ class EnsembleMeanRMSEAndStddevTest(parameterized.TestCase):
def test_on_random_dataset(self, ensemble_size):
truth, forecast = get_random_truth_and_forecast(ensemble_size=ensemble_size)

rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth)
ensemble_stddev = metrics.EnsembleStddev().compute_chunk(forecast, truth)
rmse = metrics.EnsembleMeanRMSESqrtBeforeTimeAvg().compute_chunk(
forecast, truth
)
ensemble_stddev = metrics.EnsembleStddevSqrtBeforeTimeAvg().compute_chunk(
forecast, truth
)

for dataset in [rmse, ensemble_stddev]:
self.assertEqual(
Expand Down Expand Up @@ -496,15 +500,23 @@ def test_effect_of_large_bias_on_rmse(self):
truth, forecast = get_random_truth_and_forecast(ensemble_size=10)
truth += 1000

mean_rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth).mean()
mean_rmse = (
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg()
.compute_chunk(forecast, truth)
.mean()
)

# Dominated by bias of 1000
np.testing.assert_allclose(1000, mean_rmse.geopotential.values, rtol=1e-3)

def test_perfect_prediction_zero_rmse(self):
truth, unused_forecast = get_random_truth_and_forecast(ensemble_size=10)
forecast = truth.expand_dims({metrics.REALIZATION: 1})
mean_rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth).mean()
mean_rmse = (
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg()
.compute_chunk(forecast, truth)
.mean()
)

xr.testing.assert_allclose(xr.zeros_like(mean_rmse), mean_rmse)

Expand Down
2 changes: 1 addition & 1 deletion weatherbench2/regions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def testLandRegion(self):
lsm = lsm.where(lsm.latitude < 1.0, 1)
land_region = regions.LandRegion(lsm)

rmse = metrics.RMSE()
rmse = metrics.RMSESqrtBeforeTimeAvg()

results = rmse.compute(forecast, truth, region=land_region)
np.testing.assert_allclose(results['2m_temperature'].values, 0.0)
Expand Down

0 comments on commit 99ce4af

Please sign in to comment.