|
21 | 21 | from mujoco.mjx._src import math
|
22 | 22 | import numpy as np
|
23 | 23 |
|
24 |
| -FLOOR_GEOM_ID = 0 |
25 |
| -BOX_GEOM_ID = 81 |
| 24 | +from mujoco_playground._src.manipulation.franka_emika_panda import pick_cartesian |
| 25 | + |
| 26 | + |
| 27 | +def sample_light_position(): |
| 28 | + position = np.zeros(3) |
| 29 | + while np.linalg.norm(position) < 1.0: |
| 30 | + position = np.random.uniform([1.5, -0.2, 0.8], [3, 0.2, 1.5]) |
| 31 | + return position |
| 32 | + |
| 33 | + |
| 34 | +def perturb_orientation( |
| 35 | + key: jax.Array, original: jax.Array, deg: float |
| 36 | +) -> jax.Array: |
| 37 | + """Perturbs a 3D or 4D orientation by up to deg.""" |
| 38 | + key_axis, key_theta, key = jax.random.split(key, 3) |
| 39 | + perturb_axis = jax.random.uniform(key_axis, (3,), minval=-1, maxval=1) |
| 40 | + # Only perturb upwards in the y axis. |
| 41 | + key_y, key = jax.random.split(key, 2) |
| 42 | + perturb_axis = perturb_axis.at[1].set( |
| 43 | + jax.random.uniform(key_y, (), minval=0, maxval=1) |
| 44 | + ) |
| 45 | + perturb_axis = perturb_axis / jp.linalg.norm(perturb_axis) |
| 46 | + perturb_theta = jax.random.uniform( |
| 47 | + key_theta, shape=(1,), minval=0, maxval=np.deg2rad(deg) |
| 48 | + ) |
| 49 | + rot_offset = math.axis_angle_to_quat(perturb_axis, perturb_theta) |
| 50 | + if original.shape == (4,): |
| 51 | + return math.quat_mul(rot_offset, original) |
| 52 | + elif original.shape == (3,): |
| 53 | + return math.rotate(original, rot_offset) |
| 54 | + else: |
| 55 | + raise ValueError('Invalid input shape:', original.shape) |
26 | 56 |
|
27 | 57 |
|
28 | 58 | def domain_randomize(
|
29 | 59 | mjx_model: mjx.Model, num_worlds: int
|
30 | 60 | ) -> Tuple[mjx.Model, mjx.Model]:
|
31 | 61 | """Tile the necessary axes for the Madrona BatchRenderer."""
|
| 62 | + mj_model = pick_cartesian.PandaPickCubeCartesian().mj_model |
| 63 | + FLOOR_GEOM_ID = mj_model.geom('floor').id |
| 64 | + BOX_GEOM_ID = mj_model.geom('box').id |
| 65 | + STRIP_GEOM_ID = mj_model.geom('init_space').id |
| 66 | + |
32 | 67 | in_axes = jax.tree_util.tree_map(lambda x: None, mjx_model)
|
33 | 68 | in_axes = in_axes.tree_replace({
|
34 | 69 | 'geom_rgba': 0,
|
35 | 70 | 'geom_matid': 0,
|
36 |
| - 'geom_size': 0, |
37 |
| - 'geom_friction': 0, |
38 | 71 | 'cam_pos': 0,
|
39 | 72 | 'cam_quat': 0,
|
40 | 73 | 'light_pos': 0,
|
41 | 74 | 'light_dir': 0,
|
42 | 75 | 'light_directional': 0,
|
43 | 76 | 'light_castshadow': 0,
|
44 |
| - 'light_cutoff': 0, |
45 | 77 | })
|
46 | 78 | rng = jax.random.key(0)
|
47 | 79 |
|
| 80 | + # Simpler logic implementing via Numpy. |
| 81 | + np.random.seed(0) |
| 82 | + light_positions = [sample_light_position() for _ in range(num_worlds)] |
| 83 | + light_positions = jp.array(light_positions) |
| 84 | + |
48 | 85 | @jax.vmap
|
49 |
| - def rand(rng): |
| 86 | + def rand(rng: jax.Array, light_position: jax.Array): |
| 87 | + """Generate randomized model fields.""" |
50 | 88 | _, key = jax.random.split(rng, 2)
|
51 |
| - # friction |
52 |
| - friction = jax.random.uniform(key, (1,), minval=0.6, maxval=1.4) |
53 |
| - friction = mjx_model.geom_friction.at[:, 0].set(friction) |
54 |
| - key_r, key_g, key_b, key = jax.random.split(key, 4) |
55 |
| - rgba = jp.array([ |
56 |
| - jax.random.uniform(key_r, (), minval=0.5, maxval=1.0), |
57 |
| - jax.random.uniform(key_g, (), minval=0.0, maxval=0.5), |
58 |
| - jax.random.uniform(key_b, (), minval=0.0, maxval=0.5), |
59 |
| - 1.0, |
60 |
| - ]) |
| 89 | + |
| 90 | + #### Apearance #### |
| 91 | + # Sample a random color for the box |
| 92 | + key_box, key_strip, key_floor, key = jax.random.split(key, 4) |
| 93 | + rgba = jp.array( |
| 94 | + [jax.random.uniform(key_box, (), minval=0.5, maxval=1.0), 0.0, 0.0, 1.0] |
| 95 | + ) |
61 | 96 | geom_rgba = mjx_model.geom_rgba.at[BOX_GEOM_ID].set(rgba)
|
62 | 97 |
|
| 98 | + strip_white = jax.random.uniform(key_strip, (), minval=0.8, maxval=1.0) |
| 99 | + geom_rgba = mjx_model.geom_rgba.at[STRIP_GEOM_ID].set( |
| 100 | + jp.array([strip_white, strip_white, strip_white, 1.0]) |
| 101 | + ) |
| 102 | + |
63 | 103 | # Sample a shade of gray
|
64 |
| - key_gs, key = jax.random.split(key) |
65 |
| - gray_scale = jax.random.uniform(key_gs, (), minval=0.0, maxval=0.8) |
| 104 | + gray_scale = jax.random.uniform(key_floor, (), minval=0.0, maxval=0.25) |
66 | 105 | geom_rgba = geom_rgba.at[FLOOR_GEOM_ID].set(
|
67 | 106 | jp.array([gray_scale, gray_scale, gray_scale, 1.0])
|
68 | 107 | )
|
69 | 108 |
|
70 |
| - # Set unrandomized and randomized matID's to -1 and -2. |
71 |
| - geom_matid = jp.ones_like(mjx_model.geom_matid) * -1 |
72 |
| - geom_matid = geom_matid.at[BOX_GEOM_ID].set(-2) |
| 109 | + mat_offset, num_geoms = 5, geom_rgba.shape[0] |
| 110 | + key_matid, key = jax.random.split(key) |
| 111 | + geom_matid = ( |
| 112 | + jax.random.randint(key_matid, shape=(num_geoms,), minval=0, maxval=10) |
| 113 | + + mat_offset |
| 114 | + ) |
| 115 | + geom_matid = geom_matid.at[BOX_GEOM_ID].set( |
| 116 | + -2 |
| 117 | + ) # Use the above randomized colors |
73 | 118 | geom_matid = geom_matid.at[FLOOR_GEOM_ID].set(-2)
|
| 119 | + geom_matid = geom_matid.at[STRIP_GEOM_ID].set(-2) |
74 | 120 |
|
75 |
| - key_pos, key = jax.random.split(key) |
| 121 | + #### Cameras #### |
| 122 | + key_pos, key_ori, key = jax.random.split(key, 3) |
76 | 123 | cam_offset = jax.random.uniform(key_pos, (3,), minval=-0.05, maxval=0.05)
|
| 124 | + assert ( |
| 125 | + len(mjx_model.cam_pos) == 1 |
| 126 | + ), f'Expected single camera, got {len(mjx_model.cam_pos)}' |
77 | 127 | cam_pos = mjx_model.cam_pos.at[0].set(mjx_model.cam_pos[0] + cam_offset)
|
| 128 | + cam_quat = mjx_model.cam_quat.at[0].set( |
| 129 | + perturb_orientation(key_ori, mjx_model.cam_quat[0], 10) |
| 130 | + ) |
| 131 | + |
| 132 | + #### Lighting #### |
| 133 | + nlight = mjx_model.light_pos.shape[0] |
| 134 | + assert ( |
| 135 | + nlight == 1 |
| 136 | + ), f'Sim2Real was trained with a single light source, got {nlight}' |
| 137 | + key_lsha, key_ldir, key_ldct, key = jax.random.split(key, 4) |
78 | 138 |
|
79 |
| - key_axis, key_theta, key = jax.random.split(key, 3) |
80 |
| - perturb_axis = jax.random.uniform(key_axis, (3,), minval=-1, maxval=1) |
81 |
| - perturb_axis = perturb_axis / jp.linalg.norm(perturb_axis) |
82 |
| - perturb_theta = jax.random.uniform( |
83 |
| - key_theta, shape=(1,), maxval=np.deg2rad(10) |
| 139 | + # Direction |
| 140 | + shine_at = jp.array([0.661, -0.001, 0.179]) # Gripper starting position |
| 141 | + nom_dir = (shine_at - light_position) / jp.linalg.norm( |
| 142 | + shine_at - light_position |
84 | 143 | )
|
85 |
| - camera_rot_offset = math.axis_angle_to_quat(perturb_axis, perturb_theta) |
86 |
| - cam_quat = mjx_model.cam_quat.at[0].set( |
87 |
| - math.quat_mul(camera_rot_offset, mjx_model.cam_quat[0]) |
| 144 | + light_dir = mjx_model.light_dir.at[0].set( |
| 145 | + perturb_orientation(key_ldir, nom_dir, 20) |
88 | 146 | )
|
89 | 147 |
|
90 |
| - return friction, geom_rgba, geom_matid, cam_pos, cam_quat |
| 148 | + # Whether to cast shadows |
| 149 | + light_castshadow = jax.random.bernoulli( |
| 150 | + key_lsha, 0.75, shape=(nlight,) |
| 151 | + ).astype(jp.float32) |
91 | 152 |
|
92 |
| - friction, geom_rgba, geom_matid, cam_pos, cam_quat = rand( |
93 |
| - jax.random.split(rng, num_worlds) |
94 |
| - ) |
| 153 | + # No need to randomize into specular lighting |
| 154 | + light_directional = jp.ones((nlight,)) |
| 155 | + |
| 156 | + return ( |
| 157 | + geom_rgba, |
| 158 | + geom_matid, |
| 159 | + cam_pos, |
| 160 | + cam_quat, |
| 161 | + light_dir, |
| 162 | + light_directional, |
| 163 | + light_castshadow, |
| 164 | + ) |
| 165 | + |
| 166 | + ( |
| 167 | + geom_rgba, |
| 168 | + geom_matid, |
| 169 | + cam_pos, |
| 170 | + cam_quat, |
| 171 | + light_dir, |
| 172 | + light_directional, |
| 173 | + light_castshadow, |
| 174 | + ) = rand(jax.random.split(rng, num_worlds), light_positions) |
95 | 175 |
|
96 | 176 | mjx_model = mjx_model.tree_replace({
|
97 | 177 | 'geom_rgba': geom_rgba,
|
98 | 178 | 'geom_matid': geom_matid,
|
99 |
| - 'geom_size': jp.repeat( |
100 |
| - jp.expand_dims(mjx_model.geom_size, 0), num_worlds, axis=0 |
101 |
| - ), |
102 |
| - 'geom_friction': friction, |
103 | 179 | 'cam_pos': cam_pos,
|
104 | 180 | 'cam_quat': cam_quat,
|
105 |
| - 'light_pos': jp.repeat( |
106 |
| - jp.expand_dims(mjx_model.light_pos, 0), num_worlds, axis=0 |
107 |
| - ), |
108 |
| - 'light_dir': jp.repeat( |
109 |
| - jp.expand_dims(mjx_model.light_dir, 0), num_worlds, axis=0 |
110 |
| - ), |
111 |
| - 'light_directional': jp.repeat( |
112 |
| - jp.expand_dims(mjx_model.light_directional, 0), num_worlds, axis=0 |
113 |
| - ), |
114 |
| - 'light_castshadow': jp.repeat( |
115 |
| - jp.expand_dims(mjx_model.light_castshadow, 0), num_worlds, axis=0 |
116 |
| - ), |
117 |
| - 'light_cutoff': jp.repeat( |
118 |
| - jp.expand_dims(mjx_model.light_cutoff, 0), num_worlds, axis=0 |
119 |
| - ), |
| 181 | + 'light_pos': light_positions, |
| 182 | + 'light_dir': light_dir, |
| 183 | + 'light_directional': light_directional, |
| 184 | + 'light_castshadow': light_castshadow, |
120 | 185 | })
|
121 | 186 |
|
122 | 187 | return mjx_model, in_axes
|
0 commit comments