-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathinference.py
69 lines (60 loc) · 2.9 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
from src.pipeline_pe_clone import FluxPipeline
import torch
from PIL import Image
def parse_args(argv=None):
    """Parse command-line options for FLUX/PhotoDoodle image generation.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests and programmatic use).
            ``None`` keeps the normal command-line behaviour.

    Returns:
        argparse.Namespace with ``model_path``, ``image_path``,
        ``output_path``, ``height``, ``width``, ``prompt``,
        ``guidance_scale``, ``num_steps`` and ``lora_name``.
    """
    # Single source of truth for the effect LoRAs: drives both the
    # --lora_name choices and the example prompts shown in --help.
    lora_examples = {
        'sksmagiceffects': "add a halo and wings for the cat by sksmagiceffects",
        'sksmonstercalledlulu': "add a red sksmonstercalledlulu hugging the cat",
        'skspaintingeffects': (
            "add a yellow flower on the cat's head and psychedelic colors "
            "and dynamic flows by skspaintingeffects"
        ),
        'sksedgeeffect': "add yellow flames to the cat by sksedgeeffect",
    }
    epilog = "Different LoRA effects and their example prompts:\n" + "\n".join(
        f'  - {name}: "{example}"' for name, example in lora_examples.items()
    )

    parser = argparse.ArgumentParser(
        description='FLUX image generation with LoRA',
        epilog=epilog,
        # The default formatter re-wraps text and collapses newlines, which
        # would merge the example list into one paragraph; the raw
        # description formatter prints the epilog exactly as written.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('--model_path', type=str,
                        default="black-forest-labs/FLUX.1-dev",
                        help='Path to pretrained model')
    parser.add_argument('--image_path', type=str,
                        default="assets/1.png",
                        help='Input image path')
    parser.add_argument('--output_path', type=str,
                        default="output.png",
                        help='Output image path')
    parser.add_argument('--height', type=int, default=768,
                        help='Output image height in pixels (default: %(default)s)')
    parser.add_argument('--width', type=int, default=512,
                        help='Output image width in pixels (default: %(default)s)')
    parser.add_argument('--prompt', type=str,
                        default="add a halo and wings for the cat by sksmagiceffects",
                        help='Edit prompt; include the trigger word of the selected '
                             '--lora_name (see examples below)')
    parser.add_argument('--guidance_scale', type=float, default=3.5,
                        help='Guidance scale (default: %(default)s)')
    parser.add_argument('--num_steps', type=int, default=20,
                        help='Number of inference steps')
    parser.add_argument('--lora_name', type=str,
                        choices=['pretrained', *lora_examples],
                        default="sksmagiceffects",
                        help='Name of LoRA weights to use. Use "pretrained" for base model only')
    return parser.parse_args(argv)
def _load_condition_image(path, width, height):
    """Open the image at *path* and return it as an RGB image of width x height.

    PIL's ``Image.resize`` takes the size as ``(width, height)``. The image is
    converted to RGB *before* resizing because Pillow always resamples
    palette ("P") and bilevel ("1") images with nearest-neighbour.
    """
    # ``with`` closes the file handle; convert() returns a fully loaded copy,
    # so the returned image remains valid after the file is closed.
    with Image.open(path) as img:
        return img.convert("RGB").resize((width, height))


def main():
    """Run PhotoDoodle image editing from the command line.

    Loads the FLUX pipeline in bfloat16 on CUDA and fuses the base PhotoDoodle
    LoRA (``pretrain.safetensors``). Unless ``--lora_name`` is "pretrained", it
    also loads the selected effect LoRA. It then edits ``--image_path``
    according to ``--prompt`` and saves the first generated image to
    ``--output_path``.
    """
    args = parse_args()
    lora_repo = "nicolaus-huang/PhotoDoodle"

    pipeline = FluxPipeline.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
    ).to('cuda')

    # Load and fuse base LoRA weights into the model, then unload the adapter
    # so the selected effect LoRA (if any) is loaded on its own.
    pipeline.load_lora_weights(lora_repo, weight_name="pretrain.safetensors")
    pipeline.fuse_lora()
    pipeline.unload_lora_weights()

    # Load selected LoRA effect only if not using pretrained
    if args.lora_name != 'pretrained':
        pipeline.load_lora_weights(lora_repo, weight_name=f"{args.lora_name}.safetensors")

    # Size the condition image to the requested output size
    # (previously width/height were passed to PIL in the wrong order).
    condition_image = _load_condition_image(args.image_path, args.width, args.height)

    result = pipeline(
        prompt=args.prompt,
        condition_image=condition_image,
        height=args.height,
        width=args.width,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_steps,
        max_sequence_length=512,
    ).images[0]
    result.save(args.output_path)


if __name__ == "__main__":
    main()