-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_dataset_loader_test.py
183 lines (147 loc) · 6.57 KB
/
custom_dataset_loader_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms, utils
from torch.utils import data
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils
import os
import random
from os.path import join
from os import listdir
import glob
# Directory containing this script; the default paths below are resolved
# relative to it.
SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))

# Default label file and validation-font directory.
Lable_file = os.path.join(
    SCRIPT_PATH, 'dataset_builder/labels/350-common-hangul.txt'
)
font_path = os.path.join(SCRIPT_PATH, 'dataset_builder/val_fonts')

# Dataset dimensions are hard-coded here.  They could instead be derived from
# the defaults above:
#   total_characters = sum(1 for _ in open(Lable_file, encoding='utf-8'))
#   total_styles = len(glob.glob1(font_path, "*.ttf"))
total_characters = 50
total_styles = 1
def data_sampler(dataset, shuffle, distributed):
    """Build the index sampler matching the requested loading mode.

    Args:
        dataset: Dataset the sampler will draw indices from.
        shuffle: Whether indices should be visited in random order.
        distributed: When truthy, a ``DistributedSampler`` is returned and
            ``shuffle`` is forwarded to it.

    Returns:
        A ``DistributedSampler`` if ``distributed`` is truthy, otherwise a
        ``RandomSampler`` when ``shuffle`` is truthy, else a
        ``SequentialSampler``.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
def is_image_file(filename):
    """Return True if ``filename`` ends with ``.png``, ``.jpg`` or ``.jpeg``.

    The comparison is case-sensitive, so ``"x.PNG"`` is not accepted.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    return filename.endswith(image_suffixes)
def replace_name(s, old, new, occurrence):
    """Replace the last ``occurrence`` matches of ``old`` in ``s`` with ``new``.

    Matches are counted from the right-hand end of the string.

    Args:
        s: String to edit.
        old: Substring to look for.
        new: Replacement text.
        occurrence: Maximum number of right-most matches to replace; it is
            passed straight to ``str.rsplit`` as ``maxsplit``.

    Returns:
        The edited string; ``s`` unchanged when ``old`` does not occur.
    """
    return new.join(s.rsplit(old, occurrence))
def gen_random_no(start, stop):
    """Return a random integer ``N`` with ``start <= N <= stop``.

    Both bounds are inclusive; the value comes from the module-level
    ``random`` generator.
    """
    drawn = random.randint(start, stop)
    return drawn
def replace_style(string, new, I):
    """Swap everything before the first ``I`` in ``string`` for ``new``.

    Args:
        string: Text to edit, e.g. a ``"<style>_<char>.png"`` file name.
        new: Replacement for the part of ``string`` that precedes ``I``.
        I: Marker substring; it and everything after it are kept.

    Returns:
        ``new`` followed by the tail of ``string`` starting at the first
        occurrence of ``I``.

    Raises:
        ValueError: If ``I`` does not occur in ``string``.
    """
    marker_pos = string.index(I)
    tail = string[marker_pos:]
    return new + tail
def get_positive_img(img_name, total_chars):
    """Return a file name with the anchor's style and a random character.

    The name is expected to look like ``"<style>_<char><ext>"``.  A character
    number is drawn uniformly from ``1..total_chars`` (inclusive) and
    substituted for the right-most occurrence of the anchor's character field
    in the name.  The anchor's own character is not excluded from the draw,
    so the result may equal ``img_name``.

    Args:
        img_name: Anchor image file name.
        total_chars: Largest character number that may be drawn.

    Returns:
        The file name of the positive sample.

    Raises:
        IndexError: If the name (without extension) contains no ``"_"``.
    """
    stem, extension = os.path.splitext(img_name)
    anchor_char = stem.split('_')[1]

    # Draw the replacement character number.
    drawn_char = random.randint(1, total_chars)

    return replace_name(str(stem + extension), str(anchor_char), str(drawn_char), 1)
def get_negative_img(a_img_name, total_styles, total_chars):
    """Return a file name with the anchor's character in a different style.

    The name is expected to look like ``"<style>_<char><ext>"``.  A style
    number other than the anchor's is drawn uniformly from
    ``1..total_styles`` (inclusive) and replaces everything before the first
    ``"_"`` of the name.

    The previous implementation re-drew a random style until it differed
    from the anchor's, which never terminated when the anchor's style was
    the only one available (for example ``total_styles == 1`` with a style-1
    anchor).  That situation now raises ``ValueError`` instead of hanging.

    Args:
        a_img_name: Anchor image file name.
        total_styles: Largest style number that may be drawn.
        total_chars: Unused; kept so existing callers keep working.

    Returns:
        The file name of the negative sample.

    Raises:
        IndexError: If the name (without extension) contains no ``"_"``.
        ValueError: If the style field is not an integer, or if no style
            other than the anchor's exists in ``1..total_styles``.
    """
    a_name, a_file_extension = os.path.splitext(a_img_name)
    a_img_name = a_name + a_file_extension

    fields = a_name.split('_')
    # Index both fields so a name without "_" still fails with IndexError,
    # exactly as it did before.
    a_style_no, _a_char_no = fields[0], fields[1]
    anchor_style = int(a_style_no)

    # Enumerate the admissible styles up front instead of rejection-sampling,
    # so an empty candidate set is detected rather than looped on forever.
    candidates = [
        style for style in range(1, total_styles + 1) if style != anchor_style
    ]
    if not candidates:
        raise ValueError(
            "no style other than %d is available in 1..%d"
            % (anchor_style, total_styles)
        )
    n_img_style = random.choice(candidates)

    # Same result as replace_style(a_img_name, str(n_img_style), I="_"):
    # keep the name from its first "_" onwards and prepend the new style.
    return str(n_img_style) + a_img_name[a_img_name.index("_"):]
def get_src_img(img_name, total_chars):
    """Return the file name of the same character in the source style.

    The source style is fixed to style number ``1``: everything before the
    first ``"_"`` of ``img_name`` is replaced by ``"1"``.

    Args:
        img_name: Image file name of the form ``"<style>_<char><ext>"``.
        total_chars: Unused; kept so existing callers keep working.

    Returns:
        The source-style file name, e.g. ``"1_<char><ext>"``.
    """
    source_style = 1
    stem, extension = os.path.splitext(img_name)
    return replace_style(str(stem + extension), str(source_style), I="_")
class DatasetFromFolder(data.Dataset):
    """Dataset that serves each fine-tune image with two companion images.

    Anchor images are the image files found in ``<image_dir>/fineTune/font1``.
    For every anchor the dataset also loads the same character from
    ``<image_dir>/test_img_50/printed/font`` (the source font, style 1) and a
    randomly chosen character from the anchor's own folder.
    """

    def __init__(self, image_dir, img_size):
        """Index the anchor folder and build the tensor transform.

        Args:
            image_dir: Root directory holding the ``fineTune/font1`` and
                ``test_img_50/printed/font`` sub-folders.
            img_size: Stored on the instance; not used by the code in this
                class.
        """
        super().__init__()
        self.img_size = img_size
        self.total_chars = total_characters
        self.total_styles = total_styles

        self.src_img_path = join(image_dir, "test_img_50/printed/font")
        self.b_path = join(image_dir, "fineTune/font1")
        # Sorted so that an index always refers to the same file.
        self.image_filenames = sorted(
            entry for entry in listdir(self.b_path) if is_image_file(entry)
        )

        # Convert to a tensor, then normalise with mean 0.5 and std 0.5.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            A 4-tuple ``(source, positive, anchor, name)``: the transformed
            source-font image of the anchor's character, the transformed
            randomly chosen image from the anchor's folder, the transformed
            anchor image, and the anchor file name without its extension.
        """
        anchor_file = self.image_filenames[index]
        anchor_picture = Image.open(join(self.b_path, anchor_file))

        # Source-font rendering of the anchor's character.
        source_file = get_src_img(anchor_file, self.total_chars)
        source_tensor = self.transform(
            Image.open(join(self.src_img_path, source_file))
        )

        # Random character taken from the anchor's own folder.
        positive_file = get_positive_img(anchor_file, self.total_chars)
        positive_picture = Image.open(join(self.b_path, positive_file))

        anchor_tensor = self.transform(anchor_picture)
        positive_tensor = self.transform(positive_picture)

        stem, _ = os.path.splitext(anchor_file)
        return source_tensor, positive_tensor, anchor_tensor, stem

    def __len__(self):
        """Return the number of anchor images found in the folder."""
        return len(self.image_filenames)
def plot_data(loader):
    """Show the first three image batches of ``loader`` as image grids.

    One batch is pulled from ``loader`` and its first three entries are each
    drawn (at most 64 images per grid) in their own figure, titled anchor,
    positive and negative in that positional order.

    The batch previously had to contain exactly five entries, which made the
    function fail on batches of any other length, including the four-entry
    samples produced by ``DatasetFromFolder`` in this file.  Any batch with
    at least three entries is now accepted; entries beyond the third are
    ignored.

    NOTE: titles are assigned purely by position.  ``DatasetFromFolder``
    yields ``(source, positive, anchor, name)``, so for its batches the
    first figure shows the source images and the third shows the anchors.

    Args:
        loader: Iterable of batches whose first three entries are image
            tensors accepted by ``torchvision.utils.make_grid``.

    Raises:
        ValueError: If the batch has fewer than three entries.
    """
    batch = next(iter(loader))
    if len(batch) < 3:
        raise ValueError(
            "expected a batch with at least 3 entries, got %d" % len(batch)
        )

    titles = (
        "anchor_img training Images",
        "positive_img training Images",
        "negative_img training Images",
    )
    # zip stops after the three titles, so extra batch entries are skipped.
    for title, images in zip(titles, batch):
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title(title)
        grid = vutils.make_grid(images[:64], padding=2, normalize=True).cpu()
        # make_grid returns channels-first; imshow wants channels-last.
        plt.imshow(np.transpose(grid, (1, 2, 0)))
        plt.show()