Skip to content

Commit 3e4f392

Browse files
ilopezgpWeatherbench2 authors
authored and
Weatherbench2 authors
committed
[weatherbench2] Add auxiliary variables to config.Selection.
PiperOrigin-RevId: 597794742
1 parent d8b9b1a commit 3e4f392

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

weatherbench2/config.py

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class Selection:
3535
levels: List of pressure levels.
3636
lat_slice: Latitude range in degrees.
3737
lon_slice: Longitude range in degrees.
38+
aux_variables: Sequence of auxiliary forecast variables required for certain
39+
evaluation metrics.
3840
"""
3941

4042
variables: t.Sequence[str]
@@ -46,6 +48,7 @@ class Selection:
4648
lon_slice: t.Optional[slice] = dataclasses.field(
4749
default_factory=lambda: slice(None, None)
4850
)
51+
aux_variables: t.Optional[t.Sequence[str]] = None
4952

5053

5154
@dataclasses.dataclass

weatherbench2/evaluation.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,14 @@ def _impose_data_selection(
140140
selection: config.Selection,
141141
select_time: bool = True,
142142
time_dim: Optional[str] = None,
143+
select_aux: bool = False,
143144
) -> xr.Dataset:
144145
"""Returns selection of dataset specified in Selection instance."""
145-
dataset = dataset[selection.variables].sel(
146+
if select_aux and selection.aux_variables is not None:
147+
sel_variables = set(selection.variables) | set(selection.aux_variables)
148+
else:
149+
sel_variables = selection.variables
150+
dataset = dataset[sel_variables].sel(
146151
latitude=selection.lat_slice,
147152
longitude=selection.lon_slice,
148153
)
@@ -314,10 +319,12 @@ def open_forecast_and_truth_datasets(
314319
)
315320

316321
obs_all_times = _impose_data_selection(
317-
obs, data_config.selection, select_time=False
322+
obs,
323+
data_config.selection,
324+
select_time=False,
318325
)
319326
forecast_all_times = _impose_data_selection(
320-
forecast, data_config.selection, select_time=False
327+
forecast, data_config.selection, select_time=False, select_aux=True
321328
)
322329

323330
if data_config.by_init: # Will select appropriate chunks later
@@ -328,6 +335,7 @@ def open_forecast_and_truth_datasets(
328335
forecast,
329336
data_config.selection,
330337
time_dim='init_time' if data_config.by_init else 'time',
338+
select_aux=True,
331339
)
332340

333341
# Determine ground truth dataset

0 commit comments

Comments
 (0)