@@ -7,7 +7,7 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
-
+import numpy as np
 import torch
 from torch.utils.checkpoint import checkpoint
 
@@ -45,6 +45,7 @@ def __init__(
         super(StreamEmbedTransformer, self).__init__()
 
         self.num_tokens = num_tokens
+        self.token_size = token_size
         self.num_channels = num_channels
         self.dim_in = token_size if mode == "channels" else num_channels
         self.dim_embed = dim_embed
@@ -56,8 +57,6 @@ def __init__(
 
         norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm
 
-        self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)
-
         self.layers = torch.nn.ModuleList()
         for _ in range(self.num_blocks):
             self.layers.append(
@@ -80,6 +79,8 @@ def __init__(
             )
 
         if mode == "channels":
+            self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)
+
             if self.unembed_mode == "full":
                 self.ln_final = norm(num_channels * self.dim_embed)
                 self.unembed = torch.nn.Linear(
@@ -94,6 +95,11 @@ def __init__(
                 dim_out = (self.num_tokens * self.dim_out - embed_size_centroids) // num_channels
                 self.unembed = torch.nn.ModuleList(
                     [torch.nn.Linear(dim_embed, dim_out) for _ in range(num_channels)]
+                    # [
+                    #     torch.nn.Sequential(torch.nn.Linear(dim_embed, max(dim_embed // 2, 4 * dim_out)),
+                    #                         torch.nn.GELU(),
+                    #                         torch.nn.Linear(max(dim_embed // 2, 4 * dim_out), dim_out)) for _ in range(num_channels)
+                    # ]
                 )
                 self.ln_final = torch.nn.ModuleList([norm(dim_embed) for _ in range(num_channels)])
 
@@ -103,9 +109,12 @@ def __init__(
             self.forward = self.forward_channels
 
         elif mode == "columns":
+            assert embed_size_centroids == 0
+            self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)
+
             assert self.unembed_mode == "block"  # only supported mode at the moment
             # padding needed if the unembedded columns cannot be concatenated to dim_out (e.g GPSRO)
-            self.pad = (self.dim_out - embed_size_centroids) % token_size
+            self.pad = self.dim_out % token_size
             self.out_pad = torch.nn.Parameter(torch.zeros(self.pad))
             self.unembed = torch.nn.Linear(
                 self.dim_embed,
@@ -114,6 +123,13 @@ def __init__(
             self.ln_final = norm(dim_out)
             self.forward = self.forward_columns
 
+            # TODO: factorization when sqrt is not int
+            dim1 = int(np.sqrt(dim_out))
+            assert dim1 * dim1 == dim_out
+            self.unembed1 = torch.nn.Linear(self.dim_embed, dim1)
+            self.unembed_nonlin = torch.nn.GELU()
+            self.unembed2 = torch.nn.Linear(self.token_size, dim1)
+
         else:
             assert False
 
@@ -135,7 +151,7 @@ def forward_channels(self, x_in, centroids):
         elif self.unembed_mode == "block":
             out = [
                 checkpoint(ue, ln(x[:, i]), use_reentrant=False)
-                for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=False))
+                for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
             ]
             out = torch.stack(out, dim=1).flatten(-2, -1)
         else:
@@ -153,27 +169,22 @@ def forward_channels(self, x_in, centroids):
 
         return out
 
-    # @torch.compile( dynamic=True)
     def forward_columns(self, x_in, centroids):
         # embed provided input data
         x = positional_encoding_harmonic(checkpoint(self.embed, x_in, use_reentrant=False))
 
         for layer in self.layers:
             x = checkpoint(layer, x, use_reentrant=False)
 
-        # append centroids
-        # unembed and reshape
-        out = checkpoint(self.unembed, x, use_reentrant=False)
-        out = out.flatten(-2, -1).reshape(x.shape[0], self.num_tokens, -1)
-        # TODO: unsqueeze will not work with num_tokens > 1
-        out = torch.cat([out, self.embed_centroids(centroids).unsqueeze(1)], -1)
-        # pad to uniform dim_out (that has to be uniform across streams)
-        if self.pad > 0:
-            out = torch.cat((out, self.out_pad.repeat((x.shape[0], self.num_tokens, 1))), -1)
-        # also encode centroids with overlayed positional encoding
+        out = checkpoint(self.unembed1, x, use_reentrant=False)
+        out = self.unembed_nonlin(out)
+        out = checkpoint(self.unembed2, out.transpose(-2, -1), use_reentrant=False)
+        out = out.flatten(-2, -1).unsqueeze(1)
+
+        # final normalize and dropout
         out = self.dropout_final(self.ln_final(out))
 
-        return out
+        return out.to(torch.float16)
 
 
 class StreamEmbedLinear(torch.nn.Module):
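
Note on the column-mode change: the single `self.unembed` projection (plus centroid concatenation and padding) is replaced by a two-sided factorization, where `unembed1` acts along the embedding dimension and `unembed2` along the token dimension, so the flattened result has size `dim1 * dim1 == dim_out`. The following is a minimal sketch of the shape logic only; the values of `batch`, `token_size`, `dim_embed`, and `dim_out` are illustrative assumptions, not taken from the repository.

```python
import math

import torch

# Illustrative sizes (assumptions for this sketch).
batch, token_size, dim_embed, dim_out = 4, 12, 256, 144

dim1 = int(math.isqrt(dim_out))
assert dim1 * dim1 == dim_out  # factorization currently requires a square dim_out (see TODO above)

unembed1 = torch.nn.Linear(dim_embed, dim1)   # projects the embedding dimension
unembed2 = torch.nn.Linear(token_size, dim1)  # projects the token/column dimension
nonlin = torch.nn.GELU()

x = torch.randn(batch, token_size, dim_embed)  # stand-in for the transformer output
out = nonlin(unembed1(x))                      # (batch, token_size, dim1)
out = unembed2(out.transpose(-2, -1))          # (batch, dim1, dim1)
out = out.flatten(-2, -1).unsqueeze(1)         # (batch, 1, dim1 * dim1) == (batch, 1, dim_out)
assert out.shape == (batch, 1, dim_out)
```

As the TODO in the diff notes, this only works when `dim_out` is a perfect square; a more general factorization is left for later.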