Add basic implementation of AuraFlowImg2ImgPipeline #11340

Open · wants to merge 4 commits into main
Changes from 3 commits
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -345,6 +345,7 @@
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
"AudioLDMPipeline",
"AuraFlowImg2ImgPipeline",
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/__init__.py
@@ -309,7 +309,7 @@
"StableDiffusionLDM3DPipeline",
]
)
_import_structure["aura_flow"] = ["AuraFlowPipeline"]
_import_structure["aura_flow"] = ["AuraFlowPipeline", "AuraFlowImg2ImgPipeline"]
_import_structure["stable_diffusion_3"] = [
"StableDiffusion3Pipeline",
"StableDiffusion3Img2ImgPipeline",
@@ -515,7 +515,7 @@
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
)
- from .aura_flow import AuraFlowPipeline
+ from .aura_flow import AuraFlowImg2ImgPipeline, AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/aura_flow/__init__.py
@@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"]
_import_structure["pipeline_aura_flow_img2img"] = ["AuraFlowImg2ImgPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -33,6 +34,7 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_aura_flow import AuraFlowPipeline
from .pipeline_aura_flow_img2img import AuraFlowImg2ImgPipeline

else:
import sys
761 changes: 761 additions & 0 deletions src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py

Large diffs are not rendered by default.
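
The 761-line pipeline file itself is not rendered in this view. For orientation, below is a minimal usage sketch; it is not taken from this diff, the "fal/AuraFlow" checkpoint and the local input file are assumptions, and the argument names simply mirror those exercised by the new tests further down.

# Hedged usage sketch for AuraFlowImg2ImgPipeline; checkpoint name and input file are assumptions.
import torch
import PIL.Image
from diffusers import AuraFlowImg2ImgPipeline

pipe = AuraFlowImg2ImgPipeline.from_pretrained(
    "fal/AuraFlow", torch_dtype=torch.float16
).to("cuda")

# Any RGB image whose sides are divisible by 16 (vae_scale_factor * patch_size) should work.
init_image = PIL.Image.open("input.png").convert("RGB").resize((512, 512))

result = pipe(
    prompt="A painting of a squirrel eating a burger",
    image=init_image,
    strength=0.75,  # how far along the noise schedule the input image is pushed before denoising
    num_inference_steps=30,
    guidance_scale=5.0,
).images[0]
result.save("aura_flow_img2img.png")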

3 changes: 2 additions & 1 deletion src/diffusers/pipelines/auto_pipeline.py
@@ -20,7 +20,7 @@
from ..configuration_utils import ConfigMixin
from ..models.controlnets import ControlNetUnionModel
from ..utils import is_sentencepiece_available
- from .aura_flow import AuraFlowPipeline
+ from .aura_flow import AuraFlowImg2ImgPipeline, AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
from .controlnet import (
@@ -165,6 +165,7 @@
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("auraflow", AuraFlowImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
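The new ("auraflow", AuraFlowImg2ImgPipeline) entry registers the pipeline in the image-to-image auto mapping, so AutoPipelineForImage2Image can resolve an AuraFlow checkpoint to it. A hedged sketch of what that enables (the checkpoint name is an assumption, not something stated in this diff):

# Sketch only; "fal/AuraFlow" is an assumed checkpoint.
import torch
from diffusers import AutoPipelineForImage2Image

pipe = AutoPipelineForImage2Image.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
# With the mapping entry above, this should now resolve to AuraFlowImg2ImgPipeline
# instead of raising an "unsupported pipeline" error.
print(type(pipe).__name__)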
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -257,6 +257,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class AuraFlowImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class AuraFlowPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

3 changes: 3 additions & 0 deletions tests/pipelines/aura_flow/test_pipeline_aura_flow.py
@@ -135,3 +135,6 @@ def test_fused_qkv_projections(self):
@unittest.skip("xformers attention processor does not exist for AuraFlow")
def test_xformers_attention_forwardGenerator_pass(self):
pass

def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=0.0004):
self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)
186 changes: 186 additions & 0 deletions tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py
@@ -0,0 +1,186 @@
import unittest

import numpy as np
import PIL.Image
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from diffusers import (
AuraFlowImg2ImgPipeline,
AuraFlowTransformer2DModel,
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.testing_utils import torch_device

from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)


class AuraFlowImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = AuraFlowImg2ImgPipeline
params = frozenset(
[
"prompt",
"image",
"strength",
"guidance_scale",
"negative_prompt",
"prompt_embeds",
"negative_prompt_embeds",
]
)
batch_params = frozenset(["prompt", "image", "negative_prompt"])
test_layerwise_casting = False # T5 uses multiple devices
test_group_offloading = False # T5 uses multiple devices

def get_dummy_components(self):
torch.manual_seed(0)
transformer = AuraFlowTransformer2DModel(
sample_size=32,
patch_size=2,
in_channels=4,
num_mmdit_layers=1,
num_single_dit_layers=1,
attention_head_dim=8,
num_attention_heads=4,
caption_projection_dim=32,
joint_attention_dim=32,
out_channels=4,
pos_embed_max_size=256,
)

text_encoder = UMT5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-umt5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=32,
)

scheduler = FlowMatchEulerDiscreteScheduler()

return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}

def get_dummy_inputs(self, device, seed=0):
# Ensure image dimensions are divisible by VAE scale factor * transformer patch size
# vae_scale_factor = 8, patch_size = 2 => divisible by 16
image = PIL.Image.new("RGB", (64, 64))
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)

inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"strength": 0.75,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
# height/width are inferred from image in img2img
}
return inputs

def test_attention_slicing_forward_pass(self):
# Attention slicing needs to be implemented differently for this pipeline because of how
# single DiT and MMDiT blocks interfere with each other.
return

def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]

pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer)
assert check_qkv_fusion_matches_attn_procs_length(pipe.transformer, pipe.transformer.original_attn_processors)

inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]

pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]

assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3)
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3)
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2)

@unittest.skip("xformers attention processor does not exist for AuraFlow")
def test_xformers_attention_forwardGenerator_pass(self):
pass

def test_aura_flow_img2img_output_shape(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)

# The positional embedding has a max size of 256
# The latent grid has (height/vae_scale_factor/patch_size) × (width/vae_scale_factor/patch_size) positions
# To stay within limits: (height/8/2) * (width/8/2) < 256
height_width_pairs = [(32, 32), (64, 32)]  # creates 4 and 8 positions respectively

for height, width in height_width_pairs:
inputs = self.get_dummy_inputs(torch_device)
# Override dummy image size
inputs["image"] = PIL.Image.new("RGB", (width, height))
# Pass height/width explicitly to test pipeline handles them (though inferred by default)
inputs["height"] = height
inputs["width"] = width

output = pipe(**inputs)
image = output.images[0]

# Expected shape is (height, width, 3) for np output
self.assertEqual(image.shape, (height, width, 3))

def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=0.001):
self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)

def test_num_images_per_prompt(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

batch_sizes = [1]
num_images_per_prompts = [1, 2]

for batch_size in batch_sizes:
for num_images_per_prompt in num_images_per_prompts:
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 2

inputs["image"] = PIL.Image.new("RGB", (32, 32))

for key in inputs.keys():
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]

images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images

assert len(images) == batch_size * num_images_per_prompt