|
19 | 19 | from typing import Any, Dict, Optional, Tuple, Union
|
20 | 20 | import jax
|
21 | 21 | import jax.numpy as jnp
|
| 22 | +import flax |
22 | 23 | import flax.linen as nn
|
| 24 | +from ...configuration_utils import ConfigMixin |
| 25 | +from ..modeling_flax_utils import FlaxModelMixin |
23 | 26 | from ..normalization_flax import AdaLayerNormZeroSingle, AdaLayerNormContinuous
|
24 | 27 | from ..attention_flax import FlaxAttention
|
25 | 28 | from ..embeddings_flax import FluxPosEmbed, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
|
26 | 29 | from ...common_types import BlockSizes
|
27 | 30 | from ... import max_logging
|
| 31 | +from ...utils import BaseOutput |
| 32 | + |
| 33 | +@flax.struct.dataclass |
| 34 | +class Transformer2DModelOutput(BaseOutput): |
| 35 | + """ |
| 36 | + The output of [`FluxTransformer2DModel`]. |
| 37 | +
|
| 38 | + Args: |
| 39 | + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): |
| 40 | + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. |
| 41 | + """ |
| 42 | + |
| 43 | + sample: jnp.ndarray |
28 | 44 |
|
29 | 45 | class FluxTransformerBlock(nn.Module):
|
30 | 46 | r"""
|
@@ -178,7 +194,26 @@ def setup(self):
|
178 | 194 |
|
179 | 195 | self.tranformer_blocks = nn.scan(
|
180 | 196 | FluxTransformerBlock,
|
181 |
| - |
| 197 | + variable_axes={"params" : 0}, |
| 198 | + split_rngs={"params" : True}, |
| 199 | + in_axes=( |
| 200 | + nn.broadcast, |
| 201 | + nn.broadcast, |
| 202 | + nn.broadcast, |
| 203 | + nn.broadcast, |
| 204 | + ), |
| 205 | + length=self.num_layers |
| 206 | + )( |
| 207 | + dim=self.dim, |
| 208 | + num_attention_heads=self.num_attention_heads, |
| 209 | + attention_head_dim=self.attention_head_dim, |
| 210 | + attention_kernel=self.attention_kernel, |
| 211 | + flash_min_seq_length=self.flash_min_seq_length, |
| 212 | + flash_block_sizes=self.flash_block_sizes, |
| 213 | + mesh=self.mesh, |
| 214 | + dtype=self.dtype, |
| 215 | + weights_dtype=self.weights_dtype, |
| 216 | + precision=self.precision |
182 | 217 | )
|
183 | 218 |
|
184 | 219 | self.single_tranformer_blocks = nn.scan(
|
@@ -224,7 +259,9 @@ def __call__(
|
224 | 259 | img_ids,
|
225 | 260 | txt_ids,
|
226 | 261 | guidance,
|
227 |
| - joint_attention_kwargs = None,): |
| 262 | + return_dict: bool = True, |
| 263 | + train: bool = False, |
| 264 | + ): |
228 | 265 |
|
229 | 266 | hidden_states = self.x_embedder(hidden_states)
|
230 | 267 |
|
@@ -256,6 +293,29 @@ def __call__(
|
256 | 293 | ids = jnp.concatenate((txt_ids, img_ids), axis=0)
|
257 | 294 | image_rotary_emb = self.pos_embed(ids)
|
258 | 295 |
|
259 |
| - jax.lax.fori_loop(0, ) |
| 296 | + encoder_hidden_states, hidden_states = self.tranformer_blocks( |
| 297 | + hidden_states=hidden_states, |
| 298 | + encoder_hidden_states=encoder_hidden_states, |
| 299 | + temb=temb, |
| 300 | + image_rotary_emb=image_rotary_emb |
| 301 | + ) |
| 302 | + |
| 303 | + hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], dim=1) |
| 304 | + |
| 305 | + hidden_states = self.single_tranformer_blocks( |
| 306 | + hidden_states=hidden_states, |
| 307 | + temb=temb, |
| 308 | + image_rotary_emb=image_rotary_emb |
| 309 | + ) |
| 310 | + |
| 311 | + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] |
| 312 | + |
| 313 | + hidden_states = self.norm_out(hidden_states, temb) |
| 314 | + output = self.proj_out(hidden_states) |
| 315 | + |
| 316 | + if not return_dict: |
| 317 | + return (output,) |
| 318 | + |
| 319 | + return Transformer2DModelOutput(sample=output) |
260 | 320 |
|
261 | 321 |
|
0 commit comments