Skip to content

Commit 24c16ee

Browse files
Merge pull request #17 from Andrew-Luo1:main
PiperOrigin-RevId: 716742418 Change-Id: I2a3c32a5098f1645687af489320b834a1297ad4d
2 parents c8650a3 + 55b0c8f commit 24c16ee

File tree

6 files changed

+225
-119
lines changed

6 files changed

+225
-119
lines changed

learning/notebooks/training_vision_2.ipynb

Lines changed: 30 additions & 54 deletions
Large diffs are not rendered by default.

mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from mujoco import mjx
2626
import numpy as np
2727

28+
from mujoco_playground._src import collision
2829
from mujoco_playground._src import mjx_env
2930
from mujoco_playground._src.manipulation.franka_emika_panda import panda
3031
from mujoco_playground._src.manipulation.franka_emika_panda import panda_kinematics
@@ -49,7 +50,7 @@ def default_config():
4950
episode_length=200,
5051
action_repeat=1,
5152
# Size of cartesian increment.
52-
action_scale=0.01,
53+
action_scale=0.005,
5354
reward_config=config_dict.create(
5455
reward_scales=config_dict.create(
5556
# Gripper goes to the box.
@@ -58,6 +59,8 @@ def default_config():
5859
box_target=8.0,
5960
# Do not collide the gripper with the floor.
6061
no_floor_collision=0.25,
62+
# Do not collide cube with gripper
63+
no_box_collision=0.05,
6164
# Destabilizes training in cartesian action space.
6265
robot_target_qpos=0.0,
6366
),
@@ -69,6 +72,9 @@ def default_config():
6972
vision=False,
7073
vision_config=default_vision_config(),
7174
obs_noise=config_dict.create(brightness=[1.0, 1.0]),
75+
box_init_range=0.05,
76+
success_threshold=0.05,
77+
action_history_length=1,
7278
)
7379
return config
7480

@@ -112,6 +118,7 @@ def __init__(
112118

113119
# Set gripper in sight of camera
114120
self._post_init(obj_name='box', keyframe='low_home')
121+
self._box_geom = self._mj_model.geom('box').id
115122

116123
if self._vision:
117124
try:
@@ -168,9 +175,10 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
168175

169176
# initialize box position
170177
rng, rng_box = jax.random.split(rng)
178+
r_range = self._config.box_init_range
171179
box_pos = jp.array([
172180
x_plane,
173-
jax.random.uniform(rng_box, (), minval=-0.05, maxval=0.05),
181+
jax.random.uniform(rng_box, (), minval=-r_range, maxval=r_range),
174182
0.0,
175183
])
176184

@@ -218,6 +226,9 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
218226
'newly_reset': jp.array(False, dtype=bool),
219227
'prev_action': jp.zeros(3),
220228
'_steps': jp.array(0, dtype=int),
229+
'action_history': jp.zeros((
230+
self._config.action_history_length,
231+
)), # Gripper only
221232
}
222233

223234
reward, done = jp.zeros(2)
@@ -245,6 +256,17 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
245256

246257
def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
247258
"""Runs one timestep of the environment's dynamics."""
259+
action_history = (
260+
jp.roll(state.info['action_history'], 1).at[0].set(action[2])
261+
)
262+
state.info['action_history'] = action_history
263+
# Add action delay
264+
state.info['rng'], key = jax.random.split(state.info['rng'])
265+
action_idx = jax.random.randint(
266+
key, (), minval=0, maxval=self._config.action_history_length
267+
)
268+
action = action.at[2].set(state.info['action_history'][action_idx])
269+
248270
state.info['newly_reset'] = state.info['_steps'] == 0
249271

250272
newly_reset = state.info['newly_reset']
@@ -275,9 +297,7 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
275297

276298
# Cartesian control
277299
increment = jp.zeros(4)
278-
increment = increment.at[1:].set(
279-
action[:]
280-
) # set y, z and gripper commands.
300+
increment = increment.at[1:].set(action) # set y, z and gripper commands.
281301
ctrl, new_tip_position, no_soln = self._move_tip(
282302
state.info['current_pos'],
283303
self._start_tip_transform[:3, :3],
@@ -297,6 +317,10 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
297317
for k, v in raw_rewards.items()
298318
}
299319

320+
# Penalize collision with box.
321+
hand_box = collision.geoms_colliding(data, self._box_geom, self._hand_geom)
322+
raw_rewards['no_box_collision'] = jp.where(hand_box, 0.0, 1.0)
323+
300324
total_reward = jp.clip(sum(rewards.values()), -1e4, 1e4)
301325

302326
if not self._vision:
@@ -362,7 +386,11 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
362386
def _get_success(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array:
363387
box_pos = data.xpos[self._obj_body]
364388
target_pos = info['target_pos']
365-
return jp.linalg.norm(box_pos - target_pos) < 0.05
389+
if (
390+
self._vision
391+
): # Randomized camera positions cannot see location along y line.
392+
box_pos, target_pos = box_pos[2], target_pos[2]
393+
return jp.linalg.norm(box_pos - target_pos) < self._config.success_threshold
366394

367395
def _move_tip(
368396
self,

mujoco_playground/_src/manipulation/franka_emika_panda/randomize_vision.py

Lines changed: 118 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -21,102 +21,167 @@
2121
from mujoco.mjx._src import math
2222
import numpy as np
2323

24-
FLOOR_GEOM_ID = 0
25-
BOX_GEOM_ID = 81
24+
from mujoco_playground._src.manipulation.franka_emika_panda import pick_cartesian
25+
26+
27+
def sample_light_position():
  """Draws a random light position using NumPy's global random state.

  Each candidate is sampled uniformly from the axis-aligned box spanning
  [1.5, -0.2, 0.8] to [3, 0.2, 1.5]. Candidates closer than 1.0 to the origin
  are rejected and resampled.

  Returns:
    A NumPy array of shape (3,) holding the accepted position.
  """
  lower_corner = [1.5, -0.2, 0.8]
  upper_corner = [3, 0.2, 1.5]
  while True:
    candidate = np.random.uniform(lower_corner, upper_corner)
    # Accept as soon as the candidate is at least 1.0 away from the origin.
    if np.linalg.norm(candidate) >= 1.0:
      return candidate
32+
33+
34+
def perturb_orientation(
    key: jax.Array, original: jax.Array, deg: float
) -> jax.Array:
  """Perturbs a 3D or 4D orientation by up to deg.

  A random rotation is built from a random axis and a random angle in
  [0, deg] degrees, then applied to `original`.

  Args:
    key: JAX PRNG key used for all sampling in this function.
    original: Orientation to perturb. A shape-(4,) input is treated as a
      quaternion and left-multiplied by the random rotation; a shape-(3,) input
      is treated as a vector and rotated by it.
    deg: Upper bound of the rotation angle, in degrees.

  Returns:
    The perturbed orientation, with the same shape as `original`.

  Raises:
    ValueError: If `original` is neither shape (3,) nor shape (4,).
  """
  # Validate up front so an unsupported input fails before any sampling work.
  if original.shape not in ((4,), (3,)):
    raise ValueError(f'Invalid input shape: {original.shape}')

  # Key derivation is kept identical to the original split sequence so the
  # sampled perturbations do not change.
  key_axis, key_theta, key = jax.random.split(key, 3)
  perturb_axis = jax.random.uniform(key_axis, (3,), minval=-1, maxval=1)
  # Only perturb upwards in the y axis: the axis' y component is resampled
  # from [0, 1) so it is never negative.
  key_y, _ = jax.random.split(key, 2)
  perturb_axis = perturb_axis.at[1].set(
      jax.random.uniform(key_y, (), minval=0, maxval=1)
  )
  perturb_axis = perturb_axis / jp.linalg.norm(perturb_axis)
  perturb_theta = jax.random.uniform(
      key_theta, shape=(1,), minval=0, maxval=np.deg2rad(deg)
  )
  rot_offset = math.axis_angle_to_quat(perturb_axis, perturb_theta)
  if original.shape == (4,):
    return math.quat_mul(rot_offset, original)
  return math.rotate(original, rot_offset)
2656

2757

2858
def domain_randomize(
    mjx_model: mjx.Model, num_worlds: int
) -> Tuple[mjx.Model, mjx.Model]:
  """Tile the necessary axes for the Madrona BatchRenderer.

  Generates `num_worlds` randomized variants of the appearance, camera and
  lighting fields of `mjx_model`. Sampling is deterministic: the JAX key is
  fixed to 0 and NumPy's global random state is reseeded with 0 (a side effect
  visible to callers that also use `np.random`).

  Args:
    mjx_model: Model to randomize. It must contain exactly one camera and
      exactly one light; this is checked with assertions.
    num_worlds: Number of randomized worlds to generate.

  Returns:
    A tuple `(mjx_model, in_axes)`. `mjx_model` has the randomized fields
    replaced by per-world arrays, and `in_axes` is a tree with the same
    structure holding 0 for those fields and None everywhere else.
  """
  # Geom ids are resolved by name on a freshly constructed environment's model
  # rather than hard-coded.
  mj_model = pick_cartesian.PandaPickCubeCartesian().mj_model
  floor_geom_id = mj_model.geom('floor').id
  box_geom_id = mj_model.geom('box').id
  strip_geom_id = mj_model.geom('init_space').id

  in_axes = jax.tree_util.tree_map(lambda x: None, mjx_model)
  in_axes = in_axes.tree_replace({
      'geom_rgba': 0,
      'geom_matid': 0,
      'cam_pos': 0,
      'cam_quat': 0,
      'light_pos': 0,
      'light_dir': 0,
      'light_directional': 0,
      'light_castshadow': 0,
  })
  rng = jax.random.key(0)

  # Light positions use rejection sampling, which is simpler to implement in
  # NumPy, so they are drawn up front and passed into the vmapped function.
  np.random.seed(0)
  light_positions = [sample_light_position() for _ in range(num_worlds)]
  light_positions = jp.array(light_positions)

  @jax.vmap
  def rand(rng: jax.Array, light_position: jax.Array):
    """Generate randomized model fields."""
    _, key = jax.random.split(rng, 2)

    #### Appearance ####
    # Sample a random color for the box
    key_box, key_strip, key_floor, key = jax.random.split(key, 4)
    rgba = jp.array(
        [jax.random.uniform(key_box, (), minval=0.5, maxval=1.0), 0.0, 0.0, 1.0]
    )
    geom_rgba = mjx_model.geom_rgba.at[box_geom_id].set(rgba)

    # Sample a near-white shade for the strip. The update must be chained onto
    # `geom_rgba`; starting again from `mjx_model.geom_rgba` would discard the
    # box color sampled above.
    strip_white = jax.random.uniform(key_strip, (), minval=0.8, maxval=1.0)
    geom_rgba = geom_rgba.at[strip_geom_id].set(
        jp.array([strip_white, strip_white, strip_white, 1.0])
    )

    # Sample a shade of gray
    gray_scale = jax.random.uniform(key_floor, (), minval=0.0, maxval=0.25)
    geom_rgba = geom_rgba.at[floor_geom_id].set(
        jp.array([gray_scale, gray_scale, gray_scale, 1.0])
    )

    # Every geom gets a random material id in [mat_offset, mat_offset + 10).
    mat_offset, num_geoms = 5, geom_rgba.shape[0]
    key_matid, key = jax.random.split(key)
    geom_matid = (
        jax.random.randint(key_matid, shape=(num_geoms,), minval=0, maxval=10)
        + mat_offset
    )
    # -2: use the above randomized colors
    geom_matid = geom_matid.at[box_geom_id].set(-2)
    geom_matid = geom_matid.at[floor_geom_id].set(-2)
    geom_matid = geom_matid.at[strip_geom_id].set(-2)

    #### Cameras ####
    key_pos, key_ori, key = jax.random.split(key, 3)
    cam_offset = jax.random.uniform(key_pos, (3,), minval=-0.05, maxval=0.05)
    assert (
        len(mjx_model.cam_pos) == 1
    ), f'Expected single camera, got {len(mjx_model.cam_pos)}'
    cam_pos = mjx_model.cam_pos.at[0].set(mjx_model.cam_pos[0] + cam_offset)
    cam_quat = mjx_model.cam_quat.at[0].set(
        perturb_orientation(key_ori, mjx_model.cam_quat[0], 10)
    )

    #### Lighting ####
    nlight = mjx_model.light_pos.shape[0]
    assert (
        nlight == 1
    ), f'Sim2Real was trained with a single light source, got {nlight}'
    # The 4-way split is kept so the sampled values stay unchanged; the last
    # two keys are currently unused.
    key_lsha, key_ldir, _, _ = jax.random.split(key, 4)

    # Direction
    shine_at = jp.array([0.661, -0.001, 0.179])  # Gripper starting position
    to_target = shine_at - light_position
    nom_dir = to_target / jp.linalg.norm(to_target)
    light_dir = mjx_model.light_dir.at[0].set(
        perturb_orientation(key_ldir, nom_dir, 20)
    )

    # Whether to cast shadows
    light_castshadow = jax.random.bernoulli(
        key_lsha, 0.75, shape=(nlight,)
    ).astype(jp.float32)

    # No need to randomize into specular lighting
    light_directional = jp.ones((nlight,))

    return (
        geom_rgba,
        geom_matid,
        cam_pos,
        cam_quat,
        light_dir,
        light_directional,
        light_castshadow,
    )

  (
      geom_rgba,
      geom_matid,
      cam_pos,
      cam_quat,
      light_dir,
      light_directional,
      light_castshadow,
  ) = rand(jax.random.split(rng, num_worlds), light_positions)

  # NOTE(review): `light_positions` has shape (num_worlds, 3) while `light_dir`
  # keeps the model's per-light axis — confirm the renderer accepts both.
  mjx_model = mjx_model.tree_replace({
      'geom_rgba': geom_rgba,
      'geom_matid': geom_matid,
      'cam_pos': cam_pos,
      'cam_quat': cam_quat,
      'light_pos': light_positions,
      'light_dir': light_dir,
      'light_directional': light_directional,
      'light_castshadow': light_castshadow,
  })

  return mjx_model, in_axes

mujoco_playground/_src/manipulation/franka_emika_panda/xmls/mjx_cabinet.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<mujoco model="panda scene">
1+
<mujoco model="panda cabinet">
22
<include file="mjx_scene.xml"/>
33

44
<worldbody>

mujoco_playground/_src/manipulation/franka_emika_panda/xmls/mjx_single_cube.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<mujoco model="panda scene">
1+
<mujoco model="panda single cube">
22
<include file="mjx_scene.xml"/>
33

44
<worldbody>

0 commit comments

Comments
 (0)