Skip to content

Commit fa7f5fa

Browse files
willccbb and snimu authored
Support branching rollouts via trajectories; refactor state handling (#549)
* big chungus refactor for branching rollouts + cleaner state handling * tests passing * 3.11 fix; ruff * vllm logprob args * dict indexing for messages * remove generateinputs * optional truncation in trajectorystep for tokens * small tweaks * optional decorator rank for sorting order * minor tweak * change rank -> priority * add cleanup to is_completed * tool_env error handling, sandbox command timeout * handle updated context length msg * duplicate is_truncated field * add model/sampling to state * client/model/sampling in init_state * updated config * add kimi overlong prompt message * add kimi overlong prompt message * set_max_seq_len * Add numpy, sympy, and scipy to PythonEnv * pin prime-rl to will/trajectories branch * update prime-rl wiki-search config * ruff, ty * fix init_state tests * version, release notes * env version bumps * ty fixes * opt deps for ty CI * pin trajectories for configs * use verifiers commit for configs * ty for verifiers only * bump vllm version * process overlong prompt into trajectory * skip steps with None tokens --------- Co-authored-by: Sebastian <[email protected]>
1 parent 3f20f0f commit fa7f5fa

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+4365
-3428
lines changed

.github/workflows/style.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ jobs:
2929
- name: Set up Python
3030
uses: actions/setup-python@v6
3131
with:
32-
python-version: '3.11'
32+
python-version: "3.11"
3333
- name: Install uv
3434
uses: astral-sh/setup-uv@v4
3535
with:
3636
version: "latest"
3737
- name: Install dependencies
38-
run: uv sync
38+
run: uv sync --extra rl
3939
- name: Run ty
40-
run: uv run ty check .
40+
run: uv run ty check verifiers

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ docs/build/
2929
*.pyc
3030

3131
# libraries
32-
prime-rl/
32+
prime-rl
3333

3434
# outputs
3535
wandb/

configs/prime-rl/wiki-search.toml

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
inference_gpu_ids = [0]
2-
trainer_gpu_ids = [1]
1+
inference_gpu_ids = [0,1,2,3,4,5]
2+
trainer_gpu_ids = [6,7]
33

44
max_steps = 500
5+
max_async_level = 4
56

67
[model]
7-
name = "Qwen/Qwen3-4B-Instruct-2507"
8+
name = "Qwen/Qwen3-4B-Thinking-2507"
89

910
[wandb]
10-
project = "wiki-search"
11+
project = "wiki-search-debug"
1112
name = "wiki-search-4b"
1213

1314
[trainer.optim]
@@ -31,16 +32,17 @@ target_modules = [
3132
[orchestrator]
3233
batch_size = 512
3334
rollouts_per_example = 16
34-
seq_len = 4096
35+
seq_len = 16384
3536
mask_truncated_completions = false
3637
zero_truncated_completions = true
38+
oversampling_factor = 2.0
39+
3740

3841
[orchestrator.sampling]
39-
max_tokens = 512
42+
max_tokens = 4096
4043

4144
[orchestrator.buffer]
42-
type = "online-difficulty"
43-
oversampling_factor = 2.0
45+
online_difficulty_filtering = true
4446

4547
[[orchestrator.env]]
4648
id = "primeintellect/wiki-search"

configs/vf-rl/reasoning-gym.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@ num_eval_examples = 2000
1010
seed = 1
1111

1212
[inference]
13-
gpus = 4
14-
tensor_parallel_size = 2
15-
data_parallel_size = 2
13+
gpus = 6
1614
enforce_eager = true
1715

1816
[trainer]
19-
gpus = 4
17+
gpus = 2
2018
batch_size = 512
2119
micro_batch_size = 2
2220
max_seq_len = 4096

configs/vf-rl/wiki-search.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
model = "Qwen/Qwen3-4B-Instruct-2507"
22

33
[env]
4-
id = "primeintellect/wiki-search"
4+
id = "wiki-search"
55

66
[env.args]
77
max_turns = 10
@@ -20,7 +20,7 @@ gpus = 1
2020
run_name = "wiki-search"
2121
micro_batch_size = 4
2222
rollouts_per_example = 16
23-
batch_size = 1024
23+
batch_size = 512
2424
max_steps = 500
2525
max_tokens = 512
2626
max_seq_len = 4096

configs/vf-rl/wordle.toml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
11
model = "Qwen/Qwen3-4B-Instruct-2507"
22

33
[env]
4-
id = "will/wordle"
4+
id = "wordle"
55

66
[inference]
77
gpus = 1
88

9-
[inference.args]
10-
enforce_eager = true
11-
129
[trainer]
1310
gpus = 1
1411

1512
[trainer.args]
1613
lora_target_modules = "all-linear"
1714
run_name = "wordle"
18-
micro_batch_size = 8
15+
micro_batch_size = 4
1916
rollouts_per_example = 16
2017
batch_size = 512
2118
max_steps = 500

environments/math_group/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
[project]
22
name = "math-group"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
dependencies = [
5-
"verifiers>=0.1.4",
5+
"verifiers>=0.1.8",
66
"math-verify>=0.8.0",
77
]
88

environments/math_python/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
name = "math-python"
33
description = "Solve math problems using Python in a sandbox environment"
44
tags = ["tool-use", "math", "sandbox", "train", "prime-sandboxes", "python", "coding"]
5-
version = "0.1.7"
5+
version = "0.1.8"
66
requires-python = ">=3.11"
77
dependencies = [
8-
"verifiers>=0.1.5.post0",
8+
"verifiers>=0.1.8",
99
"math-verify>=0.8.0",
1010
]
1111

environments/sentence_repeater/sentence_repeater.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import random
22
from copy import deepcopy
33
from difflib import SequenceMatcher
4-
from typing import List, Tuple
4+
from typing import List
55

66
from datasets import Dataset, load_dataset
77

@@ -75,19 +75,19 @@ class SentenceRepeaterEnv(vf.MultiTurnEnv):
7575
def __init__(self, **kwargs):
7676
super().__init__(**kwargs)
7777

78-
async def is_completed(self, messages: Messages, state: State, **kwargs) -> bool:
79-
max_turns_reached = await super().is_completed(messages, state, **kwargs)
80-
return state["turn"] >= len(state["info"]["questions"]) or max_turns_reached
78+
@vf.stop
79+
async def all_questions_answered(self, state: State) -> bool:
80+
return len(state["trajectory"]) >= len(state["info"]["questions"])
8181

8282
async def env_response(
8383
self, messages: Messages, state: State, **kwargs
84-
) -> Tuple[Messages, State]:
84+
) -> Messages:
8585
return [
8686
{
8787
"role": "user",
8888
"content": state["info"]["questions"][state["turn"]],
8989
}
90-
], state
90+
]
9191

9292

9393
def load_environment(**kwargs) -> vf.Environment:

environments/wiki_search/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ name = "wiki-search"
33
description = "Agentic RAG over Wikipedia pages for trivia Q&A"
44
tags = ["wikipedia", "multi-turn", "agentic-search", "rag", "train", "eval", "llm-judge"]
55
requires-python = ">=3.11"
6-
version = "0.1.20"
6+
version = "0.1.21"
77
dependencies = [
8-
"verifiers>=0.1.7",
8+
"verifiers>=0.1.8",
99
"chromadb",
1010
"datasets",
1111
"openai",

0 commit comments

Comments (0)