
Commit cfecf62

[Tutorial] LLM integration

ghstack-source-id: fe507484265b5cd7bbea6739de99e19b3f0b4a92
Pull Request resolved: #2832

1 parent 49a8a42 commit cfecf62

File tree: 10 files changed, +652 −167 lines
New file (+152 lines): the tutorial training script
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import ArgumentParser

import torch
from datasets import load_dataset
from grpo_utils import PrepareQuestion, ShapedCorrectnessReward
from tensordict import TensorDict
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement
from torchrl.envs import (
    DataLoadingPrimer,
    KLRewardTransform,
    LLMEnv,
    StepCounter,
    Tokenizer,
)
from torchrl.modules import from_hf_transformers
from torchrl.objectives import ClipPPOLoss, ReinforceLoss
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--repeats", type=int, default=10)
parser.add_argument("--steps_per_batch", type=int, default=16)
parser.add_argument("--optim_batch_size", type=int, default=4)


def compute_mc_advantage(trajectories):
    # Group trajectories by their ground-truth answer
    answer = trajectories["answer"]
    # Identify indices where the answers match
    answer_ids = tree_map(lambda string: hash(string), answer)
    answer_ids = torch.tensor(answer_ids)
    print("answer_ids", answer_ids)
    unique_qs = answer_ids.view(-1).unique()
    trajectories["advantage"] = trajectories["next", "reward"] * 0
    for u in unique_qs:
        idx = answer_ids == u
        rewards = trajectories[idx]["next", "reward"]
        # Standardize the rewards within the group to obtain the advantage
        rewards = (rewards - rewards.mean()) / rewards.std().clamp(min=1e-4)
        print("rewards", rewards)
        trajectories.set_at_("advantage", rewards, idx)
    return trajectories


if __name__ == "__main__":
    args = parser.parse_args()

    # Create env instance:
    # - Load the gsm8k dataset
    dataset = load_dataset(args.dataset, "main")
    train_dataset = dataset["train"]

    def collate_fn(batch):
        batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
        batch.rename_key_("question", "text")
        return batch

    # LLM
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel(GPT2Config())

    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Env
    dataloader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
    )
    env = LLMEnv.from_dataloader(
        dataloader=dataloader,
        tokenizer=tokenizer,
        str2str=True,
        batch_size=(args.batch_size * args.repeats,),
        repeats=args.repeats,
    )
    # Insert the prompt-formatting transform right before the data-loading primer
    for i, trsf in enumerate(env.transform):
        if isinstance(trsf, DataLoadingPrimer):
            env.insert_transform(i, PrepareQuestion())
            break

    # Finally, we want the env to stop after the first step
    env.append_transform(StepCounter(max_steps=1))

    print("env", env)
    print(env.reset())

    policy = from_hf_transformers(
        model,
        tokenizer=tokenizer,
        from_text=False,
        generate=True,
        return_log_probs=True,
    )

    # Reward transform
    env.append_transform(ShapedCorrectnessReward(tokenizer=tokenizer))

    # Ref model
    ref_model = GPT2LMHeadModel(GPT2Config())
    ref_model = from_hf_transformers(
        ref_model,
        tokenizer=tokenizer,
        from_text=False,
        generate=False,
        return_log_probs=True,
    )
    env.append_transform(
        KLRewardTransform(actor=ref_model, coef=0.1, log_prob_key="log_probs")
    )

    # Replay buffer
    rb = ReplayBuffer(
        storage=LazyStackStorage(args.steps_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=args.optim_batch_size,
    )

    # Collector
    collector = SyncDataCollector(
        env,
        policy,
        frames_per_batch=args.steps_per_batch,
        total_frames=1_000_000,
    )

    # Loss module: a non-generating view of the same weights for training
    policy_training = from_hf_transformers(
        model,
        tokenizer=tokenizer,
        from_text=False,
        generate=False,
        return_log_probs=True,
    )
    loss_fn = ClipPPOLoss(
        actor_network=policy_training,
        critic_network=None,
        critic_coef=0.0,
        functional=False,
    )
    loss_fn.set_keys(sample_log_prob="log_probs")
    loss_fn._set_in_keys()
    optim = torch.optim.Adam(loss_fn.parameters())

    # loss_fn = ReinforceLoss(
    #     actor_network=policy,
    #     critic_network=None,
    #     critic_coef=0.0,
    # )

    # Collect rollouts, compute group advantages, then run several PPO epochs
    for trajs in collector:
        trajs = trajs.reshape(-1)
        print("trajs from collector", trajs)
        trajs = compute_mc_advantage(trajs)
        rb.extend(trajs)
        for i in range(args.epochs):
            for batch in rb:
                print("running loss with batch", batch)
                loss = loss_fn(batch)
                loss_val = loss.mean(reduce=True)
                loss_val.backward()
                optim.step()
                optim.zero_grad()
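A note for readers of the tutorial: compute_mc_advantage above is the GRPO-style, group-relative baseline — rewards of completions that share the same ground-truth answer are standardized against one another and reused as advantages. A minimal sketch of the same idea on plain tensors (the group ids and reward values below are invented for illustration):

import torch

# Toy data: 6 completions drawn from 2 prompts (group ids), with scalar rewards.
group_ids = torch.tensor([0, 0, 0, 1, 1, 1])
rewards = torch.tensor([1.0, 0.0, 0.5, 2.0, 2.0, 0.0])

advantage = torch.zeros_like(rewards)
for g in group_ids.unique():
    idx = group_ids == g
    r = rewards[idx]
    # Standardize within the group; clamp the std so a constant group maps to ~0
    advantage[idx] = (r - r.mean()) / r.std().clamp(min=1e-4)

print(advantage)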
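The KLRewardTransform keeps the fine-tuned policy close to the frozen reference model by folding a KL penalty into the reward. A rough plain-PyTorch sketch of that shaping, using made-up log-probabilities (TorchRL derives the per-token terms from the two models' log_probs entries):

import torch

coef = 0.1
# Hypothetical per-token log-probabilities under the trained and reference models
log_probs = torch.tensor([-1.2, -0.8, -2.0])
ref_log_probs = torch.tensor([-1.0, -1.0, -1.5])
reward = torch.tensor([0.0, 0.0, 20.0])  # task reward, e.g. only on the last token

# Sample-based KL estimate: log pi(a|s) - log pi_ref(a|s)
kl = log_probs - ref_log_probs
shaped_reward = reward - coef * kl
print(shaped_reward)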
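With critic_coef=0.0 and a precomputed advantage, ClipPPOLoss reduces to the clipped policy surrogate. For reference, a standalone sketch of that objective on hypothetical tensors (not the torchrl implementation, which also handles entropy terms and key management):

import torch

def clipped_surrogate_loss(log_prob, old_log_prob, advantage, eps=0.2):
    # Importance ratio between the updated policy and the behavior policy
    ratio = (log_prob - old_log_prob).exp()
    unclipped = ratio * advantage
    clipped = ratio.clamp(1.0 - eps, 1.0 + eps) * advantage
    # PPO maximizes the minimum of the two terms; negate to get a loss
    return -torch.min(unclipped, clipped).mean()

log_prob = torch.randn(8, requires_grad=True)
old_log_prob = log_prob.detach() + 0.1 * torch.randn(8)
advantage = torch.randn(8)
loss = clipped_surrogate_loss(log_prob, old_log_prob, advantage)
loss.backward()
print(loss.item(), log_prob.grad.shape)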
New file (+152 lines): grpo_utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import torch
from tensordict import NestedKey, TensorDict, TensorDictBase
from tensordict.tensorclass import NonTensorData, NonTensorStack
from tensordict.utils import _zip_strict
from torchrl.data import Composite, TensorSpec, Unbounded
from torchrl.envs import Transform

BASE_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively, "
    "i.e., <think>reasoning process here</think> <answer>answer here</answer>. User: %s. Assistant: <think>"
)


class PrepareQuestion(Transform):
    def __init__(
        self,
        in_keys: list[NestedKey] | None = None,
        out_keys: list[NestedKey] | None = None,
    ):
        if in_keys is None:
            in_keys = ["text"]
        if out_keys is None:
            out_keys = list(in_keys)
        super().__init__(in_keys, out_keys)

    def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Wrap the raw question into the base prompt at reset time
        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
            string = tensordict.get(in_key)
            tensordict.set(out_key, self._modify_str(string))
        return tensordict

    def _modify_str(
        self, obs: str | list[str] | NonTensorData | NonTensorStack
    ) -> NonTensorData | NonTensorStack:
        if isinstance(obs, NonTensorData):
            return self._modify_str(obs.data)
        if isinstance(obs, NonTensorStack):
            return self._modify_str(obs.tolist())
        if isinstance(obs, list):
            return NonTensorStack(*[BASE_PROMPT % item for item in obs])
        return NonTensorData(BASE_PROMPT % obs)

    def _apply_transform(self, obs: torch.Tensor) -> None:
        return obs

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
            if out_key != in_key:
                observation_spec[out_key] = observation_spec[in_key].clone()
        return observation_spec


class ShapedCorrectnessReward(Transform):
    def __init__(
        self,
        tokenizer,
        in_keys: list[NestedKey] | None = None,
        out_keys: list[NestedKey] | None = None,
    ):
        if in_keys is None:
            in_keys = ["text", "answer"]
        if not isinstance(in_keys, list) or len(in_keys) != 2:
            raise ValueError(
                "ShapedCorrectnessReward requires in_keys to be of type list and have 2 elements."
            )
        if out_keys is None:
            out_keys = [
                "reward_answer",
                "reward_think",
                "reward_right",
                "reward_contained",
                "reward",
                "success",
            ]
        super().__init__(in_keys, out_keys)
        self.tokenizer = tokenizer

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        from xml.etree import ElementTree as ET

        # Get the completion
        responses = next_tensordict[self.in_keys[0]]  # batch_size, grpo_size, L
        answers = next_tensordict[self.in_keys[1]]  # batch_size, grpo_size
        if isinstance(responses, torch.Tensor):
            if responses.ndim == 3:
                batch_size, grpo_size, _ = responses.shape
            # decode
            text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist())
        else:
            text_completion = responses
        # Decomposed reward
        tds = []
        for answer, compl in zip(answers, text_completion):
            try:
                cot, potential_answer = self.extract_tags(
                    "<think>" + compl
                )  # .replace("<<", "").replace(">>", "")
            except ET.ParseError:
                cot, potential_answer = ("", "")
            tds.append(self.single_shaped_correctness_reward(potential_answer, cot))
        tds = torch.stack(tds)
        if isinstance(responses, torch.Tensor) and responses.ndim == 3:
            tds = tds.reshape(batch_size, grpo_size)
        tds = tds.apply(lambda t: t.unsqueeze(-1))
        return next_tensordict.update(tds)

    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
        shape = reward_spec.shape + (1,)
        reward_spec.update(
            Composite(
                reward_answer=Unbounded(shape),
                reward_think=Unbounded(shape),
                reward_right=Unbounded(shape),
                reward_contained=Unbounded(shape),
                reward=Unbounded(shape),
                success=Unbounded(shape, dtype=torch.bool),
            )
        )
        return reward_spec

    @classmethod
    def single_shaped_correctness_reward(cls, answer: str, cot: str) -> TensorDict:
        reward_answer = 5.0 * (len(answer) == 1)

        reward_think = 5.0 * (len(cot) == 1)

        # One of the answer tags has the right answer
        reward_right = 20.0 * (any(attempt == answer for attempt in answer))

        # One of the answer tags contains the right answer (might be e.g. $20 instead of 20)
        reward_contained = 10.0 * (any((answer in attempt) for attempt in answer))

        success = len(answer) > 0 and answer[-1] == answer
        # Compose the rewards
        reward = 100.0 * float(success) + (
            reward_answer + reward_think + reward_contained + reward_right
        ) * (1 - float(success))

        rewards = TensorDict(
            reward_answer=reward_answer,
            reward_think=reward_think,
            reward_right=reward_right,
            reward_contained=reward_contained,
            reward=reward,
            success=success,
        )
        return rewards

    @staticmethod
    def extract_tags(text: str) -> tuple[str, str]:
        """Parse XML-like tags from ``text``.

        Returns a ``(think, answer)`` tuple holding the content of the
        corresponding tags, or empty strings when a tag is missing or the
        text cannot be parsed.
        """
        from xml.etree import ElementTree as ET

        xml_string = f"<root>{text}</root>"
        try:
            root = ET.fromstring(xml_string)
        except ET.ParseError:
            return ("", "")

        return (
            root.find("think").text if root.find("think") is not None else "",
            root.find("answer").text if root.find("answer") is not None else "",
        )
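ShapedCorrectnessReward depends on completions following the <think>…</think> <answer>…</answer> format requested by BASE_PROMPT. A small standalone check of the tag-parsing step, mirroring extract_tags with stdlib XML parsing only (the sample completion string is invented for illustration):

from xml.etree import ElementTree as ET


def extract_tags(text: str) -> tuple[str, str]:
    # Wrap the completion in a root element so it parses as a single XML document
    try:
        root = ET.fromstring(f"<root>{text}</root>")
    except ET.ParseError:
        return ("", "")
    think = root.find("think")
    answer = root.find("answer")
    return (
        think.text if think is not None else "",
        answer.text if answer is not None else "",
    )


completion = "<think>3 apples plus 4 apples is 7 apples.</think> <answer>7</answer>"
print(extract_tags(completion))  # ('3 apples plus 4 apples is 7 apples.', '7')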

torchrl/data/tensor_specs.py (+1 −1)

@@ -4941,7 +4941,7 @@ def set(self, name: str, spec: TensorSpec) -> Composite:
             spec.shape = self.shape
         else:
             raise ValueError(
-                f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first "
+                f"The shapes of the spec {type(spec).__name__} and the {type(self).__name__} mismatch: the first "
                 f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
                 f"Composite.shape={self.shape}."
             )

torchrl/envs/common.py

+1
Original file line numberDiff line numberDiff line change
@@ -3383,6 +3383,7 @@ def _rollout_stop_early(
33833383
else:
33843384
tensordict.clear_device_()
33853385
# In case policy(..) does not modify in-place - no-op for TensorDict and related
3386+
print('policy input', tensordict)
33863387
tensordict.update(policy(tensordict))
33873388
if auto_cast_to_device:
33883389
if env_device is not None:
