Commit 912befb

add tranformer blocks shell
1 parent 8a771a9 commit 912befb

1 file changed: +63 -3 lines changed

src/maxdiffusion/models/transformers/transformer_flux_flax.py

Lines changed: 63 additions & 3 deletions
@@ -19,12 +19,28 @@
 from typing import Any, Dict, Optional, Tuple, Union
 import jax
 import jax.numpy as jnp
+import flax
 import flax.linen as nn
+from ...configuration_utils import ConfigMixin
+from ..modeling_flax_utils import FlaxModelMixin
 from ..normalization_flax import AdaLayerNormZeroSingle, AdaLayerNormContinuous
 from ..attention_flax import FlaxAttention
 from ..embeddings_flax import FluxPosEmbed, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
 from ...common_types import BlockSizes
 from ... import max_logging
+from ...utils import BaseOutput
+
+@flax.struct.dataclass
+class Transformer2DModelOutput(BaseOutput):
+  """
+  The output of [`FluxTransformer2DModel`].
+
+  Args:
+    sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+      The hidden states output conditioned on the `encoder_hidden_states` input. Output of the last layer of the model.
+  """
+
+  sample: jnp.ndarray
 
 class FluxTransformerBlock(nn.Module):
   r"""
@@ -178,7 +194,26 @@ def setup(self):
 
     self.tranformer_blocks = nn.scan(
         FluxTransformerBlock,
-
+        variable_axes={"params" : 0},
+        split_rngs={"params" : True},
+        in_axes=(
+            nn.broadcast,
+            nn.broadcast,
+            nn.broadcast,
+            nn.broadcast,
+        ),
+        length=self.num_layers
+    )(
+        dim=self.dim,
+        num_attention_heads=self.num_attention_heads,
+        attention_head_dim=self.attention_head_dim,
+        attention_kernel=self.attention_kernel,
+        flash_min_seq_length=self.flash_min_seq_length,
+        flash_block_sizes=self.flash_block_sizes,
+        mesh=self.mesh,
+        dtype=self.dtype,
+        weights_dtype=self.weights_dtype,
+        precision=self.precision
     )
 
     self.single_tranformer_blocks = nn.scan(
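
The `nn.scan` call above is the standard Flax pattern for stacking `num_layers` identical blocks while tracing the block body once: `variable_axes={"params": 0}` stacks each layer's parameters along a new leading axis, `split_rngs={"params": True}` gives every layer its own init RNG, `nn.broadcast` entries in `in_axes` feed the same (unscanned) inputs to every step, and `length` is required because there is no scanned input to infer it from. A minimal, self-contained sketch of the same pattern; `TinyBlock` and `TinyStack` are illustrative names, not from this file.

import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyBlock(nn.Module):
  """Stand-in for FluxTransformerBlock: one Dense layer plus a broadcast input."""
  features: int

  @nn.compact
  def __call__(self, x, shift):
    # x is the scan carry; shift is broadcast unchanged to every layer.
    x = nn.Dense(self.features)(x) + shift
    return x, None  # (carry, per-step output)


class TinyStack(nn.Module):
  features: int
  num_layers: int

  @nn.compact
  def __call__(self, x, shift):
    ScannedBlock = nn.scan(
        TinyBlock,
        variable_axes={"params": 0},   # stack per-layer params on axis 0
        split_rngs={"params": True},   # independent init RNG per layer
        in_axes=(nn.broadcast,),       # same `shift` for every layer
        length=self.num_layers,
    )
    x, _ = ScannedBlock(features=self.features)(x, shift)
    return x


model = TinyStack(features=8, num_layers=4)
x = jnp.ones((2, 8))
shift = jnp.zeros((2, 8))
variables = model.init(jax.random.PRNGKey(0), x, shift)
y = model.apply(variables, x, shift)
print(y.shape)  # (2, 8)
# Every parameter leaf carries a leading num_layers axis:
print(jax.tree_util.tree_map(jnp.shape, variables["params"]))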
@@ -224,7 +259,9 @@ def __call__(
     img_ids,
     txt_ids,
     guidance,
-    joint_attention_kwargs = None,):
+    return_dict: bool = True,
+    train: bool = False,
+  ):
 
     hidden_states = self.x_embedder(hidden_states)
 
@@ -256,6 +293,29 @@ def __call__(
     ids = jnp.concatenate((txt_ids, img_ids), axis=0)
     image_rotary_emb = self.pos_embed(ids)
 
-    jax.lax.fori_loop(0, )
+    encoder_hidden_states, hidden_states = self.tranformer_blocks(
+        hidden_states=hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        temb=temb,
+        image_rotary_emb=image_rotary_emb
+    )
+
+    hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
+
+    hidden_states = self.single_tranformer_blocks(
+        hidden_states=hidden_states,
+        temb=temb,
+        image_rotary_emb=image_rotary_emb
+    )
+
+    hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]
+
+    hidden_states = self.norm_out(hidden_states, temb)
+    output = self.proj_out(hidden_states)
+
+    if not return_dict:
+      return (output,)
+
+    return Transformer2DModelOutput(sample=output)
 
 