Skip to content

Commit 5d1b7d7

Browse files
author
Jan Michelfeit
committed
#625 make copy of train_preference_comparisons.py for pebble
1 parent 1fdfc74 commit 5d1b7d7

File tree

2 files changed

+420
-0
lines changed

2 files changed

+420
-0
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
"""Configuration for imitation.scripts.train_preference_comparisons_pebble."""

import sacred

from imitation.algorithms import preference_comparisons
from imitation.scripts.common import common, reward, rl, train

# Sacred experiment for PEBBLE-style preference-comparison training.
# The ingredients contribute nested config scopes (common.*, reward.*, rl.*,
# train.*); the named configs below override them with plain dicts of the
# same shape.
train_preference_comparisons_pebble_ex = sacred.Experiment(
    "train_preference_comparisons_pebble",
    ingredients=[
        common.common_ingredient,
        reward.reward_ingredient,
        rl.rl_ingredient,
        train.train_ingredient,
    ],
)
18+
# Config fragments shared by the MuJoCo named configs below; they are merged
# into the sacred config scope via locals().update(...).
MUJOCO_SHARED_LOCALS = dict(rl=dict(rl_kwargs=dict(ent_coef=0.1)))  # entropy coefficient for MuJoCo tasks
ANT_SHARED_LOCALS = dict(
    total_timesteps=int(3e7),  # Ant gets a much larger timestep budget than the default
    rl=dict(batch_size=16384),
)
25+
26+
# Sacred config scope: every local assignment below becomes a config key of
# the experiment, and inline comments document the individual entries.
@train_preference_comparisons_pebble_ex.config
def train_defaults():
    fragment_length = 100  # timesteps per fragment used for comparisons
    total_timesteps = int(1e6)  # total number of environment timesteps
    total_comparisons = 5000  # total number of comparisons to elicit
    num_iterations = 5  # Arbitrary, should be tuned for the task
    # max comparisons retained for reward training; presumably None means keep all — confirm
    comparison_queue_size = None
    # factor by which to oversample transitions before creating fragments
    transition_oversampling = 1
    # fraction of total_comparisons that will be sampled right at the beginning
    initial_comparison_frac = 0.1
    # fraction of sampled trajectories that will include some random actions
    exploration_frac = 0.0
    preference_model_kwargs = {}  # extra kwargs for the preference model
    reward_trainer_kwargs = {
        "epochs": 3,  # reward-model training epochs per iteration
    }
    save_preferences = False  # save preference dataset at the end?
    agent_path = None  # path to a (partially) trained agent to load at the beginning
    # type of PreferenceGatherer to use
    gatherer_cls = preference_comparisons.SyntheticGatherer
    # arguments passed on to the PreferenceGatherer specified by gatherer_cls
    gatherer_kwargs = {}
    active_selection = False  # NOTE(review): presumably enables active query selection — confirm
    active_selection_oversampling = 2  # oversampling factor used when active_selection is on
    uncertainty_on = "logit"  # quantity the active-selection uncertainty is computed on
    fragmenter_kwargs = {
        "warning_threshold": 0,  # passed to the fragmenter; 0 presumably disables the warning — confirm
    }
    # path to a pickled sequence of trajectories used instead of training an agent
    trajectory_path = None
    trajectory_generator_kwargs = {}  # kwargs to pass to trajectory generator
    allow_variable_horizon = False  # disallow episodes of differing length by default

    checkpoint_interval = 0  # Num epochs between saving (<0 disables, =0 final only)
    query_schedule = "hyperbolic"  # how comparisons are distributed across iterations
64+
# Named config: classic-control CartPole.
@train_preference_comparisons_pebble_ex.named_config
def cartpole():
    common = dict(env_name="CartPole-v1")
    allow_variable_horizon = True  # episodes can terminate early, so horizons vary
70+
# Named config: seals Ant, combining the shared MuJoCo and Ant fragments.
@train_preference_comparisons_pebble_ex.named_config
def seals_ant():
    locals().update(**MUJOCO_SHARED_LOCALS)  # sacred scope: injects shared MuJoCo config keys
    locals().update(**ANT_SHARED_LOCALS)  # adds Ant-specific timestep budget and batch size
    common = dict(env_name="seals/Ant-v0")
77+
# Named config: HalfCheetah with larger rollout and minibatch sizes.
@train_preference_comparisons_pebble_ex.named_config
def half_cheetah():
    locals().update(**MUJOCO_SHARED_LOCALS)  # sacred scope: injects shared MuJoCo config keys
    common = dict(env_name="HalfCheetah-v2")
    rl = dict(batch_size=16384, rl_kwargs=dict(batch_size=1024))
84+
# Named config: seals Hopper.
@train_preference_comparisons_pebble_ex.named_config
def seals_hopper():
    locals().update(**MUJOCO_SHARED_LOCALS)  # sacred scope: injects shared MuJoCo config keys
    common = dict(env_name="seals/Hopper-v0")
90+
# Named config: seals Humanoid with an extended timestep budget.
@train_preference_comparisons_pebble_ex.named_config
def seals_humanoid():
    locals().update(**MUJOCO_SHARED_LOCALS)  # sacred scope: injects shared MuJoCo config keys
    common = dict(env_name="seals/Humanoid-v0")
    total_timesteps = int(4e6)  # Humanoid gets 4x the default budget
97+
# Named config: seals CartPole (fixed-horizon variant, so no
# allow_variable_horizon override is needed here).
@train_preference_comparisons_pebble_ex.named_config
def seals_cartpole():
    common = dict(env_name="seals/CartPole-v0")
102+
# Named config: classic-control Pendulum.
@train_preference_comparisons_pebble_ex.named_config
def pendulum():
    common = dict(env_name="Pendulum-v1")
107+
# Named config: classic-control MountainCar.
@train_preference_comparisons_pebble_ex.named_config
def mountain_car():
    common = dict(env_name="MountainCar-v0")
    allow_variable_horizon = True  # episodes end when the goal is reached, so horizons vary
113+
# Named config: seals MountainCar (fixed-horizon variant, so no
# allow_variable_horizon override is needed here).
@train_preference_comparisons_pebble_ex.named_config
def seals_mountain_car():
    common = dict(env_name="seals/MountainCar-v0")
118+
# Named config used by the test suite: overrides the defaults with the
# smallest values that still exercise the full training loop.
@train_preference_comparisons_pebble_ex.named_config
def fast():
    # Minimize the amount of computation. Useful for test cases.
    total_timesteps = 50
    total_comparisons = 5
    initial_comparison_frac = 0.2
    num_iterations = 1
    fragment_length = 2
    reward_trainer_kwargs = {
        "epochs": 1,
    }

0 commit comments

Comments
 (0)