diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f51a4ef2b3f6..4bebc3404220 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -345,6 +345,7 @@ "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", "AudioLDMPipeline", + "AuraFlowImg2ImgPipeline", "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 011f23ed371c..352bf804d827 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -309,7 +309,7 @@ "StableDiffusionLDM3DPipeline", ] ) - _import_structure["aura_flow"] = ["AuraFlowPipeline"] + _import_structure["aura_flow"] = ["AuraFlowPipeline", "AuraFlowImg2ImgPipeline"] _import_structure["stable_diffusion_3"] = [ "StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline", @@ -515,7 +515,7 @@ AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, ) - from .aura_flow import AuraFlowPipeline + from .aura_flow import AuraFlowImg2ImgPipeline, AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .cogvideo import ( CogVideoXFunControlPipeline, diff --git a/src/diffusers/pipelines/aura_flow/__init__.py b/src/diffusers/pipelines/aura_flow/__init__.py index e1917baa61e2..ad4974bfeae3 100644 --- a/src/diffusers/pipelines/aura_flow/__init__.py +++ b/src/diffusers/pipelines/aura_flow/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"] + _import_structure["pipeline_aura_flow_img2img"] = ["AuraFlowImg2ImgPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +34,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_aura_flow import AuraFlowPipeline + from .pipeline_aura_flow_img2img import AuraFlowImg2ImgPipeline else: import sys diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py new file mode 100644 index 000000000000..5cf9a820681d --- /dev/null +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py @@ -0,0 +1,734 @@ +# Copyright 2025 AuraFlow Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict, List, Optional, Tuple, Union + +import PIL +import torch +from transformers import T5Tokenizer, UMT5EncoderModel + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AuraFlowTransformer2DModel, AutoencoderKL +from diffusers.models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AuraFlowImg2ImgPipeline + >>> import requests + >>> from PIL import Image + >>> from io import BytesIO + + >>> # download an initial image + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> pipe = AuraFlowImg2ImgPipeline.from_pretrained("fal/AuraFlow-v0.3", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + >>> prompt = "A fantasy landscape, trending on artstation" + >>> image = pipe(prompt=prompt, image=init_image, strength=0.75, num_inference_steps=50).images[0] + >>> image.save("aura_flow_img2img.png") + ``` +""" + + +class AuraFlowImg2ImgPipeline(DiffusionPipeline): + r""" + Args: + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. AuraFlow uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + transformer ([`AuraFlowTransformer2DModel`]): + Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKL, + transformer: AuraFlowTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + @staticmethod + def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, + ): + """Calculate shift parameter based on image dimensions. + + Args: + image_seq_len: Length of the image sequence (height/vae_factor/2 * width/vae_factor/2) + base_seq_len: Base sequence length for interpolation + max_seq_len: Maximum sequence length for interpolation + base_shift: Base shift value + max_shift: Maximum shift value + + Returns: + Calculated shift parameter (mu) + """ + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + def check_inputs( + self, + prompt, + height, + width, + strength, + image, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + patch_size = 2 # AuraFlow uses patch size of 2 + required_divisor = self.vae_scale_factor * patch_size + if height % required_divisor != 0 or width % required_divisor != 0: + raise ValueError( + rf"\`height\` and \`width\` have to be divisible by the VAE scale factor ({self.vae_scale_factor}) times the transformer patch size ({patch_size}), which is {required_divisor}. " + f"Your dimensions are ({height}, {width})." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = max_sequence_length + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + text_input_ids = text_inputs["input_ids"] + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + text_inputs = {k: v.to(device) for k, v in text_inputs.items()} + prompt_embeds = self.text_encoder(**text_inputs)[0] + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.reshape(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + uncond_input = {k: v.to(device) for k, v in uncond_input.items()} + negative_prompt_embeds = self.text_encoder(**uncond_input)[0] + negative_prompt_attention_mask = ( + uncond_input["attention_mask"].unsqueeze(-1).expand(negative_prompt_embeds.shape) + ) + negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.reshape(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def get_timesteps(self, num_inference_steps, strength, device): + # Set timesteps using the full range initially + timesteps = self.scheduler.timesteps.to(device=device) + + if len(timesteps) != num_inference_steps: + num_inference_steps = len(timesteps) # Adjust if scheduler changed num_steps + + # Get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] + + # Set begin index if scheduler supports it + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator=None, + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` must be `torch.Tensor`, `PIL.Image.Image` or list, got {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + latents_0 = image + else: + # VAE ⇢ latents (ALWAYS on fp32 for numerical stability) + orig_dtype = self.vae.dtype + if orig_dtype != torch.float32: + self.vae.to(dtype=torch.float32) + + latent_dist = self.vae.encode(image.to(dtype=torch.float32)).latent_dist + latents_0 = latent_dist.mean # ❶ deterministic! + + if orig_dtype != torch.float32: + self.vae.to(dtype=orig_dtype) + + # scale + latents_0 = latents_0 * self.vae.config.scaling_factor + + # replicate to match `batch_size` + if latents_0.shape[0] != batch_size: + if batch_size % latents_0.shape[0] != 0: + raise ValueError( + f"Cannot duplicate image batch of size {latents_0.shape[0]} " + f"to effective batch size {batch_size}." + ) + repeats = batch_size // latents_0.shape[0] + latents_0 = latents_0.repeat(repeats, 1, 1, 1) + + noise = randn_tensor( + latents_0.shape, + generator=generator, + device=latents_0.device, + dtype=latents_0.dtype, + ) + + # make sure `timestep` is 1-D and matches batch + if isinstance(timestep, (int, float)): + timestep = torch.tensor([timestep], device=latents_0.device, dtype=latents_0.dtype) + timestep = timestep.expand(latents_0.shape[0]) + + latents = self.scheduler.scale_noise(latents_0, timestep, noise) + return latents + + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None, + strength: float = 0.8, + negative_prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + sigmas: List[float] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose style you want to transfer. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images. + """ + # 0. Default height and width to transformer config + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + strength, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._num_inference_steps = num_inference_steps + + # 2. Preprocess image + image = self.image_processor.preprocess(image) + + # 3. Determine batch size. + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 4. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 5. Prepare timesteps + # Calculate shift parameter based on image dimensions + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + + # Calculate mu (shift parameter) based on image dimensions + mu = self.calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + # Set timesteps with shift parameter + self.scheduler.set_timesteps(num_inference_steps, device=device, mu=mu) + + # Now adjust for strength + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength, device + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # Get the first timestep(s) for initial noise + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # AureFlow use timestep value between 0 and 1, with t=1 as noise and t=0 as the image + # create a timestep tensor with the correct batch size + # ensure it matches the batch size of the model input + t_float = t / 1000 + timestep_tensor = torch.full( + (latent_model_input.shape[0],), t_float, device=latents.device, dtype=latents.dtype + ) + + # Make sure latent_model_input has the same dtype as the transformer + transformer_dtype = self.transformer.dtype + if latent_model_input.dtype != transformer_dtype: + latent_model_input = latent_model_input.to(dtype=transformer_dtype) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep_tensor, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + # Always upcast VAE to float32 for decoding + vae_dtype = self.vae.dtype + if vae_dtype != torch.float32: + self.vae.to(dtype=torch.float32) + latents = latents.to(dtype=torch.float32) + + # Apply proper scaling factor and shift factor if available + if ( + hasattr(self.vae.config, "scaling_factor") + and hasattr(self.vae.config, "shift_factor") + and getattr(self.vae.config, "shift_factor", None) is not None + ): + # Handle both scaling and shifting + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + else: + # Just scale using standard approach + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # Restore VAE dtype + if vae_dtype != torch.float32: + self.vae.to(dtype=vae_dtype) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 6a5f6098b6fb..ccc3413d8e55 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin from ..models.controlnets import ControlNetUnionModel from ..utils import is_sentencepiece_available -from .aura_flow import AuraFlowPipeline +from .aura_flow import AuraFlowImg2ImgPipeline, AuraFlowPipeline from .cogview3 import CogView3PlusPipeline from .cogview4 import CogView4ControlPipeline, CogView4Pipeline from .controlnet import ( @@ -165,6 +165,7 @@ ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), + ("auraflow", AuraFlowImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), ("flux", FluxImg2ImgPipeline), ("flux-controlnet", FluxControlNetImg2ImgPipeline), diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index b3c6efb8cdcf..05801cd3a935 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -257,6 +257,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AuraFlowImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AuraFlowPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py index 1eb9d1035c33..aeaefb527327 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py @@ -135,3 +135,6 @@ def test_fused_qkv_projections(self): @unittest.skip("xformers attention processor does not exist for AuraFlow") def test_xformers_attention_forwardGenerator_pass(self): pass + + def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=0.0004): + self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff) diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py new file mode 100644 index 000000000000..3983dfb07a19 --- /dev/null +++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py @@ -0,0 +1,186 @@ +import unittest + +import numpy as np +import PIL.Image +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from diffusers import ( + AuraFlowImg2ImgPipeline, + AuraFlowTransformer2DModel, + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.testing_utils import torch_device + +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, +) + + +class AuraFlowImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = AuraFlowImg2ImgPipeline + params = frozenset( + [ + "prompt", + "image", + "strength", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt", "image", "negative_prompt"]) + test_layerwise_casting = False # T5 uses multiple devices + test_group_offloading = False # T5 uses multiple devices + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = AuraFlowTransformer2DModel( + sample_size=32, + patch_size=2, + in_channels=4, + num_mmdit_layers=1, + num_single_dit_layers=1, + attention_head_dim=8, + num_attention_heads=4, + caption_projection_dim=32, + joint_attention_dim=32, + out_channels=4, + pos_embed_max_size=256, + ) + + text_encoder = UMT5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-umt5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=32, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + # Ensure image dimensions are divisible by VAE scale factor * transformer patch size + # vae_scale_factor = 8, patch_size = 2 => divisible by 16 + image = PIL.Image.new("RGB", (64, 64)) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "strength": 0.75, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + # height/width are inferred from image in img2img + } + return inputs + + def test_attention_slicing_forward_pass(self): + # Attention slicing needs to implemented differently for this because how single DiT and MMDiT + # blocks interfere with each other. + return + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist(pipe.transformer) + assert check_qkv_fusion_matches_attn_procs_length(pipe.transformer, pipe.transformer.original_attn_processors) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2) + + @unittest.skip("xformers attention processor does not exist for AuraFlow") + def test_xformers_attention_forwardGenerator_pass(self): + pass + + def test_aura_flow_img2img_output_shape(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(torch_device) + + # The positional embedding has a max size of 256 + # Each position is a (height/vae_scale_factor/patch_size) × (width/vae_scale_factor/patch_size) grid + # To stay within limits: (height/8/2) * (width/8/2) < 256 + height_width_pairs = [(32, 32), (64, 32)] # creates 4 and 16 positions respectively + + for height, width in height_width_pairs: + inputs = self.get_dummy_inputs(torch_device) + # Override dummy image size + inputs["image"] = PIL.Image.new("RGB", (width, height)) + # Pass height/width explicitly to test pipeline handles them (though inferred by default) + inputs["height"] = height + inputs["width"] = width + + output = pipe(**inputs) + image = output.images[0] + + # Expected shape is (height, width, 3) for np output + self.assertEqual(image.shape, (height, width, 3)) + + def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=0.001): + self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff) + + def test_num_images_per_prompt(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + batch_sizes = [1] + num_images_per_prompts = [1, 2] + + for batch_size in batch_sizes: + for num_images_per_prompt in num_images_per_prompts: + inputs = self.get_dummy_inputs(torch_device) + inputs["num_inference_steps"] = 2 + + inputs["image"] = PIL.Image.new("RGB", (32, 32)) + + for key in inputs.keys(): + if key in self.batch_params: + inputs[key] = batch_size * [inputs[key]] + + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images + + assert len(images) == batch_size * num_images_per_prompt