Commit 4df8a4d

Auto-populate and validate params specific to vision_language_sft collator in TrainingConfig (#1537)
1 parent bb28d67 commit 4df8a4d

File tree

1 file changed: +37 -8 lines changed

src/oumi/core/configs/training_config.py

Lines changed: 37 additions & 8 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
+from typing import Final
 
 import torch
 
@@ -97,13 +98,15 @@ def __post_init__(self):
                     "Model must be loaded in fp32 to enable mixed precision training."
                 )
 
+        trainer_type: Final[TrainerType] = self.training.trainer_type
+
         # Check values for model sequence length.
         if self.model.model_max_length and self.model.model_max_length > 0:
             max_seq_length_value = int(self.model.model_max_length)
             max_seq_length_key = None
-            if self.training.trainer_type == TrainerType.TRL_SFT:
+            if trainer_type == TrainerType.TRL_SFT:
                 max_seq_length_key = "max_seq_length"
-            elif self.training.trainer_type == TrainerType.TRL_DPO:
+            elif trainer_type == TrainerType.TRL_DPO:
                 max_seq_length_key = "max_length"
                 # TODO: DPOTrainer also defines "max_prompt_length" and
                 # "max_target_length". How to handle them?
@@ -130,18 +133,44 @@ def __post_init__(self):
         # patch ourselves.
         # TODO(OPE-1117): Clean up this logic after upgrading to trl 0.16.
         if self.model.enable_liger_kernel:
-            if self.training.trainer_type == TrainerType.TRL_SFT:
+            if trainer_type == TrainerType.TRL_SFT:
                 self.training.trainer_kwargs["use_liger"] = True
                 self.training.trainer_kwargs["use_liger_kernel"] = True
                 self.model.enable_liger_kernel = False
-            elif (
-                self.training.trainer_type == TrainerType.TRL_DPO
-                or self.training.trainer_type == TrainerType.HF
-            ):
+            elif trainer_type in (TrainerType.TRL_DPO, TrainerType.HF):
                 self.training.trainer_kwargs["use_liger_kernel"] = True
                 self.model.enable_liger_kernel = False
-            elif self.training.trainer_type == TrainerType.OUMI:
+            elif trainer_type == TrainerType.OUMI:
                 # We need to Liger patch ourselves for our own training loop.
                 pass
             else:
                 raise ValueError("Unrecognized trainer type!")
+
+        # Setup and validate params for "vision_language_sft" collator.
+        # The collator expects VLM SFT dataset to only produce just
+        # one column: 'conversation_json' (JSON-encoded `Conversation`)!
+        collator_name: Final[str] = self.data.train.collator_name or ""
+        if collator_name == "vision_language_sft":
+            for dataset_params in self.data.train.datasets:
+                if not dataset_params.dataset_kwargs.get("return_conversations", True):
+                    raise ValueError(
+                        "`return_conversations` must be True "
+                        f"for the dataset '{dataset_params.dataset_name}' "
+                        f"when using '{collator_name}' collator!"
+                    )
+                dataset_params.dataset_kwargs["return_conversations"] = True
+            # Extra setup for TRL_SFT.
+            if trainer_type == TrainerType.TRL_SFT:
+                if self.training.trainer_kwargs.get("remove_unused_columns", False):
+                    raise ValueError(
+                        "`remove_unused_columns` must be False "
+                        f"when using '{collator_name}' collator! "
+                        'The "unused" columns are consumed by the collator, '
+                        "not by a model."
+                    )
+                self.training.trainer_kwargs["remove_unused_columns"] = False
+
+                # `trl` shouldn't be preparing the dataset, as we do it in Oumi.
+                dataset_kwargs = self.training.trainer_kwargs.get("dataset_kwargs", {})
+                dataset_kwargs["skip_prepare_dataset"] = True
+                self.training.trainer_kwargs["dataset_kwargs"] = dataset_kwargs
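
For context, the sketch below (not part of the commit) illustrates how the new validation behaves from a user's perspective. It assumes the param classes exported from oumi.core.configs (DataParams, DatasetSplitParams, DatasetParams, ModelParams, TrainingParams) and uses placeholder model/dataset names; exact constructor arguments and required fields may differ from this approximation.

# Sketch only: shows the auto-population added in this commit.
# Class names and constructor arguments are assumptions based on the oumi
# configs package; the model and dataset names are placeholders.
from oumi.core.configs import (
    DataParams,
    DatasetParams,
    DatasetSplitParams,
    ModelParams,
    TrainerType,
    TrainingConfig,
    TrainingParams,
)

config = TrainingConfig(
    model=ModelParams(model_name="Qwen/Qwen2-VL-2B-Instruct"),  # placeholder model
    data=DataParams(
        train=DatasetSplitParams(
            collator_name="vision_language_sft",
            datasets=[DatasetParams(dataset_name="merve/vqav2-small")],  # placeholder dataset
        )
    ),
    training=TrainingParams(trainer_type=TrainerType.TRL_SFT),
)

# __post_init__ runs on construction (TrainingConfig is a dataclass), so the
# collator-specific params should now be populated automatically:
assert config.data.train.datasets[0].dataset_kwargs["return_conversations"] is True
assert config.training.trainer_kwargs["remove_unused_columns"] is False
assert config.training.trainer_kwargs["dataset_kwargs"]["skip_prepare_dataset"] is True

Conversely, explicitly opting out now fails fast: a dataset configured with dataset_kwargs={"return_conversations": False}, or trainer_kwargs with remove_unused_columns=True, raises a ValueError when the "vision_language_sft" collator is selected.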
