Skip to content

Commit 23435d2

Browse files
authored
Refactor MultiStreamDataSampler and data loaders (ecmwf#102)
* Initial commit to enable support for source and target channel selection. This requires a substantial refactoring, which was anyway necessary to modularize the code. In particular, the batch construction for source and target has been moved completely to the Batchifyer, which finally simplifies the MultiStreamDataSampler to a reasonable size. This change should also facilitate the re-introduction of masked token modeling. * Simplified interface of ctor of MultiStreamDataSampler by passing config and using it wherever possible. * Cleaned up interface * Adding more streams---but they need to properly integrated to ensure all functionality is supported. * - Fixed various issues in FESOM dataloader and made it compliant with changes in MultiStreamdataSampler - ruff-ing * Initial commit to enable support for source and target channel selection. This requires a substantial refactoring, which was anyway necessary to modularize the code. In particular, the batch construction for source and target has been moved completely to the Batchifyer, which finally simplifies the MultiStreamDataSampler to a reasonable size. This change should also facilitate the re-introduction of masked token modeling. * Simplified interface of ctor of MultiStreamDataSampler by passing config and using it wherever possible. * Cleaned up interface * Adding more streams---but they need to properly integrated to ensure all functionality is supported. * - Fixed various issues in FESOM dataloader and made it compliant with changes in MultiStreamdataSampler - ruff-ing * Updates
1 parent d3f8c78 commit 23435d2

File tree

12 files changed

+835
-555
lines changed

12 files changed

+835
-555
lines changed

config/streams/streams_anemoi/era5.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
ERA5 :
1111
type : anemoi
1212
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr']
13+
# source : ['u_', 'v_', '10u', '10v']
14+
# target : ['10u', '10v']
1315
loss_weight : 1.
14-
source_variables : [null]
15-
target_variables : [null]
1616
diagnostic : False
1717
masking_rate : 0.6
1818
masking_rate_none : 0.05
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# (C) Copyright 2024 WeatherGenerator contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
FESOM :
11+
type : fesom
12+
filenames : ['coupled_yearly']
13+
loss_weight : 1.
14+
source : null
15+
target : ['sst']
16+
masking_rate : 0.6
17+
masking_rate_none : 0.05
18+
token_size : 64
19+
embed :
20+
net : transformer
21+
num_tokens : 1
22+
num_heads : 2
23+
dim_embed : 256
24+
num_blocks : 2
25+
embed_target_coords :
26+
net : linear
27+
dim_embed : 256
28+
target_readout :
29+
type : 'obs_value' # token or obs_value
30+
num_layers : 2
31+
num_heads : 4
32+
# sampling_rate : 0.2
33+
pred_head :
34+
ens_size : 1
35+
num_layers : 1

src/weathergen/datasets/anemoi_dataset.py

Lines changed: 157 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ class AnemoiDataset:
1818

1919
def __init__(
2020
self,
21-
filename: str,
2221
start: int,
2322
end: int,
2423
len_hrs: int,
25-
step_hrs: int = None,
26-
normalize: bool = True,
27-
select: list[str] = None,
24+
step_hrs: int,
25+
filename: str,
26+
stream_info: dict,
2827
) -> None:
28+
# TODO: add support for different normalization modes
29+
2930
assert len_hrs == step_hrs, "Currently only step_hrs=len_hrs is supported"
3031

3132
# open dataset to peak that it is compatible with requested parameters
@@ -40,30 +41,61 @@ def __init__(
4041
dt_start = datetime.datetime.strptime(str(start), format_str)
4142
dt_end = datetime.datetime.strptime(str(end), format_str)
4243

44+
# TODO, TODO, TODO: we need proper alignment for the case where self.ds.frequency
45+
# is not a multile of len_hrs
46+
self.num_steps_per_window = int((len_hrs * 3600) / self.ds.frequency.seconds)
47+
4348
# open dataset
4449

4550
# caches lats and lons
4651
self.latitudes = self.ds.latitudes.astype(np.float32)
4752
self.longitudes = self.ds.longitudes.astype(np.float32)
4853

49-
# find physical fields (i.e. filter out auxiliary information to facilitate prediction)
50-
self.fields_idx = np.sort(
54+
# TODO: define in base class
55+
self.geoinfo_idx = []
56+
57+
# Determine source and target channels, filtering out forcings etc and using
58+
# specified source and target channels if specified
59+
source_channels = stream_info["source"] if "source" in stream_info else None
60+
self.source_idx = np.sort(
5161
[
5262
self.ds.name_to_index[k]
5363
for i, (k, v) in enumerate(self.ds.typed_variables.items())
54-
if not v.is_computed_forcing and not v.is_constant_in_time
64+
if (
65+
not v.is_computed_forcing
66+
and not v.is_constant_in_time
67+
and (
68+
np.array([f in k for f in source_channels]).any()
69+
if source_channels
70+
else True
71+
)
72+
)
5573
]
5674
)
57-
# TODO: use complement of self.fields_idx as geoinfo
58-
self.fields = [self.ds.variables[i] for i in self.fields_idx]
59-
self.colnames = ["lat", "lon"] + self.fields
60-
self.selected_colnames = self.colnames
75+
target_channels = stream_info["target"] if "target" in stream_info else None
76+
self.target_idx = np.sort(
77+
[
78+
self.ds.name_to_index[k]
79+
for i, (k, v) in enumerate(self.ds.typed_variables.items())
80+
if (
81+
not v.is_computed_forcing
82+
and not v.is_constant_in_time
83+
and (
84+
np.array([f in k for f in target_channels]).any()
85+
if target_channels
86+
else True
87+
)
88+
)
89+
]
90+
)
91+
self.source_channels = [self.ds.variables[i] for i in self.source_idx]
92+
self.target_channels = [self.ds.variables[i] for i in self.target_idx]
6193

6294
self.properties = {
63-
"obs_id": 0,
64-
"means": self.ds.statistics["mean"],
65-
"vars": np.square(self.ds.statistics["stdev"]),
95+
"stream_id": 0,
6696
}
97+
self.mean = self.ds.statistics["mean"]
98+
self.stdev = self.ds.statistics["stdev"]
6799

68100
# set dataset to None when no overlap with time range
69101
if dt_start >= ds_dt_end or dt_end <= ds_dt_start:
@@ -80,26 +112,128 @@ def __len__(self):
80112

81113
return len(self.ds)
82114

83-
def __getitem__(self, idx: int) -> tuple:
84-
"Get (data,datetime) for given index"
115+
def get_source(self, idx: int) -> tuple[np.array, np.array, np.array, np.array]:
116+
"""
117+
TODO
118+
"""
119+
return self._get(idx, self.source_idx)
120+
121+
def get_target(self, idx: int) -> tuple[np.array, np.array, np.array, np.array]:
122+
"""
123+
TODO
124+
"""
125+
return self._get(idx, self.target_idx)
126+
127+
def _get(
128+
self, idx: int, channels_idx: np.array
129+
) -> tuple[np.array, np.array, np.array, np.array]:
130+
"""
131+
TODO
132+
"""
85133

86134
if not self.ds:
87-
return (np.array([], dtype=np.float32), np.array([], dtype=np.float32))
135+
return (
136+
np.array([], dtype=np.float32),
137+
np.array([], dtype=np.float32),
138+
np.array([], dtype=np.float32),
139+
np.array([], dtype=np.float32),
140+
)
141+
142+
# extract number of time steps and collapse ensemble dimension
143+
data = self.ds[idx : idx + self.num_steps_per_window][:, :, 0]
144+
# extract channels
145+
data = (
146+
data[:, channels_idx].transpose([0, 2, 1]).reshape((data.shape[0] * data.shape[2], -1))
147+
)
88148

89-
# prepend lat and lon to data; squeeze out ensemble dimension (for the moment)
90-
data = np.concatenate(
149+
# construct lat/lon coords
150+
latlon = np.concatenate(
91151
[
92152
np.expand_dims(self.latitudes, 0),
93153
np.expand_dims(self.longitudes, 0),
94-
self.ds[idx].squeeze(),
95154
],
96155
0,
97156
).transpose()
157+
latlon = np.repeat(latlon, self.num_steps_per_window, axis=0).reshape((-1, latlon.shape[1]))
98158

99-
# date time matching #data points of data
100-
datetimes = np.full(data.shape[0], self.ds.dates[idx])
159+
# empty geoinfos for anemoi
160+
geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype)
101161

102-
return (data, datetimes)
162+
# date time matching #data points of data
163+
datetimes = np.repeat(
164+
np.expand_dims(self.ds.dates[idx : idx + self.num_steps_per_window], 0),
165+
data.shape[0],
166+
axis=0,
167+
).flatten()
168+
169+
return (latlon, geoinfos, data, datetimes)
170+
171+
def get_source_size(self):
172+
"""
173+
TODO
174+
"""
175+
return 2 + len(self.geoinfo_idx) + len(self.source_idx)
176+
177+
def get_source_num_channels(self):
178+
"""
179+
TODO
180+
"""
181+
return len(self.source_idx)
182+
183+
def get_target_size(self):
184+
"""
185+
TODO
186+
"""
187+
return 2 + len(self.geoinfo_idx) + len(self.target_idx)
188+
189+
def get_target_num_channels(self):
190+
"""
191+
TODO
192+
"""
193+
return len(self.target_idx)
194+
195+
def get_geoinfo_size(self):
196+
"""
197+
TODO
198+
"""
199+
return len(self.geoinfo_idx)
200+
201+
def normalize_coords(self, coords):
202+
"""
203+
TODO
204+
"""
205+
coords[..., 0] = np.sin(np.deg2rad(coords[..., 0]))
206+
coords[..., 1] = np.sin(0.5 * np.deg2rad(coords[..., 1]))
207+
208+
return coords
209+
210+
def normalize_geoinfos(self, geoinfos):
211+
"""
212+
TODO
213+
"""
214+
215+
assert geoinfos.shape[-1] == 0
216+
return geoinfos
217+
218+
def normalize_source_channels(self, source):
219+
"""
220+
TODO
221+
"""
222+
assert source.shape[1] == len(self.source_idx)
223+
for i, ch in enumerate(self.source_idx):
224+
source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch]
225+
226+
return source
227+
228+
def normalize_target_channels(self, target):
229+
"""
230+
TODO
231+
"""
232+
assert target.shape[1] == len(self.target_idx)
233+
for i, ch in enumerate(self.target_idx):
234+
target[..., i] = (target[..., i] - self.mean[ch]) / self.stdev[ch]
235+
236+
return target
103237

104238
def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]:
105239
if not self.ds:
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from datetime import datetime
2+
3+
import numpy as np
4+
import zarr
5+
6+
7+
class AtmorepDataset:
8+
def __init__(
9+
self,
10+
filename: str,
11+
start: datetime | int,
12+
end: datetime | int,
13+
len_hrs: int,
14+
step_hrs: int | None = None,
15+
normalize: bool = True,
16+
select: list[str] | None = None,
17+
):
18+
format_str = "%Y%m%d%H%M%S"
19+
if type(start) is int:
20+
start = datetime.strptime(str(start), format_str)
21+
22+
if type(end) is int:
23+
end = datetime.strptime(str(end), format_str)
24+
25+
self.normalize = normalize
26+
self.filename = filename
27+
self.z = zarr.open(filename, mode="r")
28+
29+
self.lats, self.lons = np.meshgrid(np.array(self.z["lats"]), np.array(self.z["lons"]))
30+
self.lats = self.lats.flatten()
31+
self.lons = self.lons.flatten()
32+
# Reshape lats and lons to be in shape (1, len_hrs, size_lat * size_lon), ready to added to data
33+
self.lats = np.expand_dims(np.stack((self.lats,) * len_hrs, axis=1).T, 0)
34+
self.lons = np.expand_dims(np.stack((self.lons,) * len_hrs, axis=1).T, 0)
35+
36+
self.time = np.array(self.z["time"], dtype=np.datetime64)
37+
self.start_idx = np.searchsorted(self.time, start)
38+
self.end_idx = np.searchsorted(self.time, end)
39+
40+
assert self.end_idx > self.start_idx, (
41+
f"Abort: Final index of {self.end_idx} is the same of larger than start index {self.start_idx}"
42+
)
43+
44+
self.colnames = ["lat", "lon"] + list(self.z.attrs["fields"])
45+
self.len_hrs = len_hrs
46+
# Ignore step_hrs, idk how it supposed to work
47+
self.step_hrs = 1
48+
49+
self.selected_colnames = self.colnames[2:]
50+
self.selected_cols_idx = np.arange(len(self.selected_colnames))
51+
self.data = self.z["data"]
52+
53+
self.properties = {
54+
"obs_id": 0,
55+
"means": np.zeros(len(self.colnames), dtype=np.float32),
56+
"vars": np.ones(len(self.colnames), dtype=np.float32),
57+
}
58+
59+
if select:
60+
self.select(select)
61+
62+
def select(self, cols_list: list[str]) -> None:
63+
"""
64+
Allow user to specify which columns they want to access.
65+
Get functions only returned for these specified columns.
66+
"""
67+
self.selected_colnames = cols_list
68+
self.selected_cols_idx = np.array([self.colnames.index(item) for item in cols_list])
69+
70+
def __len__(self):
71+
return self.end_idx - self.start_idx - self.len_hrs
72+
73+
def __getitem__(self, idx: int) -> tuple:
74+
start_row = self.start_idx + idx
75+
end_row = start_row + self.len_hrs
76+
77+
data = self.data.oindex[start_row:end_row, :, 0, :, :]
78+
datetimes = np.tile(self.time[start_row:end_row], data.shape[-1] * data.shape[-2])
79+
80+
data = np.reshape(data, (data.shape[1], data.shape[0], -1))
81+
data = np.concatenate([self.lats, self.lons, data], 0).T
82+
data = np.reshape(data, (-1, data.shape[-1]))
83+
84+
return (data.astype(np.float32), datetimes)
85+
86+
def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]:
87+
start_row = self.start_idx + idx
88+
end_row = start_row + self.len_hrs
89+
return (self.time[start_row], self.time[end_row])

0 commit comments

Comments
 (0)