
Commit 8a771a9

start adding flux transformer.

1 parent b6ed307 commit 8a771a9

File tree

3 files changed: +389 -9 lines changed


src/maxdiffusion/models/embeddings_flax.py

Lines changed: 162 additions & 1 deletion
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-
+from typing import List, Union
+import jax
 import flax.linen as nn
 import jax.numpy as jnp
 
@@ -96,3 +97,163 @@ def __call__(self, timesteps):
     return get_sinusoidal_embeddings(
         timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
     )
+
+
+def get_1d_rotary_pos_embed(
+    dim: int,
+    pos: Union[jnp.ndarray, int],
+    theta: float = 10000.0,
+    use_real=False,
+    linear_factor=1.0,
+    ntk_factor=1.0,
+    repeat_interleave_real=True,
+    freqs_dtype=jnp.float32,
+):
+  """
+  Precompute the frequency tensor for complex exponentials (cis) with the given dimensions.
+  """
+  assert dim % 2 == 0
+
+  if isinstance(pos, int):
+    pos = jnp.arange(pos)
+
+  theta = theta * ntk_factor
+  freqs = (
+      1.0
+      / (theta ** (jnp.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim))
+      / linear_factor
+  )
+  freqs = jnp.outer(pos, freqs)
+  if use_real and repeat_interleave_real:
+    freqs_cos = jnp.cos(freqs).repeat(2, axis=1).astype(jnp.float32)
+    freqs_sin = jnp.sin(freqs).repeat(2, axis=1).astype(jnp.float32)
+    return freqs_cos, freqs_sin
+  elif use_real:
+    freqs_cos = jnp.concatenate([jnp.cos(freqs), jnp.cos(freqs)], axis=-1).astype(jnp.float32)
+    freqs_sin = jnp.concatenate([jnp.sin(freqs), jnp.sin(freqs)], axis=-1).astype(jnp.float32)
+    return freqs_cos, freqs_sin
+  else:
+    raise ValueError(f"use_real {use_real} and repeat_interleave_real {repeat_interleave_real} is not supported")
+
+
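As a quick sanity check of the rotary helper, a minimal usage sketch follows; the sizes are illustrative, and the import path assumes the package layout under src/. With use_real=True and repeat_interleave_real=True it returns interleaved cos/sin tables of shape (num_positions, dim).

import jax.numpy as jnp
from maxdiffusion.models.embeddings_flax import get_1d_rotary_pos_embed  # assumed import path

# 4 positions, 16 rotary channels (illustrative sizes).
cos, sin = get_1d_rotary_pos_embed(dim=16, pos=4, use_real=True, repeat_interleave_real=True)
print(cos.shape, sin.shape)  # (4, 16) (4, 16): each frequency repeated once per rotation pair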
+class PixArtAlphaTextProjection(nn.Module):
+  """
+  Projects caption embeddings. Also handles dropout for classifier-free guidance.
+
+  Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+  """
+
+  hidden_size: int
+  out_features: int = None
+  act_fn: str = 'gelu_tanh'
+  dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
+  precision: jax.lax.Precision = None
+
+  @nn.compact
+  def __call__(self, caption):
+    hidden_states = nn.Dense(
+        self.hidden_size,
+        use_bias=True,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        precision=self.precision,
+    )(caption)
+
+    if self.act_fn == 'gelu_tanh':
+      act_1 = nn.gelu
+    elif self.act_fn == 'silu':
+      act_1 = nn.swish
+    else:
+      raise ValueError(f"Unknown activation function: {self.act_fn}")
+    hidden_states = act_1(hidden_states)
+
+    # Fall back to hidden_size when out_features is not provided.
+    out_features = self.out_features if self.out_features is not None else self.hidden_size
+    hidden_states = nn.Dense(
+        out_features,
+        use_bias=True,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        precision=self.precision,
+    )(hidden_states)
+    return hidden_states
+
+
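A minimal sketch of how PixArtAlphaTextProjection could be applied to a pooled text embedding; the batch size and feature widths are illustrative assumptions, and the import path assumes the src/ package layout.

import jax
import jax.numpy as jnp
from maxdiffusion.models.embeddings_flax import PixArtAlphaTextProjection  # assumed import path

pooled_text = jnp.ones((2, 768))  # pooled text embedding, (batch, features); widths illustrative
proj = PixArtAlphaTextProjection(hidden_size=3072, act_fn='silu')
params = proj.init(jax.random.PRNGKey(0), pooled_text)
out = proj.apply(params, pooled_text)
print(out.shape)  # (2, 3072): out_features falls back to hidden_size when unset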
+class FluxPosEmbed(nn.Module):
+  theta: int
+  axes_dim: List[int]
+
+  @nn.compact
+  def __call__(self, ids):
+    n_axes = ids.shape[-1]
+    cos_out = []
+    sin_out = []
+    pos = ids.astype(jnp.float32)
+    freqs_dtype = jnp.float32
+    for i in range(n_axes):
+      # Build rotary embeddings per axis from that axis' position ids.
+      cos, sin = get_1d_rotary_pos_embed(
+          self.axes_dim[i],
+          pos[:, i],
+          repeat_interleave_real=True,
+          use_real=True,
+          freqs_dtype=freqs_dtype,
+      )
+      cos_out.append(cos)
+      sin_out.append(sin)
+
+    freqs_cos = jnp.concatenate(cos_out, axis=-1)
+    freqs_sin = jnp.concatenate(sin_out, axis=-1)
+    return freqs_cos, freqs_sin
+
+
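A sketch of FluxPosEmbed on packed per-token position ids; the sequence length and the axes_dim values are illustrative assumptions, not taken from this commit, and the import path assumes the src/ package layout.

import jax
import jax.numpy as jnp
from maxdiffusion.models.embeddings_flax import FluxPosEmbed  # assumed import path

ids = jnp.zeros((512, 3))  # one (id_0, id_1, id_2) triple per token; shape assumed for illustration
rope = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
params = rope.init(jax.random.PRNGKey(0), ids)
freqs_cos, freqs_sin = rope.apply(params, ids)
print(freqs_cos.shape)  # (512, 128): per-axis tables concatenated, sum(axes_dim) channels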
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+  embedding_dim: int
+  pooled_projection_dim: int
+  dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
+  precision: jax.lax.Precision = None
+
+  @nn.compact
+  def __call__(self, timestep, pooled_projection):
+    timesteps_proj = FlaxTimesteps(dim=256, flip_sin_to_cos=True, freq_shift=0)(timestep)
+    timestep_emb = FlaxTimestepEmbedding(
+        time_embed_dim=self.embedding_dim,
+        dtype=self.dtype,
+        weights_dtype=self.weights_dtype,
+        precision=self.precision,
+    )(timesteps_proj.astype(pooled_projection.dtype))
+
+    pooled_projections = PixArtAlphaTextProjection(
+        self.embedding_dim,
+        act_fn='silu',
+        dtype=self.dtype,
+        weights_dtype=self.weights_dtype,
+        precision=self.precision,
+    )(pooled_projection)
+
+    conditioning = timestep_emb + pooled_projections
+    return conditioning
+
+
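A usage sketch for the combined timestep/text embedding, assuming FlaxTimesteps and FlaxTimestepEmbedding behave as already defined in this file; the widths and batch size are illustrative, and the import path assumes the src/ package layout.

import jax
import jax.numpy as jnp
from maxdiffusion.models.embeddings_flax import CombinedTimestepTextProjEmbeddings  # assumed import path

timestep = jnp.array([999.0, 500.0])          # (batch,)
pooled_projection = jnp.ones((2, 768))        # pooled text embedding; width illustrative
mod = CombinedTimestepTextProjEmbeddings(embedding_dim=3072, pooled_projection_dim=768)
params = mod.init(jax.random.PRNGKey(0), timestep, pooled_projection)
conditioning = mod.apply(params, timestep, pooled_projection)
print(conditioning.shape)  # (2, 3072): timestep embedding plus projected pooled text embedding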
+class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
+  embedding_dim: int
+  pooled_projection_dim: int
+  dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
+  precision: jax.lax.Precision = None
+
+  @nn.compact
+  def __call__(self, timestep, guidance, pooled_projection):
+    timesteps_proj = FlaxTimesteps(dim=256, flip_sin_to_cos=True, freq_shift=0)(timestep)
+    timestep_emb = FlaxTimestepEmbedding(
+        time_embed_dim=self.embedding_dim,
+        dtype=self.dtype,
+        weights_dtype=self.weights_dtype,
+        precision=self.precision,
+    )(timesteps_proj.astype(pooled_projection.dtype))
+
+    guidance_proj = FlaxTimesteps(dim=256, flip_sin_to_cos=True, freq_shift=0)(guidance)
+    guidance_emb = FlaxTimestepEmbedding(
+        time_embed_dim=self.embedding_dim,
+        dtype=self.dtype,
+        weights_dtype=self.weights_dtype,
+        precision=self.precision,
+    )(guidance_proj.astype(pooled_projection.dtype))
+
+    time_guidance_emb = timestep_emb + guidance_emb
+
+    pooled_projections = PixArtAlphaTextProjection(
+        self.embedding_dim,
+        act_fn='silu',
+        dtype=self.dtype,
+        weights_dtype=self.weights_dtype,
+        precision=self.precision,
+    )(pooled_projection)
+
+    conditioning = time_guidance_emb + pooled_projections
+
+    return conditioning
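The guidance-distilled variant can be sketched the same way, with an extra per-sample guidance vector; the values and widths are illustrative, and the import path assumes the src/ package layout.

import jax
import jax.numpy as jnp
from maxdiffusion.models.embeddings_flax import CombinedTimestepGuidanceTextProjEmbeddings  # assumed import path

timestep = jnp.array([999.0, 500.0])
guidance = jnp.array([3.5, 3.5])              # guidance scale per sample; values illustrative
pooled_projection = jnp.ones((2, 768))
mod = CombinedTimestepGuidanceTextProjEmbeddings(embedding_dim=3072, pooled_projection_dim=768)
params = mod.init(jax.random.PRNGKey(0), timestep, guidance, pooled_projection)
conditioning = mod.apply(params, timestep, guidance, pooled_projection)
print(conditioning.shape)  # (2, 3072): timestep, guidance, and pooled-text embeddings summed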

src/maxdiffusion/models/normalization_flax.py

Lines changed: 31 additions & 4 deletions
@@ -14,12 +14,30 @@
 limitations under the License.
 """
 
-'''This script is used an example of how to shard the UNET on TPU.'''
-
+import jax
 import jax.numpy as jnp
 import flax.linen as nn
 
-class FlaxAdaLayerNormZeroSingle(nn.Module):
+class AdaLayerNormContinuous(nn.Module):
+  embedding_dim: int
+  elementwise_affine: bool = True
+  eps: float = 1e-5
+  bias: bool = True
+  norm_type: str = "layer_norm"
+  dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
+  precision: jax.lax.Precision = None
+
+  @nn.compact
+  def __call__(self, x, conditioning_embedding):
+    assert self.norm_type == 'layer_norm'
+    emb = nn.Dense(
+        self.embedding_dim * 2,
+        use_bias=self.bias,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        precision=self.precision,
+    )(nn.silu(conditioning_embedding))
+    scale, shift = jnp.split(emb, 2, axis=1)
+    x = nn.LayerNorm(epsilon=self.eps, use_bias=self.elementwise_affine, use_scale=self.elementwise_affine)(x)
+    x = x * (1 + scale[:, None, :]) + shift[:, None, :]
+    return x
+
+class AdaLayerNormZeroSingle(nn.Module):
   r"""
   Norm layer adaptive layer norm zero (adaLN-Zero).
@@ -30,11 +48,20 @@ class FlaxAdaLayerNormZeroSingle(nn.Module):
   embedding_dim: int
   norm_type: str = "layer_norm"
   bias: bool = True
+  dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
+  precision: jax.lax.Precision = None
 
   @nn.compact
   def __call__(self, x, emb):
     emb = nn.silu(emb)
-    emb = nn.Dense(3 * self.embedding_dim, use_bias=self.bias)(emb)
+    emb = nn.Dense(
+        3 * self.embedding_dim,
+        use_bias=self.bias,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        precision=self.precision
+    )(emb)
     shift_msa, scale_msa, gate_msa = jnp.split(emb, 3, axis=1)
     if self.norm_type == "layer_norm":
       x = nn.LayerNorm(epsilon=1e-6, use_bias=False, use_scale=False)(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
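A minimal sketch of AdaLayerNormContinuous modulating a hidden-state tensor with a conditioning embedding; the shapes are illustrative assumptions, and the import path assumes the src/ package layout.

import jax
import jax.numpy as jnp
from maxdiffusion.models.normalization_flax import AdaLayerNormContinuous  # assumed import path

x = jnp.ones((2, 512, 3072))     # hidden states (batch, seq, dim); sizes illustrative
temb = jnp.ones((2, 3072))       # conditioning embedding, e.g. a combined timestep/text embedding
norm = AdaLayerNormContinuous(embedding_dim=3072)
params = norm.init(jax.random.PRNGKey(0), x, temb)
out = norm.apply(params, x, temb)
print(out.shape)  # (2, 512, 3072): LayerNorm output modulated by scale/shift from the conditioning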
