Skip to content
This repository was archived by the owner on Aug 23, 2024. It is now read-only.

Commit 404550f

Browse files
committed
Better error catching and more parameter type hints
1 parent 5bfe113 commit 404550f

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

Diff for: ezfaces/face_classifier.py

+26-15
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import time
1010
import glob
1111
import itertools as it
12-
from typing import List, Tuple
12+
from typing import List, Tuple, Union
1313

1414

1515
class FaceClassifier():
@@ -42,6 +42,8 @@ def __init__(self, ratio = 0.85, K = 200, data_pkl = None, target_pkl = None):
4242
# how many eigenfaces to keep
4343
self.K = K
4444
# how much training data to use as part of total data
45+
if not 0 < ratio <= 1:
46+
raise ValueError('Provide a training/total data ratio from 0 to 1 inclusive.')
4547
self.ratio = ratio
4648
# MxK matrix - each row stores the coords of each image in the eigenface space
4749
self.W = None
@@ -84,7 +86,7 @@ def _subtract_mean(self):
8486
self.data = np.matmul(C, self.data)
8587

8688

87-
def _read_from_webcam(self, new_label, stream = 0):
89+
def _read_from_webcam(self, new_label, stream: Union[str, int] = 0):
8890
"""Takes face snapshots from webcam. Pass the new label of the subject
8991
being photographed."""
9092
print("Position your face in the green box.\n"
@@ -158,7 +160,7 @@ def add_img_data(self, dir_img: str = "", from_webcam: bool = False) -> int:
158160

159161
def train(self):
160162
""" Find the coordinates of each training image in the eigenface space """
161-
self._divide_dataset(ratio = self.ratio)
163+
self._divide_dataset()
162164
# the matrix X to use for training
163165
X = np.array([v[0] for v in self.train_data.values()])
164166
# compute eig of MxN^2 matrix first instead of the N^2xN^2, N^2 >> M
@@ -175,16 +177,14 @@ def train(self):
175177
self.W = self.W[:, :self.K]
176178

177179

178-
def _divide_dataset(self, ratio = 0.85):
180+
def _divide_dataset(self):
179181
"""Divides dataset in training and test (prediction) data"""
180-
if not 0 < ratio < 1:
181-
raise RuntimeError("Provide a ratio between 0 and 1.")
182-
training_or_test = [self._random_binary(ratio) for _ in self.data]
182+
training_or_test = [self._random_binary(self.ratio) for _ in self.data]
183183
self._subtract_mean()
184184

185185
train_inds = [i for i,t in enumerate(training_or_test) if t == self._TRAIN_SAMPLE]
186186
test_inds = [i for i,t in enumerate(training_or_test) if t == self._PRED_SAMPLE]
187-
# {index: (data_vector, data_label)}, index starts from 0
187+
# {index: (data_vector, data_label)}, index starts from 0
188188
self.train_data = OD( # ordered dict
189189
dict(zip(train_inds, # keys
190190
zip(self.data[train_inds,:], self.labels[train_inds]))) # vals
@@ -209,7 +209,7 @@ def get_test_sample(self) -> tuple:
209209
return self.test_data[test_ind] # data, label
210210

211211

212-
def classify(self, x_new:np.array) -> tuple:
212+
def classify(self, x_new: np.ndarray) -> tuple:
213213
"""classify. Classify an input data vector.
214214
215215
Parameters
@@ -236,8 +236,19 @@ def classify(self, x_new:np.array) -> tuple:
236236
self.train_data[train_inds[np.argmin(dists)]][1]) # label
237237

238238

239-
def vec2img(self, x:list):
240-
"""Converts an 1D data vector stored in the class to image."""
239+
def vec2img(self, x: list) -> np.ndarray:
240+
"""vec2img. Converts an 1D data vector stored in the class to image.
241+
242+
Parameters
243+
----------
244+
x : list
245+
0 mean float vector of length 64^2
246+
247+
Returns
248+
-------
249+
np.ndarray
250+
the input vector 2D uint8 64x64 image
251+
"""
241252
x = np.array(x) + self._mean
242253
x = np.reshape(255*x, (64,64))
243254
return np.asarray(x, np.uint8)
@@ -269,7 +280,7 @@ def webcam2vec(self):
269280
min_shape = min(grey.shape)
270281
cv2.rectangle( frame, (0,0), (int(3*min_shape/4),
271282
int(3*min_shape/4)), (0,255,0), thickness = 4)
272-
cv2.imshow('frame',frame)
283+
cv2.imshow('frame', frame)
273284
k = cv2.waitKey(10) & 0xff
274285
if k == ord('q'):
275286
break
@@ -329,12 +340,12 @@ def benchmark(self, imshow = False, wait_time = 0.5, which_labels = []):
329340
y_pred = lbl_test)
330341

331342

332-
def export(self, dest_folder = '/tmp') -> Tuple[str, str]:
333-
"""export.
343+
def export(self, dest_folder: str = '/tmp') -> Tuple[str, str]:
344+
"""export. Exports the data and labels as serialised files.
334345
335346
Parameters
336347
----------
337-
dest_folder :
348+
dest_folder : str
338349
dest_folder
339350
340351
Returns

0 commit comments

Comments
 (0)