
Commit 5568ac1 (1 parent: 978e7b9)

Add support for warp-ai/wuerstchen

2 files changed: +43 −21 lines


generator_process/actions/huggingface_hub.py

Lines changed: 8 additions & 0 deletions
@@ -63,6 +63,7 @@ def list_dir(cache_dir):
 def detect_model_type(snapshot_folder):
     unet_config = os.path.join(snapshot_folder, 'unet', 'config.json')
     config = os.path.join(snapshot_folder, 'config.json')
+    model_index = os.path.join(snapshot_folder, 'model_index.json')
     if os.path.exists(unet_config):
         with open(unet_config, 'r') as f:
             return ModelType(json.load(f)['in_channels'])
@@ -73,6 +74,13 @@ def detect_model_type(snapshot_folder):
             return ModelType.CONTROL_NET
         else:
             return ModelType.UNKNOWN
+    elif os.path.exists(model_index):
+        with open(model_index, 'r') as f:
+            model_index_dict = json.load(f)
+        if '_class_name' in model_index_dict and model_index_dict['_class_name'] == 'WuerstchenDecoderPipeline':
+            return ModelType.PROMPT_TO_IMAGE
+        else:
+            return ModelType.UNKNOWN
     else:
         return ModelType.UNKNOWN
 
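
Note: diffusers writes a model_index.json at the root of every pipeline snapshot, and its '_class_name' field names the pipeline class; that is the field the new elif branch reads. A minimal standalone sketch of the same detection (the helper name and return convention are illustrative, not part of this commit):

import json
import os

def is_wuerstchen_decoder_snapshot(snapshot_folder):
    # model_index.json sits at the root of a diffusers pipeline snapshot
    # and records the pipeline class under the '_class_name' key.
    model_index = os.path.join(snapshot_folder, 'model_index.json')
    if not os.path.exists(model_index):
        return False
    with open(model_index, 'r') as f:
        return json.load(f).get('_class_name') == 'WuerstchenDecoderPipeline'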

generator_process/actions/prompt_to_image.py

Lines changed: 35 additions & 21 deletions
@@ -47,6 +47,7 @@ def prompt_to_image(
     import diffusers
     import torch
     from PIL import Image, ImageOps
+    from diffusers.pipelines.wuerstchen import WuerstchenCombinedPipeline
 
     device = self.choose_device(optimizations)
 
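
Note: the new import exists so the call site further down can branch on pipeline type. When AutoPipelineForText2Image loads a Würstchen checkpoint, it should resolve to WuerstchenCombinedPipeline rather than a Stable Diffusion pipeline. A rough sketch of that behavior (model ID taken from the commit title; assumes a diffusers release that ships the Würstchen pipelines):

import diffusers

# AutoPipelineForText2Image picks its pipeline class from the checkpoint's
# model_index.json, so a Würstchen checkpoint comes back as the combined
# prior+decoder pipeline instead of StableDiffusionPipeline.
pipe = diffusers.AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen")
print(type(pipe).__name__)  # expected: WuerstchenCombinedPipeline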
@@ -56,8 +57,8 @@ def prompt_to_image(
     else:
         pipe = self.load_model(diffusers.AutoPipelineForText2Image, model, optimizations, scheduler)
         refiner = None
-    height = height or pipe.unet.config.sample_size * pipe.vae_scale_factor
-    width = width or pipe.unet.config.sample_size * pipe.vae_scale_factor
+    height = height or ((pipe.unet.config.sample_size * pipe.vae_scale_factor) if hasattr(pipe, 'unet') and hasattr(pipe, 'vae_scale_factor') else 512)
+    width = width or ((pipe.unet.config.sample_size * pipe.vae_scale_factor) if hasattr(pipe, 'unet') and hasattr(pipe, 'vae_scale_factor') else 512)
 
     # Optimizations
     pipe = optimizations.apply(pipe, device)
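
Note: the old defaults assumed every pipeline exposes a UNet and a VAE scale factor; the Würstchen combined pipeline has neither attribute, so the guarded expressions fall back to 512. The same logic pulled out as a helper (a sketch; the function name and fixed fallback are illustrative):

def native_size(pipe, fallback=512):
    # Stable Diffusion pipelines derive their native resolution from the
    # UNet's sample size times the VAE scale factor; pipelines without
    # those attributes (e.g. Würstchen) get a fixed fallback instead.
    if hasattr(pipe, 'unet') and hasattr(pipe, 'vae_scale_factor'):
        return pipe.unet.config.sample_size * pipe.vae_scale_factor
    return fallback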
@@ -73,8 +74,16 @@ def prompt_to_image(
         generator = generator[0]
 
     # Seamless
-    _configure_model_padding(pipe.unet, seamless_axes)
-    _configure_model_padding(pipe.vae, seamless_axes)
+    if hasattr(pipe, 'unet'):
+        _configure_model_padding(pipe.unet, seamless_axes)
+    if hasattr(pipe, 'vae'):
+        _configure_model_padding(pipe.vae, seamless_axes)
+    if hasattr(pipe, 'prior_prior'):
+        _configure_model_padding(pipe.prior_prior, seamless_axes)
+    if hasattr(pipe, 'decoder'):
+        _configure_model_padding(pipe.decoder, seamless_axes)
+    if hasattr(pipe, 'vqgan'):
+        _configure_model_padding(pipe.vqgan, seamless_axes)
 
     # Inference
     with torch.inference_mode() if device not in ('mps', "dml") else nullcontext():
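
Note: seamless tiling now probes every submodule a supported pipeline might carry: unet and vae on Stable Diffusion, prior_prior, decoder, and vqgan on Würstchen. The chain of hasattr checks could equally be written as a loop; a sketch assuming the repo's existing _configure_model_padding helper:

def configure_seamless(pipe, seamless_axes):
    # Apply padding only to the submodules this pipeline actually has;
    # pipelines of other types simply match none of the names.
    for name in ('unet', 'vae', 'prior_prior', 'decoder', 'vqgan'):
        if hasattr(pipe, name):
            _configure_model_padding(getattr(pipe, name), seamless_axes)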
@@ -85,23 +94,28 @@ def callback(step, timestep, latents):
                 raise InterruptedError()
             future.add_response(ImageGenerationResult.step_preview(self, step_preview_mode, width, height, latents, generator, step))
         try:
-            result = pipe(
-                prompt=prompt,
-                height=height,
-                width=width,
-                num_inference_steps=steps,
-                guidance_scale=cfg_scale,
-                negative_prompt=negative_prompt if use_negative_prompt else None,
-                num_images_per_prompt=1,
-                eta=0.0,
-                generator=generator,
-                latents=None,
-                output_type=output_type,
-                return_dict=True,
-                callback=callback,
-                callback_steps=1,
-                #cfg_end=optimizations.cfg_end
-            )
+            pipe_kwargs = {
+                'prompt': prompt,
+                'height': height,
+                'width': width,
+                'num_inference_steps': steps,
+                'guidance_scale': cfg_scale,
+                'negative_prompt': negative_prompt if use_negative_prompt else None,
+                'num_images_per_prompt': 1,
+                'eta': 0.0,
+                'generator': generator,
+                'latents': None,
+                'output_type': output_type,
+                'return_dict': True,
+                'callback': callback,
+                'callback_steps': 1,
+            }
+            if isinstance(pipe, WuerstchenCombinedPipeline):
+                pipe_kwargs['prior_guidance_scale'] = pipe_kwargs.pop('guidance_scale')
+                del pipe_kwargs['eta']
+                del pipe_kwargs['callback']
+                del pipe_kwargs['callback_steps']
+            result = pipe(**pipe_kwargs)
             if is_sdxl and sdxl_refiner_model is not None and refiner is None:
                 # allow load_model() to garbage collect pipe
                 pipe = None
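
Note: the kwargs shuffle mirrors the combined Würstchen pipeline's call signature: it takes prior_guidance_scale in place of guidance_scale and, in the diffusers version this commit targets, accepts neither eta nor the callback/callback_steps pair, so those keys are dropped before calling. A hedged usage sketch (prompt and parameter values are illustrative):

from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen")
result = pipe(
    prompt="a watercolor painting of a lighthouse",
    height=1024,
    width=1024,
    prior_guidance_scale=4.0,   # stands in for guidance_scale here
    num_inference_steps=12,     # decoder steps; the prior has its own count
    # no eta, callback, or callback_steps: the combined pipeline
    # does not accept them, hence the del statements above
)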
