forked from elvisyjlin/AttGAN-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
120 lines (103 loc) · 4.4 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright (C) 2018 Elvis Yu-Jing Lin <[email protected]>
#
# This work is licensed under the MIT License. To view a copy of this license,
# visit https://opensource.org/licenses/MIT.
"""Entry point for testing AttGAN network."""
import argparse
import json
import os
from os.path import join
import numpy as np
import torch
import torch.utils.data as data
import torchvision.utils as vutils
from attgan import AttGAN
from data import check_attribute_conflict
from helpers import Progressbar
from utils import find_model
def parse(args=None):
    """Parse the command-line options for testing a trained AttGAN experiment.

    Args:
        args: Optional list of argument strings. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ('--experiment_name', dict(dest='experiment_name', required=True)),
        ('--test_int', dict(dest='test_int', type=float, default=1.0)),
        ('--num_test', dict(dest='num_test', type=int)),
        ('--load_epoch', dict(dest='load_epoch', type=str, default='latest')),
        ('--custom_img', dict(action='store_true')),
        ('--custom_data', dict(type=str, default='./data/custom')),
        ('--custom_attr', dict(type=str, default='./data/list_attr_custom.txt')),
        ('--gpu', dict(action='store_true')),
        ('--multi_gpu', dict(action='store_true')),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args(args)
# Parse the command line, then load the settings saved for this experiment.
args_ = parse()
print(args_)
settings_path = join('output', args_.experiment_name, 'setting.txt')
with open(settings_path, 'r') as settings_file:
    # Every JSON object becomes a Namespace, giving attribute access to the
    # stored settings.
    args = json.load(settings_file, object_hook=lambda d: argparse.Namespace(**d))

# Test-time options from the command line override the stored values.
# The order is kept fixed because it determines the order of print(args).
for option_name in ('test_int', 'num_test', 'gpu', 'load_epoch', 'multi_gpu',
                    'custom_img', 'custom_data', 'custom_attr'):
    setattr(args, option_name, getattr(args_, option_name))

# Derived values read from the stored settings.
args.n_attrs = len(args.attrs)
args.betas = (args.beta1, args.beta2)
print(args)
# Choose the test dataset and the directory that results are written to.
if args.custom_img:
    output_path = join('output', args.experiment_name, 'custom_testing')
    from data import Custom
    test_dataset = Custom(args.custom_data, args.custom_attr, args.img_size, args.attrs)
else:
    output_path = join('output', args.experiment_name, 'sample_testing')
    if args.data == 'CelebA':
        from data import CelebA
        test_dataset = CelebA(args.data_path, args.attr_path, args.img_size, 'test', args.attrs)
    elif args.data == 'CelebA-HQ':
        from data import CelebA_HQ
        test_dataset = CelebA_HQ(args.data_path, args.attr_path, args.image_list_path,
                                 args.img_size, 'test', args.attrs)
    else:
        # Fail early and clearly. Otherwise test_dataset would be undefined
        # and the script would only crash later with a NameError.
        raise ValueError('Unsupported dataset for testing: {!r}'.format(args.data))
os.makedirs(output_path, exist_ok=True)

# One image per batch, in dataset order. The loop that consumes this loader
# uses the batch index to name its output files.
test_dataloader = data.DataLoader(
    test_dataset, batch_size=1, num_workers=args.num_workers,
    shuffle=False, drop_last=False
)

num_images = len(test_dataset)
if args.num_test is not None:
    num_images = min(num_images, args.num_test)
print('Testing images:', num_images)
# Build the network from the stored settings, restore the requested
# checkpoint, and switch the model to eval mode.
attgan = AttGAN(args)
checkpoint_dir = join('output', args.experiment_name, 'checkpoint')
attgan.load(find_model(checkpoint_dir, args.load_epoch))
progressbar = Progressbar()  # not referenced anywhere else in this script
attgan.eval()
# Hard-coded attribute vector that replaces G's input for sample index
# CUSTOM_ATT_INDEX. It is already in model-input scale: presumably
# +/- thres_int with thres_int == 0.5, and the last entry at test_int
# 1.0 -- confirm against setting.txt.
# It was written for this 13-attribute order:
# Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows Eyeglasses
# Male Mouth_Slightly_Open Mustache No_Beard Pale_Skin Young
CUSTOM_ATT_B = [[-0.5, 0.5, 0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, 0.5, -0.5, 1.0]]
CUSTOM_ATT_INDEX = 13
# Offset added to the batch index to build non-custom output file names.
# Presumably the number of the first CelebA test-split image -- confirm
# for CelebA-HQ.
CELEBA_TEST_OFFSET = 182638

for idx, (img_a, att_a) in enumerate(test_dataloader):
    if args.num_test is not None and idx == args.num_test:
        break

    img_a = img_a.cuda() if args.gpu else img_a
    att_a = att_a.cuda() if args.gpu else att_a
    att_a = att_a.type(torch.float)

    # Target attribute sets: the original labels, then one copy per attribute
    # with that attribute flipped (1 - value) and conflicts resolved by
    # check_attribute_conflict.
    att_b_list = [att_a]
    for i in range(args.n_attrs):
        tmp = att_a.clone()
        tmp[:, i] = 1 - tmp[:, i]
        tmp = check_attribute_conflict(tmp, args.attrs[i], args.attrs)
        att_b_list.append(tmp)

    with torch.no_grad():
        samples = [img_a]
        for i, att_b in enumerate(att_b_list):
            # Map {0, 1} labels to {-thres_int, +thres_int}. For i > 0 the
            # flipped attribute (column i - 1) is rescaled to +/- test_int.
            att_b_ = (att_b * 2 - 1) * args.thres_int
            if i > 0:
                att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
            # Only apply the override when the attribute count matches its width.
            if i == CUSTOM_ATT_INDEX and args.n_attrs == len(CUSTOM_ATT_B[0]):
                print("att_b_", i, att_b_)
                # Build the override on att_b_'s device so that it also works
                # with --gpu. A CPU tensor would not match img_a's device.
                att_b_ = torch.tensor(CUSTOM_ATT_B, dtype=torch.float, device=att_b_.device)
                print("att_b_", i, att_b_)
            samples.append(attgan.G(img_a, att_b_))
        # Place the input and all edited outputs side by side (along dim 3).
        samples = torch.cat(samples, dim=3)

    if args.custom_img:
        out_file = test_dataset.images[idx]
    else:
        out_file = '{:06d}.jpg'.format(idx + CELEBA_TEST_OFFSET)
    # Rescale [-1, 1] -> [0, 1] explicitly; this is the clamp-and-rescale
    # that normalize=True with range=(-1., 1.) performed. The `range` keyword
    # was renamed `value_range` and later removed from torchvision, so
    # passing it raises TypeError on current releases.
    vutils.save_image(
        samples.clamp(-1., 1.).add(1.).div(2.),
        join(output_path, out_file),
        nrow=1,
    )
    print('{:s} done!'.format(out_file))