
Commit d45c86e

qgallouedec and lewtun authored
Conversational dataset support for CPOTrainer (huggingface#2144)
* extract prompt and apply chat template in cpo trainer
* default learning rate
* simplify example
* update doc
* test all formats
* extend extract prompt
* improve doc format
* link in dataset formats
* Update docs/source/cpo_trainer.mdx (Co-authored-by: lewtun <[email protected]>)
* Update docs/source/cpo_trainer.mdx (Co-authored-by: lewtun <[email protected]>)

Co-authored-by: lewtun <[email protected]>
1 parent c6b0d13 commit d45c86e
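In practical terms, the change lets [`CPOTrainer`] accept preference examples whose `prompt`/`chosen`/`rejected` fields are lists of chat messages rather than plain strings, with the trainer extracting the prompt and applying the chat template itself. A rough sketch of the idea follows; it is not the trainer's internal code, and the example messages are illustrative:

```python
# Sketch of what "extract prompt and apply chat template" means for one
# conversational preference example (not the trainer's actual implementation).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

example = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."},
    ],
}

# The shared leading messages act as the implicit prompt.
prompt = example["chosen"][:-1]

# The chat template turns the message lists into the plain strings the loss is computed on.
prompt_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
chosen_text = tokenizer.apply_chat_template(example["chosen"], tokenize=False)
rejected_text = tokenizer.apply_chat_template(example["rejected"], tokenize=False)
```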

File tree

6 files changed, +119 -112 lines changed

docs/source/cpo_trainer.mdx

Lines changed: 69 additions & 76 deletions
@@ -2,101 +2,66 @@
 
 [![](https://img.shields.io/badge/All_models-CPO-blue)](https://huggingface.co/models?other=cpo)
 
-Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, and Young Jin Kim. At a high-level, CPO trains models to
-avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat.
+## Overview
 
-CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
+Contrastive Preference Optimization (CPO) was introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high level, CPO trains models to avoid generating adequate but imperfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat.
 
-## SimPO
-The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the `CPOTrainer`. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0` in the `CPOConfig`.
+CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Second, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
 
-## CPO-SimPO
-We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO Github](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the CPOConfig.
+## Quick start
 
-## Expected dataset format
+This example demonstrates how to train a model using the CPO method. We use the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and the preference data from the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara-Preferences). You can view the data in the dataset here:
 
-The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
-
-- `prompt`
-- `chosen`
-- `rejected`
-
-for example:
-
-```py
-cpo_dataset_dict = {
-    "prompt": [
-        "hello",
-        "how are you",
-        "What is your name?",
-        "What is your name?",
-        "Which is the best programming language?",
-        "Which is the best programming language?",
-        "Which is the best programming language?",
-    ],
-    "chosen": [
-        "hi nice to meet you",
-        "I am fine",
-        "My name is Mary",
-        "My name is Mary",
-        "Python",
-        "Python",
-        "Java",
-    ],
-    "rejected": [
-        "leave me alone",
-        "I am not fine",
-        "Whats it to you?",
-        "I dont have a name",
-        "Javascript",
-        "C++",
-        "C++",
-    ],
-}
-```
-where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
+<iframe
+  src="https://huggingface.co/datasets/trl-lib/Capybara-Preferences/embed/viewer/default/train?row=0"
+  frameborder="0"
+  width="100%"
+  height="560px"
+></iframe>
 
-## Expected model format
-The CPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
+Below is the script to train the model:
 
-## Using the `CPOTrainer`
-For a detailed example have a look at the `examples/scripts/cpo.py` script. At a high level we need to initialize the `CPOTrainer` with a `model` we wish to train. **Note that CPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above.
+```python
+# train_cpo.py
+from datasets import load_dataset
+from trl import CPOConfig, CPOTrainer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
-```py
-training_args = CPOConfig(
-    beta=0.1,
-)
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+train_dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
 
-cpo_trainer = CPOTrainer(
-    model,
-    args=training_args,
-    train_dataset=train_dataset,
-    tokenizer=tokenizer,
-)
+training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO", logging_steps=10)
+trainer = CPOTrainer(model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset)
+trainer.train()
 ```
-After this one can then call:
 
-```py
-cpo_trainer.train()
-```
+Execute the script using the following command:
 
-## Loss functions
+```bash
+accelerate launch train_cpo.py
+```
 
-Given the preference data, the `CPOTrainer` uses the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression.
+## Expected dataset format
 
-The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. The `CPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
+CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
 
-The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike CPO which is summed only).
+## Example script
 
-### For Mixture of Experts Models: Enabling the auxiliary loss
+We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py).
 
-MOEs are the most efficient if the load is about equally distributed between experts.
-To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
+To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
 
-This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
-To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
+```bash
+accelerate launch examples/scripts/cpo.py \
+    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+    --dataset_name trl-lib/ultrafeedback_binarized \
+    --num_train_epochs 1 \
+    --logging_steps 25 \
+    --output_dir Qwen2-0.5B-CPO
+```
 
-## Logging
+## Logged metrics
 
 While training and evaluating we record the following reward metrics:
 
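To make the "Expected dataset format" section added above concrete, here is what a single preference record looks like in each of the two accepted formats. This is a sketch with illustrative values, not rows from the Capybara dataset:

```python
# Standard (explicit prompt) preference example: plain strings.
standard_example = {
    "prompt": "Which is the best programming language?",
    "chosen": "It depends on the task, but Python is a popular choice.",
    "rejected": "C++ is the only real language.",
}

# Conversational preference example: lists of chat messages.
# The trainer applies the tokenizer's chat template to these automatically.
conversational_example = {
    "prompt": [{"role": "user", "content": "Which is the best programming language?"}],
    "chosen": [{"role": "assistant", "content": "It depends on the task, but Python is a popular choice."}],
    "rejected": [{"role": "assistant", "content": "C++ is the only real language."}],
}
```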
@@ -106,6 +71,34 @@ While training and evaluating we record the following reward metrics:
 * `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
 * `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
 
+## CPO variants
+
+### Simple Preference Optimization (SimPO)
+
+The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, set `loss_type="simpo"` and `cpo_alpha=0` in the [`CPOConfig`].
+
+### CPO-SimPO
+
+We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more at the [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`].
+
+## Loss functions
+
+The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
+
+| `loss_type=`          | Description |
+| --------------------- | ----------- |
+| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
+| `"hinge"`             | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
+| `"ipo"`               | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms, identify an issue with overfitting, and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, so the smaller the `beta`, the larger this gap. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike DPO, where it is only summed). |
+
+### For Mixture of Experts Models: Enabling the auxiliary loss
+
+MOEs are the most efficient if the load is about equally distributed between experts.
+To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
+
+This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
+To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
+
 ## CPOTrainer
 
 [[autodoc]] CPOTrainer
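To make the variants and loss options in the added sections concrete, here is a plausible set of configurations. This is a sketch: the `output_dir` names and the `cpo_alpha=0.5` value are placeholders, and the MoE settings belong on the model config rather than on [`CPOConfig`].

```python
from trl import CPOConfig

# SimPO: reward-margin loss with length normalization, no BC regularization.
simpo_args = CPOConfig(output_dir="Qwen2-0.5B-SimPO", loss_type="simpo", cpo_alpha=0.0)

# CPO-SimPO: the SimPO loss combined with a non-zero BC (NLL) term.
cpo_simpo_args = CPOConfig(output_dir="Qwen2-0.5B-CPO-SimPO", loss_type="simpo", cpo_alpha=0.5)

# Alternative pairwise losses from the table: hinge or IPO instead of the default sigmoid.
hinge_args = CPOConfig(output_dir="Qwen2-0.5B-CPO-hinge", loss_type="hinge", beta=0.1)
ipo_args = CPOConfig(output_dir="Qwen2-0.5B-CPO-ipo", loss_type="ipo", beta=0.1)

# For a Mixture of Experts model, the auxiliary router loss is enabled on the model config:
#     model.config.output_router_logits = True
#     model.config.router_aux_loss_coef = 0.001
```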

docs/source/dataset_formats.mdx

Lines changed: 14 additions & 14 deletions
@@ -194,20 +194,20 @@ unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "
 
 Choosing the right dataset format depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset formats supported by each TRL trainer.
 
-| Trainer                 | Expected dataset format      |
-| ----------------------- | ---------------------------- |
-| [`BCOTrainer`]          | Unpaired preference          |
-| [`CPOTrainer`]          | Preference (explicit prompt) |
-| [`DPOTrainer`]          | Preference (explicit prompt) |
-| [`IterativeSFTTrainer`] | Unpaired preference          |
-| [`KTOTrainer`]          | Unpaired preference          |
-| [`NashMDTrainer`]       | Prompt-only                  |
-| [`OnlineDPOTrainer`]    | Prompt-only                  |
-| [`ORPOTrainer`]         | Preference (explicit prompt) |
-| [`PPOv2Trainer`]        | Tokenized language modeling  |
-| [`RewardTrainer`]       | Preference (implicit prompt) |
-| [`SFTTrainer`]          | Language modeling            |
-| [`XPOTrainer`]          | Prompt-only                  |
+| Trainer                 | Expected dataset format                                 |
+| ----------------------- | ------------------------------------------------------- |
+| [`BCOTrainer`]          | [Unpaired preference](#unpaired-preference)             |
+| [`CPOTrainer`]          | [Preference (explicit prompt recommended)](#preference) |
+| [`DPOTrainer`]          | [Preference (explicit prompt recommended)](#preference) |
+| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference)             |
+| [`KTOTrainer`]          | [Unpaired preference](#unpaired-preference)             |
+| [`NashMDTrainer`]       | [Prompt-only](#prompt-only)                             |
+| [`OnlineDPOTrainer`]    | [Prompt-only](#prompt-only)                             |
+| [`ORPOTrainer`]         | [Preference (explicit prompt)](#preference)             |
+| [`PPOv2Trainer`]        | Tokenized language modeling                             |
+| [`RewardTrainer`]       | [Preference (implicit prompt recommended)](#preference) |
+| [`SFTTrainer`]          | [Language modeling](#language-modeling)                 |
+| [`XPOTrainer`]          | [Prompt-only](#prompt-only)                             |
 
 <Tip>
 
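The "explicit prompt" versus "implicit prompt" wording in the updated table refers to whether the prompt lives in its own column or is simply repeated inside the chosen and rejected texts. Roughly, reusing the doc's own "The sky is" example:

```python
# Preference with an explicit prompt column, as expected by CPOTrainer, DPOTrainer, and ORPOTrainer.
preference_explicit_prompt = {
    "prompt": "The sky is",
    "chosen": " blue.",
    "rejected": " green.",
}

# The same pair with an implicit prompt, as expected by RewardTrainer:
# the prompt is part of both completions rather than a separate column.
preference_implicit_prompt = {
    "chosen": "The sky is blue.",
    "rejected": "The sky is green.",
}
```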
examples/scripts/cpo.py

Lines changed: 1 addition & 12 deletions
@@ -54,7 +54,6 @@
 
 from dataclasses import dataclass, field
 
-from accelerate import PartialState
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
 
@@ -65,7 +64,7 @@
 @dataclass
 class ScriptArguments:
     dataset_name: str = field(
-        default="trl-internal-testing/hh-rlhf-helpful-base-trl-style",
+        default="trl-lib/ultrafeedback_binarized",
         metadata={"help": "The name of the dataset to use."},
     )
 
@@ -93,16 +92,6 @@ class ScriptArguments:
 if tokenizer.chat_template is None:
     tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
 
-def process(row):
-    row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
-    row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
-    return row
-
-# Compute that only on the main process for faster data processing.
-# see: https://github.com/huggingface/trl/pull/1255
-with PartialState().local_main_process_first():
-    dataset = dataset.map(process, num_proc=training_args.dataset_num_proc)
-
 ################
 # Training
 ################
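The script change mirrors the doc change: the user-side `map` that pre-applied the chat template is removed because the trainer now handles it. Roughly, usage goes from templating the dataset manually to passing the conversational dataset straight to the trainer. The snippet below is a sketch under that assumption (the dataset split and `output_dir` are placeholders), not the full example script:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Before this commit, the script had to template the conversations itself:
#     def process(row):
#         row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
#         row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
#         return row
#     dataset = dataset.map(process, num_proc=training_args.dataset_num_proc)

# After this commit, the conversational dataset is passed as-is and
# CPOTrainer applies the chat template internally.
training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO")
trainer = CPOTrainer(model=model, args=training_args, tokenizer=tokenizer, train_dataset=dataset)
trainer.train()
```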

0 commit comments