Skip to content

Commit 1167715

Browse files
author
Weatherbench authors
committed
No public description
PiperOrigin-RevId: 580488801
1 parent 0fa2cba commit 1167715

File tree

3 files changed

+60
-5
lines changed

3 files changed

+60
-5
lines changed

scripts/evaluate.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
from weatherbench2 import flag_utils
4848
from weatherbench2 import metrics
4949
from weatherbench2.derived_variables import DERIVED_VARIABLE_DICT
50+
from weatherbench2.regions import CombinedRegion
51+
from weatherbench2.regions import LandRegion
5052
from weatherbench2.regions import SliceRegion
5153
import xarray as xr
5254

@@ -91,7 +93,9 @@
9193
EVALUATE_CLIMATOLOGY = flags.DEFINE_bool(
9294
'evaluate_climatology',
9395
False,
94-
'Evaluate climatology forecast specified in climatology path',
96+
'Evaluate climatology forecast specified in climatology path. Note that'
97+
' this will not work for probabilistic evaluation. Please use the'
98+
' EVALUATE_PROBABILISTIC_CLIMATOLOGY flag.',
9599
)
96100
EVALUATE_PROBABILISTIC_CLIMATOLOGY = flags.DEFINE_bool(
97101
'evaluate_probabilistic_climatology',
@@ -122,6 +126,15 @@
122126
'predefined regions.'
123127
),
124128
)
129+
LSM_DATASET = flags.DEFINE_string(
130+
'lsm_dataset',
131+
None,
132+
help=(
133+
'Dataset containing land-sea-mask at same resolution of datasets to be'
134+
' evaluated. Required if region with land-sea-mask is picked. If None,'
135+
' defaults to observation dataset.'
136+
),
137+
)
125138
COMPUTE_SEEPS = flags.DEFINE_bool(
126139
'compute_seeps', False, 'Compute SEEPS for total_precipitation_24hr.'
127140
)
@@ -305,6 +318,23 @@ def main(argv: list[str]) -> None:
305318
'arctic': SliceRegion(lat_slice=slice(60, 90)),
306319
'antarctic': SliceRegion(lat_slice=slice(-90, -60)),
307320
}
321+
try:
322+
if LSM_DATASET.value:
323+
land_sea_mask = xr.open_zarr(LSM_DATASET.value)['land_sea_mask'].compute()
324+
else:
325+
land_sea_mask = xr.open_zarr(OBS_PATH.value)['land_sea_mask'].compute()
326+
land_regions = {
327+
'global_land': LandRegion(land_sea_mask=land_sea_mask),
328+
'extra-tropics_land': CombinedRegion(
329+
regions=[
330+
SliceRegion(lat_slice=[slice(None, -20), slice(20, None)]),
331+
LandRegion(land_sea_mask=land_sea_mask),
332+
]
333+
),
334+
}
335+
predefined_regions = predefined_regions | land_regions
336+
except KeyError:
337+
print('No land_sea_mask found.')
308338
if REGIONS.value == ['all']:
309339
regions = predefined_regions
310340
elif REGIONS.value is None:

weatherbench2/evaluation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def _evaluate(
678678
)
679679
forecast_pipeline |= beam.MapTuple(
680680
self._climatology_like_forecast_chunk,
681-
probabilistic_climatology=probabilistic_climatology,
681+
climatology=probabilistic_climatology,
682682
variables=variables,
683683
)
684684
elif self.eval_config.evaluate_persistence:

weatherbench2/regions.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class Region:
3939

4040
def apply(
4141
self, dataset: xr.Dataset, weights: xr.DataArray
42-
) -> tuple[xr.Dataset, xr.Dataset]:
42+
) -> tuple[xr.Dataset, xr.DataArray]:
4343
"""Apply region selection to dataset and/or weights.
4444
4545
Args:
@@ -48,8 +48,8 @@ def apply(
4848
4949
Returns:
5050
dataset: Potentially modified (sliced) dataset.
51-
weights: Potentially modified weights dataset, to be used in combination
52-
with dataset, e.g. in _spatial_average().
51+
weights: Potentially modified weights data array, to be used in
52+
combination with dataset, e.g. in _spatial_average().
5353
"""
5454
raise NotImplementedError
5555

@@ -128,6 +128,31 @@ def apply( # pytype: disable=signature-mismatch
128128
) -> tuple[xr.Dataset, xr.DataArray]:
129129
"""Returns weights multiplied with a boolean land mask."""
130130
land_weights = self.land_sea_mask
131+
# Make sure lsm has same dtype for lat/lon
132+
land_weights = land_weights.assign_coords(
133+
latitude=land_weights.latitude.astype(dataset.latitude.dtype),
134+
longitude=land_weights.longitude.astype(dataset.longitude.dtype),
135+
)
131136
if self.threshold is not None:
132137
land_weights = (land_weights > self.threshold).astype(float)
133138
return dataset, weights * land_weights
139+
140+
141+
@dataclasses.dataclass
142+
class CombinedRegion(Region):
143+
"""Sequentially applies regions selections.
144+
145+
Allows for combination of e.g. SliceRegion and LandRegion.
146+
147+
Attributes:
148+
regions: List of Region instances
149+
"""
150+
151+
regions: list[Region] = dataclasses.field(default_factory=list)
152+
153+
def apply( # pytype: disable=signature-mismatch
154+
self, dataset: xr.Dataset, weights: xr.DataArray
155+
) -> tuple[xr.Dataset, xr.DataArray]:
156+
for region in self.regions:
157+
dataset, weights = region.apply(dataset, weights)
158+
return dataset, weights

0 commit comments

Comments
 (0)