Refactor configs #383

Merged · 48 commits · Jun 13, 2025
Changes from all commits
48 commits:
06a28a5
Add docs to sampling config and add missing sampling parameters
mikasenghaas Jun 9, 2025
48ed5aa
Extract parallel config
mikasenghaas Jun 10, 2025
557bf83
Extract model config
mikasenghaas Jun 10, 2025
07c30b2
Extract data config
mikasenghaas Jun 10, 2025
1af55b3
Correct var naming from batch_size to max_batch_size
mikasenghaas Jun 10, 2025
a37b579
Extract RL config
mikasenghaas Jun 10, 2025
dc4c786
Delete deprecated configs
mikasenghaas Jun 10, 2025
5e7bbe4
Increment data offset by problems per batch (excluding sampling.n)
mikasenghaas Jun 10, 2025
16b6e8e
Align ckpt/rollout path definition
mikasenghaas Jun 10, 2025
22aa653
Ignore ckpt/rollout dir
mikasenghaas Jun 10, 2025
c15a28b
Adapt configs
mikasenghaas Jun 10, 2025
54e2bca
Fix inference integration tests
mikasenghaas Jun 10, 2025
6836a51
Switch to using annotated fields for configs
mikasenghaas Jun 10, 2025
17c67ce
Improved config logs
mikasenghaas Jun 10, 2025
8660fb1
Migrate inference configs to pydantic-settings
mikasenghaas Jun 10, 2025
ddb24f3
Fix broken model validation
mikasenghaas Jun 10, 2025
910a3bf
Fix integration test
mikasenghaas Jun 10, 2025
98b3061
Skip ge/le checks on non-numeric type
mikasenghaas Jun 10, 2025
6e67ae6
Update README
mikasenghaas Jun 10, 2025
2b2a96a
Move comments into field description
mikasenghaas Jun 10, 2025
79ab026
Use path type
mikasenghaas Jun 10, 2025
129dd4e
Fix tab tab error
mikasenghaas Jun 10, 2025
45edead
Use implicit boolean flags
mikasenghaas Jun 10, 2025
591bec9
Fix funny legacy import error
mikasenghaas Jun 10, 2025
e08821d
Fix unit tests
mikasenghaas Jun 10, 2025
030fdab
Extract default syn-2 configs (allows for composed configs)
mikasenghaas Jun 10, 2025
291b545
Allow passing nested configs from env
mikasenghaas Jun 11, 2025
6105742
Define shared base config to remove pydantic_config dep
mikasenghaas Jun 11, 2025
32a8c3f
The return of the space and snake
mikasenghaas Jun 11, 2025
61f0039
Move group and task ID and log level into configs
mikasenghaas Jun 11, 2025
2eb80a2
Do not init logger at module level (do not trigger init when loading …
mikasenghaas Jun 11, 2025
f6ff121
Do not enable monitor by default (should be set by CLI arg in orchest…
mikasenghaas Jun 11, 2025
9c1da62
Setup task ID via arg
mikasenghaas Jun 11, 2025
bbbc71f
Fix monitor
mikasenghaas Jun 11, 2025
f8517b0
Allow passing config via env using __ delimiter
mikasenghaas Jun 11, 2025
76a7e64
Use clean argv fixture in inference config test
mikasenghaas Jun 11, 2025
76ba9d6
Explain configs in README
mikasenghaas Jun 11, 2025
b50c327
Fix bug from rebase
mikasenghaas Jun 11, 2025
379c1fd
Skip doing hex string validation
mikasenghaas Jun 11, 2025
8912cac
Optionally parse from CLI and move toml extraction away from module def
mikasenghaas Jun 12, 2025
b12b5ac
Move training configs to pydantic-settings
mikasenghaas Jun 12, 2025
521e23e
Update README
mikasenghaas Jun 12, 2025
34c0d6b
Remove pydanctic_config from deps
mikasenghaas Jun 12, 2025
31b526a
Move config description to end of file
mikasenghaas Jun 12, 2025
960af2e
Remove descriptions
mikasenghaas Jun 12, 2025
f2da109
Remove sentence
mikasenghaas Jun 12, 2025
4bcb407
location of logprobs has changed
Jackmin801 Jun 12, 2025
e64b3ae
Fix forma task configs
mikasenghaas Jun 13, 2025
3 changes: 2 additions & 1 deletion .gitignore
@@ -6,7 +6,8 @@ datasets/*
 output/*
 outputs/*
 data/*
-data_rollout/*
+rollouts/*
+checkpoints/*
 *.ipynb
 
 # Byte-compiled / optimized / DLL files
85 changes: 62 additions & 23 deletions README.md
@@ -71,7 +71,7 @@ If you have 2 GPUs, run the following commands:
 # Start inference worker
 export CUDA_VISIBLE_DEVICES=0
 export VLLM_WORKER_MULTIPROC_METHOD=spawn
-uv run python src/zeroband/infer.py @ configs/inference/simple_math.toml --dp 1 --batch-size 512
+uv run python src/zeroband/infer.py @ configs/inference/simple_math.toml --parallel.dp 1 --max-batch-size 512
 ```
 
 ```bash
@@ -87,7 +87,7 @@ If you have 4 GPUs, run the following commands:
 # Start inference workers
 export CUDA_VISIBLE_DEVICES=0,1
 export VLLM_WORKER_MULTIPROC_METHOD=spawn
-uv run python src/zeroband/infer.py @ configs/inference/simple_math.toml --dp 2 --batch-size 256
+uv run python src/zeroband/infer.py @ configs/inference/simple_math.toml --parallel.dp 2 --max-batch-size 256
 ```
 
 ```bash
@@ -148,34 +148,30 @@ CUDA_VISIBLE_DEVICES=0 uv run python src/zeroband/infer.py @ configs/inference/d
 Only TP (TP=2, PP=1, DP=1, *requires 2 GPUs*)
 
 ```bash
-CUDA_VISIBLE_DEVICES=0,1 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --tp 2
+CUDA_VISIBLE_DEVICES=0,1 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --parallel.tp 2
 ```
 
 Only DP (DP=2, TP=1, PP=1, *requires 2 GPUs*)
 
 ```bash
-CUDA_VISIBLE_DEVICES=0,1 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --dp 2
+CUDA_VISIBLE_DEVICES=0,1 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --parallel.dp 2
 ```
 
 Only PP (DP=1, TP=1, PP=2, *requires 2 GPUs*)
 
 ```bash
 # Node 1
 CUDA_VISIBLE_DEVICES=0 uv run python src/zeroband/infer.py @ configs/inference/debug.toml \
-    --pp.rank 0 \
-    --pp.world-size 2 \
-    --pp.iroh-seed 0 \
-    --pp.iroh-peer-id ff87a0b0a3c7c0ce827e9cada5ff79e75a44a0633bfcb5b50f99307ddb26b337 \
+    --parallel.pp.rank 0 \
+    --parallel.pp.world-size 2 \
     --seed 69
 ```
 
 ```bash
 # Node 2
 CUDA_VISIBLE_DEVICES=1 uv run python src/zeroband/infer.py @ configs/inference/debug.toml \
-    --pp.rank 1 \
-    --pp.world-size 2 \
-    --pp.iroh-seed 1 \
-    --pp.iroh-peer-id ee1aa49a4459dfe813a3cf6eb882041230c7b2558469de81f87c9bf23bf10a03 \
+    --parallel.pp.rank 1 \
+    --parallel.pp.world-size 2 \
     --seed 69
 ```
 
Expand All @@ -184,30 +180,26 @@ CUDA_VISIBLE_DEVICES=1 uv run python src/zeroband/infer.py @ configs/inference/d
DP+TP (DP=2, TP=2, PP=1, *requires 4 GPUs*)

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --dp 2 --tp auto
CUDA_VISIBLE_DEVICES=0,1,2,3 uv run python src/zeroband/infer.py @ configs/inference/debug.toml --parallel.dp 2 --parallel.tp auto
```

PP+TP (DP=1, TP=2, PP=2, *requires 4 GPUs*)

```bash
# Node 1
CUDA_VISIBLE_DEVICES=0,1 uv run python src/zeroband/infer.py @ configs/inference/debug.toml \
--tp auto \
--pp.rank 0 \
--pp.world-size 2 \
--pp.iroh-seed 0 \
--pp.iroh-peer-id ff87a0b0a3c7c0ce827e9cada5ff79e75a44a0633bfcb5b50f99307ddb26b337 \
--parallel.tp auto \
--parallel.pp.rank 0 \
--parallel.pp.world-size 2 \
--seed 69
```

```bash
# Node 2
CUDA_VISIBLE_DEVICES=2,3 uv run python src/zeroband/infer.py @ configs/inference/debug.toml \
--tp auto \
--pp.rank 1 \
--pp.world-size 2 \
--pp.iroh-seed 1 \
--pp.iroh-peer-id ee1aa49a4459dfe813a3cf6eb882041230c7b2558469de81f87c9bf23bf10a03 \
--parallel.tp auto \
--parallel.pp.rank 1 \
--parallel.pp.world-size 2 \
--seed 69
```

@@ -247,6 +239,53 @@ To run fast tests, use the inverse of the `slow` marker:
uv run pytest -v -m "not slow"
```

## Configs

We use `pydantic-settings` to configure `prime-rl`. To get an overview of the available configuration options, run the following commands:

```bash
uv run python src/zeroband/train.py --help
```

```bash
uv run python src/zeroband/infer.py --help
```
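
Under the hood, each entrypoint validates its arguments with `pydantic-settings` classes. The snippet below is a minimal sketch of the pattern, not the exact classes in `src/zeroband`; the class names, fields, and defaults are illustrative:

```python
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class ModelConfig(BaseModel):
    # Overridable via `--model.name` on the CLI or `PRIME_MODEL__NAME` in the environment
    name: str = Field(default="Qwen/Qwen3-0.6B", description="Name of the model to load")


class InferenceConfig(BaseSettings):
    # The `PRIME_` prefix and `__` delimiter map environment variables onto (nested) fields
    model_config = SettingsConfigDict(env_prefix="PRIME_", env_nested_delimiter="__")

    model: ModelConfig = ModelConfig()
    max_batch_size: int = 16


config = InferenceConfig()
print(config.model.name)  # picks up PRIME_MODEL__NAME if set
```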

### Sources

We support the following sources for configuration, in this order of precedence:

1. **Command-line arguments**: You can pass (nested) arguments as `--key.subkey value` to the script. For example, to set the model name you can pass `--model.name Qwen/Qwen3-0.6B`.

2. **Config files**: You can pass `.toml` config files (defined in the `configs` directory) using the `@` prefix. For example, to use the `debug.toml` config file, you can run `uv run python src/zeroband/infer.py @ configs/inference/debug.toml`. (*If you leave a space between the `@` and the config file, you will get shell path auto-completions.*)

3. **Environment variables**: You can set environment variables to override the config values. All environment variables must be prefixed with `PRIME_` and use the `__` delimiter to nest keys. For example, to set the model name you can run `export PRIME_MODEL__NAME=Qwen/Qwen3-0.6B` (see the sketch after this list).

4. **Defaults**: For almost all config arguments, we have a default value which will be used if no other source is provided.

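To make the sources concrete, the two invocations below are equivalent ways to override the model name (the model value here is arbitrary, chosen only for illustration):

```bash
# Override the model name via an environment variable ...
PRIME_MODEL__NAME=Qwen/Qwen3-8B uv run python src/zeroband/infer.py @ configs/inference/debug.toml

# ... or via the equivalent nested CLI argument
uv run python src/zeroband/infer.py @ configs/inference/debug.toml --model.name Qwen/Qwen3-8B
```
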
In general, we recommend defining reproducible experiments in config files and using command-line arguments to override individual values when running variants of the same experiment. Environment variables are usually only used in production settings to communicate with the [Prime Protocol](https://github.com/PrimeIntellect-ai/protocol) worker; in most cases, you should not need them.

The precedence order matters when multiple sources configure the same argument. For example, in the following command, every source defines a model name:

```toml
# qwen8b.toml
[model]
name = "Qwen/Qwen3-8B"
```

```toml
# qwen14b.toml
[model]
name = "Qwen/Qwen-14B"
```

```bash
PRIME_MODEL__NAME=Qwen/Qwen3-4B uv run src/zeroband/infer.py @qwen8b.toml @qwen14b.toml --model.name Qwen/Qwen3-32B
```

In this example, the CLI argument `--model.name Qwen/Qwen3-32B` takes precedence, so the script uses `Qwen/Qwen3-32B` as the model name. If the CLI argument weren't set, the second config file would take precedence and the script would use `Qwen/Qwen-14B`. If the second config file weren't passed, the first config file would take precedence and the script would use `Qwen/Qwen3-8B`. If the first config file weren't passed either, the environment variable would take precedence and the script would use `Qwen/Qwen3-4B`. Finally, if the environment variable also weren't set, the default value `Qwen/Qwen3-0.6B` would be used.


## Citation

If you find `prime-rl` useful, feel free to cite our work:
11 changes: 6 additions & 5 deletions configs/inference/debug.toml
@@ -1,10 +1,11 @@
-model_name = "Qwen/Qwen3-0.6B"
-dataset = "PrimeIntellect/INTELLECT-2-RL-Dataset"
-total_step = 2
-batch_size = 16
-enforce_eager = true
+max_steps = 2
+max_batch_size = 16
+clean_rollout_path = true
 toploc = true
+
+[model]
+enforce_eager = true
 
 [sampling]
 max_tokens = 16
 n = 2
21 changes: 15 additions & 6 deletions configs/inference/deepscaler.toml
@@ -1,7 +1,16 @@
-model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-batch_size = 352
-dp = 6
-rollout_path = "outputs"
-output_path = "data_rollout"
+max_batch_size = 352
+clean_rollout_path = true
+
+[model]
+name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 max_model_len = 2048
-dataset = "justus27/deepscaler-math-genesys-format"
+
+[data]
+name = "justus27/deepscaler-math-genesys-format"
+
+[parallel]
+dp = 6
+
+[rl]
+ckpt_path = "checkpoints"
+clean_ckpt_path = true
18 changes: 11 additions & 7 deletions configs/inference/formatask.toml
@@ -1,12 +1,16 @@
-model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
-dataset = "kalomaze/general-formatask-it2"
-batch_size = 256
-dp = 4
-rollout_path = "outputs"
-clean_output_path = true
-output_path = "data_rollout"
+max_batch_size = 256
+clean_rollout_path = true
+
+[model]
+name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
 max_model_len = 2048
+
+[data]
+name = "kalomaze/general-formatask-it2"
+
+[parallel]
+dp = 4
 
 [sampling]
 temperature = 1.0
 n = 16
26 changes: 18 additions & 8 deletions configs/inference/intellect2_qwen_distill.toml
@@ -1,14 +1,24 @@
-model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
-dataset = "PrimeIntellect/INTELLECT-2-RL-Dataset"
+max_batch_size = 256
+clean_rollout_path = true
 
-batch_size = 256
-tp = "auto"
+[model]
+name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
 max_model_len = 16000
 
-[sampling]
-n = 32
+[data]
+name = "PrimeIntellect/INTELLECT-2-RL-Dataset"
 
-[difficulty_filtering]
+[data.difficulty_filtering]
 solve_rate_field = "solve_rate_qwen_r1_distill_7b"
 min_solve_rate = 0.0
-max_solve_rate = 0.7
\ No newline at end of file
+max_solve_rate = 0.7
+
+[parallel]
+tp = "auto"
+
+[sampling]
+n = 32
+
+[rl]
+ckpt_path = "checkpoints"
+clean_ckpt_path = true
31 changes: 21 additions & 10 deletions configs/inference/intellect2_qwq.toml
@@ -1,18 +1,29 @@
-model_name = "Qwen/QwQ-32B"
-dataset = "PrimeIntellect/INTELLECT-2-RL-Dataset"
-batch_size = 96
-tp = "auto"
+max_batch_size = 96
+clean_rollout_path = true
+
+[model]
+name = "Qwen/QwQ-32B"
 max_model_len = 16_000
 
-[sampling]
-n = 16
+[data]
+name = "PrimeIntellect/INTELLECT-2-RL-Dataset"
+
+[data.difficulty_filtering]
+solve_rate_field = "solve_rate_qwen_r1_distill_7b"
+min_solve_rate = 0.1
+max_solve_rate = 0.8
 
 [rewards.len_reward]
 reward_type = "exact"
 length_prompt_location = "instruction"
 target_lengths = [2000, 4000, 6000, 8000, 10000]
 
-[difficulty_filtering]
-solve_rate_field = "solve_rate_qwen_r1_distill_7b"
-min_solve_rate = 0.1
-max_solve_rate = 0.8
+[parallel]
+tp = "auto"
+
+[sampling]
+n = 16
+
+[rl]
+ckpt_path = "checkpoints"
+clean_ckpt_path = true
21 changes: 15 additions & 6 deletions configs/inference/simple_ascii_tree.toml
@@ -1,11 +1,20 @@
-model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
-dataset = "kalomaze/ascii-tree-mix-it1"
-batch_size = 256
-dp = 4
-rollout_path = "outputs"
-output_path = "data_rollout"
+max_batch_size = 256
+clean_rollout_path = true
+
+[model]
+name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
 max_model_len = 2048
+
+[data]
+name = "kalomaze/ascii-tree-mix-it1"
+
+[parallel]
+dp = 4
+
+[rl]
+ckpt_path = "checkpoints"
+clean_ckpt_path = true
 
 [sampling]
 temperature = 1.0
 n = 16
21 changes: 14 additions & 7 deletions configs/inference/simple_math.toml
@@ -1,8 +1,15 @@
-model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-dataset = "justus27/math-hendrycks-genesys-format"
-batch_size = 112
+clean_rollout_path = true
+
+[model]
+name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+max_model_len = 2048
+
+[data]
+name = "justus27/math-hendrycks-genesys-format"
+
+[parallel]
 dp = 6
-rollout_path = "outputs"
-output_path = "data_rollout"
-clean_output_path = true
-max_model_len = 2048
+
+[rl]
+ckpt_path = "checkpoints"
+clean_ckpt_path = true
21 changes: 15 additions & 6 deletions configs/inference/simple_multitask.toml
@@ -1,11 +1,20 @@
-model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
-dataset = "kalomaze/multi-task-ascii-unscramble-it1"
-batch_size = 256
-dp = 4
-rollout_path = "outputs"
-output_path = "data_rollout"
+max_batch_size = 256
+clean_rollout_path = true
+
+[model]
+name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
 max_model_len = 2048
+
+[data]
+name = "kalomaze/multi-task-ascii-unscramble-it1"
+
+[parallel]
+dp = 4
+
+[rl]
+ckpt_path = "checkpoints"
+clean_ckpt_path = true
 
 [sampling]
 temperature = 1.0
 n = 16