Skip to content

Commit acbb783

Browse files
committed
Supporting different number of landmarks.
1 parent 07de821 commit acbb783

File tree

4 files changed

+57
-47
lines changed

4 files changed

+57
-47
lines changed

face_alignment_test.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def main() -> None:
4848
help='Weights to be loaded for face alignment, can be either 2DFAN2, 2DFAN4, ' +
4949
'or 2DFAN2_ALT (default=2DFAN2_ALT)')
5050
parser.add_argument('--alignment-alternative-pth', '-ap', default=None,
51-
help='Alternative pth file to be loaded for face alaignment')
51+
help='Alternative pth file to be loaded for face alignment')
52+
parser.add_argument('--alignment-alternative-landmarks', '-al', default=None,
53+
help='Alternative number of landmarks to detect')
5254
parser.add_argument('--alignment-device', '-ad', default='cuda:0',
5355
help='Device to be used for face alignment (default=cuda:0)')
5456
parser.add_argument('--hide-alignment-results', '-ha', help='Do not visualise face alignment results',
@@ -89,6 +91,8 @@ def main() -> None:
8991
fa_model = FANPredictor.get_model(args.alignment_weights)
9092
if args.alignment_alternative_pth is not None:
9193
fa_model.weights = args.alignment_alternative_pth
94+
if args.alignment_alternative_landmarks is not None:
95+
fa_model.config.num_landmarks = int(args.alignment_alternative_landmarks)
9296
landmark_detector = FANPredictor(device=args.alignment_device, model=fa_model)
9397
print(f"Landmark detector created using FAN ({fa_model.weights}).")
9498
else:

ibug/face_alignment/fan/fan.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,16 @@ def __init__(self, config):
133133
self.add_module('bn_end' + str(hg_module),
134134
nn.InstanceNorm2d(self.config.hg_num_features) if self.config.use_instance_norm
135135
else nn.BatchNorm2d(self.config.hg_num_features))
136-
self.add_module('l' + str(hg_module), nn.Conv2d(self.config.hg_num_features, 68,
136+
self.add_module('l' + str(hg_module), nn.Conv2d(self.config.hg_num_features,
137+
self.config.num_landmarks,
137138
kernel_size=1, stride=1, padding=0))
138139

139140
if hg_module < self.config.num_modules - 1:
140141
self.add_module('bl' + str(hg_module), nn.Conv2d(self.config.hg_num_features,
141142
self.config.hg_num_features,
142143
kernel_size=1, stride=1, padding=0))
143-
self.add_module('al' + str(hg_module), nn.Conv2d(68, self.config.hg_num_features,
144+
self.add_module('al' + str(hg_module), nn.Conv2d(self.config.num_landmarks,
145+
self.config.hg_num_features,
144146
kernel_size=1, stride=1, padding=0))
145147

146148
def forward(self, x):

ibug/face_alignment/fan/fan_predictor.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,22 @@ def get_model(name: str = '2dfan2') -> SimpleNamespace:
3434
config=SimpleNamespace(crop_ratio=0.55, input_size=256, num_modules=2,
3535
hg_num_features=256, hg_depth=4, use_avg_pool=False,
3636
use_instance_norm=False, stem_conv_kernel_size=7,
37-
stem_conv_stride=2, stem_pool_kernel_size=2))
37+
stem_conv_stride=2, stem_pool_kernel_size=2,
38+
num_landmarks=68))
3839
elif name == '2dfan4':
3940
return SimpleNamespace(weights=os.path.join(os.path.dirname(__file__), 'weights', '2dfan4.pth'),
4041
config=SimpleNamespace(crop_ratio=0.55, input_size=256, num_modules=4,
4142
hg_num_features=256, hg_depth=4, use_avg_pool=True,
4243
use_instance_norm=False, stem_conv_kernel_size=7,
43-
stem_conv_stride=2, stem_pool_kernel_size=2))
44+
stem_conv_stride=2, stem_pool_kernel_size=2,
45+
num_landmarks=68))
4446
elif name == '2dfan2_alt':
4547
return SimpleNamespace(weights=os.path.join(os.path.dirname(__file__), 'weights', '2dfan2_alt.pth'),
4648
config=SimpleNamespace(crop_ratio=0.55, input_size=256, num_modules=2,
4749
hg_num_features=256, hg_depth=4, use_avg_pool=False,
4850
use_instance_norm=False, stem_conv_kernel_size=7,
49-
stem_conv_stride=2, stem_pool_kernel_size=2))
51+
stem_conv_stride=2, stem_pool_kernel_size=2,
52+
num_landmarks=68))
5053
else:
5154
raise ValueError('name must be set to either 2dfan2, 2dfan4, or 2dfan2_alt')
5255

ibug/face_alignment/utils.py

+42-41
Original file line numberDiff line numberDiff line change
@@ -3,48 +3,49 @@
33
from typing import Optional, Sequence, Tuple
44

55

6-
__all__ = ['plot_landmarks']
6+
__all__ = ['get_landmark_connectivity', 'plot_landmarks']
7+
8+
9+
def get_landmark_connectivity(num_landmarks):
10+
if num_landmarks == 68:
11+
return ((0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12),
12+
(12, 13), (13, 14), (14, 15), (15, 16), (17, 18), (18, 19), (19, 20), (20, 21), (22, 23), (23, 24),
13+
(24, 25), (25, 26), (27, 28), (28, 29), (29, 30), (31, 32), (32, 33), (33, 34), (34, 35), (36, 37),
14+
(37, 38), (38, 39), (40, 41), (41, 36), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 42),
15+
(48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58),
16+
(58, 59), (59, 48), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 60),
17+
(39, 40))
18+
elif num_landmarks == 100:
19+
return ((0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12),
20+
(12, 13), (13, 14), (14, 15), (15, 16), (17, 18), (18, 19), (19, 20), (20, 21), (22, 23), (23, 24),
21+
(24, 25), (25, 26), (68, 69), (69, 70), (70, 71), (72, 73), (73, 74), (74, 75), (36, 76), (76, 37),
22+
(37, 77), (77, 38), (38, 78), (78, 39), (39, 40), (40, 79), (79, 41), (41, 36), (42, 80), (80, 43),
23+
(43, 81), (81, 44), (44, 82), (82, 45), (45, 46), (46, 83), (83, 47), (47, 42), (27, 28), (28, 29),
24+
(29, 30), (30, 33), (31, 32), (32, 33), (33, 34), (34, 35), (84, 85), (86, 87), (48, 49), (49, 88),
25+
(88, 50), (50, 51), (51, 52), (52, 89), (89, 53), (53, 54), (54, 55), (55, 90), (90, 56), (56, 57),
26+
(57, 58), (58, 91), (91, 59), (59, 48), (60, 92), (92, 93), (93, 61), (61, 62), (62, 63), (63, 94),
27+
(94, 95), (95, 64), (64, 96), (96, 97), (97, 65), (65, 66), (66, 67), (67, 98), (98, 99), (99, 60),
28+
(17, 68), (21, 71), (22, 72), (26, 75))
29+
else:
30+
return None
731

832

933
def plot_landmarks(image: np.ndarray, landmarks: np.ndarray, landmark_scores: Optional[Sequence[float]] = None,
1034
threshold: float = 0.2, line_colour: Tuple[int, int, int] = (0, 255, 0),
11-
pts_colour: Tuple[int, int, int] = (0, 0, 255),
12-
line_thickness: int = 1, pts_radius: int = 1) -> None:
13-
if landmarks.shape[0] > 0:
14-
if landmark_scores is None:
15-
landmark_scores = np.full(shape=(landmarks.shape[0],), fill_value=threshold + 1)
16-
if landmarks.shape[0] == 68:
17-
for idx in range(len(landmarks) - 1):
18-
if idx not in [16, 21, 26, 30, 35, 41, 47, 59]:
19-
if landmark_scores[idx] >= threshold and landmark_scores[idx + 1] >= threshold:
20-
cv2.line(image, tuple(landmarks[idx].astype(int).tolist()),
21-
tuple(landmarks[idx + 1].astype(int).tolist()),
22-
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
23-
if idx == 30:
24-
if landmark_scores[30] >= threshold and landmark_scores[33] >= threshold:
25-
cv2.line(image, tuple(landmarks[30].astype(int).tolist()),
26-
tuple(landmarks[33].astype(int).tolist()),
27-
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
28-
elif idx == 36:
29-
if landmark_scores[36] >= threshold and landmark_scores[41] >= threshold:
30-
cv2.line(image, tuple(landmarks[36].astype(int).tolist()),
31-
tuple(landmarks[41].astype(int).tolist()),
32-
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
33-
elif idx == 42:
34-
if landmark_scores[42] >= threshold and landmark_scores[47] >= threshold:
35-
cv2.line(image, tuple(landmarks[42].astype(int).tolist()),
36-
tuple(landmarks[47].astype(int).tolist()),
37-
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
38-
elif idx == 48:
39-
if landmark_scores[48] >= threshold and landmark_scores[59] >= threshold:
40-
cv2.line(image, tuple(landmarks[48].astype(int).tolist()),
41-
tuple(landmarks[59].astype(int).tolist()),
42-
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
43-
elif idx == 60:
44-
if landmark_scores[60] >= threshold and landmark_scores[67] >= threshold:
45-
cv2.line(image, tuple(landmarks[60].astype(int).tolist()),
46-
tuple(landmarks[67].astype(int).tolist()),
47-
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
48-
for landmark, score in zip(landmarks, landmark_scores):
49-
if score >= threshold:
50-
cv2.circle(image, tuple(landmark.astype(int).tolist()), pts_radius, pts_colour, -1)
35+
pts_colour: Tuple[int, int, int] = (0, 0, 255), line_thickness: int = 1, pts_radius: int = 1,
36+
landmark_connectivity: Optional[Sequence[Sequence[int]]] = None) -> None:
37+
num_landmarks = len(landmarks)
38+
if landmark_scores is None:
39+
landmark_scores = np.full((num_landmarks,), threshold + 1.0, dtype=float)
40+
if landmark_connectivity is None:
41+
landmark_connectivity = get_landmark_connectivity(len(landmarks))
42+
if landmark_connectivity is not None:
43+
for (idx1, idx2) in landmark_connectivity:
44+
if (idx1 < num_landmarks and idx2 < num_landmarks and
45+
landmark_scores[idx1] >= threshold and landmark_scores[idx2] >= threshold):
46+
cv2.line(image, tuple(landmarks[idx1].astype(int).tolist()),
47+
tuple(landmarks[idx2].astype(int).tolist()),
48+
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
49+
for landmark, score in zip(landmarks, landmark_scores):
50+
if score >= threshold:
51+
cv2.circle(image, tuple(landmark.astype(int).tolist()), pts_radius, pts_colour, -1)

0 commit comments

Comments
 (0)