
Commit 9282cfd

[Tutorial] LLM integration
ghstack-source-id: c53afce03ca7216908298686535e3777da59884e
Pull Request resolved: #2832
1 parent 27d3680 commit 9282cfd

File tree

24 files changed, +2471 -728 lines

+167 lines (new file)
@@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import ArgumentParser

import torch
from datasets import load_dataset
from grpo_utils import (
    HF2vLLMLocalWeightUpdater,
    PrepareQuestion,
    ShapedCorrectnessReward,
)
from tensordict import TensorDict
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyStackStorage, RayReplayBuffer, ReplayBuffer, SamplerWithoutReplacement
from torchrl.envs import DataLoadingPrimer, KLRewardTransform, LLMEnv, StepCounter
from torchrl.modules import from_hf_transformers, from_vllm
from torchrl.objectives import ClipPPOLoss
from transformers import AutoTokenizer, GPT2LMHeadModel
from vllm import LLM

parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--repeats", type=int, default=10)
parser.add_argument("--steps_per_batch", type=int, default=16)
parser.add_argument("--optim_batch_size", type=int, default=4)


def compute_mc_advantage(trajectories):
    # Use the reference answer to group trajectories that share the same question
    answer = trajectories["answer"]
    # Identify indices where the answers match
    answer_ids = tree_map(lambda string: hash(string), answer)
    answer_ids = torch.tensor(answer_ids)
    unique_qs = answer_ids.view(-1).unique()
    trajectories["advantage"] = trajectories["next", "reward"] * 0
    for u in unique_qs:
        idx = answer_ids == u
        rewards = trajectories[idx]["next", "reward"]
        rewards = (rewards - rewards.mean()) / rewards.std().clamp(min=1e-4)
        trajectories.set_at_("advantage", rewards, idx)
    return trajectories


if __name__ == "__main__":
    args = parser.parse_args()
    # Create env instance:
    # - Load the gsm8k dataset
    dataset = load_dataset(args.dataset, "main")
    train_dataset = dataset["train"]

    def collate_fn(batch):
        batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
        batch.rename_key_("question", "text")
        return batch

    # LLM
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # inference_model = GPT2LMHeadModel(GPT2Config())
    inference_model = LLM("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Env
    dataloader = DataLoader(  # noqa: TOR401
        train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
    )
    env = LLMEnv.from_dataloader(
        dataloader=dataloader,
        tokenizer=tokenizer,
        str2str=True,
        batch_size=(args.batch_size * args.repeats,),
        repeats=args.repeats,
    )
    for i, trsf in enumerate(env.transform):
        if isinstance(trsf, DataLoadingPrimer):
            env.insert_transform(i, PrepareQuestion())
            break

    # Finally, we want the env to stop after the first step
    env.append_transform(StepCounter(max_steps=1))

    policy = from_vllm(
        inference_model,
        tokenizer=tokenizer,
        from_text=False,
        generate=True,
        return_log_probs=True,
    )

    # Reward transform
    env.append_transform(ShapedCorrectnessReward(tokenizer=tokenizer))

    # Ref model
    ref_model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
    TensorDict.from_module(ref_model).data.to_module(ref_model)
    ref_model = from_hf_transformers(
        ref_model,
        tokenizer=tokenizer,
        from_text=False,
        generate=False,
        return_log_probs=True,
    )
    env.append_transform(
        KLRewardTransform(actor=ref_model, coef=0.1, log_prob_key="log_probs")
    )

    # Replay buffer
    rb = ReplayBuffer(
        storage=LazyStackStorage(args.steps_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=args.optim_batch_size,
    )

    # Collector
    train_model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
    collector = SyncDataCollector(
        env,
        policy,
        frames_per_batch=args.steps_per_batch,
        total_frames=1_000_000,
        local_weights_updater=HF2vLLMLocalWeightUpdater(
            hf_model=train_model, vllm_model=inference_model
        ),
    )

    # Loss module
    policy_training = from_hf_transformers(
        train_model,
        tokenizer=tokenizer,
        from_text=False,
        generate=False,
        return_log_probs=True,
    )
    loss_fn = ClipPPOLoss(
        actor_network=policy_training,
        critic_network=None,
        critic_coef=0.0,
        functional=False,
    )
    loss_fn.set_keys(sample_log_prob="log_probs")
    loss_fn._set_in_keys()
    optim = torch.optim.Adam(loss_fn.parameters())

    # loss_fn = ReinforceLoss(
    #     actor_network=policy,
    #     critic_network=None,
    #     critic_coef=0.0,
    # )

    for trajs in collector:
        trajs = trajs.reshape(-1)
        trajs = compute_mc_advantage(trajs)
        rb.extend(trajs)
        for _ in range(args.epochs):
            for batch in rb:
                loss = loss_fn(batch)
                loss_val = loss.mean(reduce=True)
                loss_val.backward()
                optim.step()
                optim.zero_grad()
        collector.update_policy_weights_()
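
A note on running this example (not part of the diff): the script reads its configuration from the argparse flags defined at the top, so once the imported dependencies (torch, tensordict, torchrl, datasets, transformers, vllm) and the grpo_utils module added elsewhere in this PR are installed, it could be launched roughly as below. The file name grpo_gsm8k.py is a placeholder, since this view does not show the actual path, and the flag values are simply the parser defaults.

    python grpo_gsm8k.py --dataset gsm8k --batch_size 4 --epochs 10 --repeats 10 --steps_per_batch 16 --optim_batch_size 4

For intuition about compute_mc_advantage above: it appears to implement a GRPO-style Monte Carlo baseline, standardizing rewards within each group of trajectories that share the same reference answer (i.e. the same question repeated args.repeats times). A minimal standalone sketch of that normalization with plain tensors and made-up reward values:

    import torch

    # Six hypothetical trajectories: two questions, three completions each
    rewards = torch.tensor([1.0, 0.0, 0.5, 0.2, 0.2, 0.8])
    group_ids = torch.tensor([0, 0, 0, 1, 1, 1])  # which question each trajectory answers

    advantage = torch.zeros_like(rewards)
    for g in group_ids.unique():
        idx = group_ids == g
        r = rewards[idx]
        # Center and scale within the group, mirroring compute_mc_advantage
        advantage[idx] = (r - r.mean()) / r.std().clamp(min=1e-4)

    print(advantage)  # per-group standardized advantages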

0 commit comments
