
Commit f05c3fa

kawine and kashif authored

minor KTO setting changes + KL batch size (huggingface#2153)

* add argument for dropout
* increase default lr
* change default lr in examples
* fix bug in calculation of KL batch size
* KL batch size should be args.per_device_train_batch_size
* Update kto_trainer.mdx with hparam recs
* typo
* allow dropout to be disabled
* update lr in sample script
* Update kto_config.py
* Update trl/trainer/kto_trainer.py
* Update docs/source/kto_trainer.mdx

---------

Co-authored-by: Kashif Rasul <[email protected]>
1 parent 4799ba4 commit f05c3fa

File tree

3 files changed (+26, -18 lines)


docs/source/kto_trainer.mdx

Lines changed: 9 additions & 3 deletions

@@ -7,6 +7,7 @@ For a full example have a look at [`examples/scripts/kto.py`].

 Depending on how good your base model is, you may or may not need to do SFT before KTO.
 This is different from standard RLHF and DPO, which always require SFT.
+You can also train with imbalanced data (more chosen than rejected examples, or vice-versa), but you will need to adjust hyperparameters accordingly (see below).

 ## Expected dataset format

@@ -51,7 +52,8 @@ kto_dataset_dict = {
 ```

 where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
-A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
+A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
+In theory, the dataset must contain at least one desirable and one undesirable completion; however, some people have had success running KTO on _only_ desirable or undesirable data (in the latter case, it is best to use a conservative learning rate).


 ## Expected model format
@@ -61,13 +63,17 @@ The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that

 For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.

-The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
+The `beta` refers to the hyperparameter that controls how quickly the loss saturates, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).

 The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
 By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.

 <Tip>
-It is strongly recommended you use a learning rate between `5e-7` and `5e-6` with an effective batch size between `8` and `32`, for both LoRA and full finetuning. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, using smaller batch sizes and/or more training epochs will give you better results.
+Every choice of `beta` has a maximum learning rate it will tolerate before learning degenerates. For the default `beta = 0.1`, this learning rate is `1e-6` for most models. The lower the beta is, the lower your learning rate should be. In general, we strongly recommend a learning rate between `5e-7` and `5e-6`. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, use more epochs.
+</Tip>
+
+<Tip>
+Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, a per-step batch size that is too small will yield a poor KL estimate in KTO.
 </Tip>

 ```py
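
The upweighting rule in the documentation change above lends itself to a quick worked example. Below is a minimal sketch (the `suggest_kto_weights` helper and the example counts are hypothetical, not part of TRL or this commit): it upweights the rarer label so that `desirable_weight` times the number of positives and `undesirable_weight` times the number of negatives land at roughly 1:1, which sits inside the recommended 1:1 to 4:3 band.

```python
from typing import Tuple

# Hypothetical helper illustrating the weighting rule from the docs above:
# keep (desirable_weight * n_desirable) : (undesirable_weight * n_undesirable)
# within 1:1 to 4:3 by upweighting the rarer label.

def suggest_kto_weights(n_desirable: int, n_undesirable: int) -> Tuple[float, float]:
    """Return (desirable_weight, undesirable_weight) that balance the weighted counts."""
    if n_desirable >= n_undesirable:
        # Desirable examples dominate: leave their weight at 1 and raise
        # the undesirable weight until the weighted counts are ~equal.
        return 1.0, n_desirable / n_undesirable
    # Undesirable examples dominate: upweight the desirable side instead.
    return n_undesirable / n_desirable, 1.0

# e.g. 10,000 desirable vs. 3,000 undesirable completions:
print(suggest_kto_weights(10_000, 3_000))  # (1.0, 3.33...) -> weighted ratio ~1:1
```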

trl/trainer/kto_config.py

Lines changed: 4 additions & 1 deletion

@@ -75,9 +75,11 @@ class KTOConfig(TrainingArguments):
             from a string.
         dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
             Number of processes to use for processing the dataset.
+        disable_dropout (`bool`, *optional*, defaults to `True`):
+            Whether to disable dropout in the model.
     """

-    learning_rate: float = 5e-7
+    learning_rate: float = 1e-6
     max_length: Optional[int] = None
     max_prompt_length: Optional[int] = None
     max_completion_length: Optional[int] = None
@@ -90,6 +92,7 @@ class KTOConfig(TrainingArguments):
     truncation_mode: str = "keep_end"
     generate_during_eval: bool = False
     is_encoder_decoder: Optional[bool] = None
+    disable_dropout: bool = True
     precompute_ref_log_probs: bool = False
     model_init_kwargs: Optional[Dict[str, Any]] = None
     ref_model_init_kwargs: Optional[Dict[str, Any]] = None
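
For context, here is a minimal usage sketch of the updated config (the output path and batch size below are illustrative, not from the commit): `disable_dropout` defaults to `True`, so the previous always-disable behavior is preserved unless you opt out, and the new default learning rate matches the `1e-6` guidance for `beta = 0.1`.

```python
from trl import KTOConfig

# Illustrative values; only learning_rate and disable_dropout relate to this commit.
config = KTOConfig(
    output_dir="kto-model",            # hypothetical output path
    per_device_train_batch_size=4,     # keep > 1 so the KL estimate is meaningful
    learning_rate=1e-6,                # new default, per the beta = 0.1 guidance
    disable_dropout=True,              # new flag; set False to keep dropout active
)
```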

trl/trainer/kto_trainer.py

Lines changed: 13 additions & 14 deletions

@@ -73,7 +73,11 @@


 def _get_kl_dataset(batch: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
-    """Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions."""
+    """
+    Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
+    For best results, the mismatched outputs y' used to estimate the KL term for a batch should be the same set as the matched
+    outputs y used to estimate the rewards in that batch, just paired with different x.
+    """
     batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1]
     batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1]
     return batch
@@ -514,10 +518,10 @@ def make_inputs_require_grad(module, input, output):
         else:
             self.use_dpo_data_collator = False

-        # disable dropout in the model and reference model
-        disable_dropout_in_model(model)
-        if self.ref_model is not None:
-            disable_dropout_in_model(self.ref_model)
+        if args.disable_dropout:
+            disable_dropout_in_model(model)
+            if self.ref_model is not None:
+                disable_dropout_in_model(self.ref_model)

         self.loss_type = args.loss_type
         self.max_length = max_length
@@ -601,22 +605,17 @@ def make_inputs_require_grad(module, input, output):

         # Get KL datasets if needed
         if self.calculate_KL:
-            total_batch_size = (
-                max(torch.cuda.device_count(), 1)
-                * args.per_device_train_batch_size
-                * args.gradient_accumulation_steps
-            )
-            if total_batch_size <= 1:
+            if args.per_device_train_batch_size <= 1:
                 raise ValueError(
-                    "Batch size is 1 (too small). KTO will not work properly because the KL term will be equivalent to the implied reward."
+                    "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
                 )

             # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
             # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
             train_kl_dataset = train_dataset.map(
                 _get_kl_dataset,
                 batched=True,
-                batch_size=total_batch_size,
+                batch_size=args.per_device_train_batch_size,
                 num_proc=args.dataset_num_proc,
                 desc="Extracting KL train dataset",
             )
@@ -638,7 +637,7 @@ def make_inputs_require_grad(module, input, output):
             eval_kl_dataset = eval_dataset.map(
                 _get_kl_dataset,
                 batched=True,
-                batch_size=total_batch_size,
+                batch_size=args.per_device_train_batch_size,
                 num_proc=args.dataset_num_proc,
                 desc="Extracting eval KL dataset",
             )
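
The `_get_kl_dataset` rotation and the batch-size guard are easiest to see on a toy batch. Here is a minimal sketch with made-up token ids (the function body mirrors the diff above): completions are shifted by one position so each prompt is paired with a different prompt's completion, and with a batch of size 1 the shift is the identity, which is exactly the degenerate case the `ValueError` rules out.

```python
from typing import Any, Dict, List

def _get_kl_dataset(batch: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    # Same +1 offset as in the diff: move the last completion to the front.
    batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1]
    batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1]
    return batch

# Toy batch of three examples with made-up token ids.
batch = {
    "answer_input_ids": [[101], [102], [103]],
    "answer_attention_mask": [[1], [1], [1]],
}
print(_get_kl_dataset(batch)["answer_input_ids"])  # [[103], [101], [102]]
# With a batch of size 1 the rotation returns the batch unchanged, so the
# "mismatched" pair equals the matched one and the KL term collapses into
# the implied reward -- hence the per_device_train_batch_size > 1 check.
```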
