1
1
import argparse
2
2
import json
3
3
import os
4
- from typing import List , Tuple
4
+ from typing import List , Optional , Tuple
5
5
6
- import torch
7
6
from joblib import Parallel , delayed
8
7
from loguru import logger
9
- from PIL import Image
10
8
from tqdm import tqdm
11
9
12
- from config import create_cfg , merge_possible_with_base , show_config
13
- from modeling import build_model
14
10
from modeling .text_translation import TextTranslationDiffusion
15
11
16
12
@@ -23,34 +19,36 @@ def copy_parameters(from_parameters, to_parameters):
23
19
24
20
def parse_args ():
25
21
parser = argparse .ArgumentParser ()
26
- parser .add_argument ("--config" , default = None , type = str )
27
22
parser .add_argument ("--save-folder" , default = "batch_images" , type = str )
28
23
parser .add_argument ("--source-root" , required = True , type = str )
29
24
parser .add_argument ("--source-list" , required = True , type = str )
30
- parser .add_argument ("--source-label" , required = True , type = int )
31
25
parser .add_argument ("--num-process" , default = 1 , type = int )
32
26
parser .add_argument ("--num-of-step" , default = 180 , type = int )
33
- parser .add_argument ("--opts" , nargs = argparse .REMAINDER , default = None , type = str )
27
+ parser .add_argument ("--img-size" , default = 512 , type = int )
28
+ parser .add_argument ("--model-path" , default = None , type = str )
29
+ parser .add_argument ("--scheduler" , default = "ddpm" , type = str )
30
+ parser .add_argument ("--sample-steps" , default = 1000 , type = int )
34
31
return parser .parse_args ()
35
32
36
33
37
34
def generate_image (
38
- cfg ,
35
+ img_size : int ,
39
36
save_folder : str ,
40
37
source_list : List [Tuple [str , str ]],
41
- source_label : int ,
42
38
offset : int ,
43
39
device : str ,
44
40
num_of_step : int ,
41
+ scheduler : str ,
42
+ sample_steps : int ,
43
+ model_path : Optional [str ] = None ,
45
44
):
46
- model = build_model (cfg ).to (device )
47
- if cfg .MODEL .PRETRAINED :
48
- logger .info (f"Loading pretrained model from { cfg .MODEL .PRETRAINED } " )
49
- weight = torch .load (cfg .MODEL .PRETRAINED , map_location = device )
50
- copy_parameters (weight ["ema_state_dict" ]["shadow_params" ], model .parameters ())
51
- del weight
52
- torch .cuda .empty_cache ()
53
- diffuser = TextTranslationDiffusion (cfg , device = device )
45
+ diffuser = TextTranslationDiffusion (
46
+ img_size = img_size ,
47
+ scheduler = scheduler ,
48
+ device = device ,
49
+ model_path = model_path ,
50
+ sample_steps = sample_steps ,
51
+ )
54
52
os .makedirs (args .save_folder , exist_ok = True )
55
53
56
54
progress_bar = tqdm (total = len (source_list ), position = int (device .split (":" )[- 1 ]))
@@ -65,8 +63,6 @@ def generate_image(
65
63
source_mask = source_mask .replace ("jpg" , "png" )
66
64
try :
67
65
editing_result = diffuser .modify_with_text (
68
- model = model ,
69
- source_label = source_label ,
70
66
image = source_image ,
71
67
mask = source_mask ,
72
68
prompt = [editing_prompt ],
@@ -77,9 +73,7 @@ def generate_image(
77
73
logger .error (str (e ))
78
74
count_error += 1
79
75
continue
80
- save_image = Image .fromarray (
81
- (editing_result [0 ].permute (1 , 2 , 0 ).cpu ().numpy () * 255 ).astype ("uint8" )
82
- )
76
+ save_image = editing_result [0 ]
83
77
save_image .save (save_image_name )
84
78
progress_bar .update (1 )
85
79
progress_bar .close ()
@@ -90,12 +84,6 @@ def generate_image(
90
84
91
85
if __name__ == "__main__" :
92
86
args = parse_args ()
93
- cfg = create_cfg ()
94
- if args .config :
95
- merge_possible_with_base (cfg , args .config )
96
- if args .opts :
97
- cfg .merge_from_list (args .opts )
98
- show_config (cfg )
99
87
100
88
with open (args .source_list , "r" ) as f :
101
89
data = json .load (f )
@@ -118,7 +106,6 @@ def generate_image(
118
106
else len (source_list )
119
107
)
120
108
],
121
- source_label = args .source_label ,
122
109
offset = gpu_idx * task_per_process ,
123
110
device = f"cuda:{ gpu_idx } " ,
124
111
num_of_step = args .num_of_step ,
0 commit comments