@@ -135,8 +135,8 @@ def _spatial_average_l2_norm(
135
135
136
136
137
137
@dataclasses .dataclass
138
- class WindVectorRMSE (Metric ):
139
- """Compute wind vector RMSE . See WB2 paper for definition.
138
+ class WindVectorMSE (Metric ):
139
+ """Compute wind vector mean square error . See WB2 paper for definition.
140
140
141
141
Attributes:
142
142
u_name: Name of U component.
@@ -155,25 +155,57 @@ def compute_chunk(
155
155
region : t .Optional [Region ] = None ,
156
156
) -> xr .Dataset :
157
157
diff = forecast - truth
158
- result = np .sqrt (
159
- _spatial_average (
160
- diff [self .u_name ] ** 2 + diff [self .v_name ] ** 2 ,
161
- region = region ,
162
- )
158
+ result = _spatial_average (
159
+ diff [self .u_name ] ** 2 + diff [self .v_name ] ** 2 ,
160
+ region = region ,
163
161
)
164
162
return result
165
163
166
164
167
165
@dataclasses .dataclass
168
- class RMSE (Metric ):
166
+ class WindVectorRMSESqrtBeforeTimeAvg (Metric ):
167
+ """Compute wind vector RMSE. See WB2 paper for definition.
168
+
169
+ This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
170
+ Most users will prefer to use WindVectorMSE and then take a square root in
171
+ user code after running the evaluate script.
172
+
173
+ Attributes:
174
+ u_name: Name of U component.
175
+ v_name: Name of V component.
176
+ vector_name: Name of wind vector to be computed.
177
+ """
178
+
179
+ u_name : str
180
+ v_name : str
181
+ vector_name : str
182
+
183
+ def compute_chunk (
184
+ self ,
185
+ forecast : xr .Dataset ,
186
+ truth : xr .Dataset ,
187
+ region : t .Optional [Region ] = None ,
188
+ ) -> xr .Dataset :
189
+ mse = WindVectorMSE (
190
+ u_name = self .u_name , v_name = self .v_name , vector_name = self .vector_name
191
+ ).compute_chunk (forecast , truth , region = region )
192
+ return np .sqrt (mse )
193
+
194
+
195
+ @dataclasses .dataclass
196
+ class RMSESqrtBeforeTimeAvg (Metric ):
169
197
"""Root mean squared error.
170
198
199
+ This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
200
+ Most users will prefer to use MSE and then take a square root in user
201
+ code after running the evaluate script.
202
+
171
203
Attributes:
172
- wind_vector_rmse: Optionally provide list of WindVectorRMSE instances to
173
- compute.
204
+ wind_vector_rmse: Optionally provide list of WindVectorRMSESqrtBeforeTimeAvg
205
+ instances to compute.
174
206
"""
175
207
176
- wind_vector_rmse : t .Optional [list [WindVectorRMSE ]] = None
208
+ wind_vector_rmse : t .Optional [list [WindVectorRMSESqrtBeforeTimeAvg ]] = None
177
209
178
210
def compute_chunk (
179
211
self ,
@@ -192,15 +224,28 @@ def compute_chunk(
192
224
193
225
@dataclasses .dataclass
194
226
class MSE (Metric ):
195
- """Mean squared error."""
227
+ """Mean squared error.
228
+
229
+ Attributes:
230
+ wind_vector_mse: Optionally provide list of WindVectorMSE instances to
231
+ compute.
232
+ """
233
+
234
+ wind_vector_mse : t .Optional [list [WindVectorMSE ]] = None
196
235
197
236
def compute_chunk (
198
237
self ,
199
238
forecast : xr .Dataset ,
200
239
truth : xr .Dataset ,
201
240
region : t .Optional [Region ] = None ,
202
241
) -> xr .Dataset :
203
- return _spatial_average ((forecast - truth ) ** 2 , region = region )
242
+ results = _spatial_average ((forecast - truth ) ** 2 , region = region )
243
+ if self .wind_vector_mse is not None :
244
+ for wv in self .wind_vector_mse :
245
+ results [wv .vector_name ] = wv .compute_chunk (
246
+ forecast , truth , region = region
247
+ )
248
+ return results
204
249
205
250
206
251
@dataclasses .dataclass
@@ -717,14 +762,15 @@ def _pointwise_gaussian_crps(
717
762
718
763
719
764
@dataclasses .dataclass
720
- class EnsembleStddev (EnsembleMetric ):
765
+ class EnsembleStddevSqrtBeforeTimeAvg (EnsembleMetric ):
721
766
"""The standard deviation of an ensemble of forecasts.
722
767
723
- This forms the SPREAD component of the traditional spread-skill-ratio. See
724
- [Garg & Rasp & Thuerey, 2022].
768
+ This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
769
+ Most users will prefer to use EnsembleVariance and then take a square root in
770
+ user code after running the evaluate script.
725
771
726
772
Given predictive ensemble Xₜ at times t = (1,..., T),
727
- EnsembleStddev := (1 / T) Σₜ ‖σ(Xₜ)‖
773
+ EnsembleStddevSqrtBeforeTimeAvg := (1 / T) Σₜ ‖σ(Xₜ)‖
728
774
Above σ(Xₜ) is element-wise standard deviation, and ‖⋅‖ is an area-weighted
729
775
L2 norm.
730
776
@@ -737,15 +783,6 @@ class EnsembleStddev(EnsembleMetric):
737
783
738
784
NaN values propagate through and result in NaN in the corresponding output
739
785
position.
740
-
741
- We use the unbiased estimator of σ(Xₜ) (dividing by n_ensemble - 1). If
742
- n_ensemble = 1, we return zero for the stddev. This choice allows
743
- EnsembleStddev to behave in the spread-skill-ratio as expected.
744
-
745
- References:
746
- [Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
747
- for probabilistic medium-range weather forecasting along with deep learning
748
- baseline models.
749
786
"""
750
787
751
788
def compute_chunk (
@@ -754,7 +791,7 @@ def compute_chunk(
754
791
truth : xr .Dataset ,
755
792
region : t .Optional [Region ] = None ,
756
793
) -> xr .Dataset :
757
- """EnsembleStddev , averaged over space, for a time chunk of data."""
794
+ """Ensemble Stddev , averaged over space, for a time chunk of data."""
758
795
del truth # unused
759
796
n_ensemble = _get_n_ensemble (forecast , self .ensemble_dim )
760
797
@@ -825,15 +862,16 @@ def compute_chunk(
825
862
826
863
827
864
@dataclasses .dataclass
828
- class EnsembleMeanRMSE (EnsembleMetric ):
865
+ class EnsembleMeanRMSESqrtBeforeTimeAvg (EnsembleMetric ):
829
866
"""RMSE between the ensemble mean and ground truth.
830
867
831
- This forms the SKILL component of the traditional spread-skill-ratio. See
832
- [Garg & Rasp & Thuerey, 2022].
868
+ This SqrtBeforeTimeAvg metric takes a square root before any time averaging.
869
+ Most users will prefer to use EnsembleMeanMSE and then take a square root in
870
+ user code after running the evaluate script.
833
871
834
872
Given ground truth Yₜ, and predictive ensemble Xₜ, both at times
835
873
t = (1,..., T),
836
- EnsembleMeanRMSE := (1 / T) Σₜ ‖Y - E(Xₜ)‖.
874
+ EnsembleMeanRMSESqrtBeforeTimeAvg := (1 / T) Σₜ ‖Y - E(Xₜ)‖.
837
875
Above, `E` is ensemble average, and ‖⋅‖ is an area-weighted L2 norm.
838
876
839
877
Estimation is done separately for each tendency, level, and lag time.
@@ -845,11 +883,6 @@ class EnsembleMeanRMSE(EnsembleMetric):
845
883
846
884
NaN values propagate through and result in NaN in the corresponding output
847
885
position.
848
-
849
- References:
850
- [Garg & Rasp & Thuerey, 2022], WeatherBench Probability: A benchmark dataset
851
- for probabilistic medium-range weather forecasting along with deep learning
852
- baseline models.
853
886
"""
854
887
855
888
def compute_chunk (
@@ -1005,7 +1038,7 @@ def compute_chunk(
1005
1038
1006
1039
1007
1040
# TODO(shoyer): Consider adding WindVectorEnergyScore based on a pair of wind
1008
- # components, as a sort of probabilistic variant of WindVectorRMSE .
1041
+ # components, as a sort of probabilistic variant of WindVectorMSE .
1009
1042
1010
1043
1011
1044
@dataclasses .dataclass
0 commit comments