 # limitations under the License.
 
 from dataclasses import dataclass, field
+from typing import Final
 
 import torch
 
@@ -97,13 +98,15 @@ def __post_init__(self):
                 "Model must be loaded in fp32 to enable mixed precision training."
             )
 
+        trainer_type: Final[TrainerType] = self.training.trainer_type
+
         # Check values for model sequence length.
         if self.model.model_max_length and self.model.model_max_length > 0:
             max_seq_length_value = int(self.model.model_max_length)
             max_seq_length_key = None
-            if self.training.trainer_type == TrainerType.TRL_SFT:
+            if trainer_type == TrainerType.TRL_SFT:
                 max_seq_length_key = "max_seq_length"
-            elif self.training.trainer_type == TrainerType.TRL_DPO:
+            elif trainer_type == TrainerType.TRL_DPO:
                 max_seq_length_key = "max_length"
                 # TODO: DPOTrainer also defines "max_prompt_length" and
                 # "max_target_length". How to handle them?
@@ -130,18 +133,44 @@ def __post_init__(self):
         # patch ourselves.
         # TODO(OPE-1117): Clean up this logic after upgrading to trl 0.16.
         if self.model.enable_liger_kernel:
-            if self.training.trainer_type == TrainerType.TRL_SFT:
+            if trainer_type == TrainerType.TRL_SFT:
                 self.training.trainer_kwargs["use_liger"] = True
                 self.training.trainer_kwargs["use_liger_kernel"] = True
                 self.model.enable_liger_kernel = False
-            elif (
-                self.training.trainer_type == TrainerType.TRL_DPO
-                or self.training.trainer_type == TrainerType.HF
-            ):
+            elif trainer_type in (TrainerType.TRL_DPO, TrainerType.HF):
                 self.training.trainer_kwargs["use_liger_kernel"] = True
                 self.model.enable_liger_kernel = False
-            elif self.training.trainer_type == TrainerType.OUMI:
+            elif trainer_type == TrainerType.OUMI:
                 # We need to Liger patch ourselves for our own training loop.
                 pass
             else:
                 raise ValueError("Unrecognized trainer type!")
+
+        # Set up and validate params for the "vision_language_sft" collator.
+        # The collator expects the VLM SFT dataset to produce just one
+        # column: 'conversation_json' (a JSON-encoded `Conversation`)!
+        collator_name: Final[str] = self.data.train.collator_name or ""
+        if collator_name == "vision_language_sft":
+            for dataset_params in self.data.train.datasets:
+                if not dataset_params.dataset_kwargs.get("return_conversations", True):
+                    raise ValueError(
+                        "`return_conversations` must be True "
+                        f"for the dataset '{dataset_params.dataset_name}' "
+                        f"when using the '{collator_name}' collator!"
+                    )
+                dataset_params.dataset_kwargs["return_conversations"] = True
+            # Extra setup for TRL_SFT.
+            if trainer_type == TrainerType.TRL_SFT:
+                if self.training.trainer_kwargs.get("remove_unused_columns", False):
+                    raise ValueError(
+                        "`remove_unused_columns` must be False "
+                        f"when using the '{collator_name}' collator! "
+                        'The "unused" columns are consumed by the collator, '
+                        "not by the model."
+                    )
+                self.training.trainer_kwargs["remove_unused_columns"] = False
+
+                # `trl` shouldn't prepare the dataset, since we already do it in Oumi.
+                dataset_kwargs = self.training.trainer_kwargs.get("dataset_kwargs", {})
+                dataset_kwargs["skip_prepare_dataset"] = True
+                self.training.trainer_kwargs["dataset_kwargs"] = dataset_kwargs
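
To illustrate the new check in isolation, here is a minimal, self-contained sketch of the collator validation logic added above. It deliberately avoids Oumi's real config classes; `_DatasetParams` and `validate_vlm_collator` are hypothetical names used only for this example.

from dataclasses import dataclass, field


@dataclass
class _DatasetParams:
    # Hypothetical stand-in for one dataset entry in the training config.
    dataset_name: str
    dataset_kwargs: dict = field(default_factory=dict)


def validate_vlm_collator(collator_name: str, datasets: list[_DatasetParams]) -> None:
    # Mirrors the rule above: every dataset used with the "vision_language_sft"
    # collator must emit JSON-encoded conversations ("return_conversations").
    if collator_name != "vision_language_sft":
        return
    for params in datasets:
        if not params.dataset_kwargs.get("return_conversations", True):
            raise ValueError(
                f"`return_conversations` must be True for '{params.dataset_name}' "
                f"when using the '{collator_name}' collator!"
            )
        params.dataset_kwargs["return_conversations"] = True


# An explicitly disabled flag is rejected; an unset flag is forced to True.
try:
    validate_vlm_collator(
        "vision_language_sft",
        [_DatasetParams("my_vlm_dataset", {"return_conversations": False})],
    )
except ValueError as e:
    print(e)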