@@ -1493,6 +1493,7 @@ def __init__(
         transition_oversampling: float = 1,
         initial_comparison_frac: float = 0.1,
         initial_epoch_multiplier: float = 200.0,
+        initial_agent_pretrain_frac: float = 0.01,
         custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
         allow_variable_horizon: bool = False,
         rng: Optional[np.random.Generator] = None,
@@ -1542,6 +1543,9 @@ def __init__(
             initial_epoch_multiplier: before agent training begins, train the reward
                 model for this many more epochs than usual (on fragments sampled from a
                 random agent).
+            initial_agent_pretrain_frac: fraction of total_timesteps for which the
+                agent will be trained without preference gathering (and reward model
+                training)
             custom_logger: Where to log to; if None (default), creates a new logger.
             allow_variable_horizon: If False (default), algorithm will raise an
                 exception if it detects trajectories of different length during
@@ -1640,6 +1644,7 @@ def __init__(
         self.fragment_length = fragment_length
         self.initial_comparison_frac = initial_comparison_frac
         self.initial_epoch_multiplier = initial_epoch_multiplier
+        self.initial_agent_pretrain_frac = initial_agent_pretrain_frac
         self.num_iterations = num_iterations
         self.transition_oversampling = transition_oversampling
         if callable(query_schedule):
@@ -1672,10 +1677,11 @@ def train(
         preference_query_schedule = self._preference_gather_schedule(total_comparisons)
         print(f"Query schedule: {preference_query_schedule}")

-        timesteps_per_iteration, extra_timesteps = divmod(
-            total_timesteps,
-            self.num_iterations,
-        )
+        (
+            agent_pretrain_timesteps,
+            timesteps_per_iteration,
+            extra_timesteps,
+        ) = self._compute_timesteps(total_timesteps)
         reward_loss = None
         reward_accuracy = None

@@ -1752,3 +1758,13 @@ def _preference_gather_schedule(self, total_comparisons):
         shares = util.oric(probs * total_comparisons)
         schedule = [initial_comparisons] + shares.tolist()
         return schedule
+
+    def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]:
+        agent_pretrain_timesteps = int(
+            total_timesteps * self.initial_agent_pretrain_frac
+        )
+        timesteps_per_iteration, extra_timesteps = divmod(
+            total_timesteps - agent_pretrain_timesteps,
+            self.num_iterations,
+        )
+        return agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
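For reference, a minimal standalone sketch of the timestep split introduced by `_compute_timesteps`. The arithmetic mirrors the diff above; the concrete values (`total_timesteps`, `num_iterations`) are illustrative placeholders, not taken from this PR:

```python
# Illustrative sketch of the new timestep split (values are assumptions, not from the PR).
total_timesteps = 1_000_000
initial_agent_pretrain_frac = 0.01  # default of the new __init__ argument
num_iterations = 5                  # hypothetical value for illustration

# Budget for agent pretraining without preference gathering or reward-model training.
agent_pretrain_timesteps = int(total_timesteps * initial_agent_pretrain_frac)  # 10_000

# Remaining timesteps are divided evenly across iterations; any remainder is tracked separately.
timesteps_per_iteration, extra_timesteps = divmod(
    total_timesteps - agent_pretrain_timesteps,
    num_iterations,
)

print(agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps)  # 10000 198000 0
```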