Commit 8cecab4

Update TD3/DDPG defaults and upgrade to MuJoCo v4 envs (#430)
* Update TD3/DDPG defaults and upgrade to MuJoCo v4 envs
* Update SB3 version
1 parent 28dc228 · commit 8cecab4

14 files changed (+132 −116)

CHANGELOG.md (+18)

@@ -1,3 +1,21 @@
+## Release 2.3.0a1 (WIP)
+
+### Breaking Changes
+- Updated default hyperparameters for TD3/DDPG to be more consistent with SAC
+- Upgraded MuJoCo env hyperparameters to v4 (pre-trained agents need to be updated)
+- Upgraded to SB3 >= 2.3.0
+
+### New Features
+
+### Bug fixes
+
+### Documentation
+
+### Other
+
 ## Release 2.2.1 (2023-11-17)
 
 ### Breaking Changes
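As the changelog notes, pre-trained agents need to be updated: the v4 MuJoCo environments ship with Gymnasium and use the official `mujoco` bindings rather than the deprecated `mujoco-py` that backed v3, so dynamics can differ slightly between versions. A minimal sketch of creating one of the renamed envs (assuming Gymnasium is installed with MuJoCo support):

```python
import gymnasium as gym

# v4 MuJoCo envs are backed by the official `mujoco` bindings;
# the v3 IDs required the deprecated mujoco-py package.
env = gym.make("HalfCheetah-v4")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```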

hyperparams/a2c.yml (+6 −6)

@@ -165,24 +165,24 @@ ReacherBulletEnv-v0:
 
 # === Mujoco Envs ===
 
-HalfCheetah-v3: &mujoco-defaults
+HalfCheetah-v4: &mujoco-defaults
   normalize: true
   n_timesteps: !!float 1e6
   policy: 'MlpPolicy'
 
-Ant-v3:
+Ant-v4:
   <<: *mujoco-defaults
 
-Hopper-v3:
+Hopper-v4:
   <<: *mujoco-defaults
 
-Walker2d-v3:
+Walker2d-v4:
   <<: *mujoco-defaults
 
-Humanoid-v3:
+Humanoid-v4:
   <<: *mujoco-defaults
   n_timesteps: !!float 2e6
 
-Swimmer-v3:
+Swimmer-v4:
   <<: *mujoco-defaults
   gamma: 0.9999
hyperparams/ars.yml (+6 −6)

@@ -108,7 +108,7 @@ ReacherBulletEnv-v0:
 
 # === Mujoco Envs ===
 # Params closest to original paper
-Swimmer-v3:
+Swimmer-v4:
   n_envs: 1
   policy: 'LinearPolicy'
   n_timesteps: !!float 2e6
@@ -119,7 +119,7 @@ Swimmer-v3:
   alive_bonus_offset: 0
   # normalize: "dict(norm_obs=True, norm_reward=False)"
 
-Hopper-v3:
+Hopper-v4:
   n_envs: 1
   policy: 'LinearPolicy'
   n_timesteps: !!float 7e6
@@ -130,7 +130,7 @@ Hopper-v3:
   alive_bonus_offset: -1
   normalize: "dict(norm_obs=True, norm_reward=False)"
 
-HalfCheetah-v3:
+HalfCheetah-v4:
   n_envs: 1
   policy: 'LinearPolicy'
   n_timesteps: !!float 1.25e7
@@ -141,7 +141,7 @@ HalfCheetah-v3:
   alive_bonus_offset: 0
   normalize: "dict(norm_obs=True, norm_reward=False)"
 
-Walker2d-v3:
+Walker2d-v4:
   n_envs: 1
   policy: 'LinearPolicy'
   n_timesteps: !!float 7.5e7
@@ -152,7 +152,7 @@ Walker2d-v3:
   alive_bonus_offset: -1
   normalize: "dict(norm_obs=True, norm_reward=False)"
 
-Ant-v3:
+Ant-v4:
   n_envs: 1
   policy: 'LinearPolicy'
   n_timesteps: !!float 7.5e7
@@ -164,7 +164,7 @@ Ant-v3:
   normalize: "dict(norm_obs=True, norm_reward=False)"
 
 
-Humanoid-v3:
+Humanoid-v4:
   n_envs: 1
   policy: 'LinearPolicy'
   n_timesteps: !!float 2.5e8
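The quoted `normalize: "dict(norm_obs=True, norm_reward=False)"` entries are evaluated by the zoo into keyword arguments for SB3's `VecNormalize` wrapper (here: normalize observations, leave rewards untouched). A rough sketch of the equivalent manual wrapping (the zoo normally builds this for you):

```python
import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

# Equivalent of normalize: "dict(norm_obs=True, norm_reward=False)"
# from hyperparams/ars.yml.
venv = DummyVecEnv([lambda: gym.make("Hopper-v4")])
venv = VecNormalize(venv, norm_obs=True, norm_reward=False)
```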

hyperparams/ddpg.yml (+29 −29)

@@ -4,6 +4,11 @@ MountainCarContinuous-v0:
   policy: 'MlpPolicy'
   noise_type: 'ornstein-uhlenbeck'
   noise_std: 0.5
+  gradient_steps: 1
+  train_freq: 1
+  learning_rate: !!float 1e-3
+  batch_size: 256
+  policy_kwargs: "dict(net_arch=[400, 300])"
 
 Pendulum-v1:
   n_timesteps: 20000
@@ -13,8 +18,8 @@ Pendulum-v1:
   learning_starts: 10000
   noise_type: 'normal'
   noise_std: 0.1
-  gradient_steps: -1
-  train_freq: [1, "episode"]
+  gradient_steps: 1
+  train_freq: 1
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -26,8 +31,8 @@ LunarLanderContinuous-v2:
   learning_starts: 10000
   noise_type: 'normal'
   noise_std: 0.1
-  gradient_steps: -1
-  train_freq: [1, "episode"]
+  gradient_steps: 1
+  train_freq: 1
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -39,23 +44,23 @@ BipedalWalker-v3:
   learning_starts: 10000
   noise_type: 'normal'
   noise_std: 0.1
-  gradient_steps: -1
-  train_freq: [1, "episode"]
+  gradient_steps: 1
+  train_freq: 1
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
 # To be tuned
 BipedalWalkerHardcore-v3:
   n_timesteps: !!float 1e7
   policy: 'MlpPolicy'
-  gamma: 0.98
-  buffer_size: 200000
+  gamma: 0.99
+  buffer_size: 1000000
   learning_starts: 10000
   noise_type: 'normal'
   noise_std: 0.1
-  gradient_steps: -1
-  train_freq: [1, "episode"]
-  learning_rate: !!float 1e-3
+  batch_size: 256
+  train_freq: 1
+  learning_rate: lin_7e-4
   policy_kwargs: "dict(net_arch=[400, 300])"
 
 # Tuned
@@ -69,28 +74,21 @@ HalfCheetahBulletEnv-v0: &pybullet-defaults
   noise_std: 0.1
   gradient_steps: 1
   train_freq: 1
-  learning_rate: !!float 1e-3
+  batch_size: 256
+  learning_rate: !!float 7e-4
   policy_kwargs: "dict(net_arch=[400, 300])"
 
 # Tuned
 AntBulletEnv-v0:
   <<: *pybullet-defaults
-  learning_rate: !!float 7e-4
-  policy_kwargs: "dict(net_arch=[400, 300])"
 
 # Tuned
 HopperBulletEnv-v0:
   <<: *pybullet-defaults
-  train_freq: 64
-  gradient_steps: 64
-  batch_size: 256
-  learning_rate: !!float 7e-4
 
 # Tuned
 Walker2DBulletEnv-v0:
   <<: *pybullet-defaults
-  batch_size: 256
-  learning_rate: !!float 7e-4
 
 # TO BE tested
 HumanoidBulletEnv-v0:
@@ -123,29 +121,31 @@ InvertedPendulumSwingupBulletEnv-v0:
   n_timesteps: !!float 3e5
 
 # === Mujoco Envs ===
-
-HalfCheetah-v3: &mujoco-defaults
+HalfCheetah-v4: &mujoco-defaults
   n_timesteps: !!float 1e6
   policy: 'MlpPolicy'
   learning_starts: 10000
   noise_type: 'normal'
   noise_std: 0.1
+  train_freq: 1
+  gradient_steps: 1
+  learning_rate: !!float 1e-3
+  batch_size: 256
+  policy_kwargs: "dict(net_arch=[400, 300])"
 
-Ant-v3:
+Ant-v4:
   <<: *mujoco-defaults
 
-Hopper-v3:
+Hopper-v4:
   <<: *mujoco-defaults
 
-Walker2d-v3:
+Walker2d-v4:
   <<: *mujoco-defaults
 
-Humanoid-v3:
+Humanoid-v4:
   <<: *mujoco-defaults
   n_timesteps: !!float 2e6
 
-Swimmer-v3:
+Swimmer-v4:
   <<: *mujoco-defaults
   gamma: 0.9999
-  train_freq: 1
-  gradient_steps: 1
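The recurring change here is the training schedule: `gradient_steps: -1` with `train_freq: [1, "episode"]` (train once per episode, doing as many gradient steps as env steps collected) becomes one gradient step per environment step, which matches SAC's default schedule, plus a larger batch size. A minimal sketch of what the new `HalfCheetah-v4` entry corresponds to in SB3 code (assuming SB3 >= 2.3.0; `lin_7e-4` above is the zoo's notation for a linearly decayed learning rate):

```python
import numpy as np
import gymnasium as gym
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise

env = gym.make("HalfCheetah-v4")
n_actions = env.action_space.shape[0]

model = DDPG(
    "MlpPolicy",
    env,
    learning_starts=10_000,
    action_noise=NormalActionNoise(
        mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)
    ),
    train_freq=1,        # update every env step (was once per episode)
    gradient_steps=1,    # one gradient step per update (was -1: match env steps)
    learning_rate=1e-3,
    batch_size=256,
    policy_kwargs=dict(net_arch=[400, 300]),
)
model.learn(total_timesteps=1_000_000)
```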

hyperparams/ppo.yml (+11 −11)

@@ -380,28 +380,28 @@ CarRacing-v2:
 
 
 # === Mujoco Envs ===
-# HalfCheetah-v3: &mujoco-defaults
+# HalfCheetah-v4: &mujoco-defaults
 #   normalize: true
 #   n_timesteps: !!float 1e6
 #   policy: 'MlpPolicy'
 
-Ant-v3: &mujoco-defaults
+Ant-v4: &mujoco-defaults
   normalize: true
   n_timesteps: !!float 1e6
   policy: 'MlpPolicy'
 
-# Hopper-v3:
+# Hopper-v4:
 #   <<: *mujoco-defaults
 #
-# Walker2d-v3:
+# Walker2d-v4:
 #   <<: *mujoco-defaults
 #
-# Humanoid-v3:
+# Humanoid-v4:
 #   <<: *mujoco-defaults
 #   n_timesteps: !!float 2e6
 #
 # tuned
-Swimmer-v3:
+Swimmer-v4:
   <<: *mujoco-defaults
   gamma: 0.9999
   n_envs: 4
@@ -413,7 +413,7 @@ Swimmer-v3:
 # Tuned
 # 10 mujoco envs
 
-HalfCheetah-v3:
+HalfCheetah-v4:
   normalize: true
   n_envs: 1
   policy: 'MlpPolicy'
@@ -435,7 +435,7 @@ HalfCheetah-v3:
                    net_arch=dict(pi=[256, 256], vf=[256, 256])
                   )"
 
-# Ant-v3:
+# Ant-v4:
 #   normalize: true
 #   n_envs: 1
 #   policy: 'MlpPolicy'
@@ -451,7 +451,7 @@ HalfCheetah-v3:
 #   max_grad_norm: 0.6
 #   vf_coef: 0.677239
 
-Hopper-v3:
+Hopper-v4:
   normalize: true
   n_envs: 1
   policy: 'MlpPolicy'
@@ -495,7 +495,7 @@ HumanoidStandup-v2:
                    net_arch=dict(pi=[256, 256], vf=[256, 256])
                   )"
 
-Humanoid-v3:
+Humanoid-v4:
   normalize: true
   n_envs: 1
   policy: 'MlpPolicy'
@@ -565,7 +565,7 @@ Reacher-v2:
   max_grad_norm: 0.9
   vf_coef: 0.950368
 
-Walker2d-v3:
+Walker2d-v4:
   normalize: true
   n_envs: 1
   policy: 'MlpPolicy'