Skip to content

Commit

Permalink
[weatherbench2] Add GaussianVariance metric.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 592111655
  • Loading branch information
ilopezgp authored and Weatherbench2 authors committed Dec 19, 2023
1 parent 582d3d2 commit 2bfbeb7
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
31 changes: 29 additions & 2 deletions weatherbench2/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,15 +712,15 @@ def _rank_da(da: xr.DataArray) -> np.ndarray:

@dataclasses.dataclass
class GaussianCRPS(Metric):
"""The spread measure associated with CRPS, E|X - X'|."""
"""The analytical formulation of CRPS for a Gaussian."""

def compute_chunk(
self,
forecast: xr.Dataset,
truth: xr.Dataset,
region: t.Optional[Region] = None,
) -> xr.Dataset:
"""CRPSSpread, averaged over space, for a time chunk of data."""
"""GaussianCRPS, averaged over space, for a time chunk of data."""
return _spatial_average(
_pointwise_gaussian_crps(forecast, truth),
region=region,
Expand Down Expand Up @@ -770,6 +770,33 @@ def _pointwise_gaussian_crps(
return xr.Dataset(dataset, coords=forecast.coords)


@dataclasses.dataclass
class GaussianVariance(Metric):
"""The variance of a Gaussian forecast."""

def compute_chunk(
self,
forecast: xr.Dataset,
truth: xr.Dataset,
region: t.Optional[Region] = None,
) -> xr.Dataset:
"""GaussianVariance, averaged over space, for a time chunk of data."""
del truth # unused
var_list = []
dataset = {}
for var in forecast.keys():
if f"{var}_std" in forecast.keys():
var_list.append(var)
for var_name in var_list:
variance = forecast[f"{var_name}_std"] * forecast[f"{var_name}_std"]
dataset[var_name] = variance

return _spatial_average(
xr.Dataset(dataset, coords=forecast.coords),
region=region,
)


@dataclasses.dataclass
class EnsembleStddevSqrtBeforeTimeAvg(EnsembleMetric):
"""The standard deviation of an ensemble of forecasts.
Expand Down
22 changes: 22 additions & 0 deletions weatherbench2/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,28 @@ def test_gaussian_crps(self):
np.testing.assert_allclose(result['2m_temperature'].values, expected)


class GaussianVarianceTest(parameterized.TestCase):

def test_gaussian_variance(self):
forecast = schema.mock_forecast_data(
variables_3d=[],
variables_2d=['2m_temperature', '2m_temperature_std'],
time_start='2022-01-01',
time_stop='2022-01-02',
lead_stop='1 day',
)
truth = schema.mock_truth_data(
variables_3d=[],
variables_2d=['2m_temperature'],
time_start='2022-01-01',
time_stop='2022-01-20',
)
forecast['2m_temperature_std'] = forecast['2m_temperature_std'] + 1.0
result = metrics.GaussianVariance().compute(forecast, truth)
expected = np.array([1.0, 1.0])
np.testing.assert_allclose(result['2m_temperature'].values, expected)


class RankHistogramTest(parameterized.TestCase):

@parameterized.named_parameters(
Expand Down

0 comments on commit 2bfbeb7

Please sign in to comment.