Skip to content

Conversation

@DefTruth
Copy link
Contributor

@DefTruth DefTruth commented Nov 14, 2025

fixed #12661, fix the crash of ChronoEdit with context parallelism.

  1. We need to disable the splitting of encoder_hidden_states because the image_encoder consistently generates 257 tokens for image_embed. This causes the shape of encoder_hidden_states—whose token count is always 769 (512 + 257) after concatenation—to be indivisible by the number of devices in the CP.

  2. Since the key/value in cross-attention depends solely on encoder_hidden_states (text or img), the (q_chunk * k) * v computation can be parallelized independently. Thus, there is no need to pass the parallel_config for cross-attention. This change reduces redundant all-to-all communications—specifically (3+1)×2=8 for the two cross-attention operations (text and img)—thereby improving ChronoEdit’s performance under context parallelism. With this optimization alone, I have achieved a nearly 1.85× speedup on L20x2, without relying on other optimizations such as torch.compile.

@sayakpaul @yiyixuxu @DN6

Reproduce

  • test script
import os
import time
import torch
import numpy as np
from PIL import Image
import torch.distributed as dist
from diffusers import (
    AutoencoderKLWan,
    ChronoEditTransformer3DModel,
    ChronoEditPipeline,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import ContextParallelConfig
from diffusers.utils import load_image
from transformers import CLIPVisionModel


dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
world_size = dist.get_world_size()
torch.cuda.set_device(device)

model_id = "nvidia/ChronoEdit-14B-Diffusers"
model_id = os.environ.get("CHRONO_EDIT_DIR", model_id)

image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(
    model_id, subfolder="vae", torch_dtype=torch.float32
)
transformer = ChronoEditTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = ChronoEditPipeline.from_pretrained(
    model_id,
    vae=vae,
    image_encoder=image_encoder,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
    quantization_config=(
        PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            # text_encoder: ~ 6GiB, transformer: ~ 8GiB, total: ~14GiB
            components_to_quantize=["text_encoder", "transformer"],
        )
    ),
).to(device)

torch.cuda.empty_cache()
assert isinstance(pipe.vae, AutoencoderKLWan)
pipe.vae.enable_tiling()

image = load_image("../examples/data/chrono_edit_example.png")

max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = (
    pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
)
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = (
    "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
    "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
)

assert isinstance(pipe.transformer, ChronoEditTransformer3DModel)
pipe.transformer.set_attention_backend("native")
if world_size > 1:
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )

pipe.set_progress_bar_config(disable=rank != 0)


def run_pipe(warmup: bool = False):
    output = pipe(
        image=image,
        prompt=prompt,
        height=height,
        width=width,
        num_frames=5,
        guidance_scale=5.0,
        enable_temporal_reasoning=False,
        num_temporal_reasoning_steps=0,
        num_inference_steps=50 if not warmup else 2,
        generator=torch.Generator("cuda").manual_seed(0),
    ).frames[0]
    output = Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8"))
    return output


start = time.time()
output = run_pipe()
end = time.time()

if rank == 0:
    time_cost = end - start
    save_path = f"chrono-edit.{world_size}gpus.png"
    print(f"Time cost: {time_cost:.2f}s")
    print(f"Saving image to {save_path}")
    output.save(save_path)

if dist.is_initialized():
    dist.destroy_process_group()
  • test cmd:
torchrun --nproc_per_node=4 run_chrono_edit.py
  • w/o this PR:
rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]:     args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 157, in pre_forward
[rank1]:     input_val = self._prepare_cp_input(input_val, cpm)
[rank1]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank1]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 261, in shard
[rank1]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
  • w/ this PR:
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
100%|████████████████████████████████| 50/50 [01:22<00:00,  1.64s/it]
Saving image to chrono-edit.4gpus.png
Baseline Ulysses 4
chrono-edit C0_Q1_bitsandbytes_4bit_NONE chrono-edit C0_Q1_bitsandbytes_4bit_NONE_Ulysses4

@DefTruth
Copy link
Contributor Author

@sayakpaul @yiyixuxu @DN6 Hi~ can you take a look to this PR?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Crash while using ChronoEdit with context parallelism

1 participant