Commit 7390638

update

1 parent 21a03f9 commit 7390638

File tree

1 file changed (+1 −21 lines changed)

  • src/diffusers/modular_pipelines/stable_diffusion_xl


src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py

Lines changed: 1 addition & 21 deletions
@@ -21,7 +21,6 @@
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL
-from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from ...utils import logging
 from ..modular_pipeline import (
     ModularPipelineBlocks,
@@ -74,25 +73,6 @@ def intermediate_outputs(self) -> List[str]:
             )
         ]

-    @staticmethod
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components
-    def upcast_vae(components):
-        dtype = components.vae.dtype
-        components.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            components.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            components.vae.post_quant_conv.to(dtype)
-            components.vae.decoder.conv_in.to(dtype)
-            components.vae.decoder.mid_block.to(dtype)
-
     @torch.no_grad()
     def __call__(self, components, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
@@ -103,7 +83,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
         block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast

         if block_state.needs_upcasting:
-            self.upcast_vae(components)
+            self.components.vae.to(torch.float32)
             latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
         elif latents.dtype != components.vae.dtype:
             if torch.backends.mps.is_available():
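For context, here is a minimal standalone sketch of what the simplified upcast path now does at decode time. The model id, latent shape, and scaling step are illustrative assumptions and not part of the commit; the dtype logic mirrors the diff above.

import torch
from diffusers import AutoencoderKL

# Load an SDXL VAE in fp16 (the model id here is an illustrative assumption).
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float16)

# Dummy latents standing in for the denoiser's output (shape is illustrative).
latents = torch.randn(1, 4, 128, 128, dtype=torch.float16)

# Same check as in the diff: the SDXL VAE is numerically unstable in fp16,
# so its config sets force_upcast and decoding happens in fp32.
needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast

if needs_upcasting:
    # Post-commit behavior: cast the entire VAE to fp32. (The removed
    # upcast_vae helper additionally moved attention-adjacent modules back
    # to fp16 when torch 2.0 / xFormers attention processors were in use.)
    vae.to(torch.float32)
    # Match the latents to whatever dtype the VAE now expects.
    latents = latents.to(next(iter(vae.post_quant_conv.parameters())).dtype)

with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

The trade-off relative to the removed helper: keeping the whole VAE in fp32 costs some memory when memory-efficient attention is available, in exchange for a single unconditional cast with no dependence on the installed attention processors.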
