Skip to content

Commit 4f65b13

Browse files
bordeauxredMichael Panchenko
and
Michael Panchenko
authored
Feat/refactor collector (#1063)
Closes: #1058 ### Api Extensions - Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063 - `Collector`s can now be closed, and their reset is more granular. #1063 - Trainers can control whether collectors should be reset prior to training. #1063 - Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 - Introduced a first iteration of a naming convention for vars in `Collector`s. #1063 - Generally improved readability of Collector code and associated tests (still quite some way to go). #1063 - Improved typing for `exploration_noise` and within Collector. #1063 ### Breaking Changes - Removed `.data` attribute from `Collector` and its child classes. #1063 - Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset` explicitly or pass `reset_before_collect=True`. #1063 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 - Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 --------- Co-authored-by: Michael Panchenko <[email protected]>
1 parent edae9e4 commit 4f65b13

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1152
-642
lines changed

CHANGELOG.md

+23
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,27 @@
11
# Changelog
22

3+
## Release 1.1.0
4+
5+
### Api Extensions
6+
- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063
7+
- `Collector`s can now be closed, and their reset is more granular. #1063
8+
- Trainers can control whether collectors should be reset prior to training. #1063
9+
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
10+
11+
### Internal Improvements
12+
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
13+
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
14+
- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
15+
- Improved typing for `exploration_noise` and within Collector. #1063
16+
17+
### Breaking Changes
18+
19+
- Removed `.data` attribute from `Collector` and its child classes. #1063
20+
- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset`
21+
explicitly or pass `reset_before_collect=True`. #1063
22+
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
23+
- Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
24+
25+
326
Started after v1.0.0
427

docs/02_notebooks/L0_overview.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@
164164
"source": [
165165
"# Let's watch its performance!\n",
166166
"policy.eval()\n",
167-
"eval_result = test_collector.collect(n_episode=1, render=False)\n",
167+
"eval_result = test_collector.collect(n_episode=3, render=False)\n",
168168
"print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")"
169169
]
170170
},

docs/02_notebooks/L5_Collector.ipynb

+2-3
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@
119119
},
120120
"outputs": [],
121121
"source": [
122-
"collect_result = test_collector.collect(n_episode=9)\n",
122+
"collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n",
123123
"\n",
124124
"collect_result.pprint_asdict()"
125125
]
@@ -146,8 +146,7 @@
146146
"outputs": [],
147147
"source": [
148148
"# Reset the collector\n",
149-
"test_collector.reset()\n",
150-
"collect_result = test_collector.collect(n_episode=9, random=True)\n",
149+
"collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n",
151150
"\n",
152151
"collect_result.pprint_asdict()"
153152
]

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ ignore = [
166166
"RET505",
167167
"D106", # undocumented public nested class
168168
"D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
169+
"PLW2901", # overwrite vars in loop
169170
]
170171
unfixable = [
171172
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all

test/base/env.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,24 @@
99
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple
1010

1111

12-
class MyTestEnv(gym.Env):
13-
"""A task for "going right". The task is to go right ``size`` steps."""
12+
class MoveToRightEnv(gym.Env):
13+
"""A task for "going right". The task is to go right ``size`` steps.
14+
15+
The observation is the current index, and the action is to go left or right.
16+
Action 0 is to go left, and action 1 is to go right.
17+
Taking action 0 at index 0 will keep the index at 0.
18+
Arriving at index ``size`` means the task is done.
19+
In the current implementation, stepping after the task is done is possible, which will
20+
lead the index to be larger than ``size``.
21+
22+
Index 0 is the starting point. If reset is called with default options, the index will
23+
be reset to 0.
24+
"""
1425

1526
def __init__(
1627
self,
1728
size: int,
18-
sleep: int = 0,
29+
sleep: float = 0.0,
1930
dict_state: bool = False,
2031
recurse_state: bool = False,
2132
ma_rew: int = 0,
@@ -74,8 +85,13 @@ def __init__(
7485
def reset(
7586
self,
7687
seed: int | None = None,
88+
# TODO: passing a dict here doesn't make any sense
7789
options: dict[str, Any] | None = None,
7890
) -> tuple[dict[str, Any] | np.ndarray, dict]:
91+
""":param seed:
92+
:param options: the start index is provided in options["state"]
93+
:return:
94+
"""
7995
if options is None:
8096
options = {"state": 0}
8197
super().reset(seed=seed)
@@ -188,7 +204,7 @@ def step(
188204
return self._encode_obs(), 1.0, False, False, {}
189205

190206

191-
class MyGoalEnv(MyTestEnv):
207+
class MyGoalEnv(MoveToRightEnv):
192208
def __init__(self, *args: Any, **kwargs: Any) -> None:
193209
assert (
194210
kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0

test/base/test_buffer.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
from tianshou.data.utils.converter import to_hdf5
2323

2424
if __name__ == "__main__":
25-
from env import MyGoalEnv, MyTestEnv
25+
from env import MoveToRightEnv, MyGoalEnv
2626
else: # pytest
27-
from test.base.env import MyGoalEnv, MyTestEnv
27+
from test.base.env import MoveToRightEnv, MyGoalEnv
2828

2929

3030
def test_replaybuffer(size=10, bufsize=20) -> None:
31-
env = MyTestEnv(size)
31+
env = MoveToRightEnv(size)
3232
buf = ReplayBuffer(bufsize)
3333
buf.update(buf)
3434
assert str(buf) == buf.__class__.__name__ + "()"
@@ -209,7 +209,7 @@ def test_ignore_obs_next(size=10) -> None:
209209

210210

211211
def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
212-
env = MyTestEnv(size)
212+
env = MoveToRightEnv(size)
213213
buf = ReplayBuffer(bufsize, stack_num=stack_num)
214214
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
215215
buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
@@ -280,7 +280,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
280280

281281

282282
def test_priortized_replaybuffer(size=32, bufsize=15) -> None:
283-
env = MyTestEnv(size)
283+
env = MoveToRightEnv(size)
284284
buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
285285
buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
286286
obs, info = env.reset()
@@ -1028,7 +1028,7 @@ def test_multibuf_stack() -> None:
10281028
bufsize = 9
10291029
stack_num = 4
10301030
cached_num = 3
1031-
env = MyTestEnv(size)
1031+
env = MoveToRightEnv(size)
10321032
# test if CachedReplayBuffer can handle stack_num + ignore_obs_next
10331033
buf4 = CachedReplayBuffer(
10341034
ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),

0 commit comments

Comments
 (0)