Skip to content

Commit f2911ff

Browse files
committedOct 20, 2024
🚀 Update the editing
1 parent 7708a74 commit f2911ff

File tree

2 files changed

+131
-143
lines changed

2 files changed

+131
-143
lines changed
 

‎generate_text_editing.py

+17-30
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
import argparse
22
import json
33
import os
4-
from typing import List, Tuple
4+
from typing import List, Optional, Tuple
55

6-
import torch
76
from joblib import Parallel, delayed
87
from loguru import logger
9-
from PIL import Image
108
from tqdm import tqdm
119

12-
from config import create_cfg, merge_possible_with_base, show_config
13-
from modeling import build_model
1410
from modeling.text_translation import TextTranslationDiffusion
1511

1612

@@ -23,34 +19,36 @@ def copy_parameters(from_parameters, to_parameters):
2319

2420
def parse_args():
2521
parser = argparse.ArgumentParser()
26-
parser.add_argument("--config", default=None, type=str)
2722
parser.add_argument("--save-folder", default="batch_images", type=str)
2823
parser.add_argument("--source-root", required=True, type=str)
2924
parser.add_argument("--source-list", required=True, type=str)
30-
parser.add_argument("--source-label", required=True, type=int)
3125
parser.add_argument("--num-process", default=1, type=int)
3226
parser.add_argument("--num-of-step", default=180, type=int)
33-
parser.add_argument("--opts", nargs=argparse.REMAINDER, default=None, type=str)
27+
parser.add_argument("--img-size", default=512, type=int)
28+
parser.add_argument("--model-path", default=None, type=str)
29+
parser.add_argument("--scheduler", default="ddpm", type=str)
30+
parser.add_argument("--sample-steps", default=1000, type=int)
3431
return parser.parse_args()
3532

3633

3734
def generate_image(
38-
cfg,
35+
img_size: int,
3936
save_folder: str,
4037
source_list: List[Tuple[str, str]],
41-
source_label: int,
4238
offset: int,
4339
device: str,
4440
num_of_step: int,
41+
scheduler: str,
42+
sample_steps: int,
43+
model_path: Optional[str] = None,
4544
):
46-
model = build_model(cfg).to(device)
47-
if cfg.MODEL.PRETRAINED:
48-
logger.info(f"Loading pretrained model from {cfg.MODEL.PRETRAINED}")
49-
weight = torch.load(cfg.MODEL.PRETRAINED, map_location=device)
50-
copy_parameters(weight["ema_state_dict"]["shadow_params"], model.parameters())
51-
del weight
52-
torch.cuda.empty_cache()
53-
diffuser = TextTranslationDiffusion(cfg, device=device)
45+
diffuser = TextTranslationDiffusion(
46+
img_size=img_size,
47+
scheduler=scheduler,
48+
device=device,
49+
model_path=model_path,
50+
sample_steps=sample_steps,
51+
)
5452
os.makedirs(args.save_folder, exist_ok=True)
5553

5654
progress_bar = tqdm(total=len(source_list), position=int(device.split(":")[-1]))
@@ -65,8 +63,6 @@ def generate_image(
6563
source_mask = source_mask.replace("jpg", "png")
6664
try:
6765
editing_result = diffuser.modify_with_text(
68-
model=model,
69-
source_label=source_label,
7066
image=source_image,
7167
mask=source_mask,
7268
prompt=[editing_prompt],
@@ -77,9 +73,7 @@ def generate_image(
7773
logger.error(str(e))
7874
count_error += 1
7975
continue
80-
save_image = Image.fromarray(
81-
(editing_result[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
82-
)
76+
save_image = editing_result[0]
8377
save_image.save(save_image_name)
8478
progress_bar.update(1)
8579
progress_bar.close()
@@ -90,12 +84,6 @@ def generate_image(
9084

9185
if __name__ == "__main__":
9286
args = parse_args()
93-
cfg = create_cfg()
94-
if args.config:
95-
merge_possible_with_base(cfg, args.config)
96-
if args.opts:
97-
cfg.merge_from_list(args.opts)
98-
show_config(cfg)
9987

10088
with open(args.source_list, "r") as f:
10189
data = json.load(f)
@@ -118,7 +106,6 @@ def generate_image(
118106
else len(source_list)
119107
)
120108
],
121-
source_label=args.source_label,
122109
offset=gpu_idx * task_per_process,
123110
device=f"cuda:{gpu_idx}",
124111
num_of_step=args.num_of_step,

0 commit comments

Comments
 (0)
Please sign in to comment.