本项目从零实现RLHF核心算法(PPO DPO GRPO),剥离工业级框架的冗余设计,提供:
- 清晰的代码结构:分模块实现损失函数、奖励函数、数据收集
- 轻量级依赖:仅需PyTorch+Transformers,适合教学与二次开发
适用于算法学习者理解RLHF全流程和小规模任务(对话机器人、文本风格迁移)的快速验证
.
├── README.md # 项目说明文档
│
├── dpo/
│ ├── 1.actor.ipynb # 策略模型训练
│ ├── 2.dpo.ipynb # DPO算法实现
│ ├── 3.test.ipynb # 调用训练好的模型,实现效果
│ └── 4.trl.ipynb # 与Transformer强化学习库的集成
│
├── grpo/
│ └── grpo.ipynb # GRPO主实现文件
│
├── ppo0/ # torch实现
│ ├── 1.actor.ipynb # 策略模型训练
│ ├── 2.critic.ipynb # 价值模型训练
│ ├── 3.ppo.ipynb # PPO算法实现
│ ├── 4.test.ipynb # 调用训练好的模型,实现效果
│ └── 5.trl.ipynb # TRL库适配实现
│
└── ppo1/ # transformers实现
├── 1.actor.ipynb # 策略模型训练
├── 2.reward.ipynb # 奖励模型实现
├── 3.rlhf.ipynb # 基于人类反馈的强化学习
├── 4.test.ipynb # 调用训练好的模型,实现效果
├── lora.py # LoRA微调模块
└── util.py # 工具函数集
PPO情感可控文本生成
通过强化学习(PPO)优化生成模型(model_actor
),使其生成的文本符合指定的情感标签(0=负面,1=正面)。
• 输入:情感标签(0/1) + 部分文本(如 "This movie is"
)
• 输出:生成符合指定情感的完整评论(如 "great!"
或 "terrible!"
)
- 最大化奖励:生成文本的情感与标签一致(通过
model_critic
评分) - 最小化KL散度:保持生成质量,避免偏离初始模型(
model_actor_ref
)
• 数据:IMDB电影评论
• 部分文本用于 model_actor
训练(自回归)
• 完整评论用于 model_critic
监督微调(情感分类)
• 模型:
• model_actor
(GPT-2, CausalLM
):生成评论
• model_critic
(DistilBERT, SequenceClassification
):情感评分
• model_actor_ref
(GPT-2):约束生成质量(KL散度)
- 监督微调(预训练阶段):
•model_actor
:自回归训练(保持语言模型能力)
•model_critic
:情感分类微调(二分类) - PPO强化学习(优化阶段):
• 优化model_actor
,使其生成文本:
◦ 符合model_critic
的情感要求(奖励最大化)
◦ 不偏离model_actor_ref
的生成质量(KL散度最小化)
- 训练模型生成大小写反转的SQL语句,展现PPO的格式微调能力。
- 通过强化学习优化生成结果,确保SQL语法正确且大小写反转符合要求。
• 数据:
• Prompt:原始SQL语句的自然语言描述
• Chosen(正样本):大小写反转的SQL(如 "select ENGINE from TABLE_NAME_27 where YEAR = "61""
)。
• Rejected(负样本):原始SQL。
• 模型:
• Actor模型:facebook/opt-1.3b
。
• Critic模型:facebook/opt-350m
。
- 微调Actor(LoRA):
• 输入:原始SQL → 目标:生成反转大小写SQL(监督学习)。 - 训练Critic(Reward Model):
• 输入:(prompt, chosen, rejected)
,学习区分“好”(反转SQL)和“坏”(原始SQL)样本。 - PPO强化学习:
• Actor生成SQL → Critic打分 → PPO更新策略,最大化奖励。
DPO SQL语句生成
基于已有的语言模型(GPT2),使用DPO方法,对SQL生成任务进行微调。根据给定的上下文(context)和问题(question),生成符合要求的SQL答案(answer)。
• 数据:b-mc2/sql-create-context 数据集
- context(背景知识)
- question(自然语言问题)
- answer(SQL语句)
• 模型:
- GPT2
-
输入:自然语言的上下文(context)与问题(question)
示例:context: <背景知识> question: <问题> answer:
-
输出:对应的SQL语句(answer)
(1) 数据准备 (get_data
)
• 输入:从数据集中采样 b
个样本(question
+ answer
)。
• 构造对比对:
• Chosen(优选回答):question + correct_answer + EOS
• Rejected(拒绝回答):question
(简化处理,实际应采样错误答案)
• 输出:
• input_ids
• attention_mask
• answer_mask
(2) 概率差异计算 (get_prob_diff
)
• 公式:
prob_diff = log P_θ(chosen) - log P_θ(rejected)
• 实现:
- 模型前向计算:获取
logits
。 - 对齐预测与目标:
◦input_ids
和logits
偏移 1 位(自回归)。 - 计算对数概率:
◦ 对logits
取log_softmax
。
◦ 用gather
提取目标 token 对应的概率值。 - 掩码处理:
◦ 通过answer_mask
仅累加答案部分的联合对数概率。 - 差异计算:
◦ 返回chosen
和rejected
的概率差。
(3) 损失函数与优化
loss = -log(σ(prob_diff - prob_diff_ref))
• prob_diff_ref
提供稳定基线。
• 目标:最大化 prob_diff
,使模型更偏好优质回答。
GRPO小模型数学推理及特定格式保持
- 训练模型生成符合特定XML格式的数学推理答案(包含
<reasoning>
和<answer>
标签)。 - 通过多奖励函数优化生成结果,确保答案正确性、格式规范性和数值合理性。
• 数据:
• GSM8K数据集(数学推理问答),包含问题和人工标注的逐步解答。使用main分支
• 预处理:提取问题作为输入,答案中的最终数值(#### 答案
格式)作为标签。
• 模型:
• Qwen2.5-0.5B-Instruct
• 使用transformers
库加载。
• 输入:
• GSM8K数据集中的数学问题(如"John has 3 apples. He buys 5 more. How many does he have?"
)。
• 系统提示要求模型按指定XML格式生成答案。
• 输出:
• 模型生成的响应,需包含推理过程和最终答案,例如:
xml <reasoning> John initially has 3 apples. After buying 5 more, total = 3 + 5 = 8. </reasoning> <answer> /box{8} </answer>
通过以下奖励函数联合优化模型输出:
- 答案正确性(
correctness_reward_func
):答案与标签完全匹配得2分。 - 格式规范性:
• 严格XML格式(换行符正确,strict_format_reward_func
)得0.5分。
• 宽松XML格式(允许空白字符,soft_format_reward_func
)得0.5分。
• 标签完整性(xmlcount_reward_func
)按标签出现位置和冗余文本扣分。 - 数值合理性(
int_reward_func
):答案为整数得0.5分。