Commit ed051bc

[Feature] AddThinkingPrompt transform (#3027)
1 parent 205243c commit ed051bc
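
Judging from the tests added in this commit, the new `AddThinkingPrompt` transform injects an extra "keep thinking" turn into the conversation whenever a condition on the step output holds (here, a low reward), and can optionally edit the last turn in place, zero the reward for that step, and clear the done flag so the episode continues. A minimal usage sketch assembled from the test code below; the constructor arguments and their interpretation (given in the comments) are taken from the test parametrization and assertions, not from separate API documentation:

    # Sketch based on test/llm/test_envs.py as added in this commit.
    from transformers import AutoTokenizer

    from torchrl.envs.llm import GSM8KEnv
    from torchrl.envs.llm.transforms import AddThinkingPrompt

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
    env = GSM8KEnv(shuffle=False, tokenizer=tokenizer, max_steps=10)
    env = env.append_transform(
        AddThinkingPrompt(
            cond=lambda td: td["reward"] < 50,  # when to inject the thinking prompt
            role="assistant",                   # role of the injected/edited turn
            edit_last_turn=True,                # rewrite the last turn instead of appending a new one
            zero_reward=True,                   # the re-prompted step then reports a zero reward
            undo_done=True,                     # clear done/terminated so the rollout continues
        )
    )

In the tests, `zero_reward` and `undo_done` control whether the re-prompted step keeps its original reward and done flag, while `edit_last_turn` decides whether the history grows by one turn or the last turn is replaced.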

File tree

9 files changed: +443 −4 lines


test/llm/test_envs.py

Lines changed: 122 additions & 0 deletions
@@ -1130,6 +1130,128 @@ def test_async_mcp_tools(self):
         env_pool.close()
 
 
+class TestThinkingPrompt:
+    @pytest.fixture(autouse=True, scope="class")
+    def base_env(self):
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+        env = GSM8KEnv(shuffle=False, tokenizer=tokenizer, max_steps=10)
+        return env
+
+    @pytest.mark.skipif(not _has_transformers, reason="requires transformers")
+    @pytest.mark.skipif(not _has_datasets, reason="requires gsm8k")
+    @pytest.mark.parametrize(
+        "role,edit_last_turn",
+        [("assistant", True), ("assistant", False), ("user", False)],
+    )
+    @pytest.mark.parametrize("zero_reward", [True, False])
+    @pytest.mark.parametrize("undo_done", [True, False])
+    @pytest.mark.parametrize("random_prompt", [True, False])
+    def test_thinking_prompt_wrong_answer(
+        self,
+        role,
+        edit_last_turn,
+        zero_reward,
+        undo_done,
+        random_prompt,
+        tmp_path,
+        base_env,
+    ):
+        from torchrl.envs.llm.transforms import AddThinkingPrompt
+
+        if isinstance(base_env.transform[-1], AddThinkingPrompt):
+            base_env.transform.pop()
+        env = base_env.reset_dataloader()
+        env = base_env.append_transform(
+            AddThinkingPrompt(
+                cond=lambda td: td["reward"] < 50,
+                role=role,
+                edit_last_turn=edit_last_turn,
+                zero_reward=zero_reward,
+                undo_done=undo_done,
+                random_prompt=random_prompt,
+            )
+        )
+        reset = env.reset()
+        assert reset[0]["history"][-1].content.startswith(
+            "Natalia sold clips to 48 of her friends in April"
+        )
+        policy_answer = (
+            "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
+            "To find the total, I need to add April and May: 48 + 24 = 72. Therefore, Natalia sold 72 clips altogether in April and May.</think>\n<answer>322 clips</answer><|im_end|>"
+        )
+        reset["text_response"] = [policy_answer]
+        s = env.step(reset)
+        if zero_reward:
+            assert (s["next", "reward"] == 0).all()
+        else:
+            assert (s["next", "reward"] != 0).all()
+        if undo_done:
+            assert (s["next", "done"] == 0).all()
+        else:
+            assert (s["next", "done"] != 0).all()
+        if edit_last_turn:
+            assert s["next", "history"].shape == (1, 3)
+        else:
+            assert s["next", "history"].shape == (1, 4)
+        if role == "assistant":
+            assert s[0]["next", "history", "role"][-1] == "assistant"
+        else:
+            assert s[0]["next", "history", "role"][-1] == "user"
+
+    @pytest.mark.skipif(not _has_transformers, reason="requires transformers")
+    @pytest.mark.skipif(not _has_datasets, reason="requires gsm8k")
+    @pytest.mark.parametrize(
+        "role,edit_last_turn",
+        [("assistant", True), ("assistant", False), ("user", False)],
+    )
+    @pytest.mark.parametrize("zero_reward", [True, False])
+    @pytest.mark.parametrize("undo_done", [True, False])
+    @pytest.mark.parametrize("random_prompt", [True, False])
+    def test_thinking_prompt_correct_answer(
+        self,
+        role,
+        edit_last_turn,
+        zero_reward,
+        undo_done,
+        random_prompt,
+        tmp_path,
+        base_env,
+    ):
+        # checks that if cond returns False, nothing is changed
+        from torchrl.envs.llm.transforms import AddThinkingPrompt
+
+        if isinstance(base_env.transform[-1], AddThinkingPrompt):
+            base_env.transform.pop()
+        env = base_env
+        env = env.reset_dataloader()
+        env = env.append_transform(
+            AddThinkingPrompt(
+                cond=lambda td: td["reward"] < 50,
+                role=role,
+                edit_last_turn=edit_last_turn,
+                zero_reward=zero_reward,
+                undo_done=undo_done,
+                random_prompt=random_prompt,
+            )
+        )
+        reset = env.reset()
+        assert reset[0]["history"][-1].content.startswith(
+            "Natalia sold clips to 48 of her friends in April"
+        )
+        policy_answer = (
+            "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
+            "To find the total, I need to add April and May: 48 + 24 = 72. Therefore, Natalia sold 72 clips altogether in April and May.</think>\n<answer>72</answer><|im_end|>"
+        )
+        reset["text_response"] = [policy_answer]
+        s = env.step(reset)
+        assert (s["next", "reward"] != 0).all(), s["next", "reward"]
+        assert s[0]["next", "history", "role"][-1] == "assistant"
+        assert s["next", "done"].all()
+        assert len(s[0]["next", "history", "content"]) == 3
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/llm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 from .libs import make_mlgym, MLGymWrapper
 from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer
 from .transforms import (
+    AddThinkingPrompt,
     as_nested_tensor,
     as_padded_tensor,
     BrowserTransform,
@@ -33,6 +34,7 @@
     "ChatEnv",
     "DataLoadingPrimer",
     "DatasetChatEnv",
+    "AddThinkingPrompt",
     "GSM8KEnv",
     "GSM8KPrepareQuestion",
     "GSM8KRewardParser",

torchrl/envs/llm/chat.py

Lines changed: 14 additions & 1 deletion
@@ -284,6 +284,7 @@ class DatasetChatEnv(TransformedEnv):
 
     Keyword Args:
         dataset (str): The name of the dataset.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
         name (str, optional): name of the dataset configuration.
         split (str, optional): the split to use (usually from `"train"`, `"val"` or `"test"`). Defaults to `None` (no split).
         num_envs (int, optional): The number of environments to create. Defaults to `1`.
@@ -317,6 +318,7 @@ def __init__(
         self,
         *,
         dataset: str,
+        shuffle: bool = True,
         name: str | None = None,
         split: Literal["train", "val", "test"] | None = None,
         num_envs: int = 1,
@@ -355,7 +357,7 @@ def __init__(
         dataloader = DataLoader( # noqa: TOR401
             dataset,
             batch_size=batch_size_dl,
-            shuffle=True,
+            shuffle=shuffle,
             collate_fn=collate_fn,
             generator=generator,
         )
@@ -375,3 +377,14 @@ def __init__(
             apply_template=apply_template,
         )
         return super().__init__(env_base, primer)
+
+    def reset_dataloader(self):
+        """Reset the dataloader.
+
+        This is useful when the dataloader is not infinite and we want to reset it.
+
+        Returns:
+            self: The environment itself.
+        """
+        self.transform[0].reset_dataloader()
+        return self
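
The new `reset_dataloader` helper simply delegates to the data-loading primer (the first transform of the environment), which clears its queue and rebuilds its endless iterator (see torchrl/envs/llm/transforms/dataloading.py below). A rough sketch of the intended call pattern, assuming a `GSM8KEnv` configured as in the tests above:

    # Hypothetical usage: rewind the underlying dataloader before iterating over the dataset again.
    from transformers import AutoTokenizer

    from torchrl.envs.llm import GSM8KEnv

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
    env = GSM8KEnv(shuffle=False, tokenizer=tokenizer, max_steps=10)
    env.reset_dataloader()  # start again from the beginning of the dataset
    reset = env.reset()     # the next reset draws from the rewound dataloader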

torchrl/envs/llm/datasets/gsm8k.py

Lines changed: 4 additions & 1 deletion
@@ -135,6 +135,7 @@ class GSM8KEnv(DatasetChatEnv):
 
     Keyword Args:
         dataset (str, optional): The name of the dataset. Defaults to `"gsm8k"`.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
         num_envs (int, optional): The number of environments to create. Defaults to `1`.
         repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
             based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
@@ -284,12 +285,13 @@ class GSM8KEnv(DatasetChatEnv):
     SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
 The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
 The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively,
-i.e., <think>reasoning process here</think> <answer>answer here</answer>."""
+i.e., <think>reasoning process here</think> <answer>answer here</answer>. The answer should be a number."""
 
     def __init__(
         self,
         *,
         dataset: str = "gsm8k",
+        shuffle: bool = True,
         num_envs: int = 1,
         repeats: int | None = None,
         batch_size_dl: int = 1,
@@ -307,6 +309,7 @@ def __init__(
         collate_fn = _collate_fn
         super().__init__(
             dataset=dataset,
+            shuffle=shuffle,
             name="main",
             num_envs=num_envs,
             repeats=repeats,

torchrl/envs/llm/datasets/ifeval.py

Lines changed: 3 additions & 0 deletions
@@ -41,6 +41,7 @@ class IFEvalEnv(DatasetChatEnv):
 
     Keyword Args:
         dataset (str, optional): The name of the dataset. Defaults to `"google/IFeval"`.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
         num_envs (int, optional): The number of environments to create. Defaults to `1`.
         repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
             based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
@@ -146,6 +147,7 @@ def __init__(
         self,
         *,
         dataset: str = "google/IFeval",
+        shuffle: bool = True,
         num_envs: int = 1,
         repeats: int | None = None,
         batch_size_dl: int = 1,
@@ -163,6 +165,7 @@ def __init__(
         collate_fn = _collate_fn
         super().__init__(
             dataset=dataset,
+            shuffle=shuffle,
             num_envs=num_envs,
             repeats=repeats,
             batch_size_dl=batch_size_dl,

torchrl/envs/llm/reward/gsm8k.py

Lines changed: 24 additions & 2 deletions
@@ -20,6 +20,7 @@ class GSM8KRewardParser(Transform):
         in_keys (list of NestedKey): the input keys. Defaults to `["text_response", "answer"]`.
         out_keys (list of NestedKey): the output keys. Defaults to `["reward_answer", "reward_think", "reward_right", "reward_contained", "reward", "success"]`.
         eos_token (str): the end of sentence token. Defaults to `tokenizer.eos_token` if not provided.
+        set_done_if_answer (bool): whether to set the done flag to `True` when an answer is present. Defaults to `True`.
 
     """
 
@@ -29,10 +30,18 @@ def __init__(
         in_keys: list[NestedKey] | None = None,
         out_keys: list[NestedKey] | None = None,
         eos_token: str | None = None,
+        set_done_if_answer: bool = True,
     ):
         super().__init__()
         self.tokenizer = tokenizer
-        self.eos_token = eos_token if eos_token is not None else tokenizer.eos_token
+        self.eos_token = (
+            eos_token
+            if eos_token is not None
+            else tokenizer.eos_token
+            if tokenizer is not None
+            else None
+        )
+        self.set_done_if_answer = set_done_if_answer
         if in_keys is None:
             in_keys = ["text_response", "answer"]
         if not isinstance(in_keys, list) or len(in_keys) != 2:
@@ -118,7 +127,20 @@ def _step(
         tds = tds.add(
             next_td_exist, default=torch.zeros((), device=next_tensordict.device)
         )
-        return next_tensordict.update(tds)
+        next_tensordict = next_tensordict.update(tds)
+        if (
+            self.set_done_if_answer
+            and (reward_answer := (next_tensordict["reward_answer"] > 0)).any()
+        ):
+            done = next_tensordict.get("done")
+            if done is not None:
+                next_tensordict.set("done", reward_answer.view_as(done) | done)
+            terminated = next_tensordict.get("terminated")
+            if terminated is not None:
+                next_tensordict.set(
+                    "terminated", reward_answer.view_as(terminated) | terminated
+                )
+        return next_tensordict
 
     def transform_reward_spec(self, reward_spec: Composite) -> Composite:
         shape = reward_spec.shape + (1, 1)
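
The net effect of `set_done_if_answer` is that any batch entry whose `reward_answer` is positive also has `done` and `terminated` forced to `True`, so producing a rewarded `<answer>` block ends the episode. A small illustration of the OR-masking performed above, on plain tensors rather than the actual tensordict (shapes are made up for the example):

    # Illustration of the done-flag update: a positive answer reward is OR-ed into done.
    import torch

    reward_answer = torch.tensor([[True], [False]])  # an answer was rewarded in item 0 only
    done = torch.tensor([[False], [False]])
    done = reward_answer.view_as(done) | done
    print(done)  # tensor([[ True], [False]])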

torchrl/envs/llm/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 from .format import TemplateTransform
 from .kl import KLRewardTransform, RetrieveLogProb
 from .policy_version import PolicyVersion
+from .reason import AddThinkingPrompt
 from .tokenizer import Tokenizer
 from .tools import MCPToolTransform, PythonInterpreter
 
@@ -19,6 +20,7 @@
     "MCPToolTransform",
     "PolicyVersion",
     "PythonInterpreter",
+    "AddThinkingPrompt",
     "TemplateTransform",
     "Tokenizer",
     "as_nested_tensor",

torchrl/envs/llm/transforms/dataloading.py

Lines changed: 12 additions & 0 deletions
@@ -447,6 +447,18 @@ def __init__(
             )
         self._reset_key = "_reset"
 
+    def reset_dataloader(self):
+        """Reset the dataloader.
+
+        This is useful when the dataloader is not infinite and we want to reset it.
+
+        Returns:
+            self: The transform itself.
+        """
+        self._queue.clear()
+        self.endless_dataloader = self._endless_iter(self.dataloader)
+        return self
+
     @classmethod
     def _endless_iter(self, obj):
         while True:
