Skip to content

Commit 2195c54

Browse files
committed
au detection
1 parent 6ebb69a commit 2195c54

20 files changed

+2830
-0
lines changed

AU_Detection/data.py

+331
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
import os
2+
import random
3+
import pandas as pd
4+
from PIL import Image
5+
6+
import torch
7+
import torch.utils.data as data
8+
from torchvision import transforms
9+
10+
from data_utils import au2heatmap
11+
import numpy as np
12+
13+
class image_train(object):
14+
def __init__(self, img_size=256, crop_size=224):
15+
self.img_size = img_size
16+
self.crop_size = crop_size
17+
18+
def __call__(self, img):
19+
transform = transforms.Compose([
20+
transforms.Resize(self.img_size),
21+
transforms.ToTensor(),
22+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
23+
std=[0.229, 0.224, 0.225])
24+
])
25+
img = transform(img)
26+
27+
return img
28+
29+
30+
class image_test(object):
31+
def __init__(self, img_size=256, crop_size=224):
32+
self.img_size = img_size
33+
self.crop_size = crop_size
34+
35+
def __call__(self, img):
36+
transform = transforms.Compose([
37+
transforms.Resize(self.img_size),
38+
transforms.CenterCrop(self.crop_size),
39+
transforms.ToTensor(),
40+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
41+
std=[0.229, 0.224, 0.225])
42+
])
43+
img = transform(img)
44+
45+
return img
46+
47+
48+
class MyDataset(data.Dataset):
49+
def __init__(self, csv_file, train, config):
50+
self.config = config
51+
self.csv_file = csv_file
52+
53+
self.data = config.data
54+
self.data_root = config.data_root
55+
self.img_size = config.image_size
56+
self.crop_size = config.crop_size
57+
self.train = train
58+
if self.train:
59+
self.transform = image_train(img_size=self.img_size, crop_size=self.crop_size)
60+
else:
61+
self.transform = image_test(img_size=self.img_size, crop_size=self.crop_size)
62+
63+
self.file_list = pd.read_csv(csv_file)
64+
self.images = self.file_list['image_path']
65+
if self.data == 'BP4D':
66+
self.labels = [
67+
self.file_list['au1'],
68+
self.file_list['au2'],
69+
self.file_list['au4'],
70+
self.file_list['au6'],
71+
self.file_list['au7'],
72+
self.file_list['au10'],
73+
self.file_list['au12'],
74+
self.file_list['au14'],
75+
self.file_list['au15'],
76+
self.file_list['au17'],
77+
self.file_list['au23'],
78+
self.file_list['au24'],
79+
]
80+
elif self.data == 'DISFA':
81+
self.labels = [
82+
self.file_list['au1'],
83+
self.file_list['au2'],
84+
self.file_list['au4'],
85+
self.file_list['au5'],
86+
self.file_list['au6'],
87+
self.file_list['au9'],
88+
self.file_list['au12'],
89+
self.file_list['au15'],
90+
self.file_list['au17'],
91+
self.file_list['au20'],
92+
self.file_list['au25'],
93+
self.file_list['au26']
94+
]
95+
self.num_labels = len(self.labels)
96+
97+
def data_augmentation(self, image, flip, crop_size, offset_x, offset_y):
98+
image = image[:,offset_x:offset_x+crop_size,offset_y:offset_y+crop_size]
99+
if flip:
100+
image = torch.flip(image, [2])
101+
102+
return image
103+
104+
def pil_loader(self, path):
105+
with open(path, 'rb') as f:
106+
with Image.open(f) as img:
107+
return img.convert('RGB')
108+
109+
def __getitem__(self, index):
110+
image_path = self.images[index]
111+
image_name = os.path.join(self.data_root, image_path)
112+
image = self.pil_loader(image_name)
113+
114+
label = []
115+
for i in range(self.num_labels):
116+
label.append(float(self.labels[i][index]))
117+
label = torch.FloatTensor(label)
118+
119+
if self.train:
120+
heatmap = au2heatmap(image_name, label, self.img_size, self.config)
121+
heatmap = torch.from_numpy(heatmap)
122+
offset_y = random.randint(0, self.img_size - self.crop_size)
123+
offset_x = random.randint(0, self.img_size - self.crop_size)
124+
flip = random.randint(0, 1)
125+
image = self.transform(image)
126+
image = self.data_augmentation(image, flip, self.crop_size, offset_x, offset_y)
127+
heatmap = self.data_augmentation(heatmap, flip, self.crop_size // 4, offset_x // 4, offset_y // 4)
128+
129+
return image, label, heatmap
130+
else:
131+
image = self.transform(image)
132+
133+
return image, label
134+
135+
def collate_fn(self, data):
136+
if self.train:
137+
images, labels, heatmaps = zip(*data)
138+
139+
images = torch.stack(images)
140+
labels = torch.stack(labels).float()
141+
heatmaps = torch.stack(heatmaps).float()
142+
143+
return images, labels, heatmaps
144+
else:
145+
images, labels = zip(*data)
146+
147+
images = torch.stack(images)
148+
labels = torch.stack(labels).float()
149+
150+
return images, labels
151+
152+
def __len__(self):
153+
return len(self.images)
154+
155+
156+
157+
class MyDataset_GH_Feat(data.Dataset):
158+
def __init__(self, csv_file, config):
159+
self.config = config
160+
self.csv_file = csv_file
161+
162+
self.data = config.data
163+
self.data_root = config.data_root
164+
165+
self.file_list = pd.read_csv(csv_file)
166+
self.images = self.file_list['image_path']
167+
168+
if self.data == 'BP4D':
169+
self.labels = [
170+
self.file_list['au1'],
171+
self.file_list['au2'],
172+
self.file_list['au4'],
173+
self.file_list['au6'],
174+
self.file_list['au7'],
175+
self.file_list['au10'],
176+
self.file_list['au12'],
177+
self.file_list['au14'],
178+
self.file_list['au15'],
179+
self.file_list['au17'],
180+
self.file_list['au23'],
181+
self.file_list['au24']
182+
]
183+
elif self.data == 'DISFA':
184+
self.labels = [
185+
self.file_list['au1'],
186+
self.file_list['au2'],
187+
self.file_list['au4'],
188+
self.file_list['au5'],
189+
self.file_list['au6'],
190+
self.file_list['au9'],
191+
self.file_list['au12'],
192+
self.file_list['au15'],
193+
self.file_list['au17'],
194+
self.file_list['au20'],
195+
self.file_list['au25'],
196+
self.file_list['au26']
197+
]
198+
199+
self.num_labels = len(self.labels)
200+
201+
202+
def __getitem__(self, index):
203+
image_path = self.images[index]
204+
feature_path = os.path.join('/home/ICT2000/dchang/TAC_project/data', image_path[:-4]+'.npy')
205+
feature_path = feature_path.replace('images', 'gh_feat')
206+
feature = np.load(feature_path)
207+
feature = torch.from_numpy(feature).view(-1)
208+
209+
label = []
210+
for i in range(self.num_labels):
211+
label.append(int(self.labels[i][index]))
212+
label = torch.FloatTensor(label)
213+
214+
return feature, label
215+
216+
217+
def collate_fn(self, data):
218+
features, labels = zip(*data)
219+
220+
features = torch.stack(features)
221+
labels = torch.stack(labels)
222+
223+
return features, labels
224+
225+
226+
def __len__(self):
227+
return len(self.images)
228+
229+
230+
class MyDataset_with_lm(data.Dataset):
231+
def __init__(self, csv_file, train, config):
232+
self.config = config
233+
self.csv_file = csv_file
234+
235+
self.data = config.data
236+
self.data_root = config.data_root
237+
self.img_size = config.image_size
238+
self.crop_size = config.crop_size
239+
self.train = train
240+
if self.train:
241+
self.transform = image_train(img_size=self.img_size, crop_size=self.crop_size)
242+
else:
243+
self.transform = image_test(img_size=self.img_size, crop_size=self.crop_size)
244+
245+
self.file_list = pd.read_csv(csv_file)
246+
self.images = self.file_list['image_path']
247+
if self.data == 'BP4D':
248+
self.labels = [
249+
self.file_list['au6'],
250+
self.file_list['au10'],
251+
self.file_list['au12'],
252+
self.file_list['au14'],
253+
self.file_list['au17']
254+
]
255+
elif self.data == 'DISFA':
256+
self.labels = [
257+
self.file_list['au1'],
258+
self.file_list['au2'],
259+
self.file_list['au4'],
260+
self.file_list['au5'],
261+
self.file_list['au6'],
262+
self.file_list['au9'],
263+
self.file_list['au12'],
264+
self.file_list['au15'],
265+
self.file_list['au17'],
266+
self.file_list['au20'],
267+
self.file_list['au25'],
268+
self.file_list['au26']
269+
]
270+
self.num_labels = len(self.labels)
271+
272+
def data_augmentation(self, image, flip, crop_size, offset_x, offset_y):
273+
image = image[:,offset_x:offset_x+crop_size,offset_y:offset_y+crop_size]
274+
if flip:
275+
image = torch.flip(image, [2])
276+
277+
return image
278+
279+
def pil_loader(self, path):
280+
with open(path, 'rb') as f:
281+
with Image.open(f) as img:
282+
return img.convert('RGB')
283+
284+
def __getitem__(self, index):
285+
image_path = self.images[index]
286+
image_name = os.path.join(self.data_root, image_path)
287+
image = self.pil_loader(image_name)
288+
289+
lm_path = image_path.replace('images', 'landmarks')[:-4]+'.npy'
290+
lm_name = os.path.join(self.data_root, lm_path)
291+
landmark = np.load(lm_name)
292+
landmark = torch.FloatTensor(landmark)
293+
label = []
294+
for i in range(self.num_labels):
295+
label.append(float(self.labels[i][index]))
296+
label = torch.FloatTensor(label)
297+
298+
if self.train:
299+
heatmap = au2heatmap(image_name, label, self.img_size, self.config)
300+
heatmap = torch.from_numpy(heatmap)
301+
offset_y = random.randint(0, self.img_size - self.crop_size)
302+
offset_x = random.randint(0, self.img_size - self.crop_size)
303+
flip = random.randint(0, 1)
304+
image = self.transform(image)
305+
image = self.data_augmentation(image, flip, self.crop_size, offset_x, offset_y)
306+
heatmap = self.data_augmentation(heatmap, flip, self.crop_size // 4, offset_x // 4, offset_y // 4)
307+
return image, label, heatmap, landmark
308+
else:
309+
image = self.transform(image)
310+
311+
return image, label, landmark
312+
313+
def collate_fn(self, data):
314+
if self.train:
315+
images, labels, heatmaps, landmarks = zip(*data)
316+
317+
images = torch.stack(images)
318+
labels = torch.stack(labels).float()
319+
heatmaps = torch.stack(heatmaps).float()
320+
landmarks = torch.stack(landmarks).float()
321+
return images, labels, heatmaps, landmarks
322+
else:
323+
images, labels, landmarks = zip(*data)
324+
325+
images = torch.stack(images)
326+
labels = torch.stack(labels).float()
327+
landmarks = torch.stack(landmarks).float()
328+
return images, labels, landmarks
329+
330+
def __len__(self):
331+
return len(self.images)

0 commit comments

Comments
 (0)