简体中文 | English
本仓库包含一个使用参数高效微调(PEFT)技术训练 Phi3-V 模型的脚本,支持各种配置和选项。
使用 requirements.txt
或 environment.yml
安装所需的软件包。
pip install -r requirements.txt
conda env create -f environment.yml
conda activate phi3v
在训练之前,从 HuggingFace 下载 Phi3-V 模型。建议使用 huggingface-cli
进行下载。
- 安装 HuggingFace CLI:
pip install -U "huggingface_hub[cli]"
- 下载模型:
huggingface-cli download microsoft/Phi-3-vision-128k-instruct --local-dir Phi-3-vision-128k-instruct --resume-download
运行训练脚本,请使用以下命令:
bash scripts/train.sh
注意: 请记得将 train.sh
中的路径替换为你的具体路径。
--data_path
(str): LLaVA 格式的训练数据路径(JSON 文件)。(必需)--image_folder
(str): LLaVA 格式的训练数据中引用的图像文件夹路径。(必需)--model_id
(str): Phi3-V 模型的路径。(必需)--proxy
(str): 代理设置(默认:无)。--output_dir
(str): 模型 checkpoint 的输出目录(默认:"output/test_train")。--num_train_epochs
(int): 训练 epoch 数(默认:1)。--per_device_train_batch_size
(int): 每个 GPU 每个前向步骤的训练批量大小。--gradient_accumulation_steps
(int): 梯度累积步骤(默认:4)。--deepspeed_config
(str): DeepSpeed 配置文件的路径(默认:"scripts/zero2.json")。--num_lora_modules
(int): 要添加 LoRA 的目标模块数(-1 表示所有层)。--lora_namespan_exclude
(str): 排除具有名称范围的模块以添加 LoRA。--max_seq_length
(int): 最大序列长度(默认:3072)。--quantization
(flag): 启用量化。--disable_flash_attn2
(flag): 禁用 Flash Attention 2。--report_to
(str): 报告工具(选项:'tensorboard', 'wandb', 'none')(默认:'tensorboard')。--logging_dir
(str): 日志目录(默认:"./tf-logs")。--lora_rank
(int): LoRA rank(默认:128)。--lora_alpha
(int): LoRA alpha(默认:256)。--lora_dropout
(float): LoRA dropout(默认:0.05)。--logging_steps
(int): 日志记录步骤(默认:1)。--dataloader_num_workers
(int): 数据加载器工作线程数(默认:4)。
该脚本需要按照 LLaVA 规范格式化的数据集。数据集应为 JSON 文件,每个条目包含对话和图像信息。确保数据集中图像路径与提供的 --image_folder
相匹配。
数据集示例
[
{
"id": "000000033471",
"image": "000000033471.jpg",
"conversations": [
{
"from": "human",
"value": "<image>\nWhat are the colors of the bus in the image?"
},
{
"from": "gpt",
"value": "The bus in the image is white and red."
},
{
"from": "human",
"value": "What feature can be seen on the back of the bus?"
},
{
"from": "gpt",
"value": "The back of the bus features an advertisement."
},
{
"from": "human",
"value": "Is the bus driving down the street or pulled off to the side?"
},
{
"from": "gpt",
"value": "The bus is driving down the street, which is crowded with people and other vehicles."
}
]
}
...
]
- 添加对 DeepSpeed ZeRO-3 的支持。
- 添加对 FSDP 的支持。
- 添加对全量微调的支持。
- 增加Grounding微调支持
- 支持多图微调
- 川虎 Chat 集成
本项目使用 Apache-2.0 许可证。详见 LICENSE 文件。
本项目借用了 LLaVA 和 Microsoft Phi-3-vision-128k-instruct 的代码。感谢这两个项目的贡献。
如果你在工作中使用了这个代码库,请引用本项目:
@misc{phi3vfinetuning2023,
author = {Gai Zhenbiao & Shao Zhenwei},
title = {Phi3V-Finetuning},
year = {2023},
publisher = {GitHub},
url = {https://github.com/GaiZhenbiao/Phi3V-Finetuning},
note = {GitHub repository},
}