38 changes: 37 additions & 1 deletion src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -610,7 +610,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
block_state = self.get_block_state(state)

# for edit, image size can be different from the target size (height/width)

block_state.img_shapes = [
[
(
@@ -640,6 +639,43 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state


class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
model_name = "qwenimage-edit-plus"

@property
def inputs(self) -> List[InputParam]:
existing_inputs = super().inputs
current_inputs = [InputParam("vae_image_sizes", type_hint=List[Tuple[int, int]], required=True)]
return existing_inputs + current_inputs

def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

vae_scale_factor = components.vae_scale_factor
block_state.img_shapes = [
[
(1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
*[
(1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
for vae_width, vae_height in block_state.vae_image_sizes
],
]
] * block_state.batch_size

block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
)
block_state.negative_txt_seq_lens = (
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
if block_state.negative_prompt_embeds_mask is not None
else None
)

self.set_block_state(state, block_state)

return components, state


## ControlNet inputs for denoiser
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
model_name = "qwenimage"
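Reviewer note: a minimal standalone sketch of the shape bookkeeping `QwenImageEditPlusRoPEInputsStep.__call__` performs above — one `(frames, latent_height, latent_width)` entry for the target canvas, followed by one entry per resized reference image, repeated for every batch item. The concrete numbers (`vae_scale_factor=8`, the two reference sizes) are illustrative assumptions, not values from this PR.

```python
from typing import List, Tuple


def edit_plus_img_shapes(
    height: int,
    width: int,
    vae_image_sizes: List[Tuple[int, int]],
    vae_scale_factor: int,
    batch_size: int,
) -> List[List[Tuple[int, int, int]]]:
    """Mirror the img_shapes list built in QwenImageEditPlusRoPEInputsStep.__call__:
    one (frames, latent_height, latent_width) entry for the target canvas, then one
    entry per resized reference image, repeated for every item in the batch."""
    target = (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)
    references = [
        (1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
        for vae_width, vae_height in vae_image_sizes
    ]
    return [[target, *references]] * batch_size


# Illustrative numbers only: a 1024x1024 target with two reference images the
# VAE-encode step resized to 1024x768 and 512x1024 (vae_scale_factor assumed 8).
print(edit_plus_img_shapes(1024, 1024, [(1024, 768), (512, 1024)], vae_scale_factor=8, batch_size=1))
# -> [[(1, 64, 64), (1, 48, 64), (1, 64, 32)]]
```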
45 changes: 27 additions & 18 deletions src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import PIL
import torch
@@ -330,7 +330,7 @@ def __init__(
output_name: str = "resized_image",
vae_image_output_name: str = "vae_image",
):
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
"""Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.

This block resizes an input image or a list input images and exposes the resized result under configurable
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
@@ -803,9 +803,7 @@ def inputs(self) -> List[InputParam]:

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="processed_image"),
]
return [OutputParam(name="processed_image")]

@staticmethod
def check_inputs(height, width, vae_scale_factor):
@@ -845,7 +843,10 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):

class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
model_name = "qwenimage-edit-plus"
vae_image_size = 1024 * 1024

def __init__(self):
self.vae_image_size = 1024 * 1024
super().__init__()

@property
def description(self) -> str:
@@ -855,13 +856,20 @@ def description(self) -> str:
def inputs(self) -> List[InputParam]:
return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]

@property
def intermediate_outputs(self):
existing_outputs = super().intermediate_outputs
current_outputs = [OutputParam("vae_image_sizes", type_hint=List[Tuple[int, int]])]
return existing_outputs + current_outputs

@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)

if block_state.vae_image is None and block_state.image is None:
raise ValueError("`vae_image` and `image` cannot be None at the same time")

vae_image_sizes = None
if block_state.vae_image is None:
image = block_state.image
self.check_inputs(
@@ -873,12 +881,18 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
image=image, height=height, width=width
)
else:
width, height = block_state.vae_image[0].size
image = block_state.vae_image
processed_images = []
vae_image_sizes = []
for img in block_state.vae_image:
width, height = img.size
vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
vae_image_sizes.append((vae_width, vae_height))
processed_images.append(
components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
)
block_state.processed_image = torch.stack(processed_images, dim=0).squeeze(1)

block_state.processed_image = components.image_processor.preprocess(
image=image, height=height, width=width
)
block_state.vae_image_sizes = vae_image_sizes

self.set_block_state(state, block_state)
return components, state
@@ -920,17 +934,12 @@ def description(self) -> str:

@property
def expected_components(self) -> List[ComponentSpec]:
components = [
ComponentSpec("vae", AutoencoderKLQwenImage),
]
components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
return components

@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(self._image_input_name, required=True),
InputParam("generator"),
]
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
return inputs

@property
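Reviewer note: the key behavioural change above is that each reference image now gets its own VAE target size instead of all of them reusing the size of the first image. Below is a small sketch of that planning step; `_fit_to_area` is a hypothetical stand-in for the library's `calculate_dimensions` helper (assumed to keep the aspect ratio, target roughly `vae_image_size` pixels, and snap to a multiple of 32 — the exact rounding lives in the real helper).

```python
import math
from typing import List, Tuple


def _fit_to_area(target_area: int, ratio: float, multiple: int = 32) -> Tuple[int, int]:
    # Hypothetical stand-in for diffusers' calculate_dimensions helper: keep the
    # aspect ratio, target roughly `target_area` pixels, snap to a multiple of 32.
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    return int(round(width / multiple) * multiple), int(round(height / multiple) * multiple)


def plan_vae_sizes(image_sizes: List[Tuple[int, int]], vae_image_size: int = 1024 * 1024) -> List[Tuple[int, int]]:
    """Per-image size planning mirroring QwenImageEditPlusProcessImagesInputStep:
    every reference image gets its own (vae_width, vae_height) instead of
    inheriting the size of the first image in the list."""
    return [_fit_to_area(vae_image_size, width / height) for width, height in image_sizes]


# Illustrative sizes only: one landscape and one portrait reference image.
print(plan_vae_sizes([(1920, 1080), (768, 1344)]))
# -> [(1376, 768), (768, 1344)] with this stand-in rounding
```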
8 changes: 8 additions & 0 deletions src/diffusers/modular_pipelines/qwenimage/inputs.py
@@ -228,6 +228,7 @@ def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
reshape_to_seq_dim: bool = False,
):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"

@@ -245,6 +246,9 @@
Names of additional conditional input tensors to expand batch size. These tensors will only have their
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
Defaults to []. Examples: ["processed_mask_image"]
reshape_to_seq_dim (bool, optional):
Whether the packed output should be reshaped along the sequence dimension. Example: `[2, 4096, 64]` => `[1,
8192, 64]`. This is needed for QwenImage Edit Plus.

Examples:
# Configure to process image_latents (default behavior) QwenImageInputsDynamicStep()
@@ -263,6 +267,7 @@ def __init__(

self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
self.reshape_to_seq_dim = reshape_to_seq_dim
super().__init__()

@property
@@ -341,6 +346,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -

# 2. Patchify the image latent tensor
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
if self.reshape_to_seq_dim:
channels = image_latent_tensor.shape[-1]
image_latent_tensor = image_latent_tensor.reshape(1, -1, channels)

# 3. Expand batch size
image_latent_tensor = repeat_tensor_to_batch_size(
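Reviewer note: a minimal illustration of the new `reshape_to_seq_dim` flag — after packing, the latents of N reference images are folded into a single batch entry along the sequence dimension, matching the docstring example `[2, 4096, 64]` => `[1, 8192, 64]`.

```python
import torch

# Two packed reference-image latents, as produced by pachifier.pack_latents.
packed = torch.randn(2, 4096, 64)

# With reshape_to_seq_dim=True the step folds them into one batch entry along
# the sequence dimension, which is the layout the Edit Plus denoiser consumes.
channels = packed.shape[-1]
as_one_sequence = packed.reshape(1, -1, channels)
print(as_one_sequence.shape)  # torch.Size([1, 8192, 64])
```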
69 changes: 60 additions & 9 deletions src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
@@ -18,6 +18,7 @@
from .before_denoise import (
QwenImageControlNetBeforeDenoiserStep,
QwenImageCreateMaskLatentsStep,
QwenImageEditPlusRoPEInputsStep,
QwenImageEditRoPEInputsStep,
QwenImagePrepareLatentsStep,
QwenImagePrepareLatentsWithStrengthStep,
@@ -911,7 +912,7 @@ def description(self) -> str:


class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
model_name = "qwenimage-edit-plus"
block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
block_names = QwenImageEditPlusVaeEncoderBlocks.keys()

@@ -920,25 +921,62 @@ def description(self) -> str:
return "Vae encoder step that encode the image inputs into their latent representations."


#### QwenImage Edit Plus input blocks
QwenImageEditPlusInputBlocks = InsertableDict(
[
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
(
"additional_inputs",
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"], reshape_to_seq_dim=True),
),
]
)


class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = QwenImageEditPlusInputBlocks.values()
block_names = QwenImageEditPlusInputBlocks.keys()


#### QwenImage Edit Plus presets
EDIT_PLUS_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageEditPlusVLEncoderStep()),
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
("input", QwenImageEditInputStep()),
("input", QwenImageEditPlusInputStep()),
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
("denoise", QwenImageEditDenoiseStep()),
("decode", QwenImageDecodeStep()),
]
)


QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
]
)


class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values()
block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys()

@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."


# auto before_denoise step for edit tasks
class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [QwenImageEditBeforeDenoiseStep]
block_classes = [QwenImageEditPlusBeforeDenoiseStep]
block_names = ["edit"]
block_trigger_inputs = ["image_latents"]

@@ -947,7 +985,7 @@ def description(self):
return (
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ "This is an auto pipeline block that works for edit (img2img) task.\n"
+ " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ " - `QwenImageEditPlusBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ " - if `image_latents` is not provided, step will be skipped."
)

@@ -956,9 +994,7 @@ def description(self):


class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [
QwenImageEditPlusVaeEncoderStep,
]
block_classes = [QwenImageEditPlusVaeEncoderStep]
block_names = ["edit"]
block_trigger_inputs = ["image"]

@@ -975,10 +1011,25 @@ def description(self):
## 3.3 QwenImage-Edit/auto blocks & presets


class QwenImageEditPlusAutoInputStep(AutoPipelineBlocks):
block_classes = [QwenImageEditPlusInputStep]
block_names = ["edit"]
block_trigger_inputs = ["image_latents"]

@property
def description(self):
return (
"Input step that prepares the inputs for the edit denoising step.\n"
+ " It is an auto pipeline block that works for edit task.\n"
+ " - `QwenImageEditPlusInputStep` (edit) is used when `image_latents` is provided.\n"
+ " - if `image_latents` is not provided, step will be skipped."
)


class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditAutoInputStep,
QwenImageEditPlusAutoInputStep,
QwenImageEditPlusAutoBeforeDenoiseStep,
QwenImageEditAutoDenoiseStep,
]
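Reviewer note: a quick way to confirm which blocks the updated `EDIT_PLUS_BLOCKS` preset now wires in (assumes this branch is installed and that `InsertableDict` exposes the usual dict interface, as its `.keys()`/`.values()` usage above suggests).

```python
from diffusers.modular_pipelines.qwenimage.modular_blocks import EDIT_PLUS_BLOCKS

# Print the preset's step names alongside the block classes they resolve to;
# "input" should map to QwenImageEditPlusInputStep and "prepare_rope_inputs"
# to QwenImageEditPlusRoPEInputsStep after this change.
for name, block in EDIT_PLUS_BLOCKS.items():
    print(f"{name:20s} -> {type(block).__name__}")
```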