Skip to content

Commit

Permalink
Add SOURCE_TIME variable to probabilistic climatological forecasts.
Browse files Browse the repository at this point in the history

This allows users to determine what time in the source data set was
responsible for the "forecast".

By default, `--add_source_time=False`, preventing unexpected variables from being included in downstream metrics. Of course, users can always specify the `--variables` they want to use for individual metrics and avoid this issue. However, it's best if the defaults are safe.

PiperOrigin-RevId: 703656568
  • Loading branch information
langmore authored and Weatherbench2 authors committed Dec 7, 2024
1 parent 90550ae commit 75fb2b9
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 31 deletions.
42 changes: 39 additions & 3 deletions scripts/compute_probabilistic_climatological_forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
)
TIMEDELTA_SPACING = flags.DEFINE_string(
'timedelta_spacing',
'12h',
'6h',
help=(
'Distance between lead times in forecasts. Must be a multiple of'
' difference between times in INPUT. Must be a multiple or divisor of'
Expand All @@ -155,6 +155,16 @@
),
)

# Name of the optional output variable that records which time in the input
# dataset each output sample was taken from.
SOURCE_TIME = 'source_time'
# Opt-in boolean flag (defaults to False) controlling whether the SOURCE_TIME
# variable is added to the output.
ADD_SOURCE_TIME = flags.DEFINE_boolean(
'add_source_time',
False,
help=(
f'Whether to add a "{SOURCE_TIME}" variable, indicating what time in'
' INPUT_PATH was used for each output sample'
),
)

# Determines how to form ensembles.
DAY_WINDOW_SIZE = flags.DEFINE_integer(
'day_window_size',
Expand Down Expand Up @@ -496,9 +506,17 @@ def _emit_sampled_weather(
# output times to scatter it to. That's okay, we will Yield nothing.
for info in values['time_key_and_index_info']:
info = info.copy()
del info['sampled_time_value'] # Was only for ValueError printouts above.
output_ds = ds.copy()
sampled_time_value = info.pop('sampled_time_value')
if ADD_SOURCE_TIME.value:
output_ds[SOURCE_TIME] = xr.DataArray(
# Insert as a DataArray, which lets us assign the proper dims.
[sampled_time_value],
dims=TIME_DIM.value,
coords={TIME_DIM.value: ds[TIME_DIM.value]},
)
output_ds = (
ds.expand_dims({DELTA: [info.pop('timedelta_value')]})
output_ds.expand_dims({DELTA: [info.pop('timedelta_value')]})
.assign_coords({TIME_DIM.value: [info.pop('output_init_time_value')]})
.expand_dims({REALIZATION_NAME.value: [info[REALIZATION_NAME.value]]})
)
Expand Down Expand Up @@ -560,6 +578,24 @@ def main(argv: abc.Sequence[str]) -> None:
assert isinstance(input_ds, xr.Dataset) # To satisfy pytype.
if DELTA in input_ds.dims:
raise ValueError(f'INPUT_PATH data already had {DELTA} as a dimension')
if ADD_SOURCE_TIME.value:
input_ds = input_ds.assign(
# Assign SOURCE_TIME with an arbitrary DataArray of type datetime64[ns].
# Using a DataArray with time index is important: It ensures it will be
# stored as a data_var, and will get indices sliced/expanded below
# correctly. It is also important to not directly use input_ds.time.
{
# TODO(langmore) Remove the "+1" once Xarray bug is fixed;
# https://github.com/pydata/xarray/issues/9859
# Until then, assigning to input_ds[TIME_DIM.value] without the "+1"
# results in an error:
# ValueError: Cannot assign to the .data attribute of dimension
# coordinate a.k.a. IndexVariable 'time'. Instead, add 1 so it is a
# new variable.
SOURCE_TIME: (input_ds[TIME_DIM.value]
+ np.array(1, dtype='timedelta64[ns]')) # fmt: skip
}
)
template = (
xbeam.make_template(input_ds)
.isel({TIME_DIM.value: 0}, drop=True)
Expand Down
108 changes: 80 additions & 28 deletions scripts/compute_probabilistic_climatological_forecasts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,29 @@

from . import compute_probabilistic_climatological_forecasts as cpcf

SOURCE_TIME = cpcf.SOURCE_TIME


def assert_uniform(values, stderr_tol=4, expected_min=None, expected_max=None):
  """Asserts that the distinct entries of `values` occur uniformly often.

  The observed fraction of every distinct value is compared with the uniform
  fraction `1 / (number of distinct values)`. The allowed absolute deviation is
  `stderr_tol` binomial standard errors, so `stderr_tol=0` demands that every
  distinct value occurs (essentially) exactly equally often.

  Args:
    values: 1-D array-like of samples to check.
    stderr_tol: Number of standard errors each observed fraction may deviate
      from the uniform fraction.
    expected_min: If not None, the smallest entry of `values` must equal this.
    expected_max: If not None, the largest entry of `values` must equal this.

  Raises:
    AssertionError: If `values` is empty, if the min/max does not match the
      expected one, or if the value counts are not uniform within tolerance.
  """
  values = np.asarray(values)
  # Fail with a clear message rather than a ZeroDivisionError (from
  # 1 / len(counts)) or a ValueError (from min/max of an empty array).
  if values.size == 0:
    raise AssertionError('Cannot check uniformity of an empty set of values.')
  if expected_min is not None:
    np.testing.assert_array_equal(values.min(), expected_min)
  if expected_max is not None:
    np.testing.assert_array_equal(values.max(), expected_max)
  counts = pd.Series(values).value_counts().sort_index()
  ensemble_size = counts.sum()
  fracs = counts / ensemble_size
  # Only values that were actually observed are counted, so a value that is
  # missing entirely is not detected here; use expected_min/expected_max to
  # pin down the range.
  expected_frac = 1 / len(counts)
  # Standard error of a binomial proportion with success probability
  # expected_frac, estimated from ensemble_size samples.
  standard_error = np.sqrt(expected_frac * (1 - expected_frac) / ensemble_size)
  np.testing.assert_allclose(
      fracs, expected_frac, atol=stderr_tol * standard_error
  )


class GetSampledInitTimesTest(parameterized.TestCase):
"""Test this private method, mostly because the style guide says not to."""

def assert_uniform(self, values, stderr_tol):
"""Asserts values counts uniformly distributed up to 4 standard errors."""
counts = pd.Series(values).value_counts().sort_index()
ensemble_size = counts.sum()
fracs = counts / ensemble_size
expected_frac = 1 / len(counts)
standard_error = np.sqrt(
expected_frac * (1 - expected_frac) / ensemble_size
)
np.testing.assert_allclose(
fracs, expected_frac, atol=stderr_tol * standard_error
)

@parameterized.named_parameters(
dict(
testcase_name='WithReplacement_Ensemble50',
Expand Down Expand Up @@ -124,22 +130,19 @@ def test_sample_statistics(self, with_replacement, ensemble_size):
not with_replacement and ensemble_size == -1
)
if no_edge_effects:
self.assert_uniform(
assert_uniform(
perturbation.days,
stderr_tol=0 if expect_everything_sampled_once else 4,
)
self.assertEqual(perturbation.days.min(), -day_window_size // 2)
self.assertEqual(
perturbation.days.max(),
day_window_size // 2 + day_window_size % 2 - 1,
expected_min=-day_window_size // 2,
expected_max=day_window_size // 2 + day_window_size % 2 - 1,
)

# The years should be uniform.
years = pd.to_datetime(sampled_t).year
self.assertEqual(years.min(), climatology_start_year)
self.assertEqual(years.max(), climatology_end_year)
self.assert_uniform(
years, stderr_tol=0 if expect_everything_sampled_once else 4
assert_uniform(
pd.to_datetime(sampled_t).year,
stderr_tol=0 if expect_everything_sampled_once else 4,
expected_min=climatology_start_year,
expected_max=climatology_end_year,
)


Expand Down Expand Up @@ -170,7 +173,13 @@ def _make_dataset_that_grows_by_one_with_every_timedelta(

@parameterized.named_parameters(
dict(testcase_name='Default'),
dict(testcase_name='CustomTimeName', time_dim='init'),
# A larger ensemble better tests that the distribution of days is uniform.
dict(testcase_name='LargeEnsemble', ensemble_size=100),
dict(
testcase_name='CustomTimeNameNoSourceTime',
time_dim='init',
add_source_time=False,
),
dict(testcase_name='OddWindow', day_window_size=3),
dict(testcase_name='OutputIsLeapYearInFeb', output_leap_location='feb'),
dict(testcase_name='OutputIsLeapYearInDec', output_leap_location='dec'),
Expand Down Expand Up @@ -226,7 +235,11 @@ def test_standard_workflow(
data_year_hasleap=False,
custom_prediction_timedelta_chunk=False,
with_replacement=True,
# Note that all tests passed with ensemble_size=500.
# This is relevant, since the tolerance below is proportional to
# 1 / sqrt(ensemble_size).
ensemble_size=20,
add_source_time=True,
):
input_ds = self._make_dataset_that_grows_by_one_with_every_timedelta(
input_time_resolution=input_time_resolution,
Expand Down Expand Up @@ -287,6 +300,7 @@ def test_standard_workflow(
day_window_size=str(day_window_size),
ensemble_size=str(ensemble_size),
with_replacement=str(with_replacement).lower(),
add_source_time=str(add_source_time).lower(),
variables='temperature',
output_chunks=output_chunks_flag,
runner='DirectRunner',
Expand Down Expand Up @@ -327,7 +341,10 @@ def test_standard_workflow(
)

# Check variables (this is the exciting part!)
self.assertCountEqual(['temperature'], list(output_ds))
self.assertCountEqual(
['temperature'] + ([SOURCE_TIME] if add_source_time else []),
list(output_ds),
)

# Ensemble members differ.
np.testing.assert_array_less(0, output_ds.temperature.var('realization'))
Expand All @@ -338,14 +355,20 @@ def test_standard_workflow(
1, output_ds.temperature.diff('prediction_timedelta')
)

# Source time should also be contiguous
if add_source_time:
np.testing.assert_array_equal(
pd.to_timedelta(timedelta_spacing).to_numpy(),
output_ds[SOURCE_TIME].diff('prediction_timedelta').data,
)

timedeltas_in_a_year = pd.Timedelta(
f'{365 + output_dates_have_leap}d'
) / pd.Timedelta(timedelta_spacing)
timedeltas_in_a_day = pd.Timedelta('1d') / pd.Timedelta(timedelta_spacing)

# Check that the initial times output_t, came from input days of year within
# the specified window. Use the fact that temperature is growing at a rate
# of 1 for every timedelta.
# the specified window.
for region in [
dict(latitude=0, longitude=0, level=0),
dict(
Expand All @@ -359,6 +382,35 @@ def test_standard_workflow(
temperature = (
output_ds.isel(region).isel({time_dim: i_time}).temperature
)

# Precise check using SOURCE_TIME.
# Notice that this check always runs and requires the expected uniform
# distribution, regardless of leap year.
if add_source_time:
output_time = pd.to_datetime(
temperature.isel(prediction_timedelta=0)[time_dim].data
)
source_time = pd.to_datetime(
output_ds.isel(region)
.isel(prediction_timedelta=0)
.isel({time_dim: i_time})[SOURCE_TIME]
.data
)
assert_uniform(
source_time.year,
expected_min=climatology_start_year,
expected_max=climatology_end_year,
)
assert_uniform(
source_time.dayofyear,
expected_min=output_time.dayofyear - day_window_size // 2,
expected_max=output_time.dayofyear
+ day_window_size // 2
+ day_window_size % 2
- 1,
)

# Rough check using the values
# Since temperature is growing linearly at a rate of 1 for every
# timedelta, we expect a certain spread of temperatures...roughly equal
# to the day_window_size. There are edge effects due to way windows
Expand Down

0 comments on commit 75fb2b9

Please sign in to comment.