Skip to content

Commit 538c0e3

Browse files
committed
Handle downloading of R8 weights before inferring model
1 parent db51dc3 commit 538c0e3

File tree

3 files changed

+32
-24
lines changed

3 files changed

+32
-24
lines changed

cog-safe-push-configs/standalone.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ test_model: fofr/staging
55
# Tests
66
predict:
77
compare_outputs: false
8-
predict_timeout: 600
8+
predict_timeout: 1800
99
test_cases:
1010
- inputs:
1111
prompt: "a MNCRFTMOV cat"

cog-safe-push-configs/trained.yaml

+6-1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,10 @@ predict:
88
predict_timeout: 600
99
test_cases:
1010
- inputs:
11-
prompt: "a LAEZEL portrait painting"
11+
prompt: "a 14b LAEZEL is laughing"
12+
# url contains lora.safetensors, which defaults to 14b
1213
replicate_weights: "https://replicate.delivery/xezq/begAR055rwyeb0sjY60xM01L9i2fJ8T11ofpHGfL6PCEYf2FF/trained_model.tar"
14+
- inputs:
15+
prompt: "a 1.3b HOMER is laughing"
16+
# url contains 1.3b-lora.safetensors, which defaults to 1.3b
17+
replicate_weights: "https://replicate.delivery/xezq/7kXMcHfE8GwHZ6CU4TR2iTod20dPX8SerMCMrRn16efou1hRB/trained_model.tar"

predict.py

+25-22
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
OUTPUT_DIR = "/tmp/outputs"
1414
INPUT_DIR = "/tmp/inputs"
1515
COMFYUI_TEMP_OUTPUT_DIR = "ComfyUI/temp"
16+
COMFYUI_LORAS_DIR = "ComfyUI/models/loras"
1617
ALL_DIRECTORIES = [OUTPUT_DIR, INPUT_DIR, COMFYUI_TEMP_OUTPUT_DIR]
1718

1819
mimetypes.add_type("image/webp", ".webp")
@@ -94,8 +95,6 @@ def setup(self):
9495
self.comfyUI.handle_weights(
9596
workflow,
9697
weights_to_download=[
97-
"wan2.1_t2v_1.3B_bf16.safetensors",
98-
"wan2.1_t2v_14B_bf16.safetensors",
9998
"wan_2.1_vae.safetensors",
10099
"umt5_xxl_fp16.safetensors",
101100
],
@@ -186,23 +185,7 @@ def update_workflow(self, workflow, **kwargs):
186185
if kwargs["lora_filename"]:
187186
lora_loader["lora_name"] = kwargs["lora_filename"]
188187
elif kwargs["lora_url"]:
189-
url = kwargs["lora_url"]
190-
if m := re.match(
191-
r"^(?:https?://replicate.com/)?([^/]+)/([^/]+)/?$", url
192-
):
193-
owner, model_name = m.groups()
194-
lora_filename = download_replicate_weights(
195-
f"https://replicate.com/{owner}/{model_name}/_weights",
196-
"ComfyUI/models/loras",
197-
)
198-
lora_loader["lora_name"] = lora_filename
199-
elif url.startswith("https://replicate.delivery"):
200-
lora_filename = download_replicate_weights(
201-
url, "ComfyUI/models/loras"
202-
)
203-
lora_loader["lora_name"] = lora_filename
204-
else:
205-
lora_loader["lora_name"] = url
188+
lora_loader["lora_name"] = kwargs["lora_url"]
206189

207190
lora_loader["strength_model"] = kwargs["lora_strength_model"]
208191
lora_loader["strength_clip"] = kwargs["lora_strength_clip"]
@@ -233,11 +216,31 @@ def generate(
233216
seed = seed_helper.generate(seed)
234217

235218
lora_filename = None
219+
inferred_model_type = None
236220
if replicate_weights:
237-
lora_filename, model_type = download_replicate_weights(
238-
replicate_weights, "ComfyUI/models/loras"
221+
lora_filename, inferred_model_type = download_replicate_weights(
222+
replicate_weights, COMFYUI_LORAS_DIR
239223
)
240-
model = model_type
224+
model = inferred_model_type
225+
elif lora_url:
226+
if m := re.match(
227+
r"^(?:https?://replicate.com/)?([^/]+)/([^/]+)/?$", lora_url
228+
):
229+
owner, model_name = m.groups()
230+
lora_filename, inferred_model_type = download_replicate_weights(
231+
f"https://replicate.com/{owner}/{model_name}/_weights",
232+
COMFYUI_LORAS_DIR,
233+
)
234+
elif lora_url.startswith("https://replicate.delivery"):
235+
lora_filename, inferred_model_type = download_replicate_weights(
236+
lora_url, COMFYUI_LORAS_DIR
237+
)
238+
239+
if inferred_model_type and inferred_model_type != model:
240+
print(
241+
f"Warning: Model type mismatch between requested model ({model}) and inferred model type ({inferred_model_type}). Using {inferred_model_type}."
242+
)
243+
model = inferred_model_type
241244

242245
if resolution == "720p" and model == "1.3b":
243246
print("Warning: 720p is not supported for 1.3b, using 480p instead")

0 commit comments

Comments
 (0)