|
13 | 13 | OUTPUT_DIR = "/tmp/outputs"
|
14 | 14 | INPUT_DIR = "/tmp/inputs"
|
15 | 15 | COMFYUI_TEMP_OUTPUT_DIR = "ComfyUI/temp"
|
| 16 | +COMFYUI_LORAS_DIR = "ComfyUI/models/loras" |
16 | 17 | ALL_DIRECTORIES = [OUTPUT_DIR, INPUT_DIR, COMFYUI_TEMP_OUTPUT_DIR]
|
17 | 18 |
|
18 | 19 | mimetypes.add_type("image/webp", ".webp")
|
@@ -94,8 +95,6 @@ def setup(self):
|
94 | 95 | self.comfyUI.handle_weights(
|
95 | 96 | workflow,
|
96 | 97 | weights_to_download=[
|
97 |
| - "wan2.1_t2v_1.3B_bf16.safetensors", |
98 |
| - "wan2.1_t2v_14B_bf16.safetensors", |
99 | 98 | "wan_2.1_vae.safetensors",
|
100 | 99 | "umt5_xxl_fp16.safetensors",
|
101 | 100 | ],
|
@@ -186,23 +185,7 @@ def update_workflow(self, workflow, **kwargs):
|
186 | 185 | if kwargs["lora_filename"]:
|
187 | 186 | lora_loader["lora_name"] = kwargs["lora_filename"]
|
188 | 187 | elif kwargs["lora_url"]:
|
189 |
| - url = kwargs["lora_url"] |
190 |
| - if m := re.match( |
191 |
| - r"^(?:https?://replicate.com/)?([^/]+)/([^/]+)/?$", url |
192 |
| - ): |
193 |
| - owner, model_name = m.groups() |
194 |
| - lora_filename = download_replicate_weights( |
195 |
| - f"https://replicate.com/{owner}/{model_name}/_weights", |
196 |
| - "ComfyUI/models/loras", |
197 |
| - ) |
198 |
| - lora_loader["lora_name"] = lora_filename |
199 |
| - elif url.startswith("https://replicate.delivery"): |
200 |
| - lora_filename = download_replicate_weights( |
201 |
| - url, "ComfyUI/models/loras" |
202 |
| - ) |
203 |
| - lora_loader["lora_name"] = lora_filename |
204 |
| - else: |
205 |
| - lora_loader["lora_name"] = url |
| 188 | + lora_loader["lora_name"] = kwargs["lora_url"] |
206 | 189 |
|
207 | 190 | lora_loader["strength_model"] = kwargs["lora_strength_model"]
|
208 | 191 | lora_loader["strength_clip"] = kwargs["lora_strength_clip"]
|
@@ -233,11 +216,31 @@ def generate(
|
233 | 216 | seed = seed_helper.generate(seed)
|
234 | 217 |
|
235 | 218 | lora_filename = None
|
| 219 | + inferred_model_type = None |
236 | 220 | if replicate_weights:
|
237 |
| - lora_filename, model_type = download_replicate_weights( |
238 |
| - replicate_weights, "ComfyUI/models/loras" |
| 221 | + lora_filename, inferred_model_type = download_replicate_weights( |
| 222 | + replicate_weights, COMFYUI_LORAS_DIR |
239 | 223 | )
|
240 |
| - model = model_type |
| 224 | + model = inferred_model_type |
| 225 | + elif lora_url: |
| 226 | + if m := re.match( |
| 227 | + r"^(?:https?://replicate.com/)?([^/]+)/([^/]+)/?$", lora_url |
| 228 | + ): |
| 229 | + owner, model_name = m.groups() |
| 230 | + lora_filename, inferred_model_type = download_replicate_weights( |
| 231 | + f"https://replicate.com/{owner}/{model_name}/_weights", |
| 232 | + COMFYUI_LORAS_DIR, |
| 233 | + ) |
| 234 | + elif lora_url.startswith("https://replicate.delivery"): |
| 235 | + lora_filename, inferred_model_type = download_replicate_weights( |
| 236 | + lora_url, COMFYUI_LORAS_DIR |
| 237 | + ) |
| 238 | + |
| 239 | + if inferred_model_type and inferred_model_type != model: |
| 240 | + print( |
| 241 | + f"Warning: Model type mismatch between requested model ({model}) and inferred model type ({inferred_model_type}). Using {inferred_model_type}." |
| 242 | + ) |
| 243 | + model = inferred_model_type |
241 | 244 |
|
242 | 245 | if resolution == "720p" and model == "1.3b":
|
243 | 246 | print("Warning: 720p is not supported for 1.3b, using 480p instead")
|
|
0 commit comments