diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 99a2f871c837..218394af2843 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -26,6 +26,7 @@
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+    _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
@@ -41,7 +42,6 @@
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
-    _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["cache_utils"] = ["CacheMixin"]
     _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
     _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index cb35f67fa112..c86334068302 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -15,6 +15,8 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+import numpy as np
+import PIL.Image
 import torch
 from transformers import (
     CLIPTextModelWithProjection,
@@ -39,7 +41,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
 
@@ -227,6 +229,8 @@ def __init__(
         feature_extractor: Optional[SiglipImageProcessor] = None,
     ):
         super().__init__()
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = SD3MultiControlNetModel(controlnet)
 
         self.register_modules(
             vae=vae,
@@ -572,14 +576,52 @@ def encode_prompt(
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
-    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.check_inputs
+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
+    def check_image(self, image, prompt, prompt_embeds):
+        image_is_pil = isinstance(image, PIL.Image.Image)
+        image_is_tensor = isinstance(image, torch.Tensor)
+        image_is_np = isinstance(image, np.ndarray)
+        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+        if (
+            not image_is_pil
+            and not image_is_tensor
+            and not image_is_np
+            and not image_is_pil_list
+            and not image_is_tensor_list
+            and not image_is_np_list
+        ):
+            raise TypeError(
+                f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+            )
+
+        if image_is_pil:
+            image_batch_size = 1
+        else:
+            image_batch_size = len(image)
+
+        if prompt is not None and isinstance(prompt, str):
+            prompt_batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            prompt_batch_size = len(prompt)
+        elif prompt_embeds is not None:
+            prompt_batch_size = prompt_embeds.shape[0]
+
+        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+            raise ValueError(
+                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+            )
+
     def check_inputs(
         self,
+        height,
+        width,
+        image,
         prompt,
         prompt_2,
         prompt_3,
-        height,
-        width,
         negative_prompt=None,
         negative_prompt_2=None,
         negative_prompt_3=None,
@@ -587,6 +629,11 @@
         negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
+        controlnet_conditioning_scale=1.0,
+        control_guidance_start=0.0,
+        control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -669,6 +716,76 @@
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, SD3MultiControlNetModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The conditionings will be fixed across the prompts."
+                )
+
+        # Check `image`
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        if isinstance(controlnet, SD3ControlNetModel):
+            self.check_image(image, prompt, prompt_embeds)
+        elif isinstance(controlnet, SD3MultiControlNetModel):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+
+        # Check `controlnet_conditioning_scale`
+        if isinstance(controlnet, SD3MultiControlNetModel):
+            if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                self.controlnet.nets
+            ):
+                raise ValueError(
+                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                    " the same length as the number of controlnets"
+                )
+
+        if len(control_guidance_start) != len(control_guidance_end):
+            raise ValueError(
+                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+            )
+
+        if isinstance(controlnet, SD3MultiControlNetModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
+        for start, end in zip(control_guidance_start, control_guidance_end):
+            if start >= end:
+                raise ValueError(
+                    f"control_guidance_start: {start} cannot be larger or equal to control guidance end: {end}."
+                )
+            if start < 0.0:
+                raise ValueError(f"control_guidance_start: {start} can't be smaller than 0.")
+            if end > 1.0:
+                raise ValueError(f"control_guidance_end: {end} can't be larger than 1.0.")
+
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
     def prepare_latents(
         self,
@@ -1040,11 +1157,12 @@ def __call__(
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
+            height,
+            width,
+            control_image,
             prompt,
             prompt_2,
             prompt_3,
-            height,
-            width,
             negative_prompt=negative_prompt,
             negative_prompt_2=negative_prompt_2,
             negative_prompt_3=negative_prompt_3,
@@ -1052,6 +1170,11 @@
             negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            ip_adapter_image=ip_adapter_image,
+            ip_adapter_image_embeds=ip_adapter_image_embeds,
+            controlnet_conditioning_scale=controlnet_conditioning_scale,
+            control_guidance_start=control_guidance_start,
+            control_guidance_end=control_guidance_end,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -1119,9 +1242,26 @@
             width = latent_width * self.vae_scale_factor
 
         elif isinstance(self.controlnet, SD3MultiControlNetModel):
-            raise NotImplementedError("MultiControlNetModel is not supported for SD3ControlNetInpaintingPipeline.")
+            control_images = []
+
+            for control_image_ in control_image:
+                control_image_ = self.prepare_image_with_mask(
+                    image=control_image_,
+                    mask=control_mask,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=False,
+                )
+                control_images.append(control_image_)
+
+            control_image = control_images
         else:
-            assert False
+            raise ValueError("Controlnet not found. Please check the controlnet model.")
 
         if controlnet_pooled_projections is None:
             controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
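
For reviewers, a minimal usage sketch of the multi-ControlNet path this patch enables. This is not part of the patch: the ControlNet checkpoint IDs and image URLs below are placeholders, and the list-valued conditioning arguments simply follow the new `check_inputs` validation above (one control image per ControlNet, shared `control_mask`).

# Usage sketch (assumptions: placeholder checkpoint IDs and URLs; the base model
# ID is the standard SD3 medium checkpoint).
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetInpaintingPipeline
from diffusers.utils import load_image

# Two single ControlNets; __init__ now wraps a plain list in SD3MultiControlNetModel.
controlnets = [
    SD3ControlNetModel.from_pretrained("your-org/sd3-controlnet-a", torch_dtype=torch.float16),  # placeholder
    SD3ControlNetModel.from_pretrained("your-org/sd3-controlnet-b", torch_dtype=torch.float16),  # placeholder
]

pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnets,
    torch_dtype=torch.float16,
).to("cuda")

image = load_image("https://example.com/source.png")  # placeholder URL
mask = load_image("https://example.com/mask.png")  # placeholder URL

result = pipe(
    prompt="a photo of a corgi sitting on a park bench",
    # One control image per ControlNet; check_inputs validates the list length
    # against len(controlnet.nets), and the same mask is applied to each entry.
    control_image=[image, image],
    control_mask=mask,
    controlnet_conditioning_scale=[0.9, 0.5],
    control_guidance_start=[0.0, 0.0],
    control_guidance_end=[1.0, 0.8],
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
result.save("inpainted.png")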