Phi-4 multimodal training version 1 (with limitations) #1555

Merged 17 commits on Mar 22, 2025

11 changes: 11 additions & 0 deletions configs/recipes/vision/phi4/sft/README.md
@@ -0,0 +1,11 @@
# Phi-4-multimodal-instruct

Configs for the Phi-4-multimodal-instruct 5.6B model. See https://huggingface.co/microsoft/Phi-4-multimodal-instruct

This is a multimodal model that combines text, visual, and audio inputs.
It uses a "Mixture of LoRAs" approach, allowing you to plug in adapters for each
modality without needing to retrain the base model. For more information, see:

- [Mixture-of-LoRAs](https://arxiv.org/abs/2403.03432)
- [Phi-4 Multimodal Technical Report](https://arxiv.org/abs/2503.01743)
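
As a quick standalone sanity check of the model described above, it can also be loaded directly with Hugging Face `transformers`. A minimal sketch, assuming a CUDA GPU with `flash-attn` installed (the training configs below require `flash_attention_2`):

```python
# Minimal sketch: load Phi-4-multimodal-instruct and its processor from the HF Hub.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

MODEL_ID = "microsoft/Phi-4-multimodal-instruct"

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,                   # The model ships custom code on the Hub.
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # Required by the model.
    device_map="cuda",
)
print(model.config.model_type)  # Expected: "phi4mm", the type registered in this PR.
```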
60 changes: 60 additions & 0 deletions configs/recipes/vision/phi4/sft/gcp_job.yaml
@@ -0,0 +1,60 @@
# Phi4 multimodal 5.6B full fine-tune training job config.
#
# Requirements:
# - Set up SkyPilot GCP: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html#setup
# - Log into WandB (`wandb login`) or disable `enable_wandb`
#
# Usage:
# oumi launch up --config configs/recipes/vision/phi4/sft/gcp_job.yaml --cluster phi4-multimodal
#
# See Also:
# - Documentation: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html
# - Config class: oumi.core.configs.JobConfig
# - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/job_config.py
# - Other job configs: configs/**/*job.yaml

name: phi4-mm-sft-oumi-train

resources:
cloud: gcp
accelerators: "A100:1" # Feel free to bump up the number of GPUs!
use_spot: false
disk_size: 1000 # Disk size in GBs

num_nodes: 1 # Set it to N for multi-node training.

working_dir: .

file_mounts:
~/.netrc: ~/.netrc # WandB credentials
~/.cache/huggingface/token: ~/.cache/huggingface/token # HF credentials

envs:
WANDB_PROJECT: oumi-train
OUMI_RUN_NAME: phi4-vl.fft.oumi

setup: |
set -e

pip install uv && uv pip install oumi[gpu] hf_transfer
# Download the model from the HF Hub. hf_transfer increases download speed compared to
# downloading the model lazily during training.
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download microsoft/Phi-4-multimodal-instruct

# The model requires flash_attention_2! Install it here.
pip install -U flash-attn --no-build-isolation


run: |
set -e # Exit if any command failed.
source ./configs/examples/misc/sky_init.sh
set -x
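# Smoke-test overrides below: cap training at 25 steps and skip saving checkpoints and the final model.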
oumi distributed torchrun \
-m oumi train \
-c configs/recipes/vision/phi4/sft/train.yaml \
--training.run_name "${OUMI_RUN_NAME}.${SKYPILOT_TASK_ID}" \
--training.max_steps 25 \
--training.save_steps 0 \
--training.save_final_model false

echo "Node ${SKYPILOT_NODE_RANK} is all done!"
94 changes: 94 additions & 0 deletions configs/recipes/vision/phi4/sft/train.yaml
@@ -0,0 +1,94 @@
# Phi-4-multimodal-instruct training config for SFT finetuning.
#
# Phi-4-multimodal-instruct is a multimodal model that combines text, visual, and audio
# inputs. It uses a "Mixture of LoRAs" approach, allowing you to plug in adapters for
# each modality without needing to retrain the base model.
#
# Important Note: Oumi has currently integrated and fully tested Phi-4 for vision and
# text modalities only (!).
#
# Requirements:
# - Log into WandB (`wandb login`) or disable `enable_wandb`
# - Run `pip install -U flash-attn --no-build-isolation`
#
# Usage:
# oumi train -c configs/recipes/vision/phi4/sft/train.yaml
#
# See Also:
# - Documentation: https://oumi.ai/docs/en/latest/user_guides/train/train.html
# - Config class: oumi.core.configs.TrainingConfig
# - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py
# - Other training configs: configs/**/pretraining/, configs/**/sft/, configs/**/dpo/

model:
model_name: "microsoft/Phi-4-multimodal-instruct"
torch_dtype_str: "bfloat16"
model_max_length: 4096
trust_remote_code: True
attn_implementation: "flash_attention_2" # The model requires Flash Attention.

# The model by default freezes the following audio/image-related modules:
# model.embed_tokens_extend.audio_embed
# model.embed_tokens_extend.image_embed

data:
train:
collator_name: "vision_language_with_padding"
use_torchdata: true
datasets:
- dataset_name: "merve/vqav2-small"
split: "validation"
shuffle: True
seed: 42
trust_remote_code: True
transform_num_workers: "auto"
dataset_kwargs:
processor_name: "microsoft/Phi-4-multimodal-instruct"
return_tensors: True

# Below are examples of other vision SFT datasets
# - dataset_name: "HuggingFaceH4/llava-instruct-mix-vsft"
# split: "train"
# shuffle: True
# seed: 42
# trust_remote_code: True
# transform_num_workers: "auto"
# dataset_kwargs:
# processor_name: "microsoft/Phi-4-multimodal-instruct"
# return_tensors: True
# - dataset_name: "coco_captions"
# split: "train"
# trust_remote_code: True
# dataset_kwargs:
# processor_name: "microsoft/Phi-4-multimodal-instruct"
# return_tensors: True
# - dataset_name: vision_language_jsonl
# dataset_path: "training.jsonl" # See the notebook for an example of how to generate this file
# dataset_kwargs:
# data_column: "messages"
# processor_name: "microsoft/Phi-4-multimodal-instruct"

training:
output_dir: "output/vlm_finetuned"
trainer_type: "TRL_SFT"
enable_gradient_checkpointing: True
per_device_train_batch_size: 1 # Must be 1 due to the processor's handling of variable-sized image features.
gradient_accumulation_steps: 8
max_steps: 20

gradient_checkpointing_kwargs:
# Reentrant docs: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
use_reentrant: False
ddp_find_unused_parameters: True

optimizer: "adamw_torch_fused"
learning_rate: 2e-5
warmup_ratio: 0.03
weight_decay: 0.0
lr_scheduler_type: "cosine"

logging_steps: 10
dataloader_main_process_only: True
dataloader_num_workers: 4
dataloader_prefetch_factor: 8
include_performance_metrics: True
enable_wandb: True
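
Because `per_device_train_batch_size` must stay at 1, the effective batch size comes entirely from gradient accumulation and data parallelism. A minimal sketch of reading this out of the config with plain PyYAML (`NUM_GPUS` is a placeholder for your own setup, not a config field):

```python
# Sketch: compute the effective global batch size implied by train.yaml.
import yaml

with open("configs/recipes/vision/phi4/sft/train.yaml") as f:
    cfg = yaml.safe_load(f)

NUM_GPUS = 1  # Placeholder: number of data-parallel workers in your run.
per_device = cfg["training"]["per_device_train_batch_size"]  # 1
grad_accum = cfg["training"]["gradient_accumulation_steps"]  # 8
print("Effective global batch size:", per_device * grad_accum * NUM_GPUS)
```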
1 change: 1 addition & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
"aiohttp>=3.10,<3.12", # Used by inference engine
"aiofiles>=22.1.0,<25", # Allows to use async file operations
"aioresponses>=0.7.6,<0.8", # User by inference engine tests
"backoff>=2.2.1,<2.3",
"datasets>=3.2.0,<3.3",
"jsonlines",
"lm_eval[wandb]>=0.4.5,<0.5.0",
3 changes: 3 additions & 0 deletions src/oumi/builders/collators.py
@@ -161,6 +161,9 @@ def build_collator_from_config(
collator_kwargs["allow_multi_image_inputs"] = (
model_config.visual_config.supports_multiple_images
)
collator_kwargs["main_image_feature"] = (
model_config.visual_config.main_image_feature
)

if collator_name == "vision_language_sft":
processor_name = collator_kwargs.get(
3 changes: 3 additions & 0 deletions src/oumi/builders/processors.py
@@ -50,9 +50,11 @@ def build_processor(

# Initialize model-specific params.
label_ignore_index: Optional[int] = constants.LABEL_IGNORE_INDEX
ignore_features: Optional[list[str]] = None
processor_kwargs = {}
if model_config is not None:
label_ignore_index = model_config.label_ignore_index
ignore_features = model_config.ignore_features
processor_kwargs.update(model_config.processor_kwargs)

create_processor_fn = functools.partial(
@@ -70,4 +72,5 @@ def build_processor(
worker_processor,
tokenizer,
label_ignore_index=label_ignore_index,
ignore_features=ignore_features,
)
20 changes: 11 additions & 9 deletions src/oumi/core/collators/vision_language_collator_with_padding.py
@@ -21,8 +21,6 @@
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.utils.torch_utils import pad_to_max_dim_and_stack

_PIXEL_VALUES_KEY = "pixel_values"


class VisionLanguageCollatorWithPadding:
def __init__(
@@ -33,6 +31,7 @@ def __init__(
truncation: bool = False,
label_ignore_index: Optional[int] = None,
allow_multi_image_inputs: bool = True,
main_image_feature: str = "pixel_values",
):
"""Custom collator for multi-modal vision-language training.

@@ -45,8 +44,11 @@ def __init__(
label_ignore_index: If set, then label values of tokens that shouldn't
contribute to the loss computation will be replaced by this special value.
allow_multi_image_inputs: Whether to allow multi-image inputs.
main_image_feature: The key to use for fetching the main image data
(e.g., raw pixels, patches, etc.) from the input.
"""
self._allow_multi_image_inputs = allow_multi_image_inputs
self._main_image_feature = main_image_feature
self._text_collator: TextCollatorWithPadding = TextCollatorWithPadding(
tokenizer=tokenizer,
max_length=max_length,
@@ -60,7 +62,7 @@ def __init__(
)

def __call__(self, batch) -> dict[str, Any]:
"""Custom collator for multi-modal vision-language training.
"""Custom collator for multi-modal vision-language training.

Args:
batch: List of batch items.
@@ -71,7 +73,7 @@ def __call__(self, batch) -> dict[str, Any]:
# Collate batch prompts
collated_batch = self._text_collator(batch) # type: ignore
known_input_names: set[str] = set(collated_batch.keys()).union(
{_PIXEL_VALUES_KEY}
{self._main_image_feature}
)
other_input_names: set[str] = set()

@@ -80,12 +82,12 @@ def __call__(self, batch) -> dict[str, Any]:
# TODO Consider relaxing this constraint: a vision/language model
# can handle text-only inputs e.g., a follow-up to an answer,
# or image-only inputs e.g., captioning.
if _PIXEL_VALUES_KEY not in item:
if self._main_image_feature not in item:
raise ValueError(
f"Item doesn't contain '{_PIXEL_VALUES_KEY}' key. "
f"Item doesn't contain '{self._main_image_feature}' key. "
f"Available keys: {item.keys()}"
)
images.append(item[_PIXEL_VALUES_KEY])
images.append(item[self._main_image_feature])

for key in item:
if (
@@ -96,10 +98,10 @@ def __call__(self, batch) -> dict[str, Any]:
other_input_names.add(key)

# Collate images.
pixel_values = self.collate_images(images)
image_input_features = self.collate_images(images)

# Add images to other inputs.
collated_batch[_PIXEL_VALUES_KEY] = pixel_values
collated_batch[self._main_image_feature] = image_input_features

# For other inputs, let's verify they're present in all examples and stack them.
if len(other_input_names) > 0:
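
With the image key now configurable, a Phi-4-style collator can be constructed roughly as below. A sketch only: the `tokenizer`/`max_length` keyword names are assumed from the constructor shown above, and in practice `build_collator_from_config` wires these values in from the model's visual config:

```python
# Sketch: a padding collator that fetches "input_image_embeds" instead of "pixel_values".
from transformers import AutoTokenizer

from oumi.core.collators.vision_language_collator_with_padding import (
    VisionLanguageCollatorWithPadding,
)

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-4-multimodal-instruct", trust_remote_code=True
)
collator = VisionLanguageCollatorWithPadding(
    tokenizer=tokenizer,
    max_length=4096,
    truncation=True,
    allow_multi_image_inputs=False,           # Multi-image support is pending (OPE-355).
    main_image_feature="input_image_embeds",  # Phi-4 consumes image embeds, not raw pixels.
)
# Each batch item must contain "input_image_embeds" plus tokenized text fields;
# the collator pads/stacks them into a single dict of tensors.
```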
9 changes: 9 additions & 0 deletions src/oumi/core/configs/internal/internal_model_config.py
@@ -74,6 +74,12 @@ class InternalFeatureSpec(NamedTuple):

@dataclass
class InternalVisualModelConfig(BaseConfig):
main_image_feature: str = "pixel_values"
"""The key corresponding to the main image feature consumed by the model.

E.g., raw pixels, transformed image patches, etc. resulting from data
preprocessing and consumed by the underlying model."""

variable_shape_image_features: bool = False
"""Whether image features can be of variable shape.

@@ -133,5 +139,8 @@ class InternalModelConfig(BaseConfig):
processor_kwargs: dict[str, Any] = field(default_factory=dict)
"""Extra params to pass to processor constructor."""

ignore_features: list[str] = field(default_factory=list)
"""Features from processing the input to ignore in the model's forward method."""

visual_config: Optional[InternalVisualModelConfig] = None
"""Configuration specific to visual models."""
50 changes: 50 additions & 0 deletions src/oumi/core/configs/internal/supported_models.py
@@ -200,6 +200,51 @@ def _create_phi3_vlm_config() -> InternalModelConfig:
return config


def _create_phi4_vlm_config() -> InternalModelConfig:
config = InternalModelConfig()
config.chat_template = "phi3-instruct"
config.ignore_features = [
"audio_attention_mask", # We won't use audio features.
"audio_embed_sizes",
"input_audio_embeds",
]

config.model_input_features.update(
{
feature_name: InternalFeatureSpec(
name=feature_name,
required=True,
variable_shape=True,
image_dependent=True,
first_dim_action=InternalFeatureFirstDimAction.DROP_IF_DUMMY,
)
for feature_name in (
"input_image_embeds",
"image_attention_mask",
)
}
)
config.model_input_features.update(
{
feature_name: InternalFeatureSpec(
name=feature_name,
required=True,
variable_shape=False,
image_dependent=True,
)
for feature_name in ("image_sizes",)
}
)
visual_config = InternalVisualModelConfig()
# FIXME OPE-355 Set to True once multi-image issues are resolved for the model.
visual_config.supports_multiple_images = False
visual_config.variable_shape_image_features = True
visual_config.main_image_feature = "input_image_embeds"

config.visual_config = visual_config
return config


def _create_idefics3_vlm_config() -> InternalModelConfig:
config = _create_default_vlm_config(
supports_multiple_images=True, pixel_values_variable_shape=True
@@ -321,6 +366,11 @@ def get_all_models_map() -> (
tested=True,
config=_create_phi3_vlm_config(),
),
_ModelTypeInfo(
model_type="phi4mm",
model_class=transformers.AutoModelForCausalLM,
config=_create_phi4_vlm_config(),
),
]

# Make it immutable.
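
To see what this registration exposes to the rest of Oumi, the new config can be inspected directly. A sketch that calls the private helper above purely for illustration:

```python
# Sketch: inspect the Phi-4 VLM config registered in this PR.
from oumi.core.configs.internal.supported_models import _create_phi4_vlm_config

config = _create_phi4_vlm_config()
print(config.chat_template)                     # "phi3-instruct"
print(config.ignore_features)                   # Audio features, dropped for now.
print(config.visual_config.main_image_feature)  # "input_image_embeds"
```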
6 changes: 6 additions & 0 deletions src/oumi/core/processors/base_processor.py
@@ -87,6 +87,12 @@ def label_ignore_index(self) -> Optional[int]:
"""Returns a label ignore index."""
raise NotImplementedError

@property
@abc.abstractmethod
def ignore_features(self) -> list[str]:
"""Returns a list of keys of features to ignore from feeding the model."""
raise NotImplementedError

@abc.abstractmethod
def __call__(
self,
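
A concrete processor for Phi-4 would return the audio keys listed in the registry entry above from this property, so downstream code can drop them before the forward pass. A hypothetical sketch (illustrative only, not Oumi's actual filtering code):

```python
# Sketch: filtering out processor outputs the model's forward() should never see.
from typing import Any

PHI4_IGNORE_FEATURES = [
    "audio_attention_mask",  # Audio inputs are not part of this integration yet.
    "audio_embed_sizes",
    "input_audio_embeds",
]


def drop_ignored_features(features: dict[str, Any], ignore: list[str]) -> dict[str, Any]:
    """Removes feature keys that should not be passed to the model."""
    return {key: value for key, value in features.items() if key not in ignore}
```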