-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
111 lines (89 loc) · 4.3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import tensorflow as tf
import os
##################################################################################
# Normalize and Denormalize images
##################################################################################
def preprocess(image):
    """Rescale image values from the [0, 1] range to the [-1, 1] range.

    The ops are created inside the "preprocess" name scope.
    """
    with tf.name_scope("preprocess"):
        doubled = image * 2  # [0, 1] -> [0, 2]
        return doubled - 1   # [0, 2] -> [-1, 1]
def deprocess(image):
    """Rescale image values from the [-1, 1] range back to the [0, 1] range.

    The ops are created inside the "deprocess" name scope.
    """
    with tf.name_scope("deprocess"):
        shifted = image + 1  # [-1, 1] -> [0, 2]
        return shifted / 2   # [0, 2] -> [0, 1]
##################################################################################
# Image Colorization functions
##################################################################################
def preprocess_lab(lab):
    """Split a Lab image into its three channels and normalize each one.

    ``lab`` is unstacked along axis 2, so it must have exactly three
    entries on that axis (L, a, b).

    Returns:
        A list ``[L, a, b]`` of the normalized channels:
        L is mapped from [0, 100] to [-1, 1]; a and b are divided by 110,
        mapping roughly [-110, 110] to [-1, 1] (the a/b range is not exact).
    """
    with tf.name_scope("preprocess_lab"):
        lightness, chroma_a, chroma_b = tf.unstack(lab, axis=2)
        # Lightness: [0, 100] => [-1, 1]
        lightness_norm = lightness / 50 - 1
        # Color channels: ~[-110, 110] => ~[-1, 1]
        chroma_a_norm = chroma_a / 110
        chroma_b_norm = chroma_b / 110
        return [lightness_norm, chroma_a_norm, chroma_b_norm]
def deprocess_lab(L_chan, a_chan, b_chan):
    """Undo the Lab normalization and stack the channels into one tensor.

    L is mapped from [-1, 1] back to [0, 100]; a and b are multiplied by
    110. The three results are stacked along axis 3 (not axis 2) because
    individual images are preprocessed but batches are deprocessed.
    """
    with tf.name_scope("deprocess_lab"):
        lightness = (L_chan + 1) / 2 * 100
        chroma_a = a_chan * 110
        chroma_b = b_chan * 110
        return tf.stack([lightness, chroma_a, chroma_b], axis=3)
def augment(image, brightness):
    """Combine (a, b) color channels with an L channel and convert to RGB.

    ``image`` is unstacked along axis 3 into exactly two channels (a, b);
    ``brightness`` has its axis 3 squeezed away to give the L channel.
    The channels are denormalized with ``deprocess_lab`` and the resulting
    Lab tensor is passed to ``lab_to_rgb``.
    """
    # NOTE(review): lab_to_rgb is not defined or imported in the visible part
    # of this file -- confirm it is in scope, otherwise this call raises
    # NameError at runtime.
    chroma_a, chroma_b = tf.unstack(image, axis=3)
    lightness = tf.squeeze(brightness, axis=3)
    return lab_to_rgb(deprocess_lab(lightness, chroma_a, chroma_b))
def check_image(image):
    """Validate that ``image`` is a 3-channel tensor of rank 3 or 4.

    Attaches an assertion that the last dimension equals 3, verifies the
    static rank, and pins the static size of the last dimension to 3.

    Returns:
        The identity of ``image``, created under a control dependency on
        the channel-count assertion and with its last static dimension
        set to 3.

    Raises:
        ValueError: if the static rank of the image is neither 3 nor 4.
    """
    channel_check = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels")
    with tf.control_dependencies([channel_check]):
        image = tf.identity(image)

    rank = image.get_shape().ndims
    if rank != 3 and rank != 4:
        raise ValueError("image must be either 3 or 4 dimensions")

    # make the last dimension 3 so that you can unstack the colors
    static_shape = list(image.get_shape())
    static_shape[-1] = 3
    image.set_shape(static_shape)
    return image
##################################################################################
# Save training or testing images
##################################################################################
def save_images(fetches, args, step=None):
    """Write fetched image bytes to ``<args.output_dir>/images``.

    Args:
        fetches: mapping with a "paths" entry (a sequence of UTF-8 encoded
            byte-string input paths) plus one sequence per image kind
            ("src_font", "trg_font", "trg_skeleton", "f2f_outputs",
            "f2s_outputs", "s2f_outputs"), indexed in parallel with
            "paths". Each item is written verbatim in binary mode
            (presumably already-encoded PNG data -- confirm with caller).
        args: object exposing an ``output_dir`` attribute.
        step: optional integer. When not None, every file name is prefixed
            with the step zero-padded to 8 digits.

    Returns:
        A list with one dict per input path containing "name" (input base
        name without extension), "step", and one entry per kind mapping
        to the written file name (relative to the images directory).
    """
    kinds = ("src_font", "trg_font", "trg_skeleton", "f2f_outputs", "f2s_outputs", "s2f_outputs")

    image_dir = os.path.join(args.output_dir, "images")
    # exist_ok avoids the race between checking for the directory and
    # creating it (e.g. when several processes share an output directory).
    os.makedirs(image_dir, exist_ok=True)

    filesets = []
    for i, in_path in enumerate(fetches["paths"]):
        name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
        fileset = {"name": name, "step": step}
        for kind in kinds:
            filename = name + "-" + kind + ".png"
            if step is not None:
                filename = "%08d-%s" % (step, filename)
            fileset[kind] = filename
            with open(os.path.join(image_dir, filename), "wb") as f:
                f.write(fetches[kind][i])
        filesets.append(fileset)
    return filesets
##################################################################################
# Write HTML index file for browser view
##################################################################################
def append_index(filesets, args, step=False):
    """Append one HTML table row per fileset to ``<args.output_dir>/index.html``.

    If the index file does not exist yet it is created and the opening
    ``<html><body><table>`` markup plus the header row are written first.
    The document is intentionally left unclosed so later calls can keep
    appending rows.

    Args:
        filesets: iterable of dicts with "name", one file name per image
            kind, and (when ``step`` is truthy) an integer "step".
        args: object exposing an ``output_dir`` attribute.
        step: when truthy, a leading "step" column is written.

    Returns:
        The path of the index file.
    """
    kinds = ("src_font", "trg_font", "trg_skeleton", "f2f_outputs", "f2s_outputs", "s2f_outputs")

    index_path = os.path.join(args.output_dir, "index.html")
    # Decide before opening: append mode creates the file if it is missing.
    needs_header = not os.path.exists(index_path)

    # The context manager guarantees the handle is flushed and closed, even
    # if a write fails part-way through.
    with open(index_path, "a") as index:
        if needs_header:
            index.write("<html><body><table><tr>")
            if step:
                index.write("<th>step</th>")
            index.write("<th>name</th><th>src_font</th><th>trg_font</th><th>trg_skeleton</th><th>f2f_outputs</th><th>f2s_outputs</th><th>s2f_outputs</th></tr>")

        for fileset in filesets:
            index.write("<tr>")
            if step:
                index.write("<td>%d</td>" % fileset["step"])
            index.write("<td>%s</td>" % fileset["name"])
            for kind in kinds:
                index.write("<td><img src='images/%s'></td>" % fileset[kind])
            index.write("</tr>")
    return index_path