
Commit 14fe74a

Clessig/develop/embedding1 (ecmwf#227)
- Removed some auxiliary comments that should not have been there
- Simplified code by removing dead branch
- Fixed bug in source_size()
- Fixed problem with columns embedding mode
1 parent d7a3ee1 commit 14fe74a

File tree: 5 files changed, +56 −40 lines


src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 4 additions & 1 deletion
@@ -179,7 +179,10 @@ def advance(self):
     ###################################################
     def get_sources_size(self):
         return [
-            ds[0].get_source_num_channels() + ds[0].get_geoinfo_size() + ds[0].get_coords_size()
+            ds[0].get_source_num_channels()
+            + ds[0].get_geoinfo_size()
+            + ds[0].get_coords_size()
+            + self.tokenizer.get_size_time_embedding()
             for ds in self.streams_datasets
         ]


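For context, a minimal runnable sketch of the quantity the corrected get_sources_size() now reports; the previous version omitted the time-embedding contribution. FakeDataset and FakeTokenizer below are hypothetical stand-ins that only mimic the interfaces used in the hunk, not repository classes:

# Minimal sketch of what the corrected get_sources_size() computes.

class FakeDataset:
    def get_source_num_channels(self) -> int:
        return 10  # physical channels

    def get_geoinfo_size(self) -> int:
        return 3   # e.g. static geo-information columns

    def get_coords_size(self) -> int:
        return 4   # coordinate-encoding columns


class FakeTokenizer:
    def get_size_time_embedding(self) -> int:
        return 6   # matches the value introduced in the tokenizers below


def get_sources_size(streams_datasets, tokenizer):
    # Per-stream input width: channels + geo-info + coords + time embedding.
    return [
        ds[0].get_source_num_channels()
        + ds[0].get_geoinfo_size()
        + ds[0].get_coords_size()
        + tokenizer.get_size_time_embedding()
        for ds in streams_datasets
    ]


print(get_sources_size([[FakeDataset()]], FakeTokenizer()))  # -> [23]
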
src/weathergen/datasets/tokenizer_forecast.py

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ def __init__(self, healpix_level: int):

         self.rng = np.random.default_rng(int(time.time()))

+        self.size_time_embedding = 6
+
+    def get_size_time_embedding(self) -> int:
+        """Get size of time embedding"""
+        return self.size_time_embedding
+
     def reset(self) -> None:
         self.rng = np.random.default_rng(int(time.time()))


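A short usage sketch: with this accessor the sampler can query the tokenizer for the number of time-embedding columns rather than hard-coding 6. The class name and constructor argument below are assumptions, not confirmed by the diff:

# Assumed class name (the file is tokenizer_forecast.py) and illustrative argument value.
tokenizer = TokenizerForecast(healpix_level=5)
extra_cols = tokenizer.get_size_time_embedding()  # -> 6, as set in __init__ above
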
src/weathergen/datasets/tokenizer_masking.py

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ def __init__(self, healpix_level: int):

         self.rng = np.random.default_rng(int(time.time()))

+        self.size_time_embedding = 6
+
+    def get_size_time_embedding(self) -> int:
+        """Get size of time embedding"""
+        return self.size_time_embedding
+
     def reset(self) -> None:
         self.rng = np.random.default_rng(int(time.time()))


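Because the identical accessor is added to the masking tokenizer, both tokenizers now expose the same implicit interface to the sampler; the Protocol below is an illustrative sketch of that contract, not repository code:

from typing import Protocol


class SupportsTimeEmbedding(Protocol):
    """Contract shared by the forecast and masking tokenizers after this commit."""

    def get_size_time_embedding(self) -> int: ...
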
src/weathergen/model/embeddings.py

Lines changed: 28 additions & 17 deletions
@@ -7,7 +7,7 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.

-
+import numpy as np
 import torch
 from torch.utils.checkpoint import checkpoint

@@ -45,6 +45,7 @@ def __init__(
         super(StreamEmbedTransformer, self).__init__()

         self.num_tokens = num_tokens
+        self.token_size = token_size
         self.num_channels = num_channels
         self.dim_in = token_size if mode == "channels" else num_channels
         self.dim_embed = dim_embed
@@ -56,8 +57,6 @@ def __init__(

         norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm

-        self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)
-
         self.layers = torch.nn.ModuleList()
         for _ in range(self.num_blocks):
             self.layers.append(
@@ -80,6 +79,8 @@ def __init__(
             )

         if mode == "channels":
+            self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)
+
             if self.unembed_mode == "full":
                 self.ln_final = norm(num_channels * self.dim_embed)
                 self.unembed = torch.nn.Linear(
@@ -94,6 +95,11 @@ def __init__(
                 dim_out = (self.num_tokens * self.dim_out - embed_size_centroids) // num_channels
                 self.unembed = torch.nn.ModuleList(
                     [torch.nn.Linear(dim_embed, dim_out) for _ in range(num_channels)]
+                    # [
+                    #     torch.nn.Sequential(torch.nn.Linear(dim_embed, max(dim_embed//2,4*dim_out)),
+                    #     torch.nn.GELU(),
+                    #     torch.nn.Linear(max(dim_embed//2,4*dim_out), dim_out)) for _ in range(num_channels)
+                    # ]
                 )
                 self.ln_final = torch.nn.ModuleList([norm(dim_embed) for _ in range(num_channels)])

@@ -103,9 +109,12 @@ def __init__(
             self.forward = self.forward_channels

         elif mode == "columns":
+            assert embed_size_centroids == 0
+            self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)
+
             assert self.unembed_mode == "block"  # only supported mode at the moment
             # padding needed if the unembedded columns cannot be concatenated to dim_out (e.g GPSRO)
-            self.pad = (self.dim_out - embed_size_centroids) % token_size
+            self.pad = self.dim_out % token_size
             self.out_pad = torch.nn.Parameter(torch.zeros(self.pad))
             self.unembed = torch.nn.Linear(
                 self.dim_embed,
@@ -114,6 +123,13 @@ def __init__(
             self.ln_final = norm(dim_out)
             self.forward = self.forward_columns

+            # TODO: factorization when sqrt is not int
+            dim1 = int(np.sqrt(dim_out))
+            assert dim1 * dim1 == dim_out
+            self.unembed1 = torch.nn.Linear(self.dim_embed, dim1)
+            self.unembed_nonlin = torch.nn.GELU()
+            self.unembed2 = torch.nn.Linear(self.token_size, dim1)
+
         else:
             assert False

@@ -135,7 +151,7 @@ def forward_channels(self, x_in, centroids):
         elif self.unembed_mode == "block":
             out = [
                 checkpoint(ue, ln(x[:, i]), use_reentrant=False)
-                for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=False))
+                for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
             ]
             out = torch.stack(out, dim=1).flatten(-2, -1)
         else:
@@ -153,27 +169,22 @@ def forward_channels(self, x_in, centroids):

         return out

-    # @torch.compile( dynamic=True)
     def forward_columns(self, x_in, centroids):
         # embed provided input data
         x = positional_encoding_harmonic(checkpoint(self.embed, x_in, use_reentrant=False))

         for layer in self.layers:
             x = checkpoint(layer, x, use_reentrant=False)

-        # append centroids
-        # unembed and reshape
-        out = checkpoint(self.unembed, x, use_reentrant=False)
-        out = out.flatten(-2, -1).reshape(x.shape[0], self.num_tokens, -1)
-        # TODO: unsqueeze will not work with num_tokens > 1
-        out = torch.cat([out, self.embed_centroids(centroids).unsqueeze(1)], -1)
-        # pad to uniform dim_out (that has to be uniform across streams)
-        if self.pad > 0:
-            out = torch.cat((out, self.out_pad.repeat((x.shape[0], self.num_tokens, 1))), -1)
-        # also encode centroids with overlayed positional encoding
+        out = checkpoint(self.unembed1, x, use_reentrant=False)
+        out = self.unembed_nonlin(out)
+        out = checkpoint(self.unembed2, out.transpose(-2, -1), use_reentrant=False)
+        out = out.flatten(-2, -1).unsqueeze(1)
+
+        # final normalize and dropout
         out = self.dropout_final(self.ln_final(out))

-        return out
+        return out.to(torch.float16)


 class StreamEmbedLinear(torch.nn.Module):

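To make the new columns unembedding concrete, here is a self-contained sketch of the factorised path used in forward_columns(); the class, names, and shapes below are illustrative (not the repository module) and assume dim_out is a perfect square, as the added assert requires:

import torch


class ColumnUnembedSketch(torch.nn.Module):
    """Illustrative stand-in for the factorised unembedding added to forward_columns()."""

    def __init__(self, token_size: int, dim_embed: int, dim_out: int):
        super().__init__()
        dim1 = int(dim_out**0.5)
        assert dim1 * dim1 == dim_out  # the diff asserts dim_out is a perfect square
        self.unembed1 = torch.nn.Linear(dim_embed, dim1)   # acts on the embedding dimension
        self.unembed_nonlin = torch.nn.GELU()
        self.unembed2 = torch.nn.Linear(token_size, dim1)  # acts on the token dimension
        self.ln_final = torch.nn.LayerNorm(dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, token_size, dim_embed)
        out = self.unembed_nonlin(self.unembed1(x))  # (batch, token_size, dim1)
        out = self.unembed2(out.transpose(-2, -1))   # (batch, dim1, dim1)
        out = out.flatten(-2, -1).unsqueeze(1)       # (batch, 1, dim1 * dim1 = dim_out)
        return self.ln_final(out)


x = torch.randn(2, 8, 32)  # batch=2, token_size=8, dim_embed=32
print(ColumnUnembedSketch(token_size=8, dim_embed=32, dim_out=256)(x).shape)  # torch.Size([2, 1, 256])

The design choice: instead of one large Linear from dim_embed to dim_out followed by centroid concatenation and padding, the output width is factorised as dim1 x dim1, with one small linear map per axis and a GELU in between.
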
src/weathergen/model/model.py

Lines changed: 12 additions & 22 deletions
@@ -127,16 +127,13 @@ def __init__(self, cf, sources_size, targets_num_channels, targets_coords_size):
     def create(self):
         cf = self.cf

-        # KCT:iss130
         # separate embedding networks for differnt observation types
         self.embeds = EmbeddingEngine(cf, self.sources_size).create()

-        # KCT:iss130
         # local assimilation engine
         self.ae_local_blocks = LocalAssimilationEngine(cf).create()

         ##############
-        # KCT:iss130
         # local -> global assimilation engine adapter
         self.ae_adapter = Local2GlobalAssimilationEngine(cf).create()

@@ -168,12 +165,10 @@ def create(self):
         self.q_cells = torch.nn.Parameter(q_cells, requires_grad=True)

         ##############
-        # KCT:iss130
         # global assimilation engine
         self.ae_global_blocks = GlobalAssimilationEngine(cf, self.num_healpix_cells).create()

         ###############
-        # KCT:iss130
         # forecasting engine
         self.fe_blocks = ForecastingEngine(cf, self.num_healpix_cells).create()

@@ -254,7 +249,6 @@ def create(self):
             else:
                 self.pred_adapter_kv.append(torch.nn.Identity())

-            # KCT:iss130
             # target prediction engines
             tte = TargetPredictionEngine(
                 cf,
@@ -558,28 +552,24 @@ def predict(self, model_params, fstep, tokens, streams_data, target_coords_idxs)
             zip(self.target_token_engines, self.pred_adapter_kv, strict=False)
         ):
             si = self.cf.streams[ii]
-            tro_type = si["target_readout"]["type"] if "type" in si["target_readout"] else "token"
             tc_embed = self.embed_target_coords[ii]

             assert batch_size == 1

             # embed token coords, concatenating along batch dimension (which is taking care of through
             # the varlen attention)
-            if tro_type == "obs_value":
-                tc_tokens = torch.cat(
-                    [
-                        checkpoint(
-                            tc_embed,
-                            streams_data[i_b][ii].target_coords[fstep],
-                            use_reentrant=False,
-                        )
-                        if len(streams_data[i_b][ii].target_coords[fstep].shape) > 1
-                        else streams_data[i_b][ii].target_coords[fstep]
-                        for i_b in range(len(streams_data))
-                    ]
-                )
-            else:
-                assert False
+            tc_tokens = torch.cat(
+                [
+                    checkpoint(
+                        tc_embed,
+                        streams_data[i_b][ii].target_coords[fstep],
+                        use_reentrant=False,
+                    )
+                    if len(streams_data[i_b][ii].target_coords[fstep].shape) > 1
+                    else streams_data[i_b][ii].target_coords[fstep]
+                    for i_b in range(len(streams_data))
+                ]
+            )

             if torch.isnan(tc_tokens).any():
                 nn = si["name"]

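As a stand-alone illustration of the branch that predict() keeps after the dead obs_value check is removed: each per-batch coordinate tensor with more than one dimension is embedded (under gradient checkpointing), 1-D entries pass through unchanged, and the results are concatenated along the batch dimension. The helper, tc_embed, and the shapes below are illustrative assumptions, not repository code:

import torch
from torch.utils.checkpoint import checkpoint


def embed_target_coords_sketch(tc_embed, coords_per_batch):
    # Embed multi-dimensional coordinate tensors, pass 1-D entries through,
    # and concatenate everything along the batch dimension.
    return torch.cat(
        [
            checkpoint(tc_embed, c, use_reentrant=False) if len(c.shape) > 1 else c
            for c in coords_per_batch
        ]
    )


tc_embed = torch.nn.Linear(4, 8)                 # illustrative coordinate embedding
coords = [torch.randn(5, 4), torch.randn(3, 4)]  # two batch entries
print(embed_target_coords_sketch(tc_embed, coords).shape)  # torch.Size([8, 8])
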