
Commit 54f8081

Committed Jan 21, 2025
Merge pull request #29 from Andrew-Luo1:aloha_handover
PiperOrigin-RevId: 718018875 Change-Id: If3ab6a67b96d435a89b96cae44200b1031dfdfcb
2 parents 28c9487 + fa4e8f1 commit 54f8081

8 files changed: +389 -31 lines changed

 

‎mujoco_playground/_src/manipulation/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,7 @@
 from mujoco import mjx

 from mujoco_playground._src import mjx_env
+from mujoco_playground._src.manipulation.aloha import handover as aloha_handover
 from mujoco_playground._src.manipulation.aloha import single_peg_insertion as aloha_peg
 from mujoco_playground._src.manipulation.franka_emika_panda import open_cabinet as panda_open_cabinet
 from mujoco_playground._src.manipulation.franka_emika_panda import pick as panda_pick
@@ -29,6 +30,7 @@
 from mujoco_playground._src.manipulation.leap_hand import rotate_z as leap_rotate_z

 _envs = {
+    "AlohaHandOver": aloha_handover.HandOver,
     "AlohaSinglePegInsertion": aloha_peg.SinglePegInsertion,
     "PandaPickCube": panda_pick.PandaPickCube,
     "PandaPickCubeOrientation": panda_pick.PandaPickCubeOrientation,
@@ -40,6 +42,7 @@
 }

 _cfgs = {
+    "AlohaHandOver": aloha_handover.default_config,
     "AlohaSinglePegInsertion": aloha_peg.default_config,
     "PandaPickCube": panda_pick.default_config,
     "PandaPickCubeOrientation": panda_pick.default_config,

‎mujoco_playground/_src/manipulation/aloha/aloha_constants.py

Lines changed: 8 additions & 7 deletions
@@ -16,13 +16,7 @@

 from mujoco_playground._src import mjx_env

-XML_PATH = (
-    mjx_env.ROOT_PATH
-    / "manipulation"
-    / "aloha"
-    / "xmls"
-    / "mjx_single_peg_insertion.xml"
-)
+XML_PATH = mjx_env.ROOT_PATH / "manipulation" / "aloha" / "xmls"

 ARM_JOINTS = [
     "left/waist",
@@ -49,3 +43,10 @@
     "right/right_finger_top",
     "right/right_finger_bottom",
 ]
+
+FINGER_JOINTS = [
+    "left/left_finger",
+    "left/right_finger",
+    "right/left_finger",
+    "right/right_finger",
+]
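
Since XML_PATH now points at the shared xmls directory rather than a single model file, each task joins its own XML filename onto it. A short sketch: the handover path is taken verbatim from handover.py below, and the peg-insertion task presumably does the same with the filename that used to be hard-coded here.

from mujoco_playground._src.manipulation.aloha import aloha_constants as consts

handover_xml = (consts.XML_PATH / "mjx_hand_over.xml").as_posix()
peg_xml = (consts.XML_PATH / "mjx_single_peg_insertion.xml").as_posix()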

‎mujoco_playground/_src/manipulation/aloha/base.py

Lines changed: 32 additions & 0 deletions
@@ -17,11 +17,15 @@
 from typing import Any, Dict, Optional, Union

 from etils import epath
+import jax.numpy as jp
 from ml_collections import config_dict
 import mujoco
 from mujoco import mjx
+import numpy as np

+from mujoco_playground._src import collision
 from mujoco_playground._src import mjx_env
+from mujoco_playground._src.manipulation.aloha import aloha_constants as consts


 def get_assets() -> Dict[str, bytes]:
@@ -58,6 +62,26 @@ def __init__(
     self._mjx_model = mjx.put_model(self._mj_model)
     self._xml_path = xml_path

+  def _post_init_aloha(self, keyframe: str = "home"):
+    """Initializes helpful robot properties."""
+    self._left_gripper_site = self._mj_model.site("left/gripper").id
+    self._right_gripper_site = self._mj_model.site("right/gripper").id
+    self._table_geom = self._mj_model.geom("table").id
+    self._finger_geoms = [
+        self._mj_model.geom(geom_id).id for geom_id in consts.FINGER_GEOMS
+    ]
+    self._init_q = jp.array(self._mj_model.keyframe(keyframe).qpos)
+    self._init_ctrl = jp.array(self._mj_model.keyframe(keyframe).ctrl)
+    self._lowers, self._uppers = self.mj_model.actuator_ctrlrange.T
+    arm_joint_ids = [self._mj_model.joint(j).id for j in consts.ARM_JOINTS]
+    self._arm_qadr = jp.array(
+        [self._mj_model.jnt_qposadr[joint_id] for joint_id in arm_joint_ids]
+    )
+    self._finger_qposadr = np.array([
+        self._mj_model.jnt_qposadr[self._mj_model.joint(j).id]
+        for j in consts.FINGER_JOINTS
+    ])
+
   @property
   def xml_path(self) -> str:
     return self._xml_path
@@ -73,3 +97,11 @@ def mj_model(self) -> mujoco.MjModel:
   @property
   def mjx_model(self) -> mjx.Model:
     return self._mjx_model
+
+  def hand_table_collision(self, data) -> jp.ndarray:
+    # Check for collisions with the table.
+    hand_table_collisions = [
+        collision.geoms_colliding(data, self._table_geom, g)
+        for g in self._finger_geoms
+    ]
+    return (sum(hand_table_collisions) > 0).astype(float)

mujoco_playground/_src/manipulation/aloha/handover.py (new file)

Lines changed: 289 additions & 0 deletions
@@ -0,0 +1,289 @@
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handover task for ALOHA."""

from typing import Any, Dict, Optional, Union

import jax
from jax import numpy as jp
from ml_collections import config_dict
from mujoco import mjx

from mujoco_playground._src import mjx_env
from mujoco_playground._src.manipulation.aloha import aloha_constants as consts
from mujoco_playground._src.manipulation.aloha import base as aloha_base


def default_config() -> config_dict.ConfigDict:
  return config_dict.create(
      ctrl_dt=0.02,
      sim_dt=0.005,
      episode_length=250,  # 5 sec.
      action_repeat=1,
      action_scale=0.015,
      reward_config=config_dict.create(
          scales=config_dict.create(
              gripper_box=1,
              box_handover=4,
              handover_target=8,
              no_table_collision=0.3,
          ),
      ),
  )


# Default parameters: 12 cm decay range centered around x = 0.
def logistic_barrier(x: jax.Array, x0=0, k=100, direction=1.0):
  # direction = 1.0: Penalize going to the left.
  return 1 / (1 + jp.exp(-k * direction * (x - x0)))


class HandOver(aloha_base.AlohaEnv):
  """Handover task for ALOHA."""

  def __init__(
      self,
      config: config_dict.ConfigDict = default_config(),
      config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
  ):
    super().__init__(
        xml_path=(consts.XML_PATH / 'mjx_hand_over.xml').as_posix(),
        config=config,
        config_overrides=config_overrides,
    )
    self._post_init()
    self.grip_alpha = 0.1

  def _post_init(self):
    self._post_init_aloha(keyframe='home')
    # Aid finger exploration.
    self._lowers[6] = 0.01
    self._mocap_target = self._mj_model.body('mocap_target').mocapid
    self._box_body = self._mj_model.body('box').id
    self._box_top_site = self._mj_model.site('box_top').id
    self._box_bottom_site = self._mj_model.site('box_bottom').id

    self._box_qadr = self._mj_model.jnt_qposadr[
        self._mj_model.body_jntadr[self._box_body]
    ]

    # Used for reward calculation.
    self._left_thresh = -0.1
    self._right_thresh = 0.0
    self._handover_pos = jp.array([0.0, 0.0, 0.24])

    self._box_geom = self._mj_model.geom('box').id
    self._picked_q = self._mj_model.keyframe('picked').qpos
    self._picked_ctrl = self._mj_model.keyframe('picked').ctrl
    self._transferred_q = self._mj_model.keyframe('transferred').qpos
    self._transferred_ctrl = self._mj_model.keyframe('transferred').ctrl

  def reset(self, rng: jax.Array) -> mjx_env.State:
    rng, rng_box_x, rng_box_y = jax.random.split(rng, 3)

    box_xy = jp.array([
        jax.random.uniform(rng_box_x, (), minval=-0.05, maxval=0.05),
        jax.random.uniform(rng_box_y, (), minval=-0.1, maxval=0.1),
    ])
    init_q = self._init_q.at[self._box_qadr : self._box_qadr + 2].add(box_xy)

    data = mjx_env.init(
        self._mjx_model,
        init_q,
        jp.zeros(self._mjx_model.nv, dtype=float),
        ctrl=self._init_ctrl,
    )

    rng, rng_target = jax.random.split(rng)
    target_pos = jp.array([0.20, 0.0, 0.25])
    target_pos += jax.random.uniform(
        rng_target, (3,), minval=-0.15, maxval=0.15
    )
    target_x = jp.clip(target_pos[0], 0.15, None)  # Saturate log barrier.
    target_pos = target_pos.at[0].set(target_x)

    data = data.replace(
        mocap_pos=data.mocap_pos.at[self._mocap_target].set(target_pos)
    )
    info = {
        'rng': rng,
        'target_pos': target_pos,
        'prev_potential': jp.array(0.0, dtype=float),
        '_steps': jp.array(0, dtype=int),
        'episode_picked': jp.array(0, dtype=bool),  # To help count above.
    }

    obs = self._get_obs(data, info)
    reward, done = jp.zeros(2)

    metrics = {
        'out_of_bounds': jp.array(0.0, dtype=float),
        **{k: 0.0 for k in self._config.reward_config.scales.keys()},
    }

    return mjx_env.State(data, obs, reward, done, metrics, info)

  def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
    newly_reset = state.info['_steps'] == 0
    state.info['episode_picked'] = jp.where(
        newly_reset, 0, state.info['episode_picked']
    )
    state.info['prev_potential'] = jp.where(
        newly_reset, 0.0, state.info['prev_potential']
    )

    # Scale actions.
    delta = action * self._config.action_scale
    ctrl = state.data.ctrl + delta
    ctrl = jp.clip(ctrl, self._lowers, self._uppers)

    data = mjx_env.step(self._mjx_model, state.data, ctrl, self.n_substeps)

    raw_rewards = self._get_reward(data, state.info)

    rewards = {
        k: v * self._config.reward_config.scales[k]
        for k, v in raw_rewards.items()
    }
    potential = sum(rewards.values()) / sum(
        self._config.reward_config.scales.values()
    )

    # Reward progress. Clip at zero to not penalize mistakes like dropping
    # during exploration.
    reward = jp.maximum(
        potential - state.info['prev_potential'], jp.zeros_like(potential)
    )

    box_pos = data.xpos[self._box_body]

    # Don't affect learning to transfer between hands but bias to holding the
    # end state.
    l_gripper = data.site_xpos[self._left_gripper_site]
    condition = logistic_barrier(l_gripper[0], direction=-1) * logistic_barrier(
        box_pos[0], 0.10
    )
    reward += 0.02 * potential * condition

    state.info['prev_potential'] = jp.maximum(
        potential, state.info['prev_potential']
    )
    reward = jp.where(newly_reset, 0.0, reward)  # Prevent first-step artifact.

    # No reward information if you've dropped a block after you've picked it up.
    picked = box_pos[2] > 0.15
    state.info['episode_picked'] = jp.logical_or(
        state.info['episode_picked'], picked
    )
    dropped = (box_pos[2] < 0.05) & state.info['episode_picked']
    reward += dropped.astype(float) * -0.1  # Small penalty.

    out_of_bounds = jp.any(jp.abs(box_pos) > 1.0)
    out_of_bounds |= box_pos[2] < 0.0
    done = (
        out_of_bounds
        | jp.isnan(data.qpos).any()
        | jp.isnan(data.qvel).any()
        | dropped
    )
    state.info['_steps'] += self._config.action_repeat
    state.info['_steps'] = jp.where(
        done | (state.info['_steps'] >= self._config.episode_length),
        0,
        state.info['_steps'],
    )

    state.metrics.update(**rewards, out_of_bounds=out_of_bounds.astype(float))

    obs = self._get_obs(data, state.info)
    return mjx_env.State(
        data, obs, reward, done.astype(float), state.metrics, state.info
    )

  def _get_reward(self, data: mjx.Data, info: Dict[str, Any]) -> Dict[str, Any]:
    def distance(x, y):
      return jp.exp(-10 * jp.linalg.norm(x - y))

    box_top = data.site_xpos[self._box_top_site]
    box_bottom = data.site_xpos[self._box_bottom_site]
    box = data.xpos[self._box_body]
    l_gripper = data.site_xpos[self._left_gripper_site]
    r_gripper = data.site_xpos[self._right_gripper_site]

    pre = jp.where(box[0] < self._left_thresh, 1.0, 0.0)
    past = jp.where(box[0] >= self._right_thresh, 1.0, 0.0)
    btwn = (1 - pre) * (1 - past)

    #### Gripper Box
    r_lg = distance(box_top, l_gripper) * (pre + btwn)
    # If you're past the left threshold, also reward the right gripper.
    r_rg = distance(box_bottom, r_gripper) * (btwn + past)
    # Maintain reward level after left out of range.
    r_rg_bias = distance(box_bottom, r_gripper) * past

    #### Box Handover to handover point
    box_handover = distance(box, self._handover_pos)
    # Maintain this term after RH takes box away.
    hand_handover = distance(l_gripper, self._handover_pos) * past
    box_handover = jp.maximum(box_handover, hand_handover)

    #### Bring box to target
    box_target = distance(info['target_pos'], box) * (r_rg + r_rg_bias)
    # Don't let the left hand do it.
    box_target *= logistic_barrier(l_gripper[0], direction=-1)

    #### Avoid table collision - unstable simulation.
    table_collision = self.hand_table_collision(data)

    return {
        'gripper_box': r_lg + r_rg + r_rg_bias,
        'box_handover': box_handover,
        'handover_target': box_target,
        'no_table_collision': 1 - table_collision,
    }

  def _get_obs(self, data: mjx.Data, info: Dict[str, Any]) -> jax.Array:
    left_gripper_pos = data.site_xpos[self._left_gripper_site]
    left_gripper_mat = data.site_xmat[self._left_gripper_site]
    right_gripper_pos = data.site_xpos[self._right_gripper_site]
    right_gripper_mat = data.site_xmat[self._right_gripper_site]
    box_mat = data.xmat[self._box_body]
    box_top = data.site_xpos[self._box_top_site]
    box_bottom = data.site_xpos[self._box_bottom_site]
    finger_qposadr = data.qpos[self._finger_qposadr]
    box_width = self.mjx_model.geom_size[self._box_geom][1]

    obs = jp.concatenate([
        data.qpos,
        data.qvel,
        (finger_qposadr - box_width),
        box_top,
        box_bottom,
        left_gripper_pos,
        left_gripper_mat.ravel()[3:],
        right_gripper_pos,
        right_gripper_mat.ravel()[3:],
        box_mat.ravel()[3:],
        data.xpos[self._box_body] - info['target_pos'],
        (info['_steps'].reshape((1,)) / self._config.episode_length).astype(
            float
        ),
    ])

    return obs

  @property
  def observation_size(self) -> int:
    return 83
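
A short usage sketch, not part of the commit: rolling out the new task with random delta actions, using only the reset/step interface defined above. The ctrl dimension is read from the mjx model rather than assumed.

import jax

from mujoco_playground._src.manipulation.aloha import handover

env = handover.HandOver()
rng = jax.random.PRNGKey(0)
state = jax.jit(env.reset)(rng)
step_fn = jax.jit(env.step)
for _ in range(5):
  rng, key = jax.random.split(rng)
  # Actions are deltas on ctrl, scaled by config.action_scale inside step().
  action = jax.random.uniform(key, (env.mjx_model.nu,), minval=-1.0, maxval=1.0)
  state = step_fn(state, action)
print(state.obs.shape)  # (83,), matching observation_size.

On the shaping terms: with the default k = 100, logistic_barrier transitions from roughly 0 to roughly 1 over about 12 cm centered on x0 (1 / (1 + e^(-100x)) is about 0.0025 at x = -0.06 m and about 0.9975 at x = +0.06 m), which is what the "12 cm decay range" comment refers to.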
