Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No public description #98

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,15 @@
)


def _wind_vector_rmse():
"""Defines Wind Vector RMSEs if U/V components are in variables."""
wind_vector_rmse = []
def _wind_vector_error(err_type: str):
"""Defines Wind Vector [R]MSEs if U/V components are in variables."""
if err_type == 'mse':
cls = metrics.WindVectorMSE
elif err_type == 'rmse':
cls = metrics.WindVectorRMSESqrtBeforeTimeAvg
else:
raise ValueError(f'Unrecognized {err_type=}')
wind_vector_error = []
available = set(VARIABLES.value).union(DERIVED_VARIABLES.value)
for u_name, v_name, vector_name in [
('u_component_of_wind', 'v_component_of_wind', 'wind_vector'),
Expand All @@ -248,14 +254,14 @@ def _wind_vector_rmse():
),
]:
if u_name in available and v_name in available:
wind_vector_rmse.append(
metrics.WindVectorRMSE(
wind_vector_error.append(
cls(
u_name=u_name,
v_name=v_name,
vector_name=vector_name,
)
)
return wind_vector_rmse
return wind_vector_error


def main(argv: list[str]) -> None:
Expand Down Expand Up @@ -349,12 +355,16 @@ def main(argv: list[str]) -> None:
climatology = evaluation.make_latitude_increasing(climatology)

deterministic_metrics = {
'rmse': metrics.RMSE(wind_vector_rmse=_wind_vector_rmse()),
'mse': metrics.MSE(),
'mse': metrics.MSE(wind_vector_mse=_wind_vector_error('mse')),
'acc': metrics.ACC(climatology=climatology),
'bias': metrics.Bias(),
'mae': metrics.MAE(),
}
rmse_metrics = {
'rmse_sqrt_before_time_avg': metrics.RMSESqrtBeforeTimeAvg(
wind_vector_rmse=_wind_vector_error('rmse')
),
}
spatial_metrics = {
'bias': metrics.SpatialBias(),
'mse': metrics.SpatialMSE(),
Expand Down Expand Up @@ -404,7 +414,7 @@ def main(argv: list[str]) -> None:
output_format='zarr',
),
'deterministic_temporal': config.Eval(
metrics=deterministic_metrics,
metrics=deterministic_metrics | rmse_metrics,
against_analysis=False,
regions=regions,
derived_variables=derived_variables,
Expand All @@ -427,15 +437,9 @@ def main(argv: list[str]) -> None:
ensemble_dim=ENSEMBLE_DIM.value
),
'crps_skill': metrics.CRPSSkill(ensemble_dim=ENSEMBLE_DIM.value),
'ensemble_mean_rmse': metrics.EnsembleMeanRMSE(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_mean_mse': metrics.EnsembleMeanMSE(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_stddev': metrics.EnsembleStddev(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_variance': metrics.EnsembleVariance(
ensemble_dim=ENSEMBLE_DIM.value
),
Expand All @@ -459,6 +463,16 @@ def main(argv: list[str]) -> None:
'energy_score_skill': metrics.EnergyScoreSkill(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_mean_rmse_sqrt_before_time_avg': (
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
ensemble_dim=ENSEMBLE_DIM.value
)
),
'ensemble_stddev_sqrt_before_time_avg': (
metrics.EnsembleStddevSqrtBeforeTimeAvg(
ensemble_dim=ENSEMBLE_DIM.value
)
),
},
against_analysis=False,
derived_variables=derived_variables,
Expand Down
4 changes: 2 additions & 2 deletions scripts/evaluate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def _test(self, use_beam=True, input_chunks=None):
evaluate.main([])

for config_name in eval_configs:
expected_sizes_2d = {'metric': 5, 'lead_time': 4, 'region': 4}
expected_sizes_3d = {'metric': 5, 'lead_time': 4, 'region': 4, 'level': 3}
expected_sizes_2d = {'metric': 4, 'lead_time': 4, 'region': 4}
expected_sizes_3d = {'metric': 4, 'lead_time': 4, 'region': 4, 'level': 3}

with self.subTest(config_name):
results_path = os.path.join(output_dir, f'{config_name}.nc')
Expand Down
6 changes: 3 additions & 3 deletions weatherbench2/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def test_in_memory_and_beam_consistency(self):
eval_configs = {
'forecast_vs_era': config.Eval(
metrics={
'rmse': metrics.RMSE(),
'rmse': metrics.RMSESqrtBeforeTimeAvg(),
'acc': metrics.ACC(climatology=climatology),
},
against_analysis=False,
),
'forecast_vs_era_by_region': config.Eval(
metrics={'rmse': metrics.RMSE()},
metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
against_analysis=False,
regions=regions,
),
Expand All @@ -101,7 +101,7 @@ def test_in_memory_and_beam_consistency(self):
against_analysis=False,
),
'forecast_vs_era_temporal': config.Eval(
metrics={'rmse': metrics.RMSE()},
metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
against_analysis=False,
temporal_mean=False,
),
Expand Down
107 changes: 70 additions & 37 deletions weatherbench2/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def _spatial_average_l2_norm(


@dataclasses.dataclass
class WindVectorRMSE(Metric):
"""Compute wind vector RMSE. See WB2 paper for definition.
class WindVectorMSE(Metric):
"""Compute wind vector mean square error. See WB2 paper for definition.

Attributes:
u_name: Name of U component.
Expand All @@ -155,25 +155,57 @@ def compute_chunk(
region: t.Optional[Region] = None,
) -> xr.Dataset:
diff = forecast - truth
result = np.sqrt(
_spatial_average(
diff[self.u_name] ** 2 + diff[self.v_name] ** 2,
region=region,
)
result = _spatial_average(
diff[self.u_name] ** 2 + diff[self.v_name] ** 2,
region=region,
)
return result


@dataclasses.dataclass
class RMSE(Metric):
class WindVectorRMSESqrtBeforeTimeAvg(Metric):
"""Compute wind vector RMSE. See WB2 paper for definition.

This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
Most users will prefer to use WindVectorMSE and then take a square root in
user code after running the evaluate script.

Attributes:
u_name: Name of U component.
v_name: Name of V component.
vector_name: Name of wind vector to be computed.
"""

u_name: str
v_name: str
vector_name: str

def compute_chunk(
self,
forecast: xr.Dataset,
truth: xr.Dataset,
region: t.Optional[Region] = None,
) -> xr.Dataset:
mse = WindVectorMSE(
u_name=self.u_name, v_name=self.v_name, vector_name=self.vector_name
).compute_chunk(forecast, truth, region=region)
return np.sqrt(mse)


@dataclasses.dataclass
class RMSESqrtBeforeTimeAvg(Metric):
"""Root mean squared error.

This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
Most users will prefer to use MSE and then take a square root in user
code after running the evaluate script.

Attributes:
wind_vector_rmse: Optionally provide list of WindVectorRMSE instances to
compute.
wind_vector_rmse: Optionally provide list of WindVectorRMSESqrtBeforeTimeAvg
instances to compute.
"""

wind_vector_rmse: t.Optional[list[WindVectorRMSE]] = None
wind_vector_rmse: t.Optional[list[WindVectorRMSESqrtBeforeTimeAvg]] = None

def compute_chunk(
self,
Expand All @@ -192,15 +224,28 @@ def compute_chunk(

@dataclasses.dataclass
class MSE(Metric):
"""Mean squared error."""
"""Mean squared error.

Attributes:
wind_vector_mse: Optionally provide list of WindVectorMSE instances to
compute.
"""

wind_vector_mse: t.Optional[list[WindVectorMSE]] = None

def compute_chunk(
self,
forecast: xr.Dataset,
truth: xr.Dataset,
region: t.Optional[Region] = None,
) -> xr.Dataset:
return _spatial_average((forecast - truth) ** 2, region=region)
results = _spatial_average((forecast - truth) ** 2, region=region)
if self.wind_vector_mse is not None:
for wv in self.wind_vector_mse:
results[wv.vector_name] = wv.compute_chunk(
forecast, truth, region=region
)
return results


@dataclasses.dataclass
Expand Down Expand Up @@ -717,14 +762,15 @@ def _pointwise_gaussian_crps(


@dataclasses.dataclass
class EnsembleStddev(EnsembleMetric):
class EnsembleStddevSqrtBeforeTimeAvg(EnsembleMetric):
"""The standard deviation of an ensemble of forecasts.

This forms the SPREAD component of the traditional spread-skill-ratio. See
[Garg & Rasp & Thuerey, 2022].
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
Most users will prefer to use EnsembleVariance and then take a square root in
user code after running the evaluate script.

Given predictive ensemble Xₜ at times t = (1,..., T),
EnsembleStddev := (1 / T) Σₜ ‖σ(Xₜ)‖
EnsembleStddevSqrtBeforeTimeAvg := (1 / T) Σₜ ‖σ(Xₜ)‖
Above σ(Xₜ) is element-wise standard deviation, and ‖⋅‖ is an area-weighted
L2 norm.

Expand All @@ -737,15 +783,6 @@ class EnsembleStddev(EnsembleMetric):

NaN values propagate through and result in NaN in the corresponding output
position.

We use the unbiased estimator of σ(Xₜ) (dividing by n_ensemble - 1). If
n_ensemble = 1, we return zero for the stddev. This choice allows
EnsembleStddev to behave in the spread-skill-ratio as expected.

References:
[Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
for probabilistic medium-range weather forecasting along with deep learning
baseline models.
"""

def compute_chunk(
Expand All @@ -754,7 +791,7 @@ def compute_chunk(
truth: xr.Dataset,
region: t.Optional[Region] = None,
) -> xr.Dataset:
"""EnsembleStddev, averaged over space, for a time chunk of data."""
"""Ensemble Stddev, averaged over space, for a time chunk of data."""
del truth # unused
n_ensemble = _get_n_ensemble(forecast, self.ensemble_dim)

Expand Down Expand Up @@ -825,15 +862,16 @@ def compute_chunk(


@dataclasses.dataclass
class EnsembleMeanRMSE(EnsembleMetric):
class EnsembleMeanRMSESqrtBeforeTimeAvg(EnsembleMetric):
"""RMSE between the ensemble mean and ground truth.

This forms the SKILL component of the traditional spread-skill-ratio. See
[Garg & Rasp & Thuerey, 2022].
This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
Most users will prefer to use EnsembleMeanMSE and then take a square root in
user code after running the evaluate script.

Given ground truth Yₜ, and predictive ensemble Xₜ, both at times
t = (1,..., T),
EnsembleMeanRMSE := (1 / T) Σₜ ‖Y - E(Xₜ)‖.
EnsembleMeanRMSESqrtBeforeTimeAvg := (1 / T) Σₜ ‖Y - E(Xₜ)‖.
Above, `E` is ensemble average, and ‖⋅‖ is an area-weighted L2 norm.

Estimation is done separately for each tendency, level, and lag time.
Expand All @@ -845,11 +883,6 @@ class EnsembleMeanRMSE(EnsembleMetric):

NaN values propagate through and result in NaN in the corresponding output
position.

References:
[Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
for probabilistic medium-range weather forecasting along with deep learning
baseline models.
"""

def compute_chunk(
Expand Down Expand Up @@ -1005,7 +1038,7 @@ def compute_chunk(


# TODO(shoyer): Consider adding WindVectorEnergyScore based on a pair of wind
# components, as a sort of probabilistic variant of WindVectorRMSE.
# components, as a sort of probabilistic variant of WindVectorMSE.


@dataclasses.dataclass
Expand Down
24 changes: 18 additions & 6 deletions weatherbench2/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_get_lat_weights(self):
xr.testing.assert_allclose(expected, weights)

def test_wind_vector_rmse(self):
wv = metrics.WindVectorRMSE(
wv = metrics.WindVectorRMSESqrtBeforeTimeAvg(
u_name='u_component_of_wind',
v_name='v_component_of_wind',
vector_name='wind_vector',
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_wind_vector_rmse(self):
dict(testcase_name='nan', invalid_value=np.nan),
)
def test_rmse_over_invalid_region(self, invalid_value):
rmse = metrics.RMSE()
rmse = metrics.RMSESqrtBeforeTimeAvg()
truth = xr.Dataset(
{'wind_speed': ('latitude', [0.0, invalid_value, 0.0])},
coords={'latitude': [-45, 0, 45]},
Expand Down Expand Up @@ -460,8 +460,12 @@ class EnsembleMeanRMSEAndStddevTest(parameterized.TestCase):
def test_on_random_dataset(self, ensemble_size):
truth, forecast = get_random_truth_and_forecast(ensemble_size=ensemble_size)

rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth)
ensemble_stddev = metrics.EnsembleStddev().compute_chunk(forecast, truth)
rmse = metrics.EnsembleMeanRMSESqrtBeforeTimeAvg().compute_chunk(
forecast, truth
)
ensemble_stddev = metrics.EnsembleStddevSqrtBeforeTimeAvg().compute_chunk(
forecast, truth
)

for dataset in [rmse, ensemble_stddev]:
self.assertEqual(
Expand Down Expand Up @@ -496,15 +500,23 @@ def test_effect_of_large_bias_on_rmse(self):
truth, forecast = get_random_truth_and_forecast(ensemble_size=10)
truth += 1000

mean_rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth).mean()
mean_rmse = (
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg()
.compute_chunk(forecast, truth)
.mean()
)

# Dominated by bias of 1000
np.testing.assert_allclose(1000, mean_rmse.geopotential.values, rtol=1e-3)

def test_perfect_prediction_zero_rmse(self):
truth, unused_forecast = get_random_truth_and_forecast(ensemble_size=10)
forecast = truth.expand_dims({metrics.REALIZATION: 1})
mean_rmse = metrics.EnsembleMeanRMSE().compute_chunk(forecast, truth).mean()
mean_rmse = (
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg()
.compute_chunk(forecast, truth)
.mean()
)

xr.testing.assert_allclose(xr.zeros_like(mean_rmse), mean_rmse)

Expand Down
2 changes: 1 addition & 1 deletion weatherbench2/regions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def testLandRegion(self):
lsm = lsm.where(lsm.latitude < 1.0, 1)
land_region = regions.LandRegion(lsm)

rmse = metrics.RMSE()
rmse = metrics.RMSESqrtBeforeTimeAvg()

results = rmse.compute(forecast, truth, region=land_region)
np.testing.assert_allclose(results['2m_temperature'].values, 0.0)
Expand Down
Loading