src/diffusers/models/transformers/transformer_chronoedit.py (19 changes: 14 additions & 5 deletions)
@@ -137,7 +137,11 @@ def apply_rotary_emb(
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # FIXME(DefTruth): Since the key/value in cross-attention depends
+                # solely on encoder_hidden_states_img (img), the (q_chunk * k) * v
+                # computation can be parallelized independently. Thus, there is
+                # no need to pass the parallel_config here.
+                parallel_config=None,
             )
             hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
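As a sanity check on the FIXME above, here is a minimal, self-contained sketch in plain PyTorch (not the diffusers dispatch API; shapes are invented for illustration) of why sharding only the query is exact for cross-attention: softmax and the value-weighted sum are computed per query row, so each rank can attend its q chunk against the full, replicated K/V with no communication.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, Sq, Skv, D = 1, 2, 8, 257, 16  # 257 image tokens, as in the img branch
q = torch.randn(B, H, Sq, D)
k = torch.randn(B, H, Skv, D)
v = torch.randn(B, H, Skv, D)

# Reference: full attention on a single device.
full = F.scaled_dot_product_attention(q, k, v)

# Query-sharded "context parallel": each hypothetical rank takes one q chunk
# plus the complete K/V. Softmax normalizes over the K/V axis only, so the
# chunks are independent and concatenating them recovers the full result.
world_size = 2
sharded = torch.cat(
    [F.scaled_dot_product_attention(qc, k, v) for qc in q.chunk(world_size, dim=2)],
    dim=2,
)
assert torch.allclose(full, sharded, atol=1e-6)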
@@ -150,7 +154,11 @@ def apply_rotary_emb(
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # FIXME(DefTruth): Since the key/value in cross-attention depends
+                # solely on encoder_hidden_states (text), the (q_chunk * k) * v
+                # computation can be parallelized independently. Thus, there is
+                # no need to pass the parallel_config here.
+                parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
             )
             hidden_states = hidden_states.flatten(2, 3)
             hidden_states = hidden_states.type_as(query)
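This branch has one subtlety the img branch does not: the same call site also serves self-attention, where `encoder_hidden_states is None` and the key/value do live on the sharded sequence, so the parallel config is still required there. A hypothetical helper (not in this PR), purely to spell out the inline conditional:

# Pass the parallel config only for self-attention, where K/V share the
# sharded sequence; cross-attention needs no communication and gets None.
def pick_parallel_config(parallel_config, encoder_hidden_states):
    is_self_attention = encoder_hidden_states is None
    return parallel_config if is_self_attention else None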
@@ -568,9 +576,10 @@ class ChronoEditTransformer3DModel(
         "blocks.0": {
             "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
         },
-        "blocks.*": {
-            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
-        },
+        # NOTE(DefTruth): We need to disable the splitting of encoder_hidden_states because
+        # the image_encoder consistently generates 257 tokens for image_embed. This causes
+        # the shape of encoder_hidden_states (whose token count is always 769 = 512 + 257
+        # after concatenation) to be indivisible by the number of devices in the CP group.
         "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
     }

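For the record, the arithmetic behind the NOTE, using the token counts it states (512 text + 257 image):

text_tokens, image_tokens = 512, 257
total = text_tokens + image_tokens  # 769, which is odd (in fact prime)
for world_size in (2, 4, 8):
    # 769 % ws == 1 for every power-of-two world size, so an even split
    # along the token dim is impossible and the tensor must stay replicated.
    assert total % world_size != 0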