@@ -39,7 +39,7 @@ class Region:
39
39
40
40
def apply (
41
41
self , dataset : xr .Dataset , weights : xr .DataArray
42
- ) -> tuple [xr .Dataset , xr .Dataset ]:
42
+ ) -> tuple [xr .Dataset , xr .DataArray ]:
43
43
"""Apply region selection to dataset and/or weights.
44
44
45
45
Args:
@@ -48,8 +48,8 @@ def apply(
48
48
49
49
Returns:
50
50
dataset: Potentially modified (sliced) dataset.
51
- weights: Potentially modified weights dataset , to be used in combination
52
- with dataset, e.g. in _spatial_average().
51
+ weights: Potentially modified weights data array , to be used in
52
+ combination with dataset, e.g. in _spatial_average().
53
53
"""
54
54
raise NotImplementedError
55
55
@@ -128,6 +128,31 @@ def apply( # pytype: disable=signature-mismatch
128
128
) -> tuple [xr .Dataset , xr .DataArray ]:
129
129
"""Returns weights multiplied with a boolean land mask."""
130
130
land_weights = self .land_sea_mask
131
+ # Make sure lsm has same dtype for lat/lon
132
+ land_weights = land_weights .assign_coords (
133
+ latitude = land_weights .latitude .astype (dataset .latitude .dtype ),
134
+ longitude = land_weights .longitude .astype (dataset .longitude .dtype ),
135
+ )
131
136
if self .threshold is not None :
132
137
land_weights = (land_weights > self .threshold ).astype (float )
133
138
return dataset , weights * land_weights
139
+
140
+
141
+ @dataclasses .dataclass
142
+ class CombinedRegion (Region ):
143
+ """Sequentially applies regions selections.
144
+
145
+ Allows for combination of e.g. SliceRegion and LandRegion.
146
+
147
+ Attributes:
148
+ regions: List of Region instances
149
+ """
150
+
151
+ regions : list [Region ] = dataclasses .field (default_factory = list )
152
+
153
+ def apply ( # pytype: disable=signature-mismatch
154
+ self , dataset : xr .Dataset , weights : xr .DataArray
155
+ ) -> tuple [xr .Dataset , xr .DataArray ]:
156
+ for region in self .regions :
157
+ dataset , weights = region .apply (dataset , weights )
158
+ return dataset , weights
0 commit comments