From 6fb6f2bbcdff25e3abd83bb0371e29cae849c376 Mon Sep 17 00:00:00 2001 From: AstraliteHeart <81396681+AstraliteHeart@users.noreply.github.com> Date: Wed, 16 Apr 2025 04:04:51 -0700 Subject: [PATCH 1/5] Add basic implementation for AuraFlowImg2ImgPipeline --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/aura_flow/__init__.py | 2 + .../aura_flow/pipeline_aura_flow_img2img.py | 703 ++++++++++++++++++ src/diffusers/pipelines/auto_pipeline.py | 3 +- .../test_pipeline_aura_flow_img2img.py | 117 +++ 6 files changed, 827 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py create mode 100644 tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f51a4ef2b3f6..fe2ff12dc4c8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -346,6 +346,7 @@ "AudioLDM2UNet2DConditionModel", "AudioLDMPipeline", "AuraFlowPipeline", + "AuraFlowImg2ImgPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", "CLIPImageProjection", diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 011f23ed371c..a42b0b34bb83 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -309,7 +309,7 @@ "StableDiffusionLDM3DPipeline", ] ) - _import_structure["aura_flow"] = ["AuraFlowPipeline"] + _import_structure["aura_flow"] = ["AuraFlowPipeline", "AuraFlowImg2ImgPipeline"] _import_structure["stable_diffusion_3"] = [ "StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline", @@ -515,7 +515,7 @@ AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, ) - from .aura_flow import AuraFlowPipeline + from .aura_flow import AuraFlowPipeline, AuraFlowImg2ImgPipeline from .blip_diffusion import BlipDiffusionPipeline from .cogvideo import ( CogVideoXFunControlPipeline, diff --git a/src/diffusers/pipelines/aura_flow/__init__.py b/src/diffusers/pipelines/aura_flow/__init__.py index e1917baa61e2..ad4974bfeae3 100644 --- a/src/diffusers/pipelines/aura_flow/__init__.py +++ b/src/diffusers/pipelines/aura_flow/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"] + _import_structure["pipeline_aura_flow_img2img"] = ["AuraFlowImg2ImgPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +34,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_aura_flow import AuraFlowPipeline + from .pipeline_aura_flow_img2img import AuraFlowImg2ImgPipeline else: import sys diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py new file mode 100644 index 000000000000..1e91af7bc708 --- /dev/null +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py @@ -0,0 +1,703 @@ +# Copyright 2025 AuraFlow Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict, List, Optional, Tuple, Union + +import PIL +import torch +from transformers import T5Tokenizer, UMT5EncoderModel + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AuraFlowTransformer2DModel, AutoencoderKL +from diffusers.models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AuraFlowImg2ImgPipeline + >>> import requests + >>> from PIL import Image + >>> from io import BytesIO + + >>> # download an initial image + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> pipe = AuraFlowImg2ImgPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + >>> prompt = "A fantasy landscape, trending on artstation" + >>> image = pipe(prompt=prompt, image=init_image, strength=0.75, num_inference_steps=50).images[0] + >>> image.save("aura_flow_img2img.png") + ``` +""" + + +class AuraFlowImg2ImgPipeline(DiffusionPipeline): + r""" + Args: + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. AuraFlow uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + transformer ([`AuraFlowTransformer2DModel`]): + Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKL, + transformer: AuraFlowTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + prompt, + height, + width, + strength, + image, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + patch_size = 2 # AuraFlow uses patch size of 2 + required_divisor = self.vae_scale_factor * patch_size + if height % required_divisor != 0 or width % required_divisor != 0: + raise ValueError( + f"\`height\` and \`width\` have to be divisible by the VAE scale factor ({self.vae_scale_factor}) times the transformer patch size ({patch_size}), which is {required_divisor}. " + f"Your dimensions are ({height}, {width})." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = max_sequence_length + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + text_input_ids = text_inputs["input_ids"] + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + text_inputs = {k: v.to(device) for k, v in text_inputs.items()} + prompt_embeds = self.text_encoder(**text_inputs)[0] + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.reshape(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + uncond_input = {k: v.to(device) for k, v in uncond_input.items()} + negative_prompt_embeds = self.text_encoder(**uncond_input)[0] + negative_prompt_attention_mask = ( + uncond_input["attention_mask"].unsqueeze(-1).expand(negative_prompt_embeds.shape) + ) + negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.reshape(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + def get_timesteps(self, num_inference_steps, strength, device): + # 1. Call set_timesteps with num_inference_steps + self.scheduler.set_timesteps(num_inference_steps, device=device) # Ensure scheduler uses the correct number of steps + + # 2. Calculate strength-based number of steps and offset + init_timestep_count = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep_count, 0) + + # 3. Get the timesteps *after* set_timesteps has been called (now has length num_inference_steps) + timesteps = self.scheduler.timesteps[t_start:] + + # 4. Return the correct slice and the number of actual steps + num_actual_inference_steps = len(timesteps) + return timesteps, num_actual_inference_steps + + def prepare_img2img_latents( + self, + image, + timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + latents = image + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if image.shape[0] == 1: + image = image.repeat(batch_size, 1, 1, 1) + + # encode the init image into latents and scale the latents + latents = self.vae.encode(image).latent_dist.sample(generator=generator) + latents = latents * self.vae.config.scaling_factor + + # get the original timestep using init_timestep + init_timestep = timestep + + # add noise to latents using the timesteps + noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype) + + # Ensure timestep tensor is on the same device + t = init_timestep.to(latents.device) + + # Normalize timestep to [0, 1] range (using scheduler's config) + t = t / self.scheduler.config.num_train_timesteps + + # Reshape t to match the dimensions needed for broadcasting + required_dims = len(latents.shape) + current_dims = len(t.shape) + for _ in range(required_dims - current_dims): + t = t.unsqueeze(-1) + + # Interpolation: x_t = t * x_1 + (1 - t) * x_0 + latents = t * noise + (1 - t) * latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None, + strength: float = 0.8, + negative_prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + sigmas: List[float] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose style you want to transfer. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images. + """ + # 0. Default height and width to transformer config + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + strength, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._num_inference_steps = num_inference_steps + + # 2. Preprocess image + image = self.image_processor.preprocess(image) + + # 3. Determine batch size. + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 4. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 5. Prepare timesteps + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1] + + # 6. Prepare latent variables + latents = self.prepare_img2img_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0]) + timestep = timestep.to(latents.device, dtype=latents.dtype) + + # Make sure latent_model_input has the same dtype as the transformer + transformer_dtype = self.transformer.dtype + if latent_model_input.dtype != transformer_dtype: + latent_model_input = latent_model_input.to(dtype=transformer_dtype) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # Apply proper scaling factor and shift factor if available + if hasattr(self.vae.config, "scaling_factor") and hasattr(self.vae.config, "shift_factor") and getattr(self.vae.config, "shift_factor", None) is not None: + # Handle both scaling and shifting + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + else: + # Just scale using standard approach + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) \ No newline at end of file diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 6a5f6098b6fb..56d25053123d 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin from ..models.controlnets import ControlNetUnionModel from ..utils import is_sentencepiece_available -from .aura_flow import AuraFlowPipeline +from .aura_flow import AuraFlowPipeline, AuraFlowImg2ImgPipeline from .cogview3 import CogView3PlusPipeline from .cogview4 import CogView4ControlPipeline, CogView4Pipeline from .controlnet import ( @@ -165,6 +165,7 @@ ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), + ("auraflow", AuraFlowImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), ("flux", FluxImg2ImgPipeline), ("flux-controlnet", FluxControlNetImg2ImgPipeline), diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py new file mode 100644 index 000000000000..734885f0eaca --- /dev/null +++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py @@ -0,0 +1,117 @@ +import unittest + +import numpy as np +import PIL.Image +import torch +from diffusers.utils.testing_utils import require_torch_gpu, torch_device +from transformers import AutoTokenizer, UMT5EncoderModel, AuraFlowPipelineFastTests + +from diffusers import ( + AuraFlowImg2ImgPipeline, # Added for Img2Img + AuraFlowPipeline, + AuraFlowTransformer2DModel, + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, +) + +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, +) + +class AuraFlowImg2ImgPipelineFastTests(AuraFlowPipelineFastTests): + pipeline_class = AuraFlowImg2ImgPipeline + params = frozenset( + [ + "prompt", + "image", + "strength", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt", "negative_prompt", "image"]) + test_layerwise_casting = False # T5 uses multiple devices + test_group_offloading = False # T5 uses multiple devices + + # Redefine get_dummy_inputs for Img2Img + def get_dummy_inputs(self, device, seed=0): + # Ensure image dimensions are divisible by VAE scale factor * transformer patch size + # vae_scale_factor = 8, patch_size = 2 => divisible by 16 + image = PIL.Image.new("RGB", (64, 64)) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "strength": 0.75, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + # height/width are inferred from image in img2img + } + return inputs + + # Override T2I test that requires height/width + def test_fused_qkv_projections(self): + # Inherited test expects height/width, skip for img2img dummy inputs + # Call the parent T2I test method directly if needed for coverage, + # but adapt inputs or skip if incompatible. + # For now, simply reimplement with img2img inputs + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist(pipe.transformer) + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3) + assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3) + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2) + + + def test_aura_flow_img2img_output_shape(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(torch_device) + + # Use dimensions divisible by vae_scale_factor * patch_size (8*2=16) + height_width_pairs = [(64, 64), (128, 48)] # 48 is divisible by 16 + + for height, width in height_width_pairs: + inputs = self.get_dummy_inputs(torch_device) + # Override dummy image size + inputs["image"] = PIL.Image.new("RGB", (width, height)) + # Pass height/width explicitly to test pipeline handles them (though inferred by default) + inputs["height"] = height + inputs["width"] = width + + output = pipe(**inputs) + image = output.images[0] + + # Expected shape is (height, width, 3) for np output + self.assertEqual(image.shape, (height, width, 3)) \ No newline at end of file From 6ff1af8c20683175fadef09a77cb2f79127d53e7 Mon Sep 17 00:00:00 2001 From: AstraliteHeart <astralite.heart@gmail.com> Date: Thu, 17 Apr 2025 19:30:50 +0000 Subject: [PATCH 2/5] Update i2i tests, fix style --- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- .../aura_flow/pipeline_aura_flow_img2img.py | 116 +++++++++++++----- src/diffusers/pipelines/auto_pipeline.py | 4 +- .../dummy_torch_and_transformers_objects.py | 15 +++ .../aura_flow/test_pipeline_aura_flow.py | 3 + .../test_pipeline_aura_flow_img2img.py | 109 +++++++++++++--- 7 files changed, 196 insertions(+), 55 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fe2ff12dc4c8..4bebc3404220 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -345,8 +345,8 @@ "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", "AudioLDMPipeline", - "AuraFlowPipeline", "AuraFlowImg2ImgPipeline", + "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", "CLIPImageProjection", diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a42b0b34bb83..352bf804d827 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -515,7 +515,7 @@ AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, ) - from .aura_flow import AuraFlowPipeline, AuraFlowImg2ImgPipeline + from .aura_flow import AuraFlowImg2ImgPipeline, AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .cogvideo import ( CogVideoXFunControlPipeline, diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py index 1e91af7bc708..2a96833f4617 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py @@ -21,10 +21,10 @@ from diffusers.image_processor import VaeImageProcessor from diffusers.models import AuraFlowTransformer2DModel, AutoencoderKL from diffusers.models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput if is_torch_xla_available(): @@ -119,12 +119,12 @@ def check_inputs( ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") - + patch_size = 2 # AuraFlow uses patch size of 2 required_divisor = self.vae_scale_factor * patch_size if height % required_divisor != 0 or width % required_divisor != 0: raise ValueError( - f"\`height\` and \`width\` have to be divisible by the VAE scale factor ({self.vae_scale_factor}) times the transformer patch size ({patch_size}), which is {required_divisor}. " + rf"\`height\` and \`width\` have to be divisible by the VAE scale factor ({self.vae_scale_factor}) times the transformer patch size ({patch_size}), which is {required_divisor}. " f"Your dimensions are ({height}, {width})." ) @@ -339,7 +339,7 @@ def prepare_latents( def get_timesteps(self, num_inference_steps, strength, device): # 1. Call set_timesteps with num_inference_steps - self.scheduler.set_timesteps(num_inference_steps, device=device) # Ensure scheduler uses the correct number of steps + self.scheduler.set_timesteps(num_inference_steps, device=device) # 2. Calculate strength-based number of steps and offset init_timestep_count = min(int(num_inference_steps * strength), num_inference_steps) @@ -353,14 +353,7 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_actual_inference_steps def prepare_img2img_latents( - self, - image, - timestep, - batch_size, - num_images_per_prompt, - dtype, - device, - generator=None + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( @@ -380,34 +373,87 @@ def prepare_img2img_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - if image.shape[0] == 1: - image = image.repeat(batch_size, 1, 1, 1) + # Handle different batch size scenarios + if image.shape[0] < batch_size: + if batch_size % image.shape[0] == 0: + # Duplicate the image to match the batch size + additional_image_per_prompt = batch_size // image.shape[0] + image = torch.cat([image] * additional_image_per_prompt, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to {batch_size} text prompts." + f" Batch size must be divisible by the image batch size." + ) # encode the init image into latents and scale the latents - latents = self.vae.encode(image).latent_dist.sample(generator=generator) + # 1. Get VAE distribution parameters (on device) + latent_dist = self.vae.encode(image).latent_dist + mean, std = latent_dist.mean, latent_dist.std # Already on device + + # 2. Sample noise for each batch element individually if using multiple generators + if isinstance(generator, list): + sample = torch.cat( + [ + torch.randn( + (1, *mean.shape[1:]), + generator=generator[i], + device=generator[i].device if hasattr(generator[i], "device") else "cpu", + dtype=mean.dtype, + ).to(mean.device) + for i in range(batch_size) + ] + ) + else: + # Single generator - use its device if it has one + generator_device = getattr(generator, "device", "cpu") if generator is not None else "cpu" + noise = torch.randn(mean.shape, generator=generator, device=generator_device, dtype=mean.dtype) + sample = noise.to(mean.device) + + # Compute latents + latents = mean + std * sample + + # Scale latents latents = latents * self.vae.config.scaling_factor # get the original timestep using init_timestep init_timestep = timestep # add noise to latents using the timesteps - noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype) - + # Handle noise generation with multiple generators if provided + if isinstance(generator, list): + noise = torch.cat( + [ + torch.randn( + (1, *latents.shape[1:]), + generator=generator[i], + device=generator[i].device if hasattr(generator[i], "device") else "cpu", + dtype=latents.dtype, + ).to(latents.device) + for i in range(batch_size) + ] + ) + else: + # Single generator - use its device if it has one + generator_device = getattr(generator, "device", "cpu") if generator is not None else "cpu" + noise = torch.randn( + latents.shape, generator=generator, device=generator_device, dtype=latents.dtype + ).to(latents.device) + # Ensure timestep tensor is on the same device t = init_timestep.to(latents.device) - + # Normalize timestep to [0, 1] range (using scheduler's config) t = t / self.scheduler.config.num_train_timesteps - + # Reshape t to match the dimensions needed for broadcasting required_dims = len(latents.shape) current_dims = len(t.shape) for _ in range(required_dims - current_dims): t = t.unsqueeze(-1) - + # Interpolation: x_t = t * x_1 + (1 - t) * x_0 latents = t * noise + (1 - t) * latents - + return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae @@ -606,13 +652,14 @@ def __call__( negative_prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, ) + if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 5. Prepare timesteps timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1] - + # 6. Prepare latent variables latents = self.prepare_img2img_latents( image, @@ -632,10 +679,13 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0]) - timestep = timestep.to(latents.device, dtype=latents.dtype) + # AureFlow use timestep value between 0 and 1, with t=1 as noise and t=0 as the image + # create a timestep tensor with the correct batch size + # ensure it matches the batch size of the model input + t_float = t / 1000 + timestep_tensor = torch.full( + (latent_model_input.shape[0],), t_float, device=latents.device, dtype=latents.dtype + ) # Make sure latent_model_input has the same dtype as the transformer transformer_dtype = self.transformer.dtype @@ -646,7 +696,7 @@ def __call__( noise_pred = self.transformer( latent_model_input, encoder_hidden_states=prompt_embeds, - timestep=timestep, + timestep=timestep_tensor, return_dict=False, )[0] @@ -682,15 +732,19 @@ def __call__( if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - + # Apply proper scaling factor and shift factor if available - if hasattr(self.vae.config, "scaling_factor") and hasattr(self.vae.config, "shift_factor") and getattr(self.vae.config, "shift_factor", None) is not None: + if ( + hasattr(self.vae.config, "scaling_factor") + and hasattr(self.vae.config, "shift_factor") + and getattr(self.vae.config, "shift_factor", None) is not None + ): # Handle both scaling and shifting latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor else: # Just scale using standard approach latents = latents / self.vae.config.scaling_factor - + image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) @@ -700,4 +754,4 @@ def __call__( if not return_dict: return (image,) - return ImagePipelineOutput(images=image) \ No newline at end of file + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 56d25053123d..ccc3413d8e55 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin from ..models.controlnets import ControlNetUnionModel from ..utils import is_sentencepiece_available -from .aura_flow import AuraFlowPipeline, AuraFlowImg2ImgPipeline +from .aura_flow import AuraFlowImg2ImgPipeline, AuraFlowPipeline from .cogview3 import CogView3PlusPipeline from .cogview4 import CogView4ControlPipeline, CogView4Pipeline from .controlnet import ( @@ -165,7 +165,7 @@ ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), - ("auraflow", AuraFlowImg2ImgPipeline), + ("auraflow", AuraFlowImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), ("flux", FluxImg2ImgPipeline), ("flux-controlnet", FluxControlNetImg2ImgPipeline), diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index b3c6efb8cdcf..05801cd3a935 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -257,6 +257,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AuraFlowImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AuraFlowPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py index 1eb9d1035c33..aeaefb527327 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py @@ -135,3 +135,6 @@ def test_fused_qkv_projections(self): @unittest.skip("xformers attention processor does not exist for AuraFlow") def test_xformers_attention_forwardGenerator_pass(self): pass + + def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=0.0004): + self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff) diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py index 734885f0eaca..3983dfb07a19 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow_img2img.py @@ -3,16 +3,15 @@ import numpy as np import PIL.Image import torch -from diffusers.utils.testing_utils import require_torch_gpu, torch_device -from transformers import AutoTokenizer, UMT5EncoderModel, AuraFlowPipelineFastTests +from transformers import AutoTokenizer, UMT5EncoderModel from diffusers import ( - AuraFlowImg2ImgPipeline, # Added for Img2Img - AuraFlowPipeline, + AuraFlowImg2ImgPipeline, AuraFlowTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler, ) +from diffusers.utils.testing_utils import torch_device from ..test_pipelines_common import ( PipelineTesterMixin, @@ -20,7 +19,8 @@ check_qkv_fusion_processors_exist, ) -class AuraFlowImg2ImgPipelineFastTests(AuraFlowPipelineFastTests): + +class AuraFlowImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = AuraFlowImg2ImgPipeline params = frozenset( [ @@ -33,11 +33,50 @@ class AuraFlowImg2ImgPipelineFastTests(AuraFlowPipelineFastTests): "negative_prompt_embeds", ] ) - batch_params = frozenset(["prompt", "negative_prompt", "image"]) - test_layerwise_casting = False # T5 uses multiple devices - test_group_offloading = False # T5 uses multiple devices + batch_params = frozenset(["prompt", "image", "negative_prompt"]) + test_layerwise_casting = False # T5 uses multiple devices + test_group_offloading = False # T5 uses multiple devices + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = AuraFlowTransformer2DModel( + sample_size=32, + patch_size=2, + in_channels=4, + num_mmdit_layers=1, + num_single_dit_layers=1, + attention_head_dim=8, + num_attention_heads=4, + caption_projection_dim=32, + joint_attention_dim=32, + out_channels=4, + pos_embed_max_size=256, + ) + + text_encoder = UMT5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-umt5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=32, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } - # Redefine get_dummy_inputs for Img2Img def get_dummy_inputs(self, device, seed=0): # Ensure image dimensions are divisible by VAE scale factor * transformer patch size # vae_scale_factor = 8, patch_size = 2 => divisible by 16 @@ -59,12 +98,12 @@ def get_dummy_inputs(self, device, seed=0): } return inputs - # Override T2I test that requires height/width + def test_attention_slicing_forward_pass(self): + # Attention slicing needs to implemented differently for this because how single DiT and MMDiT + # blocks interfere with each other. + return + def test_fused_qkv_projections(self): - # Inherited test expects height/width, skip for img2img dummy inputs - # Call the parent T2I test method directly if needed for coverage, - # but adapt inputs or skip if incompatible. - # For now, simply reimplement with img2img inputs device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -77,9 +116,7 @@ def test_fused_qkv_projections(self): pipe.transformer.fuse_qkv_projections() assert check_qkv_fusion_processors_exist(pipe.transformer) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ) + assert check_qkv_fusion_matches_attn_procs_length(pipe.transformer, pipe.transformer.original_attn_processors) inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images @@ -94,13 +131,18 @@ def test_fused_qkv_projections(self): assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3) assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2) + @unittest.skip("xformers attention processor does not exist for AuraFlow") + def test_xformers_attention_forwardGenerator_pass(self): + pass def test_aura_flow_img2img_output_shape(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components).to(torch_device) - # Use dimensions divisible by vae_scale_factor * patch_size (8*2=16) - height_width_pairs = [(64, 64), (128, 48)] # 48 is divisible by 16 + # The positional embedding has a max size of 256 + # Each position is a (height/vae_scale_factor/patch_size) × (width/vae_scale_factor/patch_size) grid + # To stay within limits: (height/8/2) * (width/8/2) < 256 + height_width_pairs = [(32, 32), (64, 32)] # creates 4 and 16 positions respectively for height, width in height_width_pairs: inputs = self.get_dummy_inputs(torch_device) @@ -114,4 +156,31 @@ def test_aura_flow_img2img_output_shape(self): image = output.images[0] # Expected shape is (height, width, 3) for np output - self.assertEqual(image.shape, (height, width, 3)) \ No newline at end of file + self.assertEqual(image.shape, (height, width, 3)) + + def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=0.001): + self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff) + + def test_num_images_per_prompt(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + batch_sizes = [1] + num_images_per_prompts = [1, 2] + + for batch_size in batch_sizes: + for num_images_per_prompt in num_images_per_prompts: + inputs = self.get_dummy_inputs(torch_device) + inputs["num_inference_steps"] = 2 + + inputs["image"] = PIL.Image.new("RGB", (32, 32)) + + for key in inputs.keys(): + if key in self.batch_params: + inputs[key] = batch_size * [inputs[key]] + + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images + + assert len(images) == batch_size * num_images_per_prompt From 6ac5cbb76fc3ab3d478a704c6ab747b5cde7d29b Mon Sep 17 00:00:00 2001 From: AstraliteHeart <astralite.heart@gmail.com> Date: Thu, 17 Apr 2025 22:52:20 +0000 Subject: [PATCH 3/5] Use scale_noise directly and fix VAE decoding --- .../aura_flow/pipeline_aura_flow_img2img.py | 68 ++++++++++--------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py index 2a96833f4617..423ce01a270d 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py @@ -52,7 +52,7 @@ >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") >>> init_image = init_image.resize((768, 512)) - >>> pipe = AuraFlowImg2ImgPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16) + >>> pipe = AuraFlowImg2ImgPipeline.from_pretrained("fal/AuraFlow-v0.3", torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> prompt = "A fantasy landscape, trending on artstation" >>> image = pipe(prompt=prompt, image=init_image, strength=0.75, num_inference_steps=50).images[0] @@ -338,19 +338,20 @@ def prepare_latents( return latents def get_timesteps(self, num_inference_steps, strength, device): - # 1. Call set_timesteps with num_inference_steps + # Set timesteps using the full range initially self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps.to(device=device) - # 2. Calculate strength-based number of steps and offset - init_timestep_count = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep_count, 0) + if len(timesteps) != num_inference_steps: + num_inference_steps = len(timesteps) # Adjust if scheduler changed num_steps - # 3. Get the timesteps *after* set_timesteps has been called (now has length num_inference_steps) + # Get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) timesteps = self.scheduler.timesteps[t_start:] - # 4. Return the correct slice and the number of actual steps - num_actual_inference_steps = len(timesteps) - return timesteps, num_actual_inference_steps + return timesteps, num_inference_steps - t_start def prepare_img2img_latents( self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None @@ -385,11 +386,20 @@ def prepare_img2img_latents( f" Batch size must be divisible by the image batch size." ) + # Temporarily move VAE to float32 for encoding + vae_dtype = self.vae.dtype + if vae_dtype != torch.float32: + self.vae.to(dtype=torch.float32) + # encode the init image into latents and scale the latents # 1. Get VAE distribution parameters (on device) - latent_dist = self.vae.encode(image).latent_dist + latent_dist = self.vae.encode(image.to(dtype=torch.float32)).latent_dist mean, std = latent_dist.mean, latent_dist.std # Already on device + # Restore VAE dtype + if vae_dtype != torch.float32: + self.vae.to(dtype=vae_dtype) + # 2. Sample noise for each batch element individually if using multiple generators if isinstance(generator, list): sample = torch.cat( @@ -416,7 +426,7 @@ def prepare_img2img_latents( latents = latents * self.vae.config.scaling_factor # get the original timestep using init_timestep - init_timestep = timestep + init_timestep = timestep # Use the passed timestep directly # add noise to latents using the timesteps # Handle noise generation with multiple generators if provided @@ -439,20 +449,7 @@ def prepare_img2img_latents( latents.shape, generator=generator, device=generator_device, dtype=latents.dtype ).to(latents.device) - # Ensure timestep tensor is on the same device - t = init_timestep.to(latents.device) - - # Normalize timestep to [0, 1] range (using scheduler's config) - t = t / self.scheduler.config.num_train_timesteps - - # Reshape t to match the dimensions needed for broadcasting - required_dims = len(latents.shape) - current_dims = len(t.shape) - for _ in range(required_dims - current_dims): - t = t.unsqueeze(-1) - - # Interpolation: x_t = t * x_1 + (1 - t) * x_0 - latents = t * noise + (1 - t) * latents + latents = self.scheduler.scale_noise(latents, init_timestep, noise) return latents @@ -657,8 +654,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 5. Prepare timesteps - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - latent_timestep = timesteps[:1] + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength, device + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # Get the first timestep(s) for initial noise # 6. Prepare latent variables latents = self.prepare_img2img_latents( @@ -727,11 +726,11 @@ def __call__( if output_type == "latent": image = latents else: - # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast - if needs_upcasting: - self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # Always upcast VAE to float32 for decoding + vae_dtype = self.vae.dtype + if vae_dtype != torch.float32: + self.vae.to(dtype=torch.float32) + latents = latents.to(dtype=torch.float32) # Apply proper scaling factor and shift factor if available if ( @@ -746,6 +745,11 @@ def __call__( latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] + + # Restore VAE dtype + if vae_dtype != torch.float32: + self.vae.to(dtype=vae_dtype) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models From 1b7fb36294d87e7c35e3e66c334c93b83f032f5d Mon Sep 17 00:00:00 2001 From: AstraliteHeart <astralite.heart@gmail.com> Date: Fri, 18 Apr 2025 10:29:30 +0000 Subject: [PATCH 4/5] Review updates --- .../aura_flow/pipeline_aura_flow_img2img.py | 124 +++++++++++------- 1 file changed, 73 insertions(+), 51 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py index 423ce01a270d..dae182c921f3 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py @@ -103,6 +103,31 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + @staticmethod + def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, + ): + """Calculate shift parameter based on image dimensions. + + Args: + image_seq_len: Length of the image sequence (height/vae_factor/2 * width/vae_factor/2) + base_seq_len: Base sequence length for interpolation + max_seq_len: Maximum sequence length for interpolation + base_shift: Base shift value + max_shift: Maximum shift value + + Returns: + Calculated shift parameter (mu) + """ + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + def check_inputs( self, prompt, @@ -305,41 +330,8 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - if latents is not None: - return latents.to(device=device, dtype=dtype) - - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - return latents - def get_timesteps(self, num_inference_steps, strength, device): # Set timesteps using the full range initially - self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps.to(device=device) if len(timesteps) != num_inference_steps: @@ -349,11 +341,15 @@ def get_timesteps(self, num_inference_steps, strength, device): init_timestep = min(num_inference_steps * strength, num_inference_steps) t_start = int(max(num_inference_steps - init_timestep, 0)) - timesteps = self.scheduler.timesteps[t_start:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] + + # Set begin index if scheduler supports it + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start - def prepare_img2img_latents( + def prepare_latents( self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): @@ -361,6 +357,13 @@ def prepare_img2img_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) + # Check for latents_mean and latents_std in the VAE config + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -404,26 +407,30 @@ def prepare_img2img_latents( if isinstance(generator, list): sample = torch.cat( [ - torch.randn( + randn_tensor( (1, *mean.shape[1:]), generator=generator[i], - device=generator[i].device if hasattr(generator[i], "device") else "cpu", + device=mean.device, dtype=mean.dtype, - ).to(mean.device) + ) for i in range(batch_size) ] ) else: # Single generator - use its device if it has one - generator_device = getattr(generator, "device", "cpu") if generator is not None else "cpu" - noise = torch.randn(mean.shape, generator=generator, device=generator_device, dtype=mean.dtype) - sample = noise.to(mean.device) + sample = randn_tensor(mean.shape, generator=generator, device=mean.device, dtype=mean.dtype) # Compute latents latents = mean + std * sample - # Scale latents - latents = latents * self.vae.config.scaling_factor + # Apply standardization if VAE has mean and std defined in config + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + latents = (latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + # Scale latents + latents = latents * self.vae.config.scaling_factor # get the original timestep using init_timestep init_timestep = timestep # Use the passed timestep directly @@ -433,21 +440,20 @@ def prepare_img2img_latents( if isinstance(generator, list): noise = torch.cat( [ - torch.randn( + randn_tensor( (1, *latents.shape[1:]), generator=generator[i], - device=generator[i].device if hasattr(generator[i], "device") else "cpu", + device=latents.device, dtype=latents.dtype, - ).to(latents.device) + ) for i in range(batch_size) ] ) else: # Single generator - use its device if it has one - generator_device = getattr(generator, "device", "cpu") if generator is not None else "cpu" - noise = torch.randn( - latents.shape, generator=generator, device=generator_device, dtype=latents.dtype - ).to(latents.device) + noise = randn_tensor( + latents.shape, generator=generator, device=latents.device, dtype=latents.dtype + ) latents = self.scheduler.scale_noise(latents, init_timestep, noise) @@ -654,13 +660,29 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 5. Prepare timesteps + # Calculate shift parameter based on image dimensions + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + + # Calculate mu (shift parameter) based on image dimensions + mu = self.calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + # Set timesteps with shift parameter + self.scheduler.set_timesteps(num_inference_steps, device=device, mu=mu) + + # Now adjust for strength timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, device ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # Get the first timestep(s) for initial noise # 6. Prepare latent variables - latents = self.prepare_img2img_latents( + latents = self.prepare_latents( image, latent_timestep, batch_size, From 937502046951d656dd736f62fcdc13d1108f601b Mon Sep 17 00:00:00 2001 From: AstraliteHeart <81396681+AstraliteHeart@users.noreply.github.com> Date: Thu, 1 May 2025 20:56:26 -0700 Subject: [PATCH 5/5] Update src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com> --- .../aura_flow/pipeline_aura_flow_img2img.py | 127 ++++++------------ 1 file changed, 39 insertions(+), 88 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py index dae182c921f3..5cf9a820681d 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py @@ -350,115 +350,66 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start def prepare_latents( - self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None + self, + image, + timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator=None, ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + f"`image` must be `torch.Tensor`, `PIL.Image.Image` or list, got {type(image)}" ) - # Check for latents_mean and latents_std in the VAE config - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - image = image.to(device=device, dtype=dtype) - batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: - latents = image + latents_0 = image else: - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - # Handle different batch size scenarios - if image.shape[0] < batch_size: - if batch_size % image.shape[0] == 0: - # Duplicate the image to match the batch size - additional_image_per_prompt = batch_size // image.shape[0] - image = torch.cat([image] * additional_image_per_prompt, dim=0) - else: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to {batch_size} text prompts." - f" Batch size must be divisible by the image batch size." - ) - - # Temporarily move VAE to float32 for encoding - vae_dtype = self.vae.dtype - if vae_dtype != torch.float32: + # VAE ⇢ latents (ALWAYS on fp32 for numerical stability) + orig_dtype = self.vae.dtype + if orig_dtype != torch.float32: self.vae.to(dtype=torch.float32) - # encode the init image into latents and scale the latents - # 1. Get VAE distribution parameters (on device) latent_dist = self.vae.encode(image.to(dtype=torch.float32)).latent_dist - mean, std = latent_dist.mean, latent_dist.std # Already on device + latents_0 = latent_dist.mean # ❶ deterministic! - # Restore VAE dtype - if vae_dtype != torch.float32: - self.vae.to(dtype=vae_dtype) + if orig_dtype != torch.float32: + self.vae.to(dtype=orig_dtype) - # 2. Sample noise for each batch element individually if using multiple generators - if isinstance(generator, list): - sample = torch.cat( - [ - randn_tensor( - (1, *mean.shape[1:]), - generator=generator[i], - device=mean.device, - dtype=mean.dtype, - ) - for i in range(batch_size) - ] - ) - else: - # Single generator - use its device if it has one - sample = randn_tensor(mean.shape, generator=generator, device=mean.device, dtype=mean.dtype) + # scale + latents_0 = latents_0 * self.vae.config.scaling_factor - # Compute latents - latents = mean + std * sample - - # Apply standardization if VAE has mean and std defined in config - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - latents = (latents - latents_mean) * self.vae.config.scaling_factor / latents_std - else: - # Scale latents - latents = latents * self.vae.config.scaling_factor - - # get the original timestep using init_timestep - init_timestep = timestep # Use the passed timestep directly - - # add noise to latents using the timesteps - # Handle noise generation with multiple generators if provided - if isinstance(generator, list): - noise = torch.cat( - [ - randn_tensor( - (1, *latents.shape[1:]), - generator=generator[i], - device=latents.device, - dtype=latents.dtype, - ) - for i in range(batch_size) - ] - ) - else: - # Single generator - use its device if it has one - noise = randn_tensor( - latents.shape, generator=generator, device=latents.device, dtype=latents.dtype + # replicate to match `batch_size` + if latents_0.shape[0] != batch_size: + if batch_size % latents_0.shape[0] != 0: + raise ValueError( + f"Cannot duplicate image batch of size {latents_0.shape[0]} " + f"to effective batch size {batch_size}." ) + repeats = batch_size // latents_0.shape[0] + latents_0 = latents_0.repeat(repeats, 1, 1, 1) + + noise = randn_tensor( + latents_0.shape, + generator=generator, + device=latents_0.device, + dtype=latents_0.dtype, + ) - latents = self.scheduler.scale_noise(latents, init_timestep, noise) + # make sure `timestep` is 1-D and matches batch + if isinstance(timestep, (int, float)): + timestep = torch.tensor([timestep], device=latents_0.device, dtype=latents_0.dtype) + timestep = timestep.expand(latents_0.shape[0]) + latents = self.scheduler.scale_noise(latents_0, timestep, noise) return latents + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype