-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
test.lua
executable file
·164 lines (138 loc) · 6.52 KB
/
test.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
-- usage: DATA_ROOT=/path/to/data/ name=expt1 which_direction=BtoA th test.lua
--
-- code derived from https://github.com/soumith/dcgan.torch
--
require 'image'
require 'nn'
require 'nngraph'
util = paths.dofile('util/util.lua')
torch.setdefaulttensortype('torch.FloatTensor')
-- Default options for a test run. Every key can be overridden by an
-- environment variable of the same name (see the usage line at the top).
opt = {
    DATA_ROOT         = '',              -- image root; expected to contain subfolders such as 'train', 'val'
    batchSize         = 1,               -- number of images per batch
    loadSize          = 256,             -- images are scaled to this size ...
    fineSize          = 256,             -- ... and then cropped to this size
    flip              = 0,               -- horizontal-mirroring data augmentation
    display           = 1,               -- show samples in the display window; 0 = off
    display_id        = 200,             -- id of the first display window
    gpu               = 1,               -- 0 = CPU mode, X > 0 = run on GPU X
    how_many          = 'all',           -- number of test images to run; 'all' = every image found in the data/phase folder
    which_direction   = 'AtoB',          -- 'AtoB' or 'BtoA'
    phase             = 'val',           -- train, val, test, etc.
    preprocess        = 'regular',       -- selects the preprocessing functions in util.lua (e.g. 'colorization')
    aspect_ratio      = 1.0,             -- aspect ratio of the saved result images
    name              = '',              -- experiment name; selects the model to run, normally given on the command line
    input_nc          = 3,               -- number of input image channels
    output_nc         = 3,               -- number of output image channels
    serial_batches    = 1,               -- 1 = build batches from images in order, otherwise pick them randomly
    serial_batch_iter = 1,               -- iterator into the serial image list
    cudnn             = 1,               -- set to 0 to not use cudnn (untested)
    checkpoints_dir   = './checkpoints', -- models are loaded from here
    results_dir       = './results/',    -- results are saved here
    which_epoch       = 'latest',        -- epoch to test; 'latest' uses the latest cached model
}
-- Argument parser: each option may be overridden by an environment variable
-- with the same name. Values that parse as numbers are stored as numbers,
-- anything else is kept as the raw string; unset variables leave the default.
for key in pairs(opt) do
    local raw = os.getenv(key)
    if raw ~= nil then
        opt[key] = tonumber(raw) or raw
    end
end
-- Run setup: fix the loader thread count, seed the RNG, derive the generator
-- checkpoint name and build the data loader. Statement order matters here.
opt.nThreads = 1 -- test only works with 1 thread...
-- Show the option table after the environment overrides have been applied.
print(opt)
-- 0 is truthy in Lua, so map it to false to make later `if opt.display` checks work.
if opt.display == 0 then opt.display = false end
opt.manualSeed = torch.random(1, 10000) -- set seed
print("Random Seed: " .. opt.manualSeed)
torch.manualSeed(opt.manualSeed)
torch.setdefaulttensortype('torch.FloatTensor')
-- Checkpoint path relative to checkpoints_dir, without extension:
-- <name>/<which_epoch>_net_G ('.t7' is appended when the model is loaded).
opt.netG_name = opt.name .. '/' .. opt.which_epoch .. '_net_G'
-- Data-loading module shipped with the project (data/data.lua).
local data_loader = paths.dofile('data/data.lua')
print('#threads...' .. opt.nThreads)
-- The loader is handed the thread count and the full option table.
local data = data_loader.new(opt.nThreads, opt)
print("Dataset Size: ", data:size())
-- Translation direction: idx_A / idx_B are the {first, last} channel ranges
-- used later to slice the network input (A) and the ground truth (B) out of
-- each loaded batch.
local idx_A = nil
local idx_B = nil
local input_nc = opt.input_nc
local output_nc = opt.output_nc
do
    local direction = opt.which_direction
    -- Reject anything other than the two supported directions up front.
    if direction ~= 'AtoB' and direction ~= 'BtoA' then
        error(string.format('bad direction %s',direction))
    end
    local leading_channels  = {1, input_nc}
    local trailing_channels = {input_nc+1, input_nc+output_nc}
    if direction == 'AtoB' then
        idx_A, idx_B = leading_channels, trailing_channels
    else
        idx_A, idx_B = trailing_channels, leading_channels
    end
end
----------------------------------------------------------------------------
-- Batch tensors for the network input and the ground truth. They are
-- allocated here and reassigned for every batch in the test loop below.
local input, target
do
    local batch, side = opt.batchSize, opt.fineSize
    input  = torch.FloatTensor(batch, 3, side, side)
    target = torch.FloatTensor(batch, 3, side, side)
end

-- Load the generator checkpoint <checkpoints_dir>/<netG_name>.t7.
print('checkpoints_dir', opt.checkpoints_dir)
local generator_file = paths.concat(opt.checkpoints_dir, opt.netG_name .. '.t7')
local netG = util.load(generator_file, opt)
-- NOTE(review): evaluate() is deliberately left disabled in the original;
-- confirm this is intended before enabling it.
--netG:evaluate()
print(netG)
--- Append every element of the sequence t2 to the end of t1.
-- t1 is modified in place and also returned; t2 is left untouched.
-- @param t1 destination sequence
-- @param t2 sequence whose elements are appended
-- @return t1
function TableConcat(t1,t2)
    local extra = #t2
    for idx = 1, extra do
        table.insert(t1, t2[idx])
    end
    return t1
end
-- Resolve how many images to process: 'all' means the whole dataset, and a
-- numeric request is capped at the dataset size.
do
    local dataset_size = data:size()
    local requested = opt.how_many
    if requested == 'all' then
        requested = dataset_size
    end
    opt.how_many = math.min(requested, dataset_size)
end
local filepaths = {} -- paths to images tested on

-- The result directories do not depend on the batch, so create them once
-- instead of on every iteration.
local run_dir = paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase)
paths.mkdir(run_dir)
local image_dir = paths.concat(run_dir, 'images')
paths.mkdir(image_dir)
for _, kind in ipairs({'input', 'output', 'target'}) do
    paths.mkdir(paths.concat(image_dir, kind))
end

-- Rescale one image for the requested aspect ratio and save it as
-- <image_dir>/<kind>/<filename>.
-- NOTE(review): image.scale takes (src, width, height); the size(2)/size(3)
-- arguments are kept exactly as in the original, which is only equivalent for
-- square images -- confirm before feeding non-square crops.
local function save_result(kind, filename, img)
    local scaled = image.scale(img, img:size(2), img:size(3)/opt.aspect_ratio)
    image.save(paths.concat(image_dir, kind, filename), scaled)
end

-- Run the generator over floor(how_many / batchSize) full batches, save the
-- input / output / target images and optionally show them in the display.
for n=1,math.floor(opt.how_many/opt.batchSize) do
    print('processing batch ' .. n)

    local data_curr, filepaths_curr = data:getBatch()
    filepaths_curr = util.basename_batch(filepaths_curr)
    print('filepaths_curr: ', filepaths_curr)

    -- Split the loaded batch into network input (A) and ground truth (B).
    input = data_curr[{ {}, idx_A, {}, {} }]
    target = data_curr[{ {}, idx_B, {}, {} }]

    if opt.gpu > 0 then
        input = input:cuda()
    end

    -- Generator result for this batch; kept local (it used to leak as a global).
    local output
    if opt.preprocess == 'colorization' then
        -- The network predicts the AB channels; recombine them with the L
        -- input through the util deprocess helpers.
        local output_AB = netG:forward(input):float()
        local input_L = input:float()
        output = util.deprocessLAB_batch(input_L, output_AB)
        local target_AB = target:float()
        target = util.deprocessLAB_batch(input_L, target_AB)
        input = util.deprocessL_batch(input_L)
    else
        output = util.deprocess_batch(netG:forward(input))
        input = util.deprocess_batch(input):float()
        output = output:float()
        target = util.deprocess_batch(target):float()
    end

    for i=1, opt.batchSize do
        save_result('input', filepaths_curr[i], input[i])
        save_result('output', filepaths_curr[i], output[i])
        save_result('target', filepaths_curr[i], target[i])
    end
    print('Saved images to: ', image_dir)

    if opt.display then
        if opt.preprocess == 'regular' then
            -- Required lazily so the display package is only needed when
            -- display is enabled; require caches the module after the first call.
            local disp = require 'display'
            disp.image(util.scaleBatch(input,100,100),{win=opt.display_id, title='input'})
            disp.image(util.scaleBatch(output,100,100),{win=opt.display_id+1, title='output'})
            disp.image(util.scaleBatch(target,100,100),{win=opt.display_id+2, title='target'})
            print('Displayed images')
        end
    end

    filepaths = TableConcat(filepaths, filepaths_curr)
end
-- make webpage
-- Writes <results_dir>/<netG_name>_<phase>/index.html: one table row per
-- tested image, linking its input, output and ground-truth files.
-- An explicit file handle is used instead of io.output() so the default
-- output stream is not redirected, open failures are reported, and the file
-- is closed (and therefore flushed) deterministically.
do
    local page_dir = paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase)
    -- Make sure the directory exists even if no batch was processed.
    paths.mkdir(page_dir)
    local page = assert(io.open(paths.concat(page_dir, 'index.html'), 'w'))

    local html = {
        '<table style="text-align:center;">',
        '<tr><td>Image #</td><td>Input</td><td>Output</td><td>Ground Truth</td></tr>',
    }
    for i=1, #filepaths do
        local name = filepaths[i]
        html[#html+1] = '<tr>'
        html[#html+1] = '<td>' .. name .. '</td>'
        html[#html+1] = '<td><img src="./images/input/' .. name .. '"/></td>'
        html[#html+1] = '<td><img src="./images/output/' .. name .. '"/></td>'
        html[#html+1] = '<td><img src="./images/target/' .. name .. '"/></td>'
        html[#html+1] = '</tr>'
    end
    html[#html+1] = '</table>'

    -- Single write of the joined pieces; the bytes match the original output.
    page:write(table.concat(html))
    page:close()
end