Commit 92f14c7

Add KTO support for preference tuning

1 parent 4df8a4d commit 92f14c7

9 files changed: +279 -0 lines changed
configs/recipes/phi3/kto/train.yaml

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# Phi3 KTO train config.
+#
+# Usage:
+#   oumi train -c configs/recipes/phi3/kto/train.yaml
+#
+# See Also:
+#   - Documentation: https://oumi.ai/docs/en/latest/user_guides/train/train.html
+#   - Config class: oumi.core.configs.TrainingConfig
+#   - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py
+#   - Other training configs: configs/**/pretraining/, configs/**/sft/, configs/**/dpo/
+
+model:
+  model_name: "microsoft/Phi-3-mini-4k-instruct"
+  trust_remote_code: True
+
+data:
+  train:
+    datasets:
+      - dataset_name: "mlabonne/kto-mix-40k"
+
+training:
+  optimizer: "adamw_torch"
+  use_peft: true
+  output_dir: "output/phi3.kto"
+  trainer_type: "TRL_KTO"
+
+peft:
+  q_lora: False
+  lora_target_modules:
+    - "q_proj"
+    - "k_proj"
+    - "v_proj"
+    - "o_proj"
+    - "gate_proj"
+    - "up_proj"
+    - "down_proj"

src/oumi/builders/training.py

Lines changed: 2 additions & 0 deletions
@@ -98,6 +98,8 @@ def _init_oumi_trainer(*args, **kwargs) -> BaseTrainer:
         return _create_hf_builder_fn(trl.SFTTrainer)
     elif trainer_type == TrainerType.TRL_DPO:
         return _create_hf_builder_fn(trl.DPOTrainer)
+    elif trainer_type == TrainerType.TRL_KTO:
+        return _create_hf_builder_fn(trl.KTOTrainer)
     elif trainer_type == TrainerType.TRL_GRPO:
         return _create_hf_builder_fn(trl.GRPOTrainer)
     elif trainer_type == TrainerType.HF:
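
For reference, a minimal sketch (not part of this commit) of the object this branch ultimately builds: a trl.KTOTrainer driven by a trl.KTOConfig, trained on binary-feedback rows with prompt, completion, and label fields. The model name, output path, and hyperparameters below are illustrative, and the processing_class keyword may be spelled tokenizer in older trl releases.

import trl
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token

# KTO consumes binary feedback rows rather than preference pairs.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello, how are you?", "Hello, how are you?"],
        "completion": ["I'm fine, thank you!", "Go away."],
        "label": [True, False],  # True = desirable, False = undesirable
    }
)

args = trl.KTOConfig(
    output_dir="output/kto-sketch",  # hypothetical path
    per_device_train_batch_size=2,
    max_steps=3,
    desirable_weight=0.8,  # mirrors the trainer_kwargs in the integration test at the end of this diff
)

trainer = trl.KTOTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    processing_class=tokenizer,  # `tokenizer=` in older trl releases
)
trainer.train()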

src/oumi/core/configs/params/training_params.py

Lines changed: 11 additions & 0 deletions
@@ -45,6 +45,13 @@ class TrainerType(Enum):
     for fine-tuning language models based on human preferences.
     """
 
+    TRL_KTO = "trl_kto"
+    """Kahneman-Tversky Optimization trainer from `trl` library.
+
+    This trainer implements the KTO algorithm for fine-tuning language models
+    based on binary feedback (desirable/undesirable) rather than preference pairs.
+    """
+
     TRL_GRPO = "trl_grpo"
     """Group Relative Policy Optimization trainer from `trl` library.
 
@@ -153,6 +160,8 @@ class TrainingParams(BaseParams):
         - HF: HuggingFace's Trainer
         - TRL_SFT: TRL's SFT Trainer
         - TRL_DPO: TRL's DPO Trainer
+        - TRL_KTO: TRL's KTO Trainer
+        - TRL_GRPO: TRL's GRPO Trainer
         - OUMI: Custom generic trainer implementation
     """
 
@@ -661,6 +670,8 @@ def to_hf(self):
             config_class = trl.SFTConfig
         elif self.trainer_type == TrainerType.TRL_DPO:
             config_class = trl.DPOConfig
+        elif self.trainer_type == TrainerType.TRL_KTO:
+            config_class = trl.KTOConfig
        elif self.trainer_type == TrainerType.TRL_GRPO:
             config_class = trl.GRPOConfig
         else:
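
A hedged sketch of exercising the new enum value, assuming only the TrainingParams fields visible elsewhere in this diff (trainer_type, output_dir) and the to_hf() method modified above:

from oumi.core.configs.params.training_params import TrainerType, TrainingParams

params = TrainingParams(
    trainer_type=TrainerType.TRL_KTO,
    output_dir="output/phi3.kto",
)

# Per the branch added to to_hf() above, TRL_KTO now maps to trl.KTOConfig.
hf_config = params.to_hf()
print(type(hf_config).__name__)  # expected: "KTOConfig"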

src/oumi/core/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 from oumi.core.datasets.base_dpo_dataset import BaseExperimentalDpoDataset
 from oumi.core.datasets.base_grpo_dataset import BaseExperimentalGrpoDataset
 from oumi.core.datasets.base_iterable_dataset import BaseIterableDataset
+from oumi.core.datasets.base_kto_dataset import BaseKtoDataset
 from oumi.core.datasets.base_map_dataset import BaseMapDataset
 from oumi.core.datasets.base_pretraining_dataset import BasePretrainingDataset
 from oumi.core.datasets.base_sft_dataset import BaseSftDataset
@@ -41,6 +42,7 @@
     "BaseExperimentalDpoDataset",
     "BaseExperimentalGrpoDataset",
     "BaseIterableDataset",
+    "BaseKtoDataset",
     "BaseMapDataset",
     "BasePretrainingDataset",
     "BaseSftDataset",
src/oumi/core/datasets/base_kto_dataset.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+# Copyright 2025 - Oumi
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base dataset class for KTO (Kahneman-Tversky Optimization).
+
+This module provides a base class for datasets used in KTO training.
+Unlike DPO, which requires preference pairs, KTO works with simple binary feedback
+indicating whether an output is desirable or undesirable.
+"""
+
+from typing import Optional
+
+from oumi.core.datasets.base_map_dataset import BaseMapDataset
+from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
+
+_PROMPT_KEY = "prompt"
+_RESPONSE_KEY = "response"
+_LABEL_KEY = "label"  # True for desirable, False for undesirable
+
+class BaseKtoDataset(BaseMapDataset):
+    """Base class for KTO datasets.
+
+    This class provides a foundation for creating KTO datasets that work with
+    binary feedback (desirable/undesirable) rather than preference pairs.
+
+    Warning:
+        This class is experimental and subject to change.
+    """
+
+    def __init__(
+        self,
+        *,
+        dataset_name: Optional[str] = None,
+        dataset_path: Optional[str] = None,
+        split: Optional[str] = None,
+        tokenizer: Optional[BaseTokenizer] = None,
+        return_tensors: bool = False,
+        **kwargs,
+    ) -> None:
+        """Initializes a new instance of the BaseKtoDataset class."""
+        super().__init__(
+            dataset_name=dataset_name,
+            dataset_path=dataset_path,
+            split=split,
+            **kwargs,
+        )
+
+        if return_tensors:
+            raise NotImplementedError(
+                "return_tensors=True is not implemented for this class"
+            )
+
+        self._tokenizer = tokenizer
+        self._return_tensors = return_tensors
+
+        self._data = self._load_data()
+
+    def transform_kto(self, sample: dict) -> dict:
+        """Transform the sample to the KTO format.
+
+        Args:
+            sample: A dictionary containing the raw sample data.
+
+        Returns:
+            A dictionary with the following keys:
+                - prompt: The input prompt
+                - response: The model's response
+                - label: Boolean indicating if the response is desirable (True) or undesirable (False)
+        """
+        prompt = sample[_PROMPT_KEY]
+        response = sample[_RESPONSE_KEY]
+        label = sample[_LABEL_KEY]
+
+        return {
+            _PROMPT_KEY: prompt,
+            _RESPONSE_KEY: response,
+            _LABEL_KEY: label,
+        }
+
+    def transform(self, sample: dict) -> dict:
+        """Transform the sample to the KTO format."""
+        return self.transform_kto(sample)
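
To illustrate the intended extension point, a hypothetical subclass sketch: a concrete dataset overrides transform_kto() to map its own raw column names onto the prompt/response/label keys documented above, while row loading is left to the base class machinery keyed on dataset_name (as KtoMix40kDataset does later in this diff). The registry name and column names here are invented for the example.

from oumi.core.datasets.base_kto_dataset import BaseKtoDataset
from oumi.core.registry import register_dataset


@register_dataset("my_org/my-kto-feedback")  # hypothetical registry name
class MyKtoFeedbackDataset(BaseKtoDataset):
    default_dataset = "my_org/my-kto-feedback"  # hypothetical dataset id

    def transform_kto(self, sample: dict) -> dict:
        # Raw rows use "question"/"answer"/"is_good" (made-up columns); rename
        # them to the keys the base class documents.
        return {
            "prompt": sample["question"],
            "response": sample["answer"],
            "label": bool(sample["is_good"]),
        }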

src/oumi/datasets/debug.py

Lines changed: 43 additions & 0 deletions
@@ -21,6 +21,7 @@
 from typing_extensions import override
 
 from oumi.core.datasets.base_dpo_dataset import BaseExperimentalDpoDataset
+from oumi.core.datasets.base_kto_dataset import BaseKtoDataset
 from oumi.core.datasets.base_pretraining_dataset import BasePretrainingDataset
 from oumi.core.datasets.base_sft_dataset import BaseSftDataset
 from oumi.core.registry import register_dataset
@@ -191,3 +192,45 @@ def _load_data(self) -> pd.DataFrame:
                 ],
             }
         )
+
+
+@register_dataset("debug_kto")
+class DebugKtoDataset(BaseKtoDataset):
+    default_dataset = "debug_kto"
+
+    def __init__(
+        self,
+        dataset_size: int = 5,
+        **kwargs,
+    ):
+        """Initializes a DebugKtoDataset."""
+        self.size = dataset_size
+
+        super().__init__(**kwargs)
+
+    def transform_kto(self, sample: dict) -> dict:
+        """Transforms the sample into a KTO dict."""
+        return {
+            "prompt": sample["prompt"],
+            "completion": sample["completion"],
+            "label": sample["label"],
+        }
+
+    @override
+    def _load_data(self) -> pd.DataFrame:
+        return pd.DataFrame(
+            {
+                "prompt": [
+                    f"Hello, how are you? (Document number {idx})"
+                    for idx in range(self.size)
+                ],
+                "completion": [
+                    f"I'm fine, thank you! (Document number {idx})"
+                    for idx in range(self.size)
+                ],
+                "label": [
+                    idx % 2 == 0  # True for even indices, False for odd indices
+                    for idx in range(self.size)
+                ],
+            }
+        )
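
A small usage sketch (not from the commit): the debug dataset emits prompt/completion/label rows, the binary-feedback format TRL's KTO trainer consumes, and transform() simply passes those fields through.

from oumi.datasets.debug import DebugKtoDataset

# Loads 4 synthetic rows; labels alternate True/False so both desirable and
# undesirable examples are present.
ds = DebugKtoDataset(dataset_size=4)

# transform() delegates to transform_kto(), which forwards the fields as-is.
row = ds.transform({"prompt": "Hi there", "completion": "Hello!", "label": True})
print(row)  # {'prompt': 'Hi there', 'completion': 'Hello!', 'label': True}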

src/oumi/datasets/preference_tuning/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,8 +14,10 @@
 
 """Preference tuning datasets module."""
 
+from oumi.datasets.preference_tuning.kto_mix import KtoMix40kDataset
 from oumi.datasets.preference_tuning.orpo_dpo_mix import OrpoDpoMix40kDataset
 
 __all__ = [
+    "KtoMix40kDataset",
     "OrpoDpoMix40kDataset",
 ]
src/oumi/datasets/preference_tuning/kto_mix.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# Copyright 2025 - Oumi
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from oumi.core.datasets import BaseKtoDataset
+from oumi.core.registry import register_dataset
+
+
+@register_dataset("mlabonne/kto-mix-40k")
+class KtoMix40kDataset(BaseKtoDataset):
+    """Preprocess the KTO dataset.
+
+    A dataset designed for KTO (Kahneman-Tversky Optimization) training.
+    This dataset is a combination of high-quality datasets with binary feedback,
+    including:
+    - Capybara-Preferences (converted to binary)
+    - distilabel-intel-orca-dpo-pairs (converted to binary)
+    - ultrafeedback-binarized-preferences-cleaned
+    - distilabel-math-preference-dpo (converted to binary)
+    - toxic-dpo-v0.2 (converted to binary)
+    - prm_dpo_pairs_cleaned (converted to binary)
+    - truthy-dpo-v0.1 (converted to binary)
+
+    Rule-based filtering was applied to remove 'gptisms' in the desirable answers.
+
+    Data Fields:
+        - source: string
+        - prompt: string
+        - response: string
+        - label: boolean (True for desirable, False for undesirable)
+
+    See Also:
+        For more information on how to use this dataset, refer to:
+        - Paper: https://arxiv.org/pdf/2402.01306
+        - Huggingface hub: https://huggingface.co/docs/trl/main/en/kto_trainer
+    """
+
+    default_dataset = "mlabonne/kto-mix-40k"

tests/integration/train/test_train.py

Lines changed: 42 additions & 0 deletions
@@ -194,3 +194,45 @@ def test_train_dpo():
         )
 
         train(config)
+
+
+def test_train_kto():
+    with tempfile.TemporaryDirectory() as output_temp_dir:
+        output_training_dir = str(pathlib.Path(output_temp_dir) / "train")
+        config: TrainingConfig = TrainingConfig(
+            data=DataParams(
+                train=DatasetSplitParams(
+                    datasets=[
+                        DatasetParams(
+                            dataset_name="debug_kto",
+                        )
+                    ],
+                ),
+            ),
+            model=ModelParams(
+                model_name="openai-community/gpt2",
+                model_max_length=1024,
+                trust_remote_code=True,
+                tokenizer_pad_token="<|endoftext|>",
+            ),
+            training=TrainingParams(
+                per_device_train_batch_size=2,
+                trainer_type=TrainerType.TRL_KTO,
+                max_steps=3,
+                logging_steps=3,
+                log_model_summary=True,
+                enable_wandb=False,
+                enable_tensorboard=False,
+                output_dir=output_training_dir,
+                try_resume_from_last_checkpoint=False,
+                save_final_model=True,
+                trainer_kwargs={
+                    "max_length": 512,
+                    "max_prompt_length": 128,
+                    "remove_unused_columns": False,
+                    "desirable_weight": 0.8,
+                },
+            ),
+        )
+
+        train(config)
