Commit d9e7603

Authored by Kovbo, angkywilliam, and claude

WIP: SFT (local backend) (#530)
* SFT data iterator
* Add SFT LR utils
* train_sft skeleton
* SFT Shape 0.1
* Add shuffle to SFTConfig
* change SFT args order
* Refactor SFT to accept batched trajectories — Move batching and shuffling logic from SFTConfig into iterator functions. train_sft now accepts Iterable[List[Trajectory]] instead of individual trajectories, simplifying the API and making batch management more explicit. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
* Tokenize SFT Batch
* Add num_trainable_tokens to SFTBatch
* draft train_sft
* Flatten trajectory for train_sft
* Tokenize SFT Batches support flat list and add padding
* Fix max_length duplicate name issue
* Remove unused file
* remove unused typing
* sft iterator
* SFT Iterator
* Use Unsloth for train on response
* refactoring
* implement local backend SFT training
* Add SFT to Local Backend
* avg loss
* refactor, sft works good
* remove logging
* move tokenizer, update backend
* update lr schedule and tests
* refactor sft training from file
* change batch sft
* refactor step count based on checkpoints
* update sft warmup script
* fix model registration
* make local random
* refactor backend
* refactor
* update example
* Pyright fix
* remove iterate file epochs, refactor
* refactor
* add serverless endpoint
* Rename training_folder_url to training_data_url
* update defaults, change reporting
* update lables
* make sft to produce only one checkpoint step
* refactor train from file
* refactor
* Refactor SFTTrainConfig
* refactor
* correctly register lora, fix unsloth proxy check
* add sft train from file streaming
* add openpipe qwen back
* lint fix
* calculate pbar
* rename to training_data_url
* accept model run_id from server
* update optimizer hparams
* add claude command
* remove queue, add skills
* add docs and colab example
* move zero_grad
* add final step arg
* update docs
* update docs and trajectories
* lint fix
* add cli skills
* add chunking
* lint fix
* remove inline trajectories from skills
* update chunking
* change default chunk to 10
* remove leftovers

--------

Co-authored-by: Angky William <[email protected]>
Co-authored-by: Claude <[email protected]>
1 parent 0b1dc43 commit d9e7603

33 files changed: +3121 −272 lines

.agents/skills/train-rl/SKILL.md

Lines changed: 386 additions & 0 deletions
Large diffs are not rendered by default.

.agents/skills/train-sft/SKILL.md

Lines changed: 298 additions & 0 deletions
---
name: train-sft
description: SFT training reference for the ART framework. Use when the user asks to create, write, or help with an SFT training script, fine-tune a model, train from a JSONL dataset, do distillation, or anything related to supervised fine-tuning.
---

# SFT Training Wizard

You are guiding the user through setting up Supervised Fine-Tuning (SFT) for a language model using the ART framework. Act as an interactive wizard: ask questions, validate inputs, and generate a complete runnable script.

**Important**: Ask ONE question at a time. Wait for the user's response before asking the next question. Never bundle multiple questions into a single message.

**Adaptability note**: Some steps reference tools like AskUserQuestion, Glob, or Bash. If you don't have access to these tools, simply ask the user the same questions as plain text and skip any steps that require running code (e.g., file search, dataset validation, hyperparameter computation). Do NOT fabricate results — never pretend you ran a tool or searched for files when you didn't.

## Step 1: Determine Training Scenario

Ask the user ONE question at a time. Wait for their response before moving to the next question.

**Training scenario:**
1. **Train from a JSONL file** — They have a dataset file with chat-formatted examples
2. **Distillation** — They want to train a smaller model using outputs from a larger teacher model

## Step 2: Determine Backend

**Backend:**
1. **ServerlessBackend (Recommended)** — Train on remote managed GPUs. No local GPU needed; production-ready inference endpoint.
2. **LocalBackend** — Train on your local GPU. Full control, fast iteration.

## Step 3: Select and Validate Dataset (JSONL scenario)

**IMPORTANT**: Do NOT assume a dataset. Do NOT make up or hallucinate file paths. Never pretend you searched for files if you didn't actually run a search tool.

If you have access to file system tools (Glob) and can actually execute them, search for `.jsonl` files using Glob (`**/*.jsonl`). Present real results as options. Always include "Provide my own file path" as the last option.

Otherwise, ask the user: "What is the path to your JSONL training file?" — nothing more.

Once the user has provided a file path, validate it with the script below if you can run code. If you cannot run code, skip validation and move on.

```python
import json, sys

ROLES = {"system", "user", "assistant", "developer", "tool", "function"}
errors = []
for i, line in enumerate(open(sys.argv[1]), 1):
    try:
        r = json.loads(line)
        msgs = r.get("input", r).get("messages", [])
        assert isinstance(msgs, list) and msgs, "no messages"
        for j, m in enumerate(msgs):
            assert m.get("role") in ROLES, f"messages[{j}]: invalid role {m.get('role')!r}"
            assert m.get("content") or m.get("function_call") or m.get("tool_calls"), f"messages[{j}]: no content"
        if "input" not in r:
            assert msgs[-1]["role"] == "assistant", "last message must be from assistant"
        tools = r.get("tools")
        if tools is not None:
            assert isinstance(tools, list), "tools must be a list"
    except Exception as e:
        errors.append(f" Line {i}: {e}")
print(f"{len(errors)} error(s):\n" + "\n".join(errors) if errors else f"Valid! {i} rows")
sys.exit(1 if errors else 0)
```
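To sanity-check the validator's behavior, you can run a trimmed-down version of its checks against a tiny in-memory sample — here one valid row and one row with a bad role (file contents are illustrative):

```python
import json
import tempfile

ROLES = {"system", "user", "assistant", "developer", "tool", "function"}

rows = [
    # Valid: a chat exchange ending with an assistant message.
    {"messages": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]},
    # Invalid: "robot" is not an accepted role.
    {"messages": [{"role": "robot", "content": "Beep"}]},
]

# Write the rows to a temporary JSONL file, one object per line.
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    f.write("\n".join(json.dumps(r) for r in rows))
    path = f.name

errors = []
for i, line in enumerate(open(path), 1):
    try:
        r = json.loads(line)
        msgs = r.get("input", r).get("messages", [])
        assert isinstance(msgs, list) and msgs, "no messages"
        for j, m in enumerate(msgs):
            assert m.get("role") in ROLES, f"messages[{j}]: invalid role {m.get('role')!r}"
    except Exception as e:
        errors.append(f" Line {i}: {e}")

print(errors)  # only line 2 should be flagged, for the invalid role
```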

The JSONL format supports these fields per row:
- **`messages`** (required): List of chat messages
- **`tools`** (optional): List of tool/function definitions for tool-call training
- **`response_format`** (optional): Structured output schema (not used during training, but useful as metadata)

Report the row count and validation result to the user. Do NOT read the whole dataset file. Do NOT name the dataset. If the format is wrong, help them fix it or convert their data.
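As a concrete illustration of the format (the row content here is made up), a minimal line that passes the checks above can be built like this:

```python
import json

# A hypothetical minimal training row: a chat exchange that ends with the
# assistant message the model should learn to produce.
row = {
    "messages": [
        {"role": "system", "content": "You are a concise math tutor."},
        {"role": "user", "content": "What is 7 * 8?"},
        {"role": "assistant", "content": "7 * 8 = 56."},
    ]
}

# Each JSONL line is one such object serialized onto a single line.
line = json.dumps(row)
print(line)

# The key invariant the validator enforces for training rows:
parsed = json.loads(line)
assert parsed["messages"][-1]["role"] == "assistant"
```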

## Step 4: Gather Base Parameters

Do NOT ask the user to review or confirm their answers after collecting them — just proceed to the next step.

- **Base model**: Recommend ONLY these models:
  - `OpenPipe/Qwen3-14B-Instruct`
  - `Qwen/Qwen3-30B-A3B-Instruct-2507`
  - `meta-llama/Llama-3.1-8B-Instruct`
- **Project name**: A name for this training project (default: `sft-project`)
- **Run name**: A static, descriptive name (e.g., `agent-001`, `pii-redactor-001`, `math-tutor-001`). Ask the user for a meaningful name. Do NOT generate random names.

For **distillation**, also ask:
- **Teacher model**: The larger model to distill from (e.g., an OpenRouter model)
- **Teacher API base URL and key**: If using a third-party provider
- **Prompts**: What prompts to send to the teacher model
## Step 5: Gather Hyperparameters

This step only applies if you can run code AND know the row count from validation. If you cannot run code, skip this step entirely — do NOT make up or guess hyperparameter values. The `train_sft_from_file` function has sensible built-in defaults.

Run this Python snippet via Bash to compute defaults (replace `NUM_ROWS` with the actual row count). Do NOT show any formulas or calculation steps to the user — only show the final values.

```python
import math, sys

n = int(sys.argv[1])
epochs = max(1, min(10, round(10000 / n)))
batch_size = 2
total_steps = math.ceil(n * epochs / batch_size)
steps_per_epoch = math.ceil(n / batch_size)
warmup_steps = max(10, min(1000, round(steps_per_epoch * 0.05)))
warmup_ratio = round(warmup_steps / total_steps, 4)
print(f"epochs={epochs} batch_size={batch_size} lr=2e-4 schedule=linear warmup_ratio={warmup_ratio}")
```
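To see what the defaults computation produces in practice, here is the same logic wrapped in a function (the function name is illustrative) and evaluated for two hypothetical dataset sizes — small datasets get more epochs, large ones fewer:

```python
import math

def sft_defaults(n: int) -> dict:
    """Same default-computation logic as the wizard snippet above."""
    epochs = max(1, min(10, round(10000 / n)))
    batch_size = 2
    total_steps = math.ceil(n * epochs / batch_size)
    steps_per_epoch = math.ceil(n / batch_size)
    warmup_steps = max(10, min(1000, round(steps_per_epoch * 0.05)))
    return {
        "epochs": epochs,
        "batch_size": batch_size,
        "warmup_ratio": round(warmup_steps / total_steps, 4),
    }

# A 500-row dataset gets the full 10 epochs; a 20k-row dataset gets 1.
print(sft_defaults(500))    # → {'epochs': 10, 'batch_size': 2, 'warmup_ratio': 0.0048}
print(sft_defaults(20000))  # → {'epochs': 1, 'batch_size': 2, 'warmup_ratio': 0.05}
```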

Present the output values to the user, then ask:
- **Use defaults (Recommended)** — show all values in the description
- **Customize** — adjust individual hyperparameters

If they choose "Customize", ask which parameters to change.

### For distillation:
Use the same defaults computation as JSONL (replace `NUM_ROWS` with the number of trajectories). `create_sft_dataset_iterator` handles the LR schedule automatically.

## Step 6: Generate the Training Script

Write a complete, runnable Python script. Use the patterns below. Every script MUST:
- Call `await backend.close()` at the end so the process doesn't hang
- Print post-training info and usage examples (see shared block below)

### Post-training block (append to ALL scripts; it ends with `backend.close()`):
```python
# --- Training complete ---
step = await model.get_step()
inference_name = model.get_inference_name()
client = model.openai_client()

print("\n" + "=" * 60)
print("SFT TRAINING COMPLETE")
print("=" * 60)
print(f" Model: {inference_name}")
print(f" Base model: <BASE_MODEL>")
print(f" Training step: {step}")
print(f" Inference URL: {client.base_url}")
print(f" W&B run: https://wandb.ai/<YOUR_TEAM>/<PROJECT_NAME>/runs/<RUN_NAME>")
print("=" * 60)

print("\n--- Python usage (openai SDK) ---\n")
print(f'''\
from openai import OpenAI

client = OpenAI(
    base_url="{client.base_url}",
    api_key="not-needed",
)

response = client.chat.completions.create(
    model="{inference_name}",
    messages=[
        {{"role": "user", "content": "Your prompt here"}},
    ],
)
print(response.choices[0].message.content)
''')

print("--- curl usage ---\n")
print(f'''\
curl {client.base_url}chat/completions \\
  -H "Content-Type: application/json" \\
  -d '{{
    "model": "{inference_name}",
    "messages": [
      {{"role": "user", "content": "Your prompt here"}}
    ]
  }}'
''')

await backend.close()
```

### Backend setup

Use the appropriate backend based on the user's choice:

**LocalBackend:**
```python
from art.local import LocalBackend

backend = LocalBackend()
model = art.TrainableModel(
    name="<RUN_NAME>",
    project="<PROJECT_NAME>",
    base_model="<BASE_MODEL>",
    _internal_config=art.dev.InternalModelConfig(
        engine_args={"gpu_memory_utilization": 0.7},
    ),
)
await model.register(backend)
```

**ServerlessBackend:**
```python
from art.serverless.backend import ServerlessBackend

backend = ServerlessBackend()  # uses WANDB_API_KEY env var
model = art.TrainableModel(
    name="<RUN_NAME>",
    project="<PROJECT_NAME>",
    base_model="<BASE_MODEL>",
)
await model.register(backend)
```

Note: `_internal_config` with `gpu_memory_utilization` is only used with LocalBackend. Do NOT include it for ServerlessBackend.

### JSONL file training pattern:

If hyperparameters were computed in Step 5, pass them explicitly. If Step 5 was skipped, omit them — `train_sft_from_file` has sensible defaults.

```python
"""SFT training script generated by /train-sft wizard."""
import asyncio

import art
<BACKEND_IMPORT>
from art.utils.sft import train_sft_from_file


async def main():
    <BACKEND_SETUP>

    await train_sft_from_file(
        model=model,
        file_path="<FILE_PATH>",
        # Only include these if hyperparameters were computed:
        # epochs=<EPOCHS>,
        # batch_size=<BATCH_SIZE>,
        # peak_lr=<PEAK_LR>,
        # schedule_type="<SCHEDULE_TYPE>",
        # warmup_ratio=<WARMUP_RATIO>,
        verbose=True,
    )

    # ... post-training block + backend.close() ...


if __name__ == "__main__":
    asyncio.run(main())
```

### Distillation pattern:
```python
"""Distillation SFT script generated by /train-sft wizard."""
import asyncio, os

from dotenv import load_dotenv
from openai import AsyncOpenAI

import art
<BACKEND_IMPORT>
from art.utils.sft import create_sft_dataset_iterator

load_dotenv()


async def main():
    teacher_client = AsyncOpenAI(
        api_key=os.environ["<API_KEY_ENV_VAR>"],
        base_url="<TEACHER_API_BASE>",
    )
    prompts = ["<PROMPT_1>", "<PROMPT_2>"]

    trajectories = []
    for prompt in prompts:
        completion = await teacher_client.chat.completions.create(
            model="<TEACHER_MODEL>",
            messages=[{"role": "user", "content": prompt}],
        )
        trajectories.append(
            art.Trajectory(
                messages_and_choices=[
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": completion.choices[0].message.content},
                ],
                tools=<TOOLS_OR_NONE>,
            )
        )

    <BACKEND_SETUP>

    for chunk in create_sft_dataset_iterator(
        trajectories,
        epochs=<EPOCHS>,
        batch_size=<BATCH_SIZE>,
        peak_lr=<PEAK_LR>,
        schedule_type="<SCHEDULE_TYPE>",
        warmup_ratio=<WARMUP_RATIO>,
    ):
        await model.train_sft(chunk.trajectories, chunk.config, verbose=True)

    # ... post-training block + backend.close() ...


if __name__ == "__main__":
    asyncio.run(main())
```

## Step 7: Write and Offer to Run

1. Write the script to a file (suggest `sft_train.py`)
2. Ask the user if they want to run it now with `uv run python <script_path>`
3. If yes, run it **directly using the Bash tool** (do NOT delegate to a Task subagent) so training logs stream live to the user. Use a **2-minute timeout**. If it times out, check progress and decide whether to continue.
4. **LocalBackend only — GPU memory errors**: If training fails with OOM, lower `gpu_memory_utilization` in the existing `_internal_config` (e.g. from `0.7` to `0.5`).
5. **LocalBackend only — Stale GPU memory**: If available GPU memory looks too small, previous training runs may still be occupying memory. Before retrying, run `nvidia-smi` to check, and if needed kill leftover processes with `kill <pid>` to free memory.

## Important Notes

- LocalBackend requires a GPU.
- ServerlessBackend requires a `WANDB_API_KEY` environment variable.
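Since ServerlessBackend reads `WANDB_API_KEY` from the environment, a generated script could fail fast with a clear message when it's missing. A minimal preflight sketch (the helper name is illustrative, not part of ART):

```python
import os

def check_serverless_env() -> list[str]:
    """Return the names of required env vars that are missing or empty."""
    required = ["WANDB_API_KEY"]  # ServerlessBackend's required credential
    return [name for name in required if not os.environ.get(name)]

missing = check_serverless_env()
if missing:
    print(f"Missing env vars: {missing} — set them before using ServerlessBackend.")
else:
    print("Environment looks good for ServerlessBackend.")
```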

.claude/skills

Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+../.agents/skills
```

.gitignore

Lines changed: 1 addition & 1 deletion
```diff
@@ -15,7 +15,7 @@ replays/
 trajectories/
 .DS_Store
 .local/
-.claude/
+.claude/settings.local.json
 .vscode/
 .ruff_cache/
 !/src/art/wandb/
```
