12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import math
15
-
15
+ from typing import List , Union
16
+ import jax
16
17
import flax .linen as nn
17
18
import jax .numpy as jnp
18
19
@@ -96,3 +97,163 @@ def __call__(self, timesteps):
96
97
return get_sinusoidal_embeddings (
97
98
timesteps , embedding_dim = self .dim , flip_sin_to_cos = self .flip_sin_to_cos , freq_shift = self .freq_shift
98
99
)
100
+
101
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[jnp.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=jnp.float32,
):
    """Precompute the rotary position embedding (RoPE) frequency tables.

    Args:
        dim: Embedding dimension; must be even.
        pos: Positions as a 1-D array, or an int ``n`` meaning ``jnp.arange(n)``.
        theta: Base frequency of the rotary embedding.
        use_real: Return real-valued ``(cos, sin)`` tensors. The complex form
            is not implemented here, so this must be True.
        linear_factor: Linear scaling applied to the frequencies.
        ntk_factor: NTK-aware scaling applied to ``theta``.
        repeat_interleave_real: If True, each frequency appears twice
            consecutively (``[f0, f0, f1, f1, ...]``); if False, the half
            spectrum is concatenated with itself (``[f0, f1, f0, f1, ...]``).
        freqs_dtype: dtype used while computing the frequency table.

    Returns:
        Tuple ``(freqs_cos, freqs_sin)``, each of shape ``(len(pos), dim)``
        and dtype float32.

    Raises:
        ValueError: If ``dim`` is odd, or if ``use_real`` is False.
    """
    # Validate with an exception rather than `assert`: asserts are stripped
    # under `python -O`, silently skipping the check.
    if dim % 2 != 0:
        raise ValueError(f"dim must be even, got {dim}")

    if isinstance(pos, int):
        pos = jnp.arange(pos)

    # NTK-aware scaling stretches the base; linear_factor rescales the result.
    theta = theta * ntk_factor
    freqs = (
        1.0
        / (theta ** (jnp.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim))
        / linear_factor
    )
    freqs = jnp.outer(pos, freqs)  # (len(pos), dim // 2)
    if use_real and repeat_interleave_real:
        freqs_cos = jnp.cos(freqs).repeat(2, axis=1).astype(jnp.float32)
        freqs_sin = jnp.sin(freqs).repeat(2, axis=1).astype(jnp.float32)
        return freqs_cos, freqs_sin
    elif use_real:
        freqs_cos = jnp.concatenate([jnp.cos(freqs), jnp.cos(freqs)], axis=-1).astype(jnp.float32)
        freqs_sin = jnp.concatenate([jnp.sin(freqs), jnp.sin(freqs)], axis=-1).astype(jnp.float32)
        return freqs_cos, freqs_sin
    else:
        raise ValueError(f"use_real {use_real} and repeat_interleave_real {repeat_interleave_real} is not supported")
136
+
137
class PixArtAlphaTextProjection(nn.Module):
    """Projects caption embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py

    Attributes:
        hidden_size: Width of the intermediate projection.
        out_features: Output width; falls back to ``hidden_size`` when None.
        act_fn: Activation between the two dense layers ('gelu_tanh' or 'silu').
        dtype: Computation dtype.
        weights_dtype: Parameter dtype.
        precision: Matmul precision forwarded to the dense layers.
    """

    hidden_size: int
    out_features: int = None
    act_fn: str = 'gelu_tanh'
    dtype: jnp.dtype = jnp.float32
    weights_dtype: jnp.dtype = jnp.float32
    precision: jax.lax.Precision = None

    @nn.compact
    def __call__(self, caption):
        hidden_states = nn.Dense(
            self.hidden_size,
            use_bias=True,
            dtype=self.dtype,
            param_dtype=self.weights_dtype,
            precision=self.precision,
        )(caption)

        if self.act_fn == 'gelu_tanh':
            act_1 = nn.gelu
        elif self.act_fn == 'silu':
            act_1 = nn.swish
        else:
            raise ValueError(f"Unknown activation function: {self.act_fn}")
        hidden_states = act_1(hidden_states)

        # Bug fix: `out_features` defaults to None, which nn.Dense rejects as a
        # feature count — default it to `hidden_size`. Also forward
        # dtype/param_dtype/precision so both dense layers are configured
        # consistently (the original only configured the first one).
        out_features = self.out_features if self.out_features is not None else self.hidden_size
        hidden_states = nn.Dense(
            out_features,
            use_bias=True,
            dtype=self.dtype,
            param_dtype=self.weights_dtype,
            precision=self.precision,
        )(hidden_states)
        return hidden_states
171
+
172
+
173
class FluxPosEmbed(nn.Module):
    """Multi-axis rotary position embedding used by Flux transformers.

    Attributes:
        theta: RoPE base frequency.
        axes_dim: Per-axis embedding dims; their sum is the total rotary dim.
    """

    theta: int
    axes_dim: List[int]

    @nn.compact
    def __call__(self, ids):
        """Compute (cos, sin) RoPE tables for the given position ids.

        Args:
            ids: Positions of shape (num_tokens, n_axes); the last axis holds
                one coordinate per rotary axis.

        Returns:
            Tuple ``(freqs_cos, freqs_sin)`` concatenated over all axes.
        """
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.astype(jnp.float32)
        freqs_dtype = jnp.float32
        for i in range(n_axes):
            # Bug fix: select the i-th coordinate column (pos[:, i]); the
            # original sliced the first i *rows* (pos[:i]), which is empty for
            # i == 0 and wrong for every other axis.
            # Bug fix: forward self.theta — the field was declared but never
            # used, so the embedding silently ran at the default base.
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)

        freqs_cos = jnp.concatenate(cos_out, axis=-1)
        freqs_sin = jnp.concatenate(sin_out, axis=-1)
        return freqs_cos, freqs_sin
197
+
198
class CombinedTimestepTextProjEmbeddings(nn.Module):
    """Sums a sinusoidal timestep embedding with a projected pooled text embedding.

    Attributes:
        embedding_dim: Width of the combined conditioning embedding.
        pooled_projection_dim: Width of the incoming pooled text embedding.
        dtype: Computation dtype.
        weights_dtype: Parameter dtype.
        precision: Matmul precision forwarded to the sub-modules.
    """

    embedding_dim: int
    pooled_projection_dim: int
    dtype: jnp.dtype = jnp.float32
    weights_dtype: jnp.dtype = jnp.float32
    precision: jax.lax.Precision = None

    @nn.compact
    def __call__(self, timestep, pooled_projection):
        timesteps_proj = FlaxTimesteps(dim=256, flip_sin_to_cos=True, freq_shift=0)(timestep)
        timestep_emb = FlaxTimestepEmbedding(
            time_embed_dim=self.embedding_dim,
            dtype=self.dtype,
            weights_dtype=self.weights_dtype,
            precision=self.precision,
        )(timesteps_proj.astype(pooled_projection.dtype))

        pooled_projections = PixArtAlphaTextProjection(
            self.embedding_dim,
            act_fn='silu',
            dtype=self.dtype,
            weights_dtype=self.weights_dtype,
            # Consistency: the guidance variant forwards precision here too.
            precision=self.precision,
        )(pooled_projection)

        # Bug fix: add the *projected* pooled embedding (`pooled_projections`).
        # The original added the raw input `pooled_projection`, leaving the
        # projection result unused (and mismatching embedding_dim).
        conditioning = timestep_emb + pooled_projections
        return conditioning
224
+
225
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    """Sums timestep, guidance-scale, and projected pooled-text embeddings.

    Same as CombinedTimestepTextProjEmbeddings but additionally embeds a
    guidance value through its own sinusoidal + MLP path.

    Attributes:
        embedding_dim: Width of the combined conditioning embedding.
        pooled_projection_dim: Width of the incoming pooled text embedding.
        dtype: Computation dtype.
        weights_dtype: Parameter dtype.
        precision: Matmul precision forwarded to the sub-modules.
    """

    embedding_dim: int
    pooled_projection_dim: int
    dtype: jnp.dtype = jnp.float32
    weights_dtype: jnp.dtype = jnp.float32
    precision: jax.lax.Precision = None

    @nn.compact
    def __call__(self, timestep, guidance, pooled_projection):
        timesteps_proj = FlaxTimesteps(dim=256, flip_sin_to_cos=True, freq_shift=0)(timestep)
        timestep_emb = FlaxTimestepEmbedding(
            time_embed_dim=self.embedding_dim,
            dtype=self.dtype,
            weights_dtype=self.weights_dtype,
            # Consistency fix: the declared precision field was never forwarded
            # here, unlike in CombinedTimestepTextProjEmbeddings.
            precision=self.precision,
        )(timesteps_proj.astype(pooled_projection.dtype))

        guidance_proj = FlaxTimesteps(dim=256, flip_sin_to_cos=True, freq_shift=0)(guidance)
        guidance_emb = FlaxTimestepEmbedding(
            time_embed_dim=self.embedding_dim,
            dtype=self.dtype,
            weights_dtype=self.weights_dtype,
            precision=self.precision,
        )(guidance_proj.astype(pooled_projection.dtype))

        time_guidance_emb = timestep_emb + guidance_emb

        pooled_projections = PixArtAlphaTextProjection(
            self.embedding_dim,
            act_fn='silu',
            dtype=self.dtype,
            weights_dtype=self.weights_dtype,
            precision=self.precision,
        )(pooled_projection)
        conditioning = time_guidance_emb + pooled_projections

        return conditioning
0 commit comments