1
+ import os
2
+ import random
3
+ import pandas as pd
4
+ from PIL import Image
5
+
6
+ import torch
7
+ import torch .utils .data as data
8
+ from torchvision import transforms
9
+
10
+ from data_utils import au2heatmap
11
+ import numpy as np
12
+
13
+ class image_train (object ):
14
+ def __init__ (self , img_size = 256 , crop_size = 224 ):
15
+ self .img_size = img_size
16
+ self .crop_size = crop_size
17
+
18
+ def __call__ (self , img ):
19
+ transform = transforms .Compose ([
20
+ transforms .Resize (self .img_size ),
21
+ transforms .ToTensor (),
22
+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ],
23
+ std = [0.229 , 0.224 , 0.225 ])
24
+ ])
25
+ img = transform (img )
26
+
27
+ return img
28
+
29
+
30
+ class image_test (object ):
31
+ def __init__ (self , img_size = 256 , crop_size = 224 ):
32
+ self .img_size = img_size
33
+ self .crop_size = crop_size
34
+
35
+ def __call__ (self , img ):
36
+ transform = transforms .Compose ([
37
+ transforms .Resize (self .img_size ),
38
+ transforms .CenterCrop (self .crop_size ),
39
+ transforms .ToTensor (),
40
+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ],
41
+ std = [0.229 , 0.224 , 0.225 ])
42
+ ])
43
+ img = transform (img )
44
+
45
+ return img
46
+
47
+
48
+ class MyDataset (data .Dataset ):
49
+ def __init__ (self , csv_file , train , config ):
50
+ self .config = config
51
+ self .csv_file = csv_file
52
+
53
+ self .data = config .data
54
+ self .data_root = config .data_root
55
+ self .img_size = config .image_size
56
+ self .crop_size = config .crop_size
57
+ self .train = train
58
+ if self .train :
59
+ self .transform = image_train (img_size = self .img_size , crop_size = self .crop_size )
60
+ else :
61
+ self .transform = image_test (img_size = self .img_size , crop_size = self .crop_size )
62
+
63
+ self .file_list = pd .read_csv (csv_file )
64
+ self .images = self .file_list ['image_path' ]
65
+ if self .data == 'BP4D' :
66
+ self .labels = [
67
+ self .file_list ['au1' ],
68
+ self .file_list ['au2' ],
69
+ self .file_list ['au4' ],
70
+ self .file_list ['au6' ],
71
+ self .file_list ['au7' ],
72
+ self .file_list ['au10' ],
73
+ self .file_list ['au12' ],
74
+ self .file_list ['au14' ],
75
+ self .file_list ['au15' ],
76
+ self .file_list ['au17' ],
77
+ self .file_list ['au23' ],
78
+ self .file_list ['au24' ],
79
+ ]
80
+ elif self .data == 'DISFA' :
81
+ self .labels = [
82
+ self .file_list ['au1' ],
83
+ self .file_list ['au2' ],
84
+ self .file_list ['au4' ],
85
+ self .file_list ['au5' ],
86
+ self .file_list ['au6' ],
87
+ self .file_list ['au9' ],
88
+ self .file_list ['au12' ],
89
+ self .file_list ['au15' ],
90
+ self .file_list ['au17' ],
91
+ self .file_list ['au20' ],
92
+ self .file_list ['au25' ],
93
+ self .file_list ['au26' ]
94
+ ]
95
+ self .num_labels = len (self .labels )
96
+
97
+ def data_augmentation (self , image , flip , crop_size , offset_x , offset_y ):
98
+ image = image [:,offset_x :offset_x + crop_size ,offset_y :offset_y + crop_size ]
99
+ if flip :
100
+ image = torch .flip (image , [2 ])
101
+
102
+ return image
103
+
104
+ def pil_loader (self , path ):
105
+ with open (path , 'rb' ) as f :
106
+ with Image .open (f ) as img :
107
+ return img .convert ('RGB' )
108
+
109
+ def __getitem__ (self , index ):
110
+ image_path = self .images [index ]
111
+ image_name = os .path .join (self .data_root , image_path )
112
+ image = self .pil_loader (image_name )
113
+
114
+ label = []
115
+ for i in range (self .num_labels ):
116
+ label .append (float (self .labels [i ][index ]))
117
+ label = torch .FloatTensor (label )
118
+
119
+ if self .train :
120
+ heatmap = au2heatmap (image_name , label , self .img_size , self .config )
121
+ heatmap = torch .from_numpy (heatmap )
122
+ offset_y = random .randint (0 , self .img_size - self .crop_size )
123
+ offset_x = random .randint (0 , self .img_size - self .crop_size )
124
+ flip = random .randint (0 , 1 )
125
+ image = self .transform (image )
126
+ image = self .data_augmentation (image , flip , self .crop_size , offset_x , offset_y )
127
+ heatmap = self .data_augmentation (heatmap , flip , self .crop_size // 4 , offset_x // 4 , offset_y // 4 )
128
+
129
+ return image , label , heatmap
130
+ else :
131
+ image = self .transform (image )
132
+
133
+ return image , label
134
+
135
+ def collate_fn (self , data ):
136
+ if self .train :
137
+ images , labels , heatmaps = zip (* data )
138
+
139
+ images = torch .stack (images )
140
+ labels = torch .stack (labels ).float ()
141
+ heatmaps = torch .stack (heatmaps ).float ()
142
+
143
+ return images , labels , heatmaps
144
+ else :
145
+ images , labels = zip (* data )
146
+
147
+ images = torch .stack (images )
148
+ labels = torch .stack (labels ).float ()
149
+
150
+ return images , labels
151
+
152
+ def __len__ (self ):
153
+ return len (self .images )
154
+
155
+
156
+
157
+ class MyDataset_GH_Feat (data .Dataset ):
158
+ def __init__ (self , csv_file , config ):
159
+ self .config = config
160
+ self .csv_file = csv_file
161
+
162
+ self .data = config .data
163
+ self .data_root = config .data_root
164
+
165
+ self .file_list = pd .read_csv (csv_file )
166
+ self .images = self .file_list ['image_path' ]
167
+
168
+ if self .data == 'BP4D' :
169
+ self .labels = [
170
+ self .file_list ['au1' ],
171
+ self .file_list ['au2' ],
172
+ self .file_list ['au4' ],
173
+ self .file_list ['au6' ],
174
+ self .file_list ['au7' ],
175
+ self .file_list ['au10' ],
176
+ self .file_list ['au12' ],
177
+ self .file_list ['au14' ],
178
+ self .file_list ['au15' ],
179
+ self .file_list ['au17' ],
180
+ self .file_list ['au23' ],
181
+ self .file_list ['au24' ]
182
+ ]
183
+ elif self .data == 'DISFA' :
184
+ self .labels = [
185
+ self .file_list ['au1' ],
186
+ self .file_list ['au2' ],
187
+ self .file_list ['au4' ],
188
+ self .file_list ['au5' ],
189
+ self .file_list ['au6' ],
190
+ self .file_list ['au9' ],
191
+ self .file_list ['au12' ],
192
+ self .file_list ['au15' ],
193
+ self .file_list ['au17' ],
194
+ self .file_list ['au20' ],
195
+ self .file_list ['au25' ],
196
+ self .file_list ['au26' ]
197
+ ]
198
+
199
+ self .num_labels = len (self .labels )
200
+
201
+
202
+ def __getitem__ (self , index ):
203
+ image_path = self .images [index ]
204
+ feature_path = os .path .join ('/home/ICT2000/dchang/TAC_project/data' , image_path [:- 4 ]+ '.npy' )
205
+ feature_path = feature_path .replace ('images' , 'gh_feat' )
206
+ feature = np .load (feature_path )
207
+ feature = torch .from_numpy (feature ).view (- 1 )
208
+
209
+ label = []
210
+ for i in range (self .num_labels ):
211
+ label .append (int (self .labels [i ][index ]))
212
+ label = torch .FloatTensor (label )
213
+
214
+ return feature , label
215
+
216
+
217
+ def collate_fn (self , data ):
218
+ features , labels = zip (* data )
219
+
220
+ features = torch .stack (features )
221
+ labels = torch .stack (labels )
222
+
223
+ return features , labels
224
+
225
+
226
+ def __len__ (self ):
227
+ return len (self .images )
228
+
229
+
230
+ class MyDataset_with_lm (data .Dataset ):
231
+ def __init__ (self , csv_file , train , config ):
232
+ self .config = config
233
+ self .csv_file = csv_file
234
+
235
+ self .data = config .data
236
+ self .data_root = config .data_root
237
+ self .img_size = config .image_size
238
+ self .crop_size = config .crop_size
239
+ self .train = train
240
+ if self .train :
241
+ self .transform = image_train (img_size = self .img_size , crop_size = self .crop_size )
242
+ else :
243
+ self .transform = image_test (img_size = self .img_size , crop_size = self .crop_size )
244
+
245
+ self .file_list = pd .read_csv (csv_file )
246
+ self .images = self .file_list ['image_path' ]
247
+ if self .data == 'BP4D' :
248
+ self .labels = [
249
+ self .file_list ['au6' ],
250
+ self .file_list ['au10' ],
251
+ self .file_list ['au12' ],
252
+ self .file_list ['au14' ],
253
+ self .file_list ['au17' ]
254
+ ]
255
+ elif self .data == 'DISFA' :
256
+ self .labels = [
257
+ self .file_list ['au1' ],
258
+ self .file_list ['au2' ],
259
+ self .file_list ['au4' ],
260
+ self .file_list ['au5' ],
261
+ self .file_list ['au6' ],
262
+ self .file_list ['au9' ],
263
+ self .file_list ['au12' ],
264
+ self .file_list ['au15' ],
265
+ self .file_list ['au17' ],
266
+ self .file_list ['au20' ],
267
+ self .file_list ['au25' ],
268
+ self .file_list ['au26' ]
269
+ ]
270
+ self .num_labels = len (self .labels )
271
+
272
+ def data_augmentation (self , image , flip , crop_size , offset_x , offset_y ):
273
+ image = image [:,offset_x :offset_x + crop_size ,offset_y :offset_y + crop_size ]
274
+ if flip :
275
+ image = torch .flip (image , [2 ])
276
+
277
+ return image
278
+
279
+ def pil_loader (self , path ):
280
+ with open (path , 'rb' ) as f :
281
+ with Image .open (f ) as img :
282
+ return img .convert ('RGB' )
283
+
284
+ def __getitem__ (self , index ):
285
+ image_path = self .images [index ]
286
+ image_name = os .path .join (self .data_root , image_path )
287
+ image = self .pil_loader (image_name )
288
+
289
+ lm_path = image_path .replace ('images' , 'landmarks' )[:- 4 ]+ '.npy'
290
+ lm_name = os .path .join (self .data_root , lm_path )
291
+ landmark = np .load (lm_name )
292
+ landmark = torch .FloatTensor (landmark )
293
+ label = []
294
+ for i in range (self .num_labels ):
295
+ label .append (float (self .labels [i ][index ]))
296
+ label = torch .FloatTensor (label )
297
+
298
+ if self .train :
299
+ heatmap = au2heatmap (image_name , label , self .img_size , self .config )
300
+ heatmap = torch .from_numpy (heatmap )
301
+ offset_y = random .randint (0 , self .img_size - self .crop_size )
302
+ offset_x = random .randint (0 , self .img_size - self .crop_size )
303
+ flip = random .randint (0 , 1 )
304
+ image = self .transform (image )
305
+ image = self .data_augmentation (image , flip , self .crop_size , offset_x , offset_y )
306
+ heatmap = self .data_augmentation (heatmap , flip , self .crop_size // 4 , offset_x // 4 , offset_y // 4 )
307
+ return image , label , heatmap , landmark
308
+ else :
309
+ image = self .transform (image )
310
+
311
+ return image , label , landmark
312
+
313
+ def collate_fn (self , data ):
314
+ if self .train :
315
+ images , labels , heatmaps , landmarks = zip (* data )
316
+
317
+ images = torch .stack (images )
318
+ labels = torch .stack (labels ).float ()
319
+ heatmaps = torch .stack (heatmaps ).float ()
320
+ landmarks = torch .stack (landmarks ).float ()
321
+ return images , labels , heatmaps , landmarks
322
+ else :
323
+ images , labels , landmarks = zip (* data )
324
+
325
+ images = torch .stack (images )
326
+ labels = torch .stack (labels ).float ()
327
+ landmarks = torch .stack (landmarks ).float ()
328
+ return images , labels , landmarks
329
+
330
+ def __len__ (self ):
331
+ return len (self .images )
0 commit comments