2121from ...configuration_utils import FrozenDict
2222from ...image_processor import VaeImageProcessor
2323from ...models import AutoencoderKL
24- from ...models .attention_processor import AttnProcessor2_0 , XFormersAttnProcessor
2524from ...utils import logging
2625from ..modular_pipeline import (
2726 ModularPipelineBlocks ,
@@ -74,25 +73,6 @@ def intermediate_outputs(self) -> List[str]:
7473 )
7574 ]
7675
77- @staticmethod
78- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components
79- def upcast_vae (components ):
80- dtype = components .vae .dtype
81- components .vae .to (dtype = torch .float32 )
82- use_torch_2_0_or_xformers = isinstance (
83- components .vae .decoder .mid_block .attentions [0 ].processor ,
84- (
85- AttnProcessor2_0 ,
86- XFormersAttnProcessor ,
87- ),
88- )
89- # if xformers or torch_2_0 is used attention block does not need
90- # to be in float32 which can save lots of memory
91- if use_torch_2_0_or_xformers :
92- components .vae .post_quant_conv .to (dtype )
93- components .vae .decoder .conv_in .to (dtype )
94- components .vae .decoder .mid_block .to (dtype )
95-
9676 @torch .no_grad ()
9777 def __call__ (self , components , state : PipelineState ) -> PipelineState :
9878 block_state = self .get_block_state (state )
@@ -103,7 +83,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
10383 block_state .needs_upcasting = components .vae .dtype == torch .float16 and components .vae .config .force_upcast
10484
10585 if block_state .needs_upcasting :
106- self .upcast_vae ( components )
86+ self .components . vae . to ( torch . float32 )
10787 latents = latents .to (next (iter (components .vae .post_quant_conv .parameters ())).dtype )
10888 elif latents .dtype != components .vae .dtype :
10989 if torch .backends .mps .is_available ():
0 commit comments