
Commit c681ca3

Author: Jan Michelfeit

#625 introduce parameter for pretraining steps

1 parent c2bc9dc · commit c681ca3

File tree

2 files changed: 23 additions & 4 deletions


src/imitation/algorithms/preference_comparisons.py

Lines changed: 20 additions & 4 deletions
@@ -1493,6 +1493,7 @@ def __init__(
         transition_oversampling: float = 1,
         initial_comparison_frac: float = 0.1,
         initial_epoch_multiplier: float = 200.0,
+        initial_agent_pretrain_frac: float = 0.01,
         custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
         allow_variable_horizon: bool = False,
         rng: Optional[np.random.Generator] = None,
@@ -1542,6 +1543,9 @@ def __init__(
             initial_epoch_multiplier: before agent training begins, train the reward
                 model for this many more epochs than usual (on fragments sampled from a
                 random agent).
+            initial_agent_pretrain_frac: fraction of total_timesteps for which the
+                agent will be trained without preference gathering (and reward model
+                training)
             custom_logger: Where to log to; if None (default), creates a new logger.
             allow_variable_horizon: If False (default), algorithm will raise an
                 exception if it detects trajectories of different length during
@@ -1640,6 +1644,7 @@ def __init__(
         self.fragment_length = fragment_length
         self.initial_comparison_frac = initial_comparison_frac
         self.initial_epoch_multiplier = initial_epoch_multiplier
+        self.initial_agent_pretrain_frac = initial_agent_pretrain_frac
         self.num_iterations = num_iterations
         self.transition_oversampling = transition_oversampling
         if callable(query_schedule):
@@ -1672,10 +1677,11 @@ def train(
         preference_query_schedule = self._preference_gather_schedule(total_comparisons)
         print(f"Query schedule: {preference_query_schedule}")

-        timesteps_per_iteration, extra_timesteps = divmod(
-            total_timesteps,
-            self.num_iterations,
-        )
+        (
+            agent_pretrain_timesteps,
+            timesteps_per_iteration,
+            extra_timesteps,
+        ) = self._compute_timesteps(total_timesteps)
         reward_loss = None
         reward_accuracy = None

@@ -1752,3 +1758,13 @@ def _preference_gather_schedule(self, total_comparisons):
         shares = util.oric(probs * total_comparisons)
         schedule = [initial_comparisons] + shares.tolist()
         return schedule
+
+    def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]:
+        agent_pretrain_timesteps = int(
+            total_timesteps * self.initial_agent_pretrain_frac
+        )
+        timesteps_per_iteration, extra_timesteps = divmod(
+            total_timesteps - agent_pretrain_timesteps,
+            self.num_iterations,
+        )
+        return agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
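For intuition, here is a rough standalone sketch of the budget split performed by the new _compute_timesteps helper (illustrative only; the function below mirrors the diff but is not the library code, and the numbers are made up):

from typing import Tuple


def compute_timesteps(
    total_timesteps: int,
    initial_agent_pretrain_frac: float,
    num_iterations: int,
) -> Tuple[int, int, int]:
    # Reserve a fraction of the budget for agent pretraining (no preference
    # gathering or reward model training), then split the remainder evenly
    # across the preference-comparison iterations.
    agent_pretrain_timesteps = int(total_timesteps * initial_agent_pretrain_frac)
    timesteps_per_iteration, extra_timesteps = divmod(
        total_timesteps - agent_pretrain_timesteps,
        num_iterations,
    )
    return agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps


# Made-up example: 10_000 timesteps, 25% pretraining, 3 iterations.
print(compute_timesteps(10_000, 0.25, 3))  # (2500, 2500, 0)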

src/imitation/scripts/config/train_preference_comparisons_pebble.py

Lines changed: 3 additions & 0 deletions
@@ -68,6 +68,8 @@ def train_defaults():
     initial_comparison_frac = 0.1
     # fraction of sampled trajectories that will include some random actions
     exploration_frac = 0.0
+    # fraction of total_timesteps for training before preference gathering
+    initial_agent_pretrain_frac = 0.05
     preference_model_kwargs = {}
     reward_trainer_kwargs = {
         "epochs": 3,
@@ -153,6 +155,7 @@ def fast():
     total_timesteps = 50
     total_comparisons = 5
     initial_comparison_frac = 0.2
+    initial_agent_pretrain_frac = 0.2
     num_iterations = 1
     fragment_length = 2
     reward_trainer_kwargs = {
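For a sense of scale, assuming the split follows the _compute_timesteps helper shown above, the fast() smoke-test values work out as follows (illustrative arithmetic only, not library code):

# Illustrative arithmetic for the fast() values above:
# total_timesteps=50, initial_agent_pretrain_frac=0.2, num_iterations=1.
total_timesteps, frac, num_iterations = 50, 0.2, 1
pretrain = int(total_timesteps * frac)  # 10 timesteps of agent pretraining before preference gathering
per_iter, extra = divmod(total_timesteps - pretrain, num_iterations)  # (40, 0)
print(pretrain, per_iter, extra)  # 10 40 0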
