Skip to content

shedding-ash/RLHF-impl

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 

Repository files navigation

轻量级RLHF算法实现与实验框架

项目背景

本项目从零实现RLHF核心算法(PPO DPO GRPO),剥离工业级框架的冗余设计,提供:

  • 清晰的代码结构​​:分模块实现损失函数、奖励函数、数据收集
  • 轻量级依赖​​:仅需PyTorch+Transformers,适合教学与二次开发

适用于算法学习者理解RLHF全流程和小规模任务(对话机器人、文本风格迁移)的快速验证

项目结构

.
├── README.md                  # 项目说明文档
│
├── dpo/                       
│   ├── 1.actor.ipynb          # 策略模型训练
│   ├── 2.dpo.ipynb            # DPO算法实现
│   ├── 3.test.ipynb           # 调用训练好的模型,实现效果
│   └── 4.trl.ipynb            # 与Transformer强化学习库的集成
│
├── grpo/                      
│   └── grpo.ipynb             # GRPO主实现文件
│
├── ppo0/                      # torch实现
│   ├── 1.actor.ipynb          # 策略模型训练
│   ├── 2.critic.ipynb         # 价值模型训练
│   ├── 3.ppo.ipynb            # PPO算法实现
│   ├── 4.test.ipynb           # 调用训练好的模型,实现效果 
│   └── 5.trl.ipynb            # TRL库适配实现
│
└── ppo1/                      # transformers实现
    ├── 1.actor.ipynb          # 策略模型训练
    ├── 2.reward.ipynb         # 奖励模型实现
    ├── 3.rlhf.ipynb           # 基于人类反馈的强化学习
    ├── 4.test.ipynb           # 调用训练好的模型,实现效果
    ├── lora.py                # LoRA微调模块
    └── util.py                # 工具函数集

子项目简介

/ppo0

PPO情感可控文本生成


任务目标

通过强化学习(PPO)优化生成模型(model_actor),使其生成的文本符合指定的情感标签(0=负面1=正面)。


输入 & 输出

输入:情感标签(0/1) + 部分文本(如 "This movie is"
输出:生成符合指定情感的完整评论(如 "great!""terrible!"


优化目标

  1. 最大化奖励:生成文本的情感与标签一致(通过 model_critic 评分)
  2. 最小化KL散度:保持生成质量,避免偏离初始模型(model_actor_ref

数据 & 模型

数据:IMDB电影评论
• 部分文本用于 model_actor 训练(自回归)
• 完整评论用于 model_critic 监督微调(情感分类)
模型
model_actor(GPT-2, CausalLM):生成评论
model_critic(DistilBERT, SequenceClassification):情感评分
model_actor_ref(GPT-2):约束生成质量(KL散度)


训练流程

  1. 监督微调(预训练阶段):
    model_actor:自回归训练(保持语言模型能力)
    model_critic:情感分类微调(二分类)
  2. PPO强化学习(优化阶段):
    • 优化 model_actor,使其生成文本:
    ◦ 符合 model_critic 的情感要求(奖励最大化)
    ◦ 不偏离 model_actor_ref 的生成质量(KL散度最小化)

/ppo1

PPO SQL语句生成(大小写反转)


任务目标

  1. 训练模型生成大小写反转的SQL语句,展现PPO的格式微调能力。
  2. 通过强化学习优化生成结果,确保SQL语法正确且大小写反转符合要求。

数据 & 模型

数据
Prompt:原始SQL语句的自然语言描述 • Chosen(正样本):大小写反转的SQL(如 "select ENGINE from TABLE_NAME_27 where YEAR = "61"")。
Rejected(负样本):原始SQL。

模型
Actor模型facebook/opt-1.3b
Critic模型facebook/opt-350m


训练流程

  1. 微调Actor(LoRA)
    • 输入:原始SQL → 目标:生成反转大小写SQL(监督学习)。
  2. 训练Critic(Reward Model)
    • 输入:(prompt, chosen, rejected),学习区分“好”(反转SQL)和“坏”(原始SQL)样本。
  3. PPO强化学习
    • Actor生成SQL → Critic打分 → PPO更新策略,最大化奖励。

/dpo

DPO SQL语句生成


任务目标

基于已有的语言模型(GPT2),使用DPO方法,对SQL生成任务进行微调。根据给定的上下文(context)和问题(question),生成符合要求的SQL答案(answer)。


数据 & 模型

数据:b-mc2/sql-create-context 数据集

  • context(背景知识)
  • question(自然语言问题)
  • answer(SQL语句)

模型

  • GPT2

输入 & 输出

  • 输入:自然语言的上下文(context)与问题(question)
    示例:

    context: <背景知识>
    question: <问题>
    answer:
    
  • 输出:对应的SQL语句(answer)

DPO 实现逻辑

(1) 数据准备 (get_data)
输入:从数据集中采样 b 个样本(question + answer)。
构造对比对
Chosen(优选回答)question + correct_answer + EOS
Rejected(拒绝回答)question(简化处理,实际应采样错误答案)
输出
input_ids
attention_maskanswer_mask

(2) 概率差异计算 (get_prob_diff)
公式
prob_diff = log P_θ(chosen) - log P_θ(rejected)

实现

  1. 模型前向计算:获取 logits
  2. 对齐预测与目标
    input_idslogits 偏移 1 位(自回归)。
  3. 计算对数概率
    ◦ 对 logitslog_softmax
    ◦ 用 gather 提取目标 token 对应的概率值。
  4. 掩码处理
    ◦ 通过 answer_mask 仅累加答案部分的联合对数概率。
  5. 差异计算
    ◦ 返回 chosenrejected 的概率差。

(3) 损失函数与优化

loss = -log(σ(prob_diff - prob_diff_ref))

prob_diff_ref提供稳定基线。
目标:最大化 prob_diff,使模型更偏好优质回答。

/grpo

GRPO小模型数学推理及特定格式保持


任务目标

  1. 训练模型生成符合特定XML格式的数学推理答案(包含<reasoning><answer>标签)。
  2. 通过多奖励函数优化生成结果,确保答案正确性、格式规范性和数值合理性。

数据 & 模型

数据
GSM8K数据集(数学推理问答),包含问题和人工标注的逐步解答。使用main分支
• 预处理:提取问题作为输入,答案中的最终数值(#### 答案格式)作为标签。

模型
Qwen2.5-0.5B-Instruct
• 使用transformers库加载。


输入 & 输出

输入
• GSM8K数据集中的数学问题(如"John has 3 apples. He buys 5 more. How many does he have?")。
• 系统提示要求模型按指定XML格式生成答案。
输出
• 模型生成的响应,需包含推理过程和最终答案,例如:
xml <reasoning> John initially has 3 apples. After buying 5 more, total = 3 + 5 = 8. </reasoning> <answer> /box{8} </answer>


奖励函数设计

通过以下奖励函数联合优化模型输出:

  1. 答案正确性correctness_reward_func):答案与标签完全匹配得2分。
  2. 格式规范性
    • 严格XML格式(换行符正确,strict_format_reward_func)得0.5分。
    • 宽松XML格式(允许空白字符,soft_format_reward_func)得0.5分。
    • 标签完整性(xmlcount_reward_func)按标签出现位置和冗余文本扣分。
  3. 数值合理性int_reward_func):答案为整数得0.5分。

RLHF算法简介

待更新 image-20250407020806292

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published