Skip to content

Commit abeb30f

Browse files
authored
[BugFix]: Fetch IK bug / mobile manipulation IK bug fix (#996)
* Fix: fetch ik fix * Fix: fetch cpu ik fix * Fix: allow set custom root name for pd ee * demo_manual_control_continuous play script * Fix: remove comments * Fix: fetch pd_ee_target_delta_pose by using torso link as root for ik
1 parent 407af33 commit abeb30f

File tree

7 files changed

+815
-34
lines changed

7 files changed

+815
-34
lines changed

mani_skill/agents/controllers/pd_ee_pose.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from dataclasses import dataclass
2-
from typing import Literal, Sequence, Union
2+
from typing import Literal, Optional, Sequence, Union
33

44
import numpy as np
55
import torch
66
from gymnasium import spaces
77

88
from mani_skill.agents.controllers.utils.kinematics import Kinematics
9-
from mani_skill.utils import gym_utils
9+
from mani_skill.utils import gym_utils, sapien_utils
1010
from mani_skill.utils.geometry.rotation_conversions import (
1111
euler_angles_to_matrix,
1212
matrix_to_quaternion,
@@ -45,6 +45,13 @@ def _initialize_joints(self):
4545

4646
self.ee_link = self.kinematics.end_link
4747

48+
if self.config.root_link_name is not None:
49+
self.root_link = sapien_utils.get_obj_by_name(
50+
self.articulation.get_links(), self.config.root_link_name
51+
)
52+
else:
53+
self.root_link = self.articulation.root
54+
4855
def _initialize_action_space(self):
4956
low = np.float32(np.broadcast_to(self.config.pos_lower, 3))
5057
high = np.float32(np.broadcast_to(self.config.pos_upper, 3))
@@ -60,7 +67,7 @@ def ee_pose(self):
6067

6168
@property
6269
def ee_pose_at_base(self):
63-
to_base = self.articulation.pose.inv()
70+
to_base = self.root_link.pose.inv()
6471
return to_base * (self.ee_pose)
6572

6673
def reset(self):
@@ -69,9 +76,9 @@ def reset(self):
6976
self._target_pose = self.ee_pose_at_base
7077
else:
7178
# TODO (stao): this is a strange way to mask setting individual batched pose parts
72-
self._target_pose.raw_pose[
73-
self.scene._reset_mask
74-
] = self.ee_pose_at_base.raw_pose[self.scene._reset_mask]
79+
self._target_pose.raw_pose[self.scene._reset_mask] = (
80+
self.ee_pose_at_base.raw_pose[self.scene._reset_mask]
81+
)
7582

7683
def compute_target_pose(self, prev_ee_pose_at_base, action):
7784
# Keep the current rotation and change the position
@@ -156,6 +163,8 @@ class PDEEPosControllerConfig(ControllerConfig):
156163
] = "root_translation"
157164
"""Choice of frame to use for translational and rotational control of the end-effector. To learn how these work explicitly
158165
with videos of each one's behavior see https://maniskill.readthedocs.io/en/latest/user_guide/concepts/controllers.html#pd-ee-end-effector-pose"""
166+
root_link_name: Optional[str] = None
167+
"""Optionally set different root link for root translation control (e.g. if root is different than base)"""
159168
use_delta: bool = True
160169
"""Whether to use delta-action control. If true then actions indicate the delta/change in position via translation and orientation via
161170
rotation. If false, then actions indicate in the base frame (typically wherever the root link of the robot is) what pose the end effector

mani_skill/agents/controllers/utils/kinematics.py

Lines changed: 84 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Code for kinematics utilities on CPU/GPU
33
"""
4+
45
from contextlib import contextmanager, redirect_stderr, redirect_stdout
56
from os import devnull
67
from typing import List
@@ -12,6 +13,7 @@
1213
"pytorch_kinematics_ms not installed. Install with pip install pytorch_kinematics_ms"
1314
)
1415
import torch
16+
from lxml import etree as ET
1517
from sapien.wrapper.pinocchio_model import PinocchioModel
1618

1719
from mani_skill.utils import common
@@ -44,12 +46,18 @@ def __init__(
4446
articulation (Articulation): the articulation object
4547
active_joint_indices (torch.Tensor): indices of the active joints that can be controlled
4648
"""
49+
50+
# NOTE (arth): urdf path with feasible kinematic chain. may not be same urdf used to
51+
# build the sapien articulation (e.g. sapien articulation may have added joints for
52+
# mobile base which should not be used in IK)
4753
self.urdf_path = urdf_path
4854
self.end_link = articulation.links_map[end_link_name]
49-
self.end_link_idx = articulation.links.index(self.end_link)
50-
self.active_joint_indices = active_joint_indices
55+
5156
self.articulation = articulation
5257
self.device = articulation.device
58+
59+
self.active_joint_indices = active_joint_indices
60+
5361
# note that everything past the end-link is ignored. Any joint whose ancestor is self.end_link is ignored
5462
cur_link = self.end_link
5563
active_ancestor_joints: List[ArticulationJoint] = []
@@ -58,17 +66,25 @@ def __init__(
5866
active_ancestor_joints.append(cur_link.joint)
5967
cur_link = cur_link.joint.parent_link
6068
active_ancestor_joints = active_ancestor_joints[::-1]
61-
self.active_ancestor_joints = active_ancestor_joints
6269

63-
# initially self.active_joint_indices references active joints that are controlled.
64-
# we also make the assumption that the active index is the same across all parallel managed joints
65-
self.active_ancestor_joint_idxs = [
66-
(x.active_index[0]).cpu().item() for x in self.active_ancestor_joints
67-
]
68-
self.controlled_joints_idx_in_qmask = [
69-
self.active_ancestor_joint_idxs.index(idx)
70-
for idx in self.active_joint_indices
70+
# NOTE (arth): some robots, like Fetch, have dummy joints that can mess with IK solver.
71+
# we assume that the urdf_path provides a valid kinematic chain, and prune joints
72+
# which are in the ManiSkill articulation but not in the kinematic chain
73+
with open(self.urdf_path, "r") as f:
74+
urdf_string = f.read()
75+
xml = ET.fromstring(urdf_string.encode("utf-8"))
76+
self._kinematic_chain_joint_names = set(
77+
node.get("name") for node in xml if node.tag == "joint"
78+
)
79+
self._kinematic_chain_link_names = set(
80+
node.get("name") for node in xml if node.tag == "link"
81+
)
82+
self.active_ancestor_joints = [
83+
x
84+
for x in active_ancestor_joints
85+
if x.name in self._kinematic_chain_joint_names
7186
]
87+
7288
if self.device.type == "cuda":
7389
self.use_gpu_ik = True
7490
self._setup_gpu()
@@ -79,14 +95,47 @@ def __init__(
7995
def _setup_cpu(self):
8096
"""setup the kinematics solvers on the CPU"""
8197
self.use_gpu_ik = False
82-
# NOTE (stao): currently using the pinnochio that comes packaged with SAPIEN
83-
self.qmask = torch.zeros(
84-
self.articulation.max_dof, dtype=bool, device=self.device
98+
99+
with open(self.urdf_path, "r") as f:
100+
xml = f.read()
101+
102+
joint_order = [
103+
j.name
104+
for j in self.articulation.active_joints
105+
if j.name in self._kinematic_chain_joint_names
106+
]
107+
link_order = [
108+
l.name
109+
for l in self.articulation.links
110+
if l.name in self._kinematic_chain_link_names
111+
]
112+
113+
self.pmodel = PinocchioModel(xml, [0, 0, -9.81])
114+
self.pmodel.set_joint_order(joint_order)
115+
self.pmodel.set_link_order(link_order)
116+
117+
controlled_joint_names = [
118+
self.articulation.active_joints[i].name for i in self.active_joint_indices
119+
]
120+
self.pmodel_controlled_joint_indices = torch.tensor(
121+
[joint_order.index(cj) for cj in controlled_joint_names],
122+
dtype=torch.int,
123+
device=self.device,
85124
)
86-
self.pmodel: PinocchioModel = self.articulation._objs[
87-
0
88-
].create_pinocchio_model()
89-
self.qmask[self.active_joint_indices] = 1
125+
126+
articulation_active_joint_names_to_idx = dict(
127+
(j.name, i) for i, j in enumerate(self.articulation.active_joints)
128+
)
129+
self.pmodel_active_joint_indices = torch.tensor(
130+
[articulation_active_joint_names_to_idx[jn] for jn in joint_order],
131+
dtype=torch.int,
132+
device=self.device,
133+
)
134+
135+
# NOTE (arth): pmodel will use urdf_path, set values based on this xml
136+
self.end_link_idx = link_order.index(self.end_link.name)
137+
self.qmask = torch.zeros(len(joint_order), dtype=bool, device=self.device)
138+
self.qmask[self.pmodel_controlled_joint_indices] = 1
90139

91140
def _setup_gpu(self):
92141
"""setup the kinematics solvers on the GPU"""
@@ -116,6 +165,16 @@ def suppress_stdout_stderr():
116165
num_retries=1,
117166
)
118167

168+
# initially self.active_joint_indices references active joints that are controlled.
169+
# we also make the assumption that the active index is the same across all parallel managed joints
170+
self.active_ancestor_joint_idxs = [
171+
(x.active_index[0]).cpu().item() for x in self.active_ancestor_joints
172+
]
173+
self.controlled_joints_idx_in_qmask = [
174+
self.active_ancestor_joint_idxs.index(idx)
175+
for idx in self.active_joint_indices
176+
]
177+
119178
self.qmask = torch.zeros(
120179
len(self.active_ancestor_joints), dtype=bool, device=self.device
121180
)
@@ -167,10 +226,15 @@ def compute_ik(
167226
if pos_only:
168227
jacobian = jacobian[:, 0:3]
169228

229+
# NOTE (arth): use only the parts of the jacobian that correspond to the active joints
230+
jacobian = jacobian[:, :, self.qmask]
231+
170232
# NOTE (stao): this method of IK is from https://mathweb.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf by Samuel R. Buss
171233
delta_joint_pos = torch.linalg.pinv(jacobian) @ action.unsqueeze(-1)
172-
return q0 + delta_joint_pos.squeeze(-1)
234+
235+
return q0[:, self.qmask] + delta_joint_pos.squeeze(-1)
173236
else:
237+
q0 = q0[:, self.pmodel_active_joint_indices]
174238
result, success, error = self.pmodel.compute_inverse_kinematics(
175239
self.end_link_idx,
176240
target_pose.sp,
@@ -180,7 +244,7 @@ def compute_ik(
180244
)
181245
if success:
182246
return common.to_tensor(
183-
[result[self.active_ancestor_joint_idxs]], device=self.device
247+
[result[self.pmodel_controlled_joint_indices]], device=self.device
184248
)
185249
else:
186250
return None

mani_skill/agents/robots/fetch/fetch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
class Fetch(BaseAgent):
2828
uid = "fetch"
2929
urdf_path = f"{PACKAGE_ASSET_DIR}/robots/fetch/fetch.urdf"
30+
urdf_arm_ik_path = f"{PACKAGE_ASSET_DIR}/robots/fetch/fetch_torso_up.urdf"
3031
urdf_config = dict(
3132
_materials=dict(
3233
gripper=dict(static_friction=2.0, dynamic_friction=2.0, restitution=0.0)
@@ -149,7 +150,8 @@ def _controller_configs(self):
149150
damping=self.arm_damping,
150151
force_limit=self.arm_force_limit,
151152
ee_link=self.ee_link_name,
152-
urdf_path=self.urdf_path,
153+
urdf_path=self.urdf_arm_ik_path,
154+
root_link_name="torso_lift_link",
153155
)
154156
arm_pd_ee_delta_pose = PDEEPoseControllerConfig(
155157
joint_names=self.arm_joint_names,
@@ -161,7 +163,8 @@ def _controller_configs(self):
161163
damping=self.arm_damping,
162164
force_limit=self.arm_force_limit,
163165
ee_link=self.ee_link_name,
164-
urdf_path=self.urdf_path,
166+
urdf_path=self.urdf_arm_ik_path,
167+
root_link_name="torso_lift_link",
165168
)
166169

167170
arm_pd_ee_target_delta_pos = deepcopy(arm_pd_ee_delta_pos)

0 commit comments

Comments
 (0)