Skip to content

Commit 8fb6d46

Browse files
authored
Merge pull request #65 from christsa/main
MANO support and z-up feature for Meshes
2 parents 380a2c2 + 9113538 commit 8fb6d46

File tree

3 files changed

+99
-17
lines changed

3 files changed

+99
-17
lines changed

aitviewer/models/smpl.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,65 @@ def fk(
236236

237237
return output.vertices, output.joints
238238

239+
def fk_mano(
240+
self,
241+
hand_pose,
242+
betas,
243+
global_orient=None,
244+
trans=None,
245+
normalize_root=False,
246+
mano=True,
247+
):
248+
"""
249+
Convert mano pose data (joint angles and shape parameters) to positional data (joint and mesh vertex positions).
250+
:param hand_pose: A tensor of shape (N, N_JOINTS*3), i.e. joint angles in angle-axis format or PCA format (N, N_PCA_COMPONENTS). This contains all
251+
body joints which are not the root.
252+
:param betas: A tensor of shape (N, N_BETAS) containing the betas/shape parameters.
253+
:param global_orient: Orientation of the root or None. If specified expected shape is (N, 3).
254+
:param trans: translation that is applied to vertices and joints or None, this is the 'transl' parameter
255+
of the MANO Model. If specified expected shape is (N, 3).
256+
:param normalize_root: If set, it will normalize the root such that its orientation is the identity in the
257+
first frame and its position starts at the origin.
258+
:return: The resulting vertices and joints.
259+
"""
260+
261+
batch_size = hand_pose.shape[0]
262+
device = hand_pose.device
263+
264+
if global_orient is None:
265+
global_orient = torch.zeros([batch_size, 3]).to(dtype=hand_pose.dtype, device=device)
266+
if trans is None:
267+
trans = torch.zeros([batch_size, 3]).to(dtype=hand_pose.dtype, device=device)
268+
269+
# Batch shapes if they don't match batch dimension.
270+
if len(betas.shape) == 1 or betas.shape[0] == 1:
271+
betas = betas.repeat(hand_pose.shape[0], 1)
272+
betas = betas[:, : self.num_betas]
273+
274+
if normalize_root:
275+
# Make everything relative to the first root orientation.
276+
root_ori = aa2rot(global_orient)
277+
first_root_ori = torch.inverse(root_ori[0:1])
278+
root_ori = torch.matmul(first_root_ori, root_ori)
279+
global_orient = rot2aa(root_ori)
280+
trans = torch.matmul(first_root_ori.unsqueeze(0), trans.unsqueeze(-1)).squeeze()
281+
trans = trans - trans[0:1]
282+
283+
output = self.bm(
284+
hand_pose=hand_pose,
285+
betas=betas,
286+
global_orient=global_orient,
287+
transl=trans,
288+
)
289+
290+
return output.vertices, output.joints
291+
239292
def forward(self, *args, **kwargs):
240293
"""
241294
Forward pass using forward kinematics
242295
"""
243-
return self.fk(*args, **kwargs)
296+
297+
if "mano" in kwargs.keys():
298+
return self.fk_mano(*args, **kwargs)
299+
else:
300+
return self.fk(*args, **kwargs)

aitviewer/renderables/meshes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
draw_edges=False,
5252
draw_outline=False,
5353
instance_transforms=None,
54+
z_up=False,
5455
icon="\u008d",
5556
**kwargs,
5657
):
@@ -147,6 +148,9 @@ def _maybe_unsqueeze(x):
147148
self.clip_control = np.array((0, 0, 0), np.int32)
148149
self.clip_value = np.array((0, 0, 0), np.float32)
149150

151+
if z_up:
152+
self.rotation = np.matmul(np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]]), self.rotation)
153+
150154
@classmethod
151155
def instanced(cls, *args, positions=None, rotations=None, scales=None, **kwargs):
152156
"""

aitviewer/renderables/smpl.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -157,20 +157,28 @@ def __init__(
157157

158158
# First convert the relative joint angles to global joint angles in rotation matrix form.
159159
if self.smpl_layer.model_type != "flame":
160-
global_oris = local_to_global(
161-
torch.cat([self.poses_root, self.poses_body], dim=-1),
162-
self.skeleton[:, 0],
163-
output_format="rotmat",
164-
)
160+
if self.smpl_layer.model_type != "mano":
161+
global_oris = local_to_global(
162+
torch.cat([self.poses_root, self.poses_body, self.poses_left_hand, self.poses_right_hand], dim=-1),
163+
self.skeleton[:, 0],
164+
output_format="rotmat",
165+
)
166+
else:
167+
global_oris = local_to_global(
168+
torch.cat([self.poses_root, self.poses_body], dim=-1),
169+
self.skeleton[:, 0],
170+
output_format="rotmat",
171+
)
165172
global_oris = c2c(global_oris.reshape((self.n_frames, -1, 3, 3)))
166173
else:
167174
global_oris = np.tile(np.eye(3), self.joints.shape[:-1])[np.newaxis]
168175

169176
if self._z_up and not C.z_up:
170177
self.rotation = np.matmul(np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]]), self.rotation)
171178

172-
self.rbs = RigidBodies(self.joints, global_oris, length=0.1, gui_affine=False, name="Joint Angles")
173-
self._add_node(self.rbs, enabled=self._show_joint_angles)
179+
if self.smpl_layer.model_type != "mano":
180+
self.rbs = RigidBodies(self.joints, global_oris, length=0.1, gui_affine=False, name="Joint Angles")
181+
self._add_node(self.rbs, enabled=self._show_joint_angles)
174182

175183
self.mesh_seq = Meshes(
176184
self.vertices,
@@ -397,20 +405,33 @@ def fk(self, current_frame_only=False):
397405
trans = self.trans
398406
betas = self.betas
399407

400-
verts, joints = self.smpl_layer(
401-
poses_root=poses_root,
402-
poses_body=poses_body,
403-
poses_left_hand=poses_left_hand,
404-
poses_right_hand=poses_right_hand,
405-
betas=betas,
406-
trans=trans,
407-
)
408+
if self.smpl_layer.model_type == "mano":
409+
verts, joints = self.smpl_layer(
410+
hand_pose=poses_body,
411+
betas=betas,
412+
global_orient=poses_root,
413+
trans=trans,
414+
mano=True,
415+
)
416+
else:
417+
verts, joints = self.smpl_layer(
418+
poses_root=poses_root,
419+
poses_body=poses_body,
420+
poses_left_hand=poses_left_hand,
421+
poses_right_hand=poses_right_hand,
422+
betas=betas,
423+
trans=trans,
424+
)
408425

409426
# Apply post_fk_func if specified.
410427
if self.post_fk_func:
411428
verts, joints = self.post_fk_func(self, verts, joints, current_frame_only)
412429

413-
skeleton = self.smpl_layer.skeletons()["body"].T
430+
skeleton = (
431+
self.smpl_layer.skeletons()["body"].T
432+
if not self.smpl_layer.model_type == "mano"
433+
else self.smpl_layer.skeletons()["all"].T
434+
)
414435
faces = self.smpl_layer.bm.faces.astype(np.int64)
415436
joints = joints[:, : skeleton.shape[0]]
416437

0 commit comments

Comments
 (0)