Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 580212540
  • Loading branch information
Vivian-X-Y authored and Weatherbench2 authors committed Nov 7, 2023
1 parent 8d7db4b commit 0fa2cba
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
60 changes: 60 additions & 0 deletions weatherbench2/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,66 @@ def _rank_da(da: xr.DataArray) -> np.ndarray:
return ds.copy(data={k: _rank_da(v) for k, v in ds.items()})


@dataclasses.dataclass
class GaussianCRPS(Metric):
"""The spread measure associated with CRPS, E|X - X'|."""

def compute_chunk(
self,
forecast: xr.Dataset,
truth: xr.Dataset,
region: t.Optional[Region] = None,
) -> xr.Dataset:
"""CRPSSpread, averaged over space, for a time chunk of data."""
return _spatial_average(
_pointwise_gaussian_crps(forecast, truth),
region=region,
)


def _pointwise_gaussian_crps(
forecast: xr.Dataset, truth: xr.Dataset
) -> xr.Dataset:
r"""Returns pointwise CRPS of a Gaussian distribution with mean and std values.
CRPS of a Gaussian distribution with mean value m and standard deviation s
can be computed as
CRPS(F_(m,s), y) = s * {(y-m)/s * [2G((y-m)/s) - 1] + 2g((y-m)/s) -
1/\sqrt(\pi))}
where G and g denote the CDF and PDF of a standard Gaussian distribution,
respectively.
References:
[Gneiting, Raftery, Westveld III, Goldman, 2005], Calibrated Probabilistic
Forecasting Using Ensemble Model Output Statistics and Minimum CRPS Estimation
DOI: https://doi.org/10.1175/MWR2904.1
Args:
forecast: A forecast dataset.
truth: A ground truth dataset.
Returns:
xr.Dataset: Pointwise calculated crps for a Gaussian distribution.
"""
var_list = []
dataset = {}
for var in forecast.keys():
if f"{var}_std" in forecast.keys():
var_list.append(var)
for var_name in var_list:
norm_diff = (forecast[var_name] - truth[var_name]) / forecast[
f"{var_name}_std"
]
value = forecast[f"{var_name}_std"] * (
norm_diff * (2 * xr.apply_ufunc(stats.norm.cdf, norm_diff.load()) - 1)
+ 2 * xr.apply_ufunc(stats.norm.pdf, norm_diff.load())
- 1 / np.sqrt(np.pi)
)
dataset[var_name] = value
return xr.Dataset(dataset, coords=forecast.coords)


@dataclasses.dataclass
class EnsembleStddev(EnsembleMetric):
"""The standard deviation of an ensemble of forecasts.
Expand Down
23 changes: 23 additions & 0 deletions weatherbench2/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,29 @@ def test_repeated_forecasts_are_okay(self):
)


class GaussianCRPSTest(parameterized.TestCase):

def test_gaussian_crps(self):
forecast = schema.mock_forecast_data(
variables_3d=[],
variables_2d=['2m_temperature', '2m_temperature_std'],
time_start='2022-01-01',
time_stop='2022-01-02',
lead_stop='1 day',
)
truth = schema.mock_truth_data(
variables_3d=[],
variables_2d=['2m_temperature'],
time_start='2022-01-01',
time_stop='2022-01-20',
)
forecast = forecast + 1.0
truth = truth + 1.02
result = metrics.GaussianCRPS().compute(forecast, truth)
expected = np.array([0.23385455, 0.23385455])
np.testing.assert_allclose(result['2m_temperature'].values, expected)


class RankHistogramTest(parameterized.TestCase):

@parameterized.named_parameters(
Expand Down

0 comments on commit 0fa2cba

Please sign in to comment.