
Commit 80a204b

[Feature] Cached resets, fix a few minor reset mask bugs, remove some default dict args (#1203)
* init
* Create tree.py
* work
* fix bugs
* add missing reset mask and scene id masking
* fix reset mask
* fixes
* fixes
* bug fix
* docs
1 parent acb8ea1 commit 80a204b

File tree

9 files changed: +315 −19 lines

docs/source/user_guide/wrappers/cached_reset.md

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# Cached Reset

For environments where resets are slow or expensive, or for workflows like RL with partial resets where frequent small resets occur (instead of resetting all environments in the GPU simulation simultaneously), cached resets can be useful.

Cached resets skip the process of calling the environment's reset function and instead load a previously saved environment state and observation. Loading environment state instead of running the environment's reset code (the `_initialize_episode` function) can be faster and boost environment FPS.

To use cached resets we provide a simple environment wrapper {py:class}`mani_skill.utils.wrappers.CachedResetWrapper` that can be used as follows:

```python
from mani_skill.utils.wrappers import CachedResetWrapper
import gymnasium as gym

env = gym.make("StackCube-v1", num_envs=256)
# upon applying the wrapper below, by default we sample 256 different reset states and cache them together with the corresponding observations
env = CachedResetWrapper(env)
# obs is now fetched from a cache, and the environment is initialized by setting environment state
obs, _ = env.reset()
```

Note that this caches only environment state, not geometry/texture details. Most ManiSkill environments change geometries/textures/scenes only when they are destroyed and recreated, or reconfigured, with a new seed.
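If you do need freshly sampled geometries/textures, you still have to trigger a full reconfigure yourself; a minimal sketch using the standard `reconfigure` reset option:

```python
# a minimal sketch: cached resets reuse the current scene, so sampling new
# geometries/textures still requires a full (slow) reconfigure with a new seed
obs, _ = env.reset(seed=42, options=dict(reconfigure=True))
```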
## Configuration Options

There are a few configuration options and ways to use the `CachedResetWrapper`. One is to modify how the reset states are generated. Below is the configuration dataclass that you can use and/or override when creating the wrapper:

```python
@dataclass
class CachedResetsConfig:
    num_resets: Optional[int] = None
    """The number of reset states to cache. If None, `num_envs` reset states are cached."""
    device: Optional[Device] = None
    """The device to cache the reset states on. If None, the base environment's device is used."""
    seed: Optional[int] = None
    """The seed to use for generating the cached reset states."""

    def dict(self):
        return {k: v for k, v in asdict(self).items()}
```

For example, to change the number of cached resets and the generation seed you can pass a dict like so:

```python
env = CachedResetWrapper(env, config=dict(num_resets=16384, seed=0))
```
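Since the wrapper accepts either a dict or the dataclass itself for `config`, you can equivalently pass a `CachedResetsConfig` directly; a small sketch (the import path follows the new `cached_reset` module added in this commit):

```python
from mani_skill.utils.wrappers.cached_reset import CachedResetsConfig

env = CachedResetWrapper(env, config=CachedResetsConfig(num_resets=16384, seed=0))
```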
You can also manually pass in your own reset states, optionally paired with the observations for each reset state.

```python
# env_states should be the result of env.get_state_dict(). It should be a dictionary where each leaf has the same batch size
# obs can be observations you previously generated. It can also be None
env = CachedResetWrapper(env, reset_to_env_states=dict(env_states=env_states, obs=obs))
```

It may be useful to use the `tree` utility in ManiSkill if you want to e.g. concatenate multiple env_states values from multiple calls to `env.get_state_dict`, like so:

```python
from mani_skill.utils import tree

state_dict_1 = env.get_state_dict()
# do something to the env
state_dict_2 = env.get_state_dict()
env_states = tree.cat([state_dict_1, state_dict_2])
env = CachedResetWrapper(env, reset_to_env_states=dict(env_states=env_states, obs=None))
```
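Since the collected states are plain nested dictionaries of tensors, they can also be persisted between runs; a hedged sketch using `torch.save`/`torch.load` (the file name is illustrative):

```python
import torch

# save collected reset states to disk for reuse in a later run
torch.save(env_states, "cached_reset_states.pt")

# later: reload them and construct the wrapper without regenerating resets
env_states = torch.load("cached_reset_states.pt")
env = CachedResetWrapper(env, reset_to_env_states=dict(env_states=env_states, obs=None))
```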
## Performance

The following code snippet gives a quick check of the speed gains from cached resets. For the example below with 256 environments and the state observation mode, cached resets took about 0.004s on average while normal resets took 0.007s on an RTX 3080. With the rgb observation mode the difference is more stark: cached resets took about 0.005s on average while normal resets took 0.167s.

```python
from mani_skill.utils.wrappers import CachedResetWrapper
import gymnasium as gym
import time

num_envs = 256
obs_mode = "rgb"
env = gym.make("StackCube-v1", obs_mode=obs_mode, num_envs=num_envs)
env = CachedResetWrapper(env)

trials = 100
start_time = time.time()
for i in range(trials):
    env.reset()
end_time = time.time()
print(f"Average time per cached reset: {(end_time - start_time) / trials} seconds")

env = gym.make("StackCube-v1", obs_mode=obs_mode, num_envs=num_envs)
# env = CachedResetWrapper(env)

trials = 100
start_time = time.time()
for i in range(trials):
    env.reset()
end_time = time.time()
print(f"Average time per reset: {(end_time - start_time) / trials} seconds")
```

docs/source/user_guide/wrappers/index.md

Lines changed: 1 addition & 0 deletions
````diff
@@ -7,4 +7,5 @@
 record
 flatten
 action_repeat
+cached_reset
 ```
````

mani_skill/envs/sapien_env.py

Lines changed: 30 additions & 10 deletions
```diff
@@ -37,7 +37,7 @@
     update_camera_configs_from_dict,
 )
 from mani_skill.sensors.depth_camera import StereoDepthCamera, StereoDepthCameraConfig
-from mani_skill.utils import common, gym_utils, sapien_utils
+from mani_skill.utils import common, gym_utils, sapien_utils, tree
 from mani_skill.utils.structs import Actor, Articulation
 from mani_skill.utils.structs.pose import Pose
 from mani_skill.utils.structs.types import Array, SimConfig
@@ -316,6 +316,8 @@ def __init__(
         self._elapsed_steps = (
             torch.zeros(self.num_envs, device=self.device, dtype=torch.int32)
         )
+        self._last_obs = None
+        """the last observation returned by the environment"""
         obs, _ = self.reset(seed=[2022 + i for i in range(self.num_envs)], options=dict(reconfigure=True))

         self._init_raw_obs = common.to_cpu_tensor(obs)
@@ -850,7 +852,11 @@ def reset(self, seed: Union[None, int, list[int]] = None, options: Union[None, d
         options["reconfigure"] is True, will call self._reconfigure() which deletes the entire physx scene and reconstructs everything.
         Users building custom tasks generally do not need to override this function.

-        Returns the first observation and a info dictionary. The info dictionary is of type
+        If options["reset_to_env_states"] is given, we expect there to be options["reset_to_env_states"]["env_states"] and optionally options["reset_to_env_states"]["obs"], both with
+        batch size equal to the number of environments being reset. "env_states" can be a dictionary or flat tensor and we skip calling the environment's _initialize_episode function which
+        generates the initial state on a normal reset. If "obs" is given we skip calling the environment's get_obs function which can save some compute/time.
+
+        Returns the observations and an info dictionary. The info dictionary is of type


         .. highlight:: python
@@ -917,12 +923,22 @@ def reset(self, seed: Union[None, int, list[int]] = None, options: Union[None, d
         if self.agent is not None:
             self.agent.reset()

-        if seed is not None or self._enhanced_determinism:
-            with torch.random.fork_rng():
-                torch.manual_seed(self._episode_seed[0])
-                self._initialize_episode(env_idx, options)
+        # we either reset to given env states or use the environment's defined _initialize_episode function to generate the initial state
+        reset_to_env_states_obs = None
+        if "reset_to_env_states" in options:
+            env_states = options["reset_to_env_states"]["env_states"]
+            reset_to_env_states_obs = options["reset_to_env_states"].get("obs", None)
+            if isinstance(env_states, dict):
+                self.set_state_dict(env_states, env_idx)
+            else:
+                self.set_state(env_states, env_idx)
         else:
-            self._initialize_episode(env_idx, options)
+            if seed is not None or self._enhanced_determinism:
+                with torch.random.fork_rng():
+                    torch.manual_seed(self._episode_seed[0])
+                    self._initialize_episode(env_idx, options)
+            else:
+                self._initialize_episode(env_idx, options)
         # reset the reset mask back to all ones so any internal code in maniskill can continue to manipulate all scenes at once as usual
         self.scene._reset_mask = torch.ones(
             self.num_envs, dtype=bool, device=self.device
@@ -942,9 +958,13 @@ def reset(self, seed: Union[None, int, list[int]] = None, options: Union[None, d
             self.agent.controller.reset()

         info = self.get_info()
-        obs = self.get_obs(info)
-
+        if reset_to_env_states_obs is None:
+            obs = self.get_obs(info)
+        else:
+            obs = self._last_obs
+            tree.replace(obs, env_idx, common.to_tensor(reset_to_env_states_obs, device=self.device))
         info["reconfigure"] = reconfigure
+        self._last_obs = obs
         return obs, info

     def _set_main_rng(self, seed):
@@ -1031,7 +1051,7 @@ def step(self, action: Union[None, np.ndarray, torch.Tensor, Dict]):
             terminated = info["fail"].clone()
         else:
             terminated = torch.zeros(self.num_envs, dtype=bool, device=self.device)
-
+        self._last_obs = obs
         return (
             obs,
             reward,
```
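The `reset_to_env_states` option can also be exercised directly, without the wrapper; a minimal sketch based on the docstring above (assuming `env` is a ManiSkill environment created via `gym.make`):

```python
# a minimal sketch: capture the current environment state, then later reset
# straight back to it, skipping _initialize_episode (and get_obs if "obs" is given)
state_dict = env.get_state_dict()
obs, info = env.reset(options=dict(reset_to_env_states=dict(env_states=state_dict)))
```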

mani_skill/utils/structs/articulation.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -888,7 +888,7 @@ def set_joint_drive_targets(
             else:
                 gx, gy = self.get_joint_target_indices(joint_indices)
             self.px.cuda_articulation_target_qpos.torch()[
-                gx[self.scene._reset_mask], gy[self.scene._reset_mask]
+                gx[self.scene._reset_mask[self._scene_idxs]], gy[self.scene._reset_mask[self._scene_idxs]]
             ] = targets
         else:
             for i, joint in enumerate(joints):
@@ -911,7 +911,9 @@ def set_joint_drive_velocity_targets(
                 gx, gy = self.get_joint_target_indices(joints)
             else:
                 gx, gy = self.get_joint_target_indices(joint_indices)
-            self.px.cuda_articulation_target_qvel.torch()[gx, gy] = targets
+            self.px.cuda_articulation_target_qvel.torch()[
+                gx[self.scene._reset_mask[self._scene_idxs]], gy[self.scene._reset_mask[self._scene_idxs]]
+            ] = targets
         else:
             for i, joint in enumerate(joints):
                 joint.set_drive_velocity_target(targets[0, i])
```
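For context on the fix above: `scene._reset_mask` is a per-environment boolean mask over all `num_envs` scenes, while `gx`/`gy` are batched over only the scenes this articulation exists in (`_scene_idxs`), so the mask must be re-indexed before it can select rows of `gx`/`gy`. A toy sketch of the alignment (names reused from the diff, shapes illustrative):

```python
import torch

reset_mask = torch.tensor([True, False, True, False])  # scene._reset_mask over 4 envs
scene_idxs = torch.tensor([1, 2, 3])  # envs in which this articulation exists
gx = torch.arange(3)  # one row of target indices per scene in scene_idxs

# indexing gx with the full length-4 mask would be misaligned with gx's batch
# dimension; re-indexing the mask by scene_idxs aligns it first
aligned = gx[reset_mask[scene_idxs]]
print(aligned)  # tensor([1]) -> only the row for env 2, the one being reset
```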

mani_skill/utils/tree.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@

```python
import torch


# NOTE (stao): when tensordict is used we should replace all of this
def slice(x, i):
    if isinstance(x, dict):
        return {k: slice(v, i) for k, v in x.items()}
    else:
        return x[i]


def cat(x: list):
    if isinstance(x[0], dict):
        return {k: cat([d[k] for d in x]) for k in x[0].keys()}
    else:
        return torch.cat(x, dim=0)


def replace(x, i, y):
    if isinstance(x, dict):
        for k, v in x.items():
            replace(v, i, y[k])
    else:
        x[i] = y
```
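A small usage sketch of these helpers on nested dictionaries of batched tensors (the keys and shapes are illustrative):

```python
import torch
from mani_skill.utils import tree

# two batches of env states with batch sizes 4 and 2
a = {"agent": {"qpos": torch.zeros(4, 7)}}
b = {"agent": {"qpos": torch.ones(2, 7)}}

merged = tree.cat([a, b])                 # every leaf concatenated along dim 0 -> (6, 7)
subset = tree.slice(merged, slice(0, 2))  # first two entries of every leaf
# overwrite entries 4 and 5 of every leaf in-place with the sliced values
tree.replace(merged, torch.tensor([4, 5]), subset)
```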

mani_skill/utils/wrappers/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,3 +1,5 @@
+from .action_repeat import ActionRepeatWrapper
+from .cached_reset import CachedResetWrapper
 from .flatten import (
     FlattenActionSpaceWrapper,
     FlattenObservationWrapper,
@@ -6,4 +8,3 @@
 from .frame_stack import FrameStack
 from .gymnasium import CPUGymWrapper
 from .record import RecordEpisode
-from .action_repeat import ActionRepeatWrapper
```
mani_skill/utils/wrappers/cached_reset.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@

```python
from dataclasses import asdict, dataclass
from typing import List, Optional, Union

import dacite
import gymnasium as gym
import torch

from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common, tree
from mani_skill.utils.structs.types import Device


@dataclass
class CachedResetsConfig:
    num_resets: Optional[int] = None
    """The number of reset states to cache. If None, `num_envs` reset states are cached."""
    device: Optional[Device] = None
    """The device to cache the reset states on. If None, the base environment's device is used."""
    seed: Optional[int] = None
    """The seed to use for generating the cached reset states."""

    def dict(self):
        return {k: v for k, v in asdict(self).items()}


class CachedResetWrapper(gym.Wrapper):
    """
    Cached reset wrapper for ManiSkill3 environments. Caching resets allows you to skip slower parts of the reset function call and boost environment FPS as a result.

    Args:
        env: The environment to wrap.
        reset_to_env_states: A dictionary with keys "env_states" and optionally "obs". "env_states" is a dictionary of environment states to reset to.
            "obs" contains the corresponding observations generated at those env states. If reset_to_env_states is not provided, the wrapper will sample reset states
            from the environment using the given seed.
        config: A dictionary or a `CachedResetsConfig` object that contains the configuration for the cached resets.
    """

    def __init__(
        self,
        env: gym.Env,
        reset_to_env_states: Optional[dict] = None,
        config: Union[CachedResetsConfig, dict] = CachedResetsConfig(),
    ):
        super().__init__(env)
        self.num_envs = self.base_env.num_envs
        if isinstance(config, CachedResetsConfig):
            config = config.dict()
        self.cached_resets_config = dacite.from_dict(
            data_class=CachedResetsConfig,
            data=config,
            config=dacite.Config(strict=True),
        )
        cached_data_device = self.cached_resets_config.device
        if cached_data_device is None:
            cached_data_device = self.base_env.device
        self._num_cached_resets = 0
        if reset_to_env_states is not None:
            self._cached_resets_env_states = reset_to_env_states["env_states"]
            self._cached_resets_obs_buffer = reset_to_env_states.get("obs", None)
            self._num_cached_resets = len(self._cached_resets_env_states)
        else:
            if self.cached_resets_config.num_resets is None:
                self.cached_resets_config.num_resets = 16384
            self._cached_resets_env_states = []
            self._cached_resets_obs_buffer = []
            while self._num_cached_resets < self.cached_resets_config.num_resets:
                obs, _ = self.env.reset(
                    seed=self.cached_resets_config.seed,
                    options=dict(
                        env_idx=torch.arange(
                            0,
                            min(
                                self.cached_resets_config.num_resets
                                - self._num_cached_resets,
                                self.num_envs,
                            ),
                            device=self.base_env.device,
                        )
                    ),
                )
                state = self.env.get_wrapper_attr("get_state_dict")()
                if (
                    self.cached_resets_config.num_resets - self._num_cached_resets
                    < self.num_envs
                ):
                    obs = tree.slice(
                        obs,
                        slice(
                            0,
                            self.cached_resets_config.num_resets
                            - self._num_cached_resets,
                        ),
                    )
                    state = tree.slice(
                        state,
                        slice(
                            0,
                            self.cached_resets_config.num_resets
                            - self._num_cached_resets,
                        ),
                    )
                self._cached_resets_obs_buffer.append(
                    common.to_tensor(obs, device=self.cached_resets_config.device)
                )
                self._cached_resets_env_states.append(
                    common.to_tensor(state, device=self.cached_resets_config.device)
                )
                self._num_cached_resets += self.num_envs
            self._cached_resets_env_states = tree.cat(self._cached_resets_env_states)
            self._cached_resets_obs_buffer = tree.cat(self._cached_resets_obs_buffer)

        self._cached_resets_env_states = common.to_tensor(
            self._cached_resets_env_states, device=cached_data_device
        )
        if self._cached_resets_obs_buffer is not None:
            self._cached_resets_obs_buffer = common.to_tensor(
                self._cached_resets_obs_buffer, device=cached_data_device
            )

    @property
    def base_env(self) -> BaseEnv:
        return self.env.unwrapped

    def reset(
        self,
        *args,
        seed: Optional[Union[int, List[int]]] = None,
        options: Optional[dict] = None,
        **kwargs
    ):
        env_idx = None
        if options is None:
            options = dict()
        if "env_idx" in options:
            env_idx = options["env_idx"]
        if self._cached_resets_env_states is not None:
            sampled_ids = torch.randint(
                0,
                self._num_cached_resets,
                size=(len(env_idx) if env_idx is not None else self.num_envs,),
                device=self.base_env.device,
            )
            options["reset_to_env_states"] = dict(
                env_states=tree.slice(self._cached_resets_env_states, sampled_ids),
            )
            if self._cached_resets_obs_buffer is not None:
                options["reset_to_env_states"]["obs"] = tree.slice(
                    self._cached_resets_obs_buffer, sampled_ids
                )
        obs, info = self.env.reset(seed=seed, options=options)
        return obs, info
```
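Since the wrapper samples one cached state per environment index being reset, it composes with partial resets; a hedged usage sketch:

```python
import gymnasium as gym
import torch

from mani_skill.utils.wrappers import CachedResetWrapper

env = gym.make("StackCube-v1", num_envs=256)
env = CachedResetWrapper(env)
env.reset()

# partially reset only the first 8 environments; each receives a randomly
# sampled cached state (plus its cached observation, if one was stored)
obs, info = env.reset(options=dict(env_idx=torch.arange(8)))
```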
