Skip to content

Commit 15ed55f

Browse files
authored
[BugFix] Fix bug with cached reset wrapper with counting the number of cached resets for user defined reset states (#1273)
1 parent 56d784c commit 15ed55f

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mani_skill/utils/tree.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,14 @@ def replace(x, i, y):
2222
replace(v, i, y[k])
2323
else:
2424
x[i] = y
25+
26+
def shape(x, first_only=False):
27+
"""
28+
Get the shape of leaf items in a tree. If first_only is True, return the shape of the first item only
29+
"""
30+
if isinstance(x, dict):
31+
if first_only:
32+
return shape(next(iter(x.values())), first_only)
33+
return {k: shape(v, first_only) for k, v in x.items()}
34+
else:
35+
return x.shape

mani_skill/utils/wrappers/cached_reset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
if reset_to_env_states is not None:
5959
self._cached_resets_env_states = reset_to_env_states["env_states"]
6060
self._cached_resets_obs_buffer = reset_to_env_states.get("obs", None)
61-
self._num_cached_resets = len(self._cached_resets_env_states)
61+
self._num_cached_resets = tree.shape(self._cached_resets_env_states)
6262
else:
6363
if self.cached_resets_config.num_resets is None:
6464
self.cached_resets_config.num_resets = self.num_envs

0 commit comments

Comments
 (0)