
Diffusers 0.33.0: New Image and Video Models, Memory Optimizations, Caching Methods, Remote VAEs, New Training Scripts, and more

@sayakpaul released this 09 Apr 13:37 · 210 commits to main since this release

New Pipelines for Video Generation

Wan 2.1

Wan2.1 is a comprehensive and open suite of video foundation models that pushes the boundaries of video generation. The release includes four model variants and three pipelines: text-to-video, image-to-video, and video-to-video.

  • Wan-AI/Wan2.1-T2V-1.3B-Diffusers
  • Wan-AI/Wan2.1-T2V-14B-Diffusers
  • Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
  • Wan-AI/Wan2.1-I2V-14B-720P-Diffusers

Check out the docs here to learn more.
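
Below is a minimal text-to-video sketch using the 1.3B checkpoint. The resolution, frame count, and guidance values are illustrative choices rather than settings prescribed by this release, and loading the VAE in float32 is a common stability precaution, not a requirement.

import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
# Keeping the VAE in float32 is a common choice for numerical stability
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")

frames = pipe(
    prompt="A cat and a dog baking a cake together in a cozy kitchen.",
    negative_prompt="blurry, low quality, distorted",
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(frames, "wan_t2v.mp4", fps=15)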

LTX Video 0.9.5

LTX Video 0.9.5 is the updated version of the super-fast LTX Video model series. The latest model introduces additional conditioning options, such as keyframe-based animation and video extension (both forward and backward).

To support these additional conditioning inputs, we’ve introduced the LTXConditionPipeline along with the LTXVideoCondition object.

To learn more about the usage, check out the docs here.
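
As a rough sketch of how the new conditioning API fits together (the Lightricks/LTX-Video-0.9.5 checkpoint id and the import path of LTXVideoCondition are assumptions; verify both against the linked docs):

import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition  # import path assumed
from diffusers.utils import export_to_video, load_video

pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Condition generation on an existing clip starting at frame 0 (placeholder path)
conditioning_clip = load_video("path/to/condition.mp4")
condition = LTXVideoCondition(video=conditioning_clip, frame_index=0)

video = pipe(
    conditions=[condition],
    prompt="A serene lake at sunrise with gentle ripples on the water",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery",
    width=768,
    height=512,
    num_frames=121,
    num_inference_steps=40,
).frames[0]
export_to_video(video, "ltx_conditioned.mp4", fps=24)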

Hunyuan Image to Video

Hunyuan utilizes a pre-trained Multimodal Large Language Model (MLLM) with a Decoder-Only architecture as the text encoder. The input image is processed by the MLLM to generate semantic image tokens. These tokens are then concatenated with the video latent tokens, enabling comprehensive full-attention computation across the combined data and seamlessly integrating information from both the image and its associated caption.

To learn more, check out the docs here.
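
A hedged sketch of image-to-video usage; the HunyuanVideoImageToVideoPipeline class name and the hunyuanvideo-community/HunyuanVideo-I2V checkpoint id are assumptions worth confirming in the linked docs.

import torch
from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video, load_image

model_id = "hunyuanvideo-community/HunyuanVideo-I2V"  # checkpoint id assumed
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
)
pipe.vae.enable_tiling()
pipe.to("cuda")

# The input image is encoded by the MLLM into semantic tokens that condition the video latents
image = load_image("path/to/input_image.png")  # placeholder path
output = pipe(
    image=image,
    prompt="A man strums a red electric guitar on stage.",
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "hunyuan_i2v.mp4", fps=15)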

Others

New Pipelines for Image Generation

Sana-Sprint

SANA-Sprint is an efficient diffusion model for ultra-fast text-to-image generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4, rivaling the quality of models like Flux.

Shoutout to @lawrence-cj for their help and guidance on this PR.

Check out the pipeline docs of SANA-Sprint to learn more.
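
A minimal sketch of few-step generation (the checkpoint id below is an assumption; see the docs for the official repos):

import torch
from diffusers import SanaSprintPipeline

# Checkpoint id is an assumption; check the SANA-Sprint docs for the official repos
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Only two denoising steps instead of the usual 20+ for the base model
image = pipe(prompt="a tiny astronaut hatching from an egg on the moon", num_inference_steps=2).images[0]
image.save("sana_sprint.png")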

Lumina2

Lumina-Image-2.0 is a 2B parameter flow-based diffusion transformer for text-to-image generation released under the Apache 2.0 license.

Check out the docs to learn more. Thanks to @zhuole1025 for contributing this through this PR.

One can also LoRA fine-tune Lumina2, taking advantage of its Apache 2.0 licensing. Check out the guide for more details.

Omnigen

OmniGen is a unified image generation model that can handle multiple tasks, including text-to-image, image editing, subject-driven generation, and various computer vision tasks, within a single framework. The model consists of a VAE and a single Phi-3-based transformer that handles text and image encoding as well as the diffusion process.

Check out the docs to learn more about OmniGen. Thanks to @staoxiao for contributing OmniGen in this PR.
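
A minimal text-to-image sketch (the Shitao/OmniGen-v1-diffusers checkpoint id is an assumption; OmniGen also accepts input images for editing and subject-driven tasks):

import torch
from diffusers import OmniGenPipeline

# Checkpoint id is an assumption; see the OmniGen docs for the official diffusers-format repo
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="Realistic photo of a young woman sitting on a sofa, holding a book and facing the camera.",
    height=1024,
    width=1024,
    guidance_scale=2.5,
).images[0]
image.save("omnigen_t2i.png")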

Others

New Memory Optimizations

Layerwise Casting

PyTorch supports torch.float8_e4m3fn and torch.float8_e5m2 as weight storage dtypes, but they can’t be used for computation on many devices due to unimplemented kernel support.

However, you can still store model weights in FP8 precision and upcast them on the fly to a widely supported dtype such as torch.float16 or torch.bfloat16 when the layers are used in the forward pass. This is known as layerwise weight casting and can potentially cut a model's VRAM requirements by 50%.

Code
import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video

model_id = "THUDM/CogVideoX-5b"

# Load the model in bfloat16 and enable layerwise casting
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

# Load the pipeline
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)

Group Offloading

Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either torch.nn.ModuleList or torch.nn.Sequential), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced.

On CUDA devices, there is also the option to enable layer prefetching with CUDA streams: the next layer to be executed is loaded onto the accelerator device while the current layer is still running, which makes inference substantially faster while keeping VRAM requirements very low. In other words, computation is overlapped with data transfer.

One thing to note is that using CUDA streams can cause a considerable spike in CPU RAM usage. Please ensure that the available CPU RAM is at least twice the size of the model if you set use_stream=True. You can reduce CPU RAM usage by setting low_cpu_mem_usage=True, which limits it to roughly the model size at the cost of slightly higher inference latency.

You can also use record_stream=True when using use_stream=True to obtain more speedups at the expense of slightly increased memory usage.

Code
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load the pipeline
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# We can utilize the enable_group_offload method for Diffusers model implementations
pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
# This used about 14.79 GB. Memory can be reduced further with VAE tiling and by applying leaf_level offloading to the remaining pipeline components.
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)

Group offloading can also be applied to non-Diffusers models such as text encoders from the transformers library.

Code
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video

# Load the pipeline
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# For any other model implementations, the apply_group_offloading function can be used
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)

Remote Components

Remote components are an experimental feature designed to offload memory-intensive steps of the inference pipeline to remote endpoints. The initial implementation focuses primarily on VAE decoding operations. Below are the currently supported model endpoints:

| Model | Endpoint | VAE |
|---|---|---|
| Stable Diffusion v1 | https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud | stabilityai/sd-vae-ft-mse |
| Stable Diffusion XL | https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud | madebyollin/sdxl-vae-fp16-fix |
| Flux | https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud | black-forest-labs/FLUX.1-schnell |
| HunyuanVideo | https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud | hunyuanvideo-community/HunyuanVideo |

This is an example of using remote decoding with the Hunyuan Video pipeline:

Code
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils.remote_utils import remote_decode

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
).to("cuda")

latent = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
    output_type="latent",
).frames

video = remote_decode(
    endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    output_type="mp4",
)

if isinstance(video, bytes):
    with open("video.mp4", "wb") as f:
        f.write(video) 

Check out the docs to learn more.

Introducing Cached Inference for DiTs

Cached Inference for Diffusion Transformer models is a performance optimization that significantly accelerates the denoising process by caching intermediate values. This technique reduces redundant computations across timesteps, resulting in faster generation with a slight dip in output quality.

Check out the docs to learn more about the available caching methods.

Pyramid Attention Broadcast

Code
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)

FasterCache

Code
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 901),
    unconditional_batch_skip_range=2,
    attention_weight_callback=lambda _: 0.5,
    is_guidance_distilled=True,
)
pipe.transformer.enable_cache(config)

Quantization

Quanto Backend

Diffusers now has support for the Quanto quantization backend, which provides float8, int8, int4, and int2 quantization dtypes.

import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
      model_id,
      subfolder="transformer",
      quantization_config=quantization_config,
      torch_dtype=torch.bfloat16,
)

Quanto int8 models are also compatible with torch.compile:

Code
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
      model_id,
      subfolder="transformer",
      quantization_config=quantization_config,
      torch_dtype=torch.bfloat16,
)
transformer.compile()

Improved loading for uintx TorchAO checkpoints with torch>=2.6

TorchAO checkpoints currently have to be serialized using pickle. For some quantization dtypes using the uintx format, such as uint4wo, this involves saving subclassed TorchAO Tensor objects in the model file. This made loading such models directly with Diffusers tricky, since we do not allow deserializing arbitrary Python objects from pickle files.

Torch 2.6 allows adding the expected tensor subclasses to torch's safe globals, which lets us load TorchAO checkpoints containing these objects directly.

- state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
- with init_empty_weights():
-     transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
- transformer.load_state_dict(state_dict, strict=True, assign=True)
+ transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_uint4wo/")  

LoRAs

We have shipped a couple of improvements on the LoRA front in this release.

🚨 Improved coverage for loading non-diffusers LoRA checkpoints for Flux

Take note of the breaking change introduced in this PR 🚨 We suggest upgrading your peft installation to the latest version (pip install -U peft), especially when dealing with Flux LoRAs.

torch.compile() support when hotswapping LoRAs without triggering recompilation

A common use case when serving multiple adapters is to load one adapter first, generate images, load another adapter, generate more images, load another adapter, etc. This workflow normally requires calling load_lora_weights(), set_adapters(), and possibly delete_adapters() to save memory. Moreover, if the model is compiled using torch.compile, performing these steps requires recompilation, which takes time.

To better support this common workflow, you can “hotswap” a LoRA adapter to avoid accumulating memory and, in some cases, recompilation. Hotswapping requires an adapter to already be loaded; the new adapter weights are then swapped in place for the existing adapter.

Check out the docs to learn more about this feature.
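
A hedged sketch of the workflow under torch.compile (the LoRA paths and target_rank value are placeholders; enable_lora_hotswap reserves rank headroom so differently sized adapters can be swapped without recompiling):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Reserve headroom for the largest LoRA rank you plan to swap in, before loading and compiling
pipe.enable_lora_hotswap(target_rank=64)  # placeholder rank
pipe.load_lora_weights("path/to/lora_a.safetensors", adapter_name="my_adapter")
pipe.transformer.compile()

image_a = pipe("a photo of a cat wearing sunglasses").images[0]

# Swap the new weights into the already-loaded adapter in place; the compiled graph is reused
pipe.load_lora_weights("path/to/lora_b.safetensors", adapter_name="my_adapter", hotswap=True)
image_b = pipe("a photo of a cat wearing sunglasses").images[0]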

The other major change is support for loading LoRAs into quantized model checkpoints.

dtype Maps for Pipelines

Since various pipelines require their components to run in different compute dtypes, we now support passing a dtype map when initializing a pipeline:

from diffusers import HunyuanVideoPipeline
import torch

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    torch_dtype={"transformer": torch.bfloat16, "default": torch.float16},
)
print(pipe.transformer.dtype, pipe.vae.dtype)  # (torch.bfloat16, torch.float16)

AutoModel

This release includes an AutoModel class, similar to the one found in transformers, that automatically resolves the appropriate model class for the provided repo.

from diffusers import AutoModel

unet = AutoModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

All commits

Significant community contributions

The following contributors have made significant changes to the library over the last release:

  • @guiyrt
    • IP-Adapter for StableDiffusion3Img2ImgPipeline (#10589)
    • [Docs] Update SD3 ip_adapter model_id to diffusers checkpoint (#10597)
    • MultiControlNetUnionModel on SDXL (#10747)
    • SD3 IP-Adapter runtime checkpoint conversion (#10718)
    • Comprehensive type checking for from_pretrained kwargs (#10758)
    • Multi IP-Adapter for Flux pipelines (#10867)
  • @chengzeyi
    • [Docs] Add documentation about using ParaAttention to optimize FLUX and HunyuanVideo (#10544)
    • Fix Graph Breaks When Compiling CogView4 (#10959)
    • Fix Wan I2V Quality (#11087)
  • @entrpn
    • implementing flux on TPUs with ptxla (#10515)
    • reverts accidental change that removes attn_mask in attn. Improves fl… (#11065)
    • update readme instructions. (#11096)
  • @SHYuanBest
  • @faaany
    • [tests] make tests device-agnostic (part 3) (#10437)
    • make tensors contiguous before passing to safetensors (#10761)
    • [tests] make tests device-agnostic (part 4) (#10508)
    • [tests] enable bnb tests on xpu (#11001)
    • [tests] make cuda only tests device-agnostic (#11058)
    • [tests] no hard-coded cuda (#11186)
    • [tests] HunyuanDiTControlNetPipeline inference precision issue on XPU (#11197)
  • @yiyixuxu
    • fix offload gpu tests etc (#10366)
    • follow-up refactor on lumina2 (#10776)
    • [Alibaba Wan Team] continue on #10921 Wan2.1 (#10922)
    • update check_input for cogview4 (#10966)
    • remove F.rms_norm for now (#11126)
    • add sana-sprint (#11074)
  • @DN6
    • [CI] Update HF_TOKEN in all workflows (#10613)
    • [CI] Fix Truffle Hog failure (#10769)
    • [Single File] Add Single File support for Lumina Image 2.0 Transformer (#10781)
    • [CI] Fix incorrectly named test module for Hunyuan DiT (#10854)
    • [CI] Update always test Pipelines list in Pipeline fetcher (#10856)
    • [Docs] Fix toctree sorting (#10894)
    • [CI] Improvements to conditional GPU PR tests (#10859)
    • [CI] Fix Fast GPU tests on PR (#10912)
    • [CI] Fix for failing IP Adapter test in Fast GPU PR tests (#10915)
    • [CI] Update Stylebot Permissions (#10931)
    • [Single File] Add user agent to SF download requests. (#10979)
    • [Single File] Add single file support for Wan T2V/I2V (#10991)
    • Fix for fetching variants only (#10646)
    • [Quantization] Add Quanto backend (#10756)
    • [Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6 (#11018)
    • [Refactor] Clean up import utils boilerplate (#11026)
    • Provide option to reduce CPU RAM usage in Group Offload (#11106)
    • [Quantization] dtype fix for GGUF + fix BnB tests (#11159)
    • [Docs] Update Wan Docs with memory optimizations (#11089)
    • [WIP] Add Wan Video2Video (#11053)
    • Add CacheMixin to Wan and LTX Transformers (#11187)
    • Fix Single File loading for LTX VAE (#11200)
    • Update Ruff to latest Version (#10919)
  • @Anonym0u3
    • Add pipeline_stable_diffusion_xl_attentive_eraser (#10579)
  • @lavinal712
    • create a script to train autoencoderkl (#10605)
    • [BUG] Fix Autoencoderkl train script (#11113)
  • @Marlon154
    • Add community pipeline for semantic guidance for FLUX (#10610)
  • @ParagEkbote
    • Fix Documentation about Image-to-Image Pipeline (#10704)
    • Notebooks for Community Scripts-6 (#10713)
    • Extend Support for callback_on_step_end for AuraFlow and LuminaText2Img Pipelines (#10746)
    • Notebooks for Community Scripts-7 (#10846)
    • Add Example of IPAdapterScaleCutoffCallback to Docs (#10934)
    • Notebooks for Community Scripts-8 (#11128)
  • @suzukimain
    • [Community] Enhanced Model Search (#10417)
  • @staoxiao
  • @elismasilva
    • feat: new community mixture_tiling_sdxl pipeline for SDXL (#10759)
    • fix: [Community pipeline] Fix flattened elements on image (#10774)
    • feat: add Mixture-of-Diffusers ControlNet Tile upscaler Pipeline for SDXL (#10951)
    • fix: mixture tiling sdxl pipeline - adjust gerating time_ids & embeddings (#11012)
    • fix: for checking mandatory and optional pipeline components (#11189)
    • feat: [Community Pipeline] - FaithDiff Stable Diffusion XL Pipeline (#11188)
  • @zhuole1025
    • Add support for lumina2 (#10642)
  • @zRzRzRzRzRzRzR
    • CogView4 (supports different length c and uc) (#10649)
    • Update pipeline_cogview4.py (#10944)
    • [Docs] CogView4 comment fix (#10957)
    • CogView4 Control Block (#10809)
    • Modify the implementation of retrieve_timesteps in CogView4-Control. (#11125)
  • @toshas
    • Marigold Update: v1-1 models, Intrinsic Image Decomposition pipeline, documentation (#10884)
  • @bubbliiiing
    • Add EasyAnimateV5.1 text-to-video, image-to-video, control-to-video generation model (#10626)
  • @LittleNyima
    • Add CogVideoX DDIM Inversion to Community Pipelines (#10956)
  • @kinam0252
    • Add STG to community pipelines (#10960)
  • @tolgacangoz
    • [Research Project] Add AnyText: Multilingual Visual Text Generation And Editing (#8998)
    • Update README and example code for AnyText usage (#11028)
    • [LTX0.9.5] Refactor LTXConditionPipeline for text-only conditioning (#11174)
  • @Ednaordinary
    • Add Wan with STG as a community pipeline (#11184)