Skip to content

Commit 0a435bc

Browse files
author
Jan Michelfeit
committed
#625 use an OffPolicy for pebble
1 parent 5d1b7d7 commit 0a435bc

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

src/imitation/scripts/config/train_preference_comparisons_pebble.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
"""Configuration for imitation.scripts.train_preference_comparisons_pebble."""
22

3+
import warnings
4+
35
import sacred
6+
import stable_baselines3 as sb3
47

58
from imitation.algorithms import preference_comparisons
9+
from imitation.policies import base
610
from imitation.scripts.common import common, reward, rl, train
711

812
train_preference_comparisons_pebble_ex = sacred.Experiment(
@@ -15,14 +19,42 @@
1519
],
1620
)
1721

18-
1922
# Shared config overrides, grouped by environment family. Each nested ``rl``
# dict targets the RL ingredient's configuration.
# NOTE(review): presumably consumed by named configs outside this view — confirm.
MUJOCO_SHARED_LOCALS = dict(rl=dict(rl_kwargs=dict(ent_coef=0.1)))
ANT_SHARED_LOCALS = dict(
    total_timesteps=int(3e7),
    rl=dict(batch_size=16384),
)
2427

2528

29+
@rl.rl_ingredient.config
def rl_sac():
    """Select SAC as the RL algorithm for the ``rl`` ingredient."""
    # NOTE: this is a Sacred config scope, so every local assigned here becomes
    # a config entry. Do not rename the locals or add helper variables.

    # For recommended SAC hyperparams in each environment, see:
    # https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/sac.yml
    rl_cls = sb3.SAC
    # Emitted unconditionally whenever this config scope is evaluated.
    warnings.warn(
        "SAC currently only supports continuous action spaces. "
        "Consider adding a discrete version as mentioned here: "
        "https://github.com/DLR-RM/stable-baselines3/issues/505",
        category=RuntimeWarning,
    )
    # Default HPs are as follows:
    batch_size = 256  # batch size for RL algorithm
    # NOTE(review): presumably the rl ingredient fills rl_kwargs["batch_size"]
    # from the top-level ``batch_size`` above — confirm against the ingredient.
    rl_kwargs = dict(batch_size=None)  # make sure to set batch size to None
    locals()  # quieten flake8
44+
45+
46+
@train.train_ingredient.config
def train_sac():
    """Set the policy class of the ``train`` ingredient to ``SAC1024Policy``."""
    # Sacred config scope: the local below becomes the ``policy_cls`` entry.
    # No ``noqa`` is needed; the ``locals()`` call already silences the
    # unused-variable warning, as in the sibling config functions.
    policy_cls = base.SAC1024Policy
    locals()  # quieten flake8
50+
51+
52+
@common.common_ingredient.config
def mountain_car():
    """Set the ``common`` ingredient's environment to MountainCarContinuous."""
    # Sacred config scope: the local below becomes the ``env_name`` entry.
    env_name = "MountainCarContinuous-v0"
    locals()  # quieten flake8
56+
57+
2658
@train_preference_comparisons_pebble_ex.config
2759
def train_defaults():
2860
fragment_length = 100 # timesteps per fragment used for comparisons

0 commit comments

Comments
 (0)