Skip to content

Commit db51dc3

Browse files
committed
Determine model automatically from r8 lora filename
1 parent ceb86d9 commit db51dc3

File tree

6 files changed

+37
-42
lines changed

6 files changed

+37
-42
lines changed

.github/workflows/ci.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
workflow_dispatch:
55
inputs:
66
models:
7-
description: 'Comma-separated list of models (standalone, trained1.3b, trained14b) to push or "all"'
7+
description: 'Comma-separated list of models (standalone, trained) to push or "all"'
88
type: string
99
default: 'all'
1010

@@ -44,7 +44,7 @@ jobs:
4444
REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
4545
run: |
4646
if [ "${{ inputs.models }}" = "all" ]; then
47-
models="standalone, trained1.3b, trained14b"
47+
models="standalone, trained"
4848
else
4949
models="${{ inputs.models }}"
5050
fi
File renamed without changes.

cog-safe-push-configs/trained1.3b.yaml

-12
This file was deleted.

predict.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def generate(
217217
negative_prompt: str | None = None,
218218
aspect_ratio: str = "16:9",
219219
frames: int = 81,
220-
model: str = "14b",
220+
model: str | None = None,
221221
lora_url: str | None = None,
222222
lora_strength_model: float = 1.0,
223223
lora_strength_clip: float = 1.0,
@@ -232,15 +232,16 @@ def generate(
232232
self.comfyUI.cleanup(ALL_DIRECTORIES)
233233
seed = seed_helper.generate(seed)
234234

235-
if resolution == "720p" and model == "1.3b":
236-
print("Warning: 720p is not supported for 1.3b, using 480p instead")
237-
resolution = "480p"
238-
239235
lora_filename = None
240236
if replicate_weights:
241-
lora_filename = download_replicate_weights(
237+
lora_filename, model_type = download_replicate_weights(
242238
replicate_weights, "ComfyUI/models/loras"
243239
)
240+
model = model_type
241+
242+
if resolution == "720p" and model == "1.3b":
243+
print("Warning: 720p is not supported for 1.3b, using 480p instead")
244+
resolution = "480p"
244245

245246
with open(api_json_file, "r") as file:
246247
workflow = json.loads(file.read())
@@ -308,9 +309,7 @@ def predict(
308309
)
309310

310311

311-
class Trained14BLoraPredictor(Predictor):
312-
model = "14b"
313-
312+
class TrainedLoraPredictor(Predictor):
314313
def predict(
315314
self,
316315
prompt: str = Inputs.prompt,
@@ -332,7 +331,7 @@ def predict(
332331
negative_prompt=negative_prompt,
333332
aspect_ratio=aspect_ratio,
334333
frames=frames,
335-
model=self.model,
334+
model=None,
336335
resolution=resolution,
337336
lora_url=None,
338337
lora_strength_model=lora_strength_model,
@@ -344,7 +343,3 @@ def predict(
344343
seed=seed,
345344
replicate_weights=replicate_weights,
346345
)
347-
348-
349-
class Trained1_3BLoraPredictor(Trained14BLoraPredictor):
350-
model = "1.3b"

replicate_weights.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,30 @@
55
import hashlib
66
from pathlib import Path
77

8-
def get_filename_from_url(url: str, extension: str) -> str:
9-
"""Generate a unique filename from a URL using MD5 hash"""
10-
filename = hashlib.md5(url.encode()).hexdigest()
11-
return f"{filename}.{extension}"
8+
9+
def url_hash(url: str) -> str:
10+
return hashlib.md5(url.encode()).hexdigest()
11+
12+
13+
def get_filename_from_url(url: str, model_type: str) -> str:
14+
return f"{model_type}_{url_hash(url)}.safetensors"
15+
1216

1317
def download_replicate_weights(url: str, lora_dir: str):
1418
"""Downloads weights from a Replicate tarball URL and extracts the safetensors file"""
15-
unique_filename = get_filename_from_url(url, "safetensors")
16-
target_path = Path(lora_dir) / unique_filename
19+
hash = url_hash(url)
20+
21+
# Check if either version already exists
22+
lora_dir_path = Path(lora_dir)
23+
existing_14b = lora_dir_path / f"14b_{hash}.safetensors"
24+
existing_1_3b = lora_dir_path / f"1.3b_{hash}.safetensors"
1725

18-
if target_path.exists():
19-
print(f"✅ {unique_filename} already cached")
20-
return unique_filename
26+
if existing_14b.exists():
27+
print(f"✅ {existing_14b.name} already cached")
28+
return existing_14b.name, "14b"
29+
if existing_1_3b.exists():
30+
print(f"✅ {existing_1_3b.name} already cached")
31+
return existing_1_3b.name, "1.3b"
2132

2233
with tempfile.TemporaryDirectory() as temp_dir:
2334
temp_dir = Path(temp_dir)
@@ -41,6 +52,9 @@ def download_replicate_weights(url: str, lora_dir: str):
4152
if len(safetensors_paths) > 1:
4253
raise ValueError("Multiple .safetensors files found in tarball")
4354

55+
model_type = "1.3b" if "1.3b" in safetensors_paths[0].name.lower() else "14b"
56+
unique_filename = get_filename_from_url(url, model_type)
57+
target_path = Path(lora_dir) / unique_filename
4458
shutil.move(safetensors_paths[0], target_path)
4559

46-
return unique_filename
60+
return unique_filename, model_type

scripts/select-model

+2-4
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@ MODEL=$1
3434
# Set the predictor name
3535
if [[ "$MODEL" == "standalone" ]]; then
3636
export PREDICTOR=StandaloneLoraPredictor
37-
elif [[ "$MODEL" == "trained14b" ]]; then
38-
export PREDICTOR=Trained14BLoraPredictor
39-
elif [[ "$MODEL" == "trained1.3b" ]]; then
40-
export PREDICTOR=Trained1_3BLoraPredictor
37+
elif [[ "$MODEL" == "trained" ]]; then
38+
export PREDICTOR=TrainedLoraPredictor
4139
else
4240
echo "Unknown model: $MODEL. Valid models are:"
4341
list_models

0 commit comments

Comments
 (0)