Skip to content

Commit 40198da

Browse files
shoyerWeatherbench2 authors
authored and
Weatherbench2 authors
committed
[wb2] add evaluation of WindVectorRMSE for (a)geostrophic wind
PiperOrigin-RevId: 574623237
1 parent fa6fde3 commit 40198da

File tree

4 files changed

+76
-31
lines changed

4 files changed

+76
-31
lines changed

scripts/evaluate.py

+23-22
Original file line numberDiff line numberDiff line change
@@ -217,28 +217,29 @@
217217
def _wind_vector_rmse():
218218
"""Defines Wind Vector RMSEs if U/V components are in (derived) variables."""
219219
wind_vector_rmse = []
220-
if (
221-
'u_component_of_wind' in VARIABLES.value
222-
and 'v_component_of_wind' in VARIABLES.value
223-
):
224-
wind_vector_rmse.append(
225-
metrics.WindVectorRMSE(
226-
u_name='u_component_of_wind',
227-
v_name='v_component_of_wind',
228-
vector_name='wind_vector',
229-
)
230-
)
231-
if (
232-
'10m_u_component_of_wind' in VARIABLES.value
233-
and '10m_v_component_of_wind' in VARIABLES.value
234-
):
235-
wind_vector_rmse.append(
236-
metrics.WindVectorRMSE(
237-
u_name='10m_u_component_of_wind',
238-
v_name='10m_v_component_of_wind',
239-
vector_name='10m_wind_vector',
240-
)
241-
)
220+
available = set(VARIABLES.value).union(DERIVED_VARIABLES.value)
221+
for u_name, v_name, vector_name in [
222+
('u_component_of_wind', 'v_component_of_wind', 'wind_vector'),
223+
('10m_u_component_of_wind', '10m_v_component_of_wind', '10m_wind_vector'),
224+
(
225+
'u_component_of_geostrophic_wind',
226+
'v_component_of_geostrophic_wind',
227+
'geostrophic_wind_vector',
228+
),
229+
(
230+
'u_component_of_ageostrophic_wind',
231+
'v_component_of_ageostrophic_wind',
232+
'ageostrophic_wind_vector',
233+
),
234+
]:
235+
if u_name in available and v_name in available:
236+
wind_vector_rmse.append(
237+
metrics.WindVectorRMSE(
238+
u_name=u_name,
239+
v_name=v_name,
240+
vector_name=vector_name,
241+
)
242+
)
242243
return wind_vector_rmse
243244

244245

scripts/evaluate_test.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ def _test(self, use_beam=True):
3030
'u_component_of_wind',
3131
'v_component_of_wind',
3232
]
33+
derived_variables = [
34+
'wind_speed',
35+
'u_component_of_ageostrophic_wind',
36+
'v_component_of_ageostrophic_wind',
37+
]
3338
variables_2d = ['2m_temperature']
3439
truth = schema.mock_truth_data(
3540
variables_3d=variables_3d,
@@ -49,7 +54,9 @@ def _test(self, use_beam=True):
4954
variables_2d=variables_2d,
5055
)
5156
climatology = climatology.assign(
52-
wind_speed=climatology['u_component_of_wind']
57+
wind_speed=climatology['u_component_of_wind'],
58+
u_component_of_ageostrophic_wind=climatology['u_component_of_wind'],
59+
v_component_of_ageostrophic_wind=climatology['u_component_of_wind'],
5360
)
5461

5562
truth_path = self.create_tempdir('truth').full_path
@@ -80,7 +87,7 @@ def _test(self, use_beam=True):
8087
eval_configs=','.join(eval_configs),
8188
use_beam=use_beam,
8289
variables=variables_3d + variables_2d,
83-
derived_variables=['wind_speed'],
90+
derived_variables=derived_variables,
8491
):
8592
evaluate.main([])
8693

@@ -91,12 +98,20 @@ def _test(self, use_beam=True):
9198
with self.subTest(config_name):
9299
results_path = os.path.join(output_dir, f'{config_name}.nc')
93100
actual = xarray.open_dataset(results_path)
101+
extra_out_vars = [
102+
'wind_speed',
103+
'wind_vector',
104+
'u_component_of_ageostrophic_wind',
105+
'v_component_of_ageostrophic_wind',
106+
'ageostrophic_wind_vector',
107+
]
94108
self.assertEqual(
95-
set(actual),
96-
set(variables_3d + variables_2d + ['wind_speed', 'wind_vector']),
109+
set(actual), set(variables_3d + variables_2d + extra_out_vars)
97110
)
98111
self.assertEqual(actual['geopotential'].sizes, expected_sizes_3d)
99112
self.assertEqual(actual['2m_temperature'].sizes, expected_sizes_2d)
113+
self.assertIn('wind_vector', actual)
114+
self.assertIn('ageostrophic_wind_vector', actual)
100115

101116
def test_in_memory(self):
102117
self._test(use_beam=False)

weatherbench2/derived_variables.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ def _geostrophic_wind(
236236
2 * omega * np.sin(np.deg2rad(geopotential.coords['latitude']))
237237
)
238238
# Geostrophic wind is inf on the equator. We don't clip it to ensure that the
239-
# user makes an intentional choice about how handle these invalid values.
239+
# user makes an intentional choice about how to handle these invalid values
240+
# (e.g., by evaluating over a region).
240241
return (
241242
-_d_dy(geopotential) / coriolis_parameter,
242243
+_d_dx(geopotential) / coriolis_parameter,
@@ -245,6 +246,7 @@ def _geostrophic_wind(
245246

246247
@dataclasses.dataclass
247248
class _GeostrophicWindVariable(DerivedVariable):
249+
"""Base class for geostrophic wind variables."""
248250
geopotential_name: str = 'geopotential'
249251

250252
@property
@@ -290,9 +292,8 @@ def compute(self, dataset: xr.Dataset) -> xr.DataArray:
290292

291293

292294
@dataclasses.dataclass
293-
class AgeostrophicWindSpeed(DerivedVariable):
294-
"""Calculate ageostrophic wind speed."""
295-
295+
class _AgeostrophicWindVariable(DerivedVariable):
296+
"""Base class for ageostrophic wind variables."""
296297
u_name: str = 'u_component_of_wind'
297298
v_name: str = 'v_component_of_wind'
298299
geopotential_name: str = 'geopotential'
@@ -306,13 +307,35 @@ def core_dims(self) -> t.Tuple[t.Tuple[t.List[str], ...], t.List[str]]:
306307
lon_lat = ['longitude', 'latitude']
307308
return (lon_lat, lon_lat, lon_lat), lon_lat
308309

310+
311+
class AgeostrophicWindSpeed(_AgeostrophicWindVariable):
312+
"""Calculate ageostrophic wind speed."""
313+
309314
def compute(self, dataset: xr.Dataset) -> xr.DataArray:
310315
u = dataset[self.u_name]
311316
v = dataset[self.v_name]
312317
u_geo, v_geo = _geostrophic_wind(dataset[self.geopotential_name])
313318
return np.sqrt((u - u_geo) ** 2 + (v - v_geo) ** 2)
314319

315320

321+
class UComponentOfAgeostrophicWind(_AgeostrophicWindVariable):
322+
"""East-west component of ageostrophic wind."""
323+
324+
def compute(self, dataset: xr.Dataset) -> xr.DataArray:
325+
u = dataset[self.u_name]
326+
u_geo, _ = _geostrophic_wind(dataset[self.geopotential_name])
327+
return u - u_geo
328+
329+
330+
class VComponentOfAgeostrophicWind(_AgeostrophicWindVariable):
331+
"""North-south component of ageostrophic wind."""
332+
333+
def compute(self, dataset: xr.Dataset) -> xr.DataArray:
334+
v = dataset[self.v_name]
335+
_, v_geo = _geostrophic_wind(dataset[self.geopotential_name])
336+
return v - v_geo
337+
338+
316339
@dataclasses.dataclass
317340
class LapseRate(DerivedVariable):
318341
"""Compute lapse rate in temperature."""
@@ -704,6 +727,8 @@ def compute(self, dataset: xr.Dataset):
704727
'u_component_of_geostrophic_wind': UComponentOfGeostrophicWind(),
705728
'v_component_of_geostrophic_wind': VComponentOfGeostrophicWind(),
706729
'ageostrophic_wind_speed': AgeostrophicWindSpeed(),
730+
'u_component_of_ageostrophic_wind': UComponentOfAgeostrophicWind(),
731+
'v_component_of_ageostrophic_wind': VComponentOfAgeostrophicWind(),
707732
'lapse_rate': LapseRate(),
708733
'total_column_vapor': TotalColumnWater(
709734
water_species_name='specific_humidity'

weatherbench2/metrics.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class WindVectorRMSE(Metric):
140140
141141
Attributes:
142142
u_name: Name of U component.
143-
v_name: Name of v component.
143+
v_name: Name of V component.
144144
vector_name: Name of wind vector to be computed.
145145
"""
146146

@@ -793,3 +793,7 @@ def compute_chunk(
793793
"""Energy score skill, averaged over space, for a time chunk of data."""
794794
_get_n_ensemble(forecast, self.ensemble_dim) # Will raise if no ensembles.
795795
return _spatial_average_l2_norm(forecast - truth).mean(self.ensemble_dim)
796+
797+
798+
# TODO(shoyer): Consider adding WindVectorEnergyScore based on a pair of wind
799+
# components, as a sort of probabilistic variant of WindVectorRMSE.

0 commit comments

Comments
 (0)