9
9
import time
10
10
import glob
11
11
import itertools as it
12
- from typing import List , Tuple
12
+ from typing import List , Tuple , Union
13
13
14
14
15
15
class FaceClassifier ():
@@ -42,6 +42,8 @@ def __init__(self, ratio = 0.85, K = 200, data_pkl = None, target_pkl = None):
42
42
# how many eigenfaces to keep
43
43
self .K = K
44
44
# how much training data to use as part of total data
45
+ if not 0 < ratio <= 1 :
46
+ raise ValueError ('Provide a training/total data ratio from 0 to 1 inclusive.' )
45
47
self .ratio = ratio
46
48
# MxK matrix - each row stores the coords of each image in the eigenface space
47
49
self .W = None
@@ -84,7 +86,7 @@ def _subtract_mean(self):
84
86
self .data = np .matmul (C , self .data )
85
87
86
88
87
- def _read_from_webcam (self , new_label , stream = 0 ):
89
+ def _read_from_webcam (self , new_label , stream : Union [ str , int ] = 0 ):
88
90
"""Takes face snapshots from webcam. Pass the new label of the subject
89
91
being photographed."""
90
92
print ("Position your face in the green box.\n "
@@ -158,7 +160,7 @@ def add_img_data(self, dir_img: str = "", from_webcam: bool = False) -> int:
158
160
159
161
def train (self ):
160
162
""" Find the coordinates of each training image in the eigenface space """
161
- self ._divide_dataset (ratio = self . ratio )
163
+ self ._divide_dataset ()
162
164
# the matrix X to use for training
163
165
X = np .array ([v [0 ] for v in self .train_data .values ()])
164
166
# compute eig of MxN^2 matrix first instead of the N^2xN^2, N^2 >> M
@@ -175,16 +177,14 @@ def train(self):
175
177
self .W = self .W [:, :self .K ]
176
178
177
179
178
- def _divide_dataset (self , ratio = 0.85 ):
180
+ def _divide_dataset (self ):
179
181
"""Divides dataset in training and test (prediction) data"""
180
- if not 0 < ratio < 1 :
181
- raise RuntimeError ("Provide a ratio between 0 and 1." )
182
- training_or_test = [self ._random_binary (ratio ) for _ in self .data ]
182
+ training_or_test = [self ._random_binary (self .ratio ) for _ in self .data ]
183
183
self ._subtract_mean ()
184
184
185
185
train_inds = [i for i ,t in enumerate (training_or_test ) if t == self ._TRAIN_SAMPLE ]
186
186
test_inds = [i for i ,t in enumerate (training_or_test ) if t == self ._PRED_SAMPLE ]
187
- # {index: (data_vector, data_label)}, index starts from 0
187
+ # {index: (data_vector, data_label)}, index starts from 0
188
188
self .train_data = OD ( # ordered dict
189
189
dict (zip (train_inds , # keys
190
190
zip (self .data [train_inds ,:], self .labels [train_inds ]))) # vals
@@ -209,7 +209,7 @@ def get_test_sample(self) -> tuple:
209
209
return self .test_data [test_ind ] # data, label
210
210
211
211
212
- def classify (self , x_new :np .array ) -> tuple :
212
+ def classify (self , x_new : np .ndarray ) -> tuple :
213
213
"""classify. Classify an input data vector.
214
214
215
215
Parameters
@@ -236,8 +236,19 @@ def classify(self, x_new:np.array) -> tuple:
236
236
self .train_data [train_inds [np .argmin (dists )]][1 ]) # label
237
237
238
238
239
- def vec2img (self , x :list ):
240
- """Converts an 1D data vector stored in the class to image."""
239
+ def vec2img (self , x : list ) -> np .ndarray :
240
+ """vec2img. Converts an 1D data vector stored in the class to image.
241
+
242
+ Parameters
243
+ ----------
244
+ x : list
245
+ 0 mean float vector of length 64^2
246
+
247
+ Returns
248
+ -------
249
+ np.ndarray
250
+ the input vector 2D uint8 64x64 image
251
+ """
241
252
x = np .array (x ) + self ._mean
242
253
x = np .reshape (255 * x , (64 ,64 ))
243
254
return np .asarray (x , np .uint8 )
@@ -269,7 +280,7 @@ def webcam2vec(self):
269
280
min_shape = min (grey .shape )
270
281
cv2 .rectangle ( frame , (0 ,0 ), (int (3 * min_shape / 4 ),
271
282
int (3 * min_shape / 4 )), (0 ,255 ,0 ), thickness = 4 )
272
- cv2 .imshow ('frame' ,frame )
283
+ cv2 .imshow ('frame' , frame )
273
284
k = cv2 .waitKey (10 ) & 0xff
274
285
if k == ord ('q' ):
275
286
break
@@ -329,12 +340,12 @@ def benchmark(self, imshow = False, wait_time = 0.5, which_labels = []):
329
340
y_pred = lbl_test )
330
341
331
342
332
- def export (self , dest_folder = '/tmp' ) -> Tuple [str , str ]:
333
- """export.
343
+ def export (self , dest_folder : str = '/tmp' ) -> Tuple [str , str ]:
344
+ """export. Exports the data and labels as serialised files.
334
345
335
346
Parameters
336
347
----------
337
- dest_folder :
348
+ dest_folder : str
338
349
dest_folder
339
350
340
351
Returns
0 commit comments