@@ -47,6 +47,7 @@ def prompt_to_image(
         import diffusers
         import torch
         from PIL import Image, ImageOps
+        from diffusers.pipelines.wuerstchen import WuerstchenCombinedPipeline
 
         device = self.choose_device(optimizations)
 
@@ -56,8 +57,8 @@ def prompt_to_image(
         else:
             pipe = self.load_model(diffusers.AutoPipelineForText2Image, model, optimizations, scheduler)
             refiner = None
-        height = height or pipe.unet.config.sample_size * pipe.vae_scale_factor
-        width = width or pipe.unet.config.sample_size * pipe.vae_scale_factor
+        height = height or ((pipe.unet.config.sample_size * pipe.vae_scale_factor) if hasattr(pipe, 'unet') and hasattr(pipe, 'vae_scale_factor') else 512)
+        width = width or ((pipe.unet.config.sample_size * pipe.vae_scale_factor) if hasattr(pipe, 'unet') and hasattr(pipe, 'vae_scale_factor') else 512)
 
         # Optimizations
         pipe = optimizations.apply(pipe, device)
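
The guarded ternary above keeps the usual Stable Diffusion default (the UNet's native sample size times the VAE scale factor) while falling back to a fixed 512 px for pipelines that expose neither `unet` nor `vae_scale_factor`, such as Würstchen's combined pipeline. A minimal sketch of the same logic as a standalone helper; the name `default_resolution` is hypothetical, and the 512 fallback mirrors the value hard-coded in the diff:

def default_resolution(pipe, fallback=512):
    """Best-effort native resolution for a diffusers pipeline.

    SD-style pipelines expose `unet.config.sample_size` (latent size) and
    `vae_scale_factor` (latent-to-pixel ratio); their product is the native
    pixel resolution. Pipelines lacking either attribute get `fallback`.
    """
    if hasattr(pipe, 'unet') and hasattr(pipe, 'vae_scale_factor'):
        return pipe.unet.config.sample_size * pipe.vae_scale_factor
    return fallback

# usage, equivalent to the two changed lines:
# height = height or default_resolution(pipe)
# width = width or default_resolution(pipe)
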
@@ -73,8 +74,16 @@ def prompt_to_image(
             generator = generator[0]
 
         # Seamless
-        _configure_model_padding(pipe.unet, seamless_axes)
-        _configure_model_padding(pipe.vae, seamless_axes)
+        if hasattr(pipe, 'unet'):
+            _configure_model_padding(pipe.unet, seamless_axes)
+        if hasattr(pipe, 'vae'):
+            _configure_model_padding(pipe.vae, seamless_axes)
+        if hasattr(pipe, 'prior_prior'):
+            _configure_model_padding(pipe.prior_prior, seamless_axes)
+        if hasattr(pipe, 'decoder'):
+            _configure_model_padding(pipe.decoder, seamless_axes)
+        if hasattr(pipe, 'vqgan'):
+            _configure_model_padding(pipe.vqgan, seamless_axes)
 
         # Inference
         with torch.inference_mode() if device not in ('mps', "dml") else nullcontext():
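
The repeated `hasattr` checks dispatch seamless padding over whichever submodules the loaded pipeline actually has: `unet`/`vae` for Stable Diffusion, `prior_prior`/`decoder`/`vqgan` for Würstchen. A sketch of the same duck-typed dispatch collapsed into a loop, assuming `_configure_model_padding` from this module; the tuple name is hypothetical:

# Candidate submodule names across SD and Würstchen pipelines.
SEAMLESS_TARGETS = ('unet', 'vae', 'prior_prior', 'decoder', 'vqgan')

def configure_seamless(pipe, seamless_axes):
    for name in SEAMLESS_TARGETS:
        module = getattr(pipe, name, None)
        if module is not None:  # skip missing or unset submodules
            _configure_model_padding(module, seamless_axes)
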
@@ -85,23 +94,28 @@ def callback(step, timestep, latents):
                 raise InterruptedError()
             future.add_response(ImageGenerationResult.step_preview(self, step_preview_mode, width, height, latents, generator, step))
         try:
-            result = pipe(
-                prompt=prompt,
-                height=height,
-                width=width,
-                num_inference_steps=steps,
-                guidance_scale=cfg_scale,
-                negative_prompt=negative_prompt if use_negative_prompt else None,
-                num_images_per_prompt=1,
-                eta=0.0,
-                generator=generator,
-                latents=None,
-                output_type=output_type,
-                return_dict=True,
-                callback=callback,
-                callback_steps=1,
-                #cfg_end=optimizations.cfg_end
-            )
+            pipe_kwargs = {
+                'prompt': prompt,
+                'height': height,
+                'width': width,
+                'num_inference_steps': steps,
+                'guidance_scale': cfg_scale,
+                'negative_prompt': negative_prompt if use_negative_prompt else None,
+                'num_images_per_prompt': 1,
+                'eta': 0.0,
+                'generator': generator,
+                'latents': None,
+                'output_type': output_type,
+                'return_dict': True,
+                'callback': callback,
+                'callback_steps': 1,
+            }
+            if isinstance(pipe, WuerstchenCombinedPipeline):
+                pipe_kwargs['prior_guidance_scale'] = pipe_kwargs.pop('guidance_scale')
+                del pipe_kwargs['eta']
+                del pipe_kwargs['callback']
+                del pipe_kwargs['callback_steps']
+            result = pipe(**pipe_kwargs)
             if is_sdxl and sdxl_refiner_model is not None and refiner is None:
                 # allow load_model() to garbage collect pipe
                 pipe = None
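
Building `pipe_kwargs` as a dict first makes the per-pipeline adaptation explicit: `WuerstchenCombinedPipeline` takes `prior_guidance_scale` rather than `guidance_scale`, and per this change accepts none of `eta`, `callback`, or `callback_steps`, so those keys are renamed or dropped before the call (which also means Würstchen runs produce no step previews). A hedged alternative sketch that derives the same filtering from the pipeline's call signature instead of hard-coding the class; `adapt_pipe_kwargs` is hypothetical and assumes `__call__` has no `**kwargs` catch-all:

import inspect

def adapt_pipe_kwargs(pipe, kwargs):
    kwargs = dict(kwargs)
    params = inspect.signature(pipe.__call__).parameters
    # Würstchen-style pipelines name their CFG knob prior_guidance_scale.
    if 'prior_guidance_scale' in params and 'guidance_scale' not in params:
        kwargs['prior_guidance_scale'] = kwargs.pop('guidance_scale')
    # Drop anything __call__ cannot accept (eta, callback, callback_steps, ...).
    if not any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        kwargs = {k: v for k, v in kwargs.items() if k in params}
    return kwargs

# result = pipe(**adapt_pipe_kwargs(pipe, pipe_kwargs))
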