1+ # Copyright 2025 - Oumi
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ """Base dataset class for KTO (Kahneman-Tversky Optimization).
16+
17+ This module provides a base class for datasets used in KTO training.
18+ Unlike DPO which requires preference pairs, KTO works with simple binary feedback
19+ indicating whether an output is desirable or undesirable.
20+ """
21+
22+ from typing import Optional
23+
24+ from oumi .core .datasets .base_map_dataset import BaseMapDataset
25+ from oumi .core .tokenizers .base_tokenizer import BaseTokenizer
26+
27+ _PROMPT_KEY = "prompt"
28+ _RESPONSE_KEY = "response"
29+ _LABEL_KEY = "label" # True for desirable, False for undesirable
30+
31+ class BaseKtoDataset (BaseMapDataset ):
32+ """Base class for KTO datasets.
33+
34+ This class provides a foundation for creating KTO datasets that work with
35+ binary feedback (desirable/undesirable) rather than preference pairs.
36+
37+ Warning:
38+ This class is experimental and subject to change.
39+ """
40+
41+ def __init__ (
42+ self ,
43+ * ,
44+ dataset_name : Optional [str ] = None ,
45+ dataset_path : Optional [str ] = None ,
46+ split : Optional [str ] = None ,
47+ tokenizer : Optional [BaseTokenizer ] = None ,
48+ return_tensors : bool = False ,
49+ ** kwargs ,
50+ ) -> None :
51+ """Initializes a new instance of the BaseKtoDataset class."""
52+ super ().__init__ (
53+ dataset_name = dataset_name ,
54+ dataset_path = dataset_path ,
55+ split = split ,
56+ ** kwargs ,
57+ )
58+
59+ if return_tensors :
60+ raise NotImplementedError (
61+ "return_tensors=True is not implemented for this class"
62+ )
63+
64+ self ._tokenizer = tokenizer
65+ self ._return_tensors = return_tensors
66+
67+ self ._data = self ._load_data ()
68+
69+ def transform_kto (self , sample : dict ) -> dict :
70+ """Transform the sample to the KTO format.
71+
72+ Args:
73+ sample: A dictionary containing the raw sample data.
74+
75+ Returns:
76+ A dictionary with the following keys:
77+ - prompt: The input prompt
78+ - response: The model's response
79+ - label: Boolean indicating if the response is desirable (True) or undesirable (False)
80+ """
81+ prompt = sample [_PROMPT_KEY ]
82+ response = sample [_RESPONSE_KEY ]
83+ label = sample [_LABEL_KEY ]
84+
85+ return {
86+ _PROMPT_KEY : prompt ,
87+ _RESPONSE_KEY : response ,
88+ _LABEL_KEY : label ,
89+ }
90+
91+ def transform (self , sample : dict ) -> dict :
92+ """Transform the sample to the KTO format."""
93+ return self .transform_kto (sample )
0 commit comments