Skip to content

Commit 2bfbeb7

Browse files
ilopezgpWeatherbench2 authors
authored andcommitted
[weatherbench2] Add GaussianVariance metric.
PiperOrigin-RevId: 592111655
1 parent 582d3d2 commit 2bfbeb7

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

weatherbench2/metrics.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -712,15 +712,15 @@ def _rank_da(da: xr.DataArray) -> np.ndarray:
712712

713713
@dataclasses.dataclass
714714
class GaussianCRPS(Metric):
715-
"""The spread measure associated with CRPS, E|X - X'|."""
715+
"""The analytical formulation of CRPS for a Gaussian."""
716716

717717
def compute_chunk(
718718
self,
719719
forecast: xr.Dataset,
720720
truth: xr.Dataset,
721721
region: t.Optional[Region] = None,
722722
) -> xr.Dataset:
723-
"""CRPSSpread, averaged over space, for a time chunk of data."""
723+
"""GaussianCRPS, averaged over space, for a time chunk of data."""
724724
return _spatial_average(
725725
_pointwise_gaussian_crps(forecast, truth),
726726
region=region,
@@ -770,6 +770,33 @@ def _pointwise_gaussian_crps(
770770
return xr.Dataset(dataset, coords=forecast.coords)
771771

772772

773+
@dataclasses.dataclass
774+
class GaussianVariance(Metric):
775+
"""The variance of a Gaussian forecast."""
776+
777+
def compute_chunk(
778+
self,
779+
forecast: xr.Dataset,
780+
truth: xr.Dataset,
781+
region: t.Optional[Region] = None,
782+
) -> xr.Dataset:
783+
"""GaussianVariance, averaged over space, for a time chunk of data."""
784+
del truth # unused
785+
var_list = []
786+
dataset = {}
787+
for var in forecast.keys():
788+
if f"{var}_std" in forecast.keys():
789+
var_list.append(var)
790+
for var_name in var_list:
791+
variance = forecast[f"{var_name}_std"] * forecast[f"{var_name}_std"]
792+
dataset[var_name] = variance
793+
794+
return _spatial_average(
795+
xr.Dataset(dataset, coords=forecast.coords),
796+
region=region,
797+
)
798+
799+
773800
@dataclasses.dataclass
774801
class EnsembleStddevSqrtBeforeTimeAvg(EnsembleMetric):
775802
"""The standard deviation of an ensemble of forecasts.

weatherbench2/metrics_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,28 @@ def test_gaussian_crps(self):
259259
np.testing.assert_allclose(result['2m_temperature'].values, expected)
260260

261261

262+
class GaussianVarianceTest(parameterized.TestCase):
263+
264+
def test_gaussian_variance(self):
265+
forecast = schema.mock_forecast_data(
266+
variables_3d=[],
267+
variables_2d=['2m_temperature', '2m_temperature_std'],
268+
time_start='2022-01-01',
269+
time_stop='2022-01-02',
270+
lead_stop='1 day',
271+
)
272+
truth = schema.mock_truth_data(
273+
variables_3d=[],
274+
variables_2d=['2m_temperature'],
275+
time_start='2022-01-01',
276+
time_stop='2022-01-20',
277+
)
278+
forecast['2m_temperature_std'] = forecast['2m_temperature_std'] + 1.0
279+
result = metrics.GaussianVariance().compute(forecast, truth)
280+
expected = np.array([1.0, 1.0])
281+
np.testing.assert_allclose(result['2m_temperature'].values, expected)
282+
283+
262284
class RankHistogramTest(parameterized.TestCase):
263285

264286
@parameterized.named_parameters(

0 commit comments

Comments
 (0)