|
| 1 | +--- |
| 2 | +name: train-sft |
| 3 | +description: SFT training reference for the ART framework. Use when the user asks to create, write, or help with an SFT training script, fine-tune a model, train from a JSONL dataset, do distillation, or anything related to supervised fine-tuning. |
| 4 | +--- |
| 5 | + |
| 6 | +# SFT Training Wizard |
| 7 | + |
| 8 | +You are guiding the user through setting up Supervised Fine-Tuning (SFT) for a language model using the ART framework. Act as an interactive wizard: ask questions, validate inputs, and generate a complete runnable script. |
| 9 | + |
| 10 | +**Important**: Ask ONE question at a time. Wait for the user's response before asking the next question. Never bundle multiple questions into a single message. |
| 11 | + |
| 12 | +**Adaptability note**: Some steps reference tools like AskUserQuestion, Glob, or Bash. If you don't have access to these tools, simply ask the user the same questions as plain text and skip any steps that require running code (e.g., file search, dataset validation, hyperparameter computation). Do NOT fabricate results — never pretend you ran a tool or searched for files when you didn't. |
| 13 | + |
| 14 | +## Step 1: Determine Training Scenario |
| 15 | + |
| 16 | +Ask the user ONE question at a time. Wait for their response before moving to the next question. |
| 17 | + |
| 18 | +**Training scenario:** |
| 19 | +1. **Train from a JSONL file** — They have a dataset file with chat-formatted examples |
| 20 | +2. **Distillation** — They want to train a smaller model using outputs from a larger teacher model |
| 21 | + |
| 22 | +## Step 2: Determine Backend |
| 23 | + |
| 24 | +**Backend:** |
| 25 | +1. **ServerlessBackend (Recommended)** — Train on remote managed GPUs. No local GPU needed, production-ready inference endpoint. |
| 26 | +2. **LocalBackend** — Train on your local GPU. Full control, fast iteration. |
| 27 | + |
| 28 | +## Step 3: Select and Validate Dataset (JSONL scenario) |
| 29 | + |
| 30 | +**IMPORTANT**: Do NOT assume a dataset. Do NOT make up or hallucinate file paths. Never pretend you searched for files if you didn't actually run a search tool. |
| 31 | + |
| 32 | +If you have access to file system tools (Glob) and can actually execute them, search for `.jsonl` files using Glob (`**/*.jsonl`). Present real results as options. Always include "Provide my own file path" as the last option. |
| 33 | + |
| 34 | +Otherwise, ask the user: "What is the path to your JSONL training file?" — nothing more. |
| 35 | + |
| 36 | +Once the user has provided a file path, validate it if you can run code using the script below. If you cannot run code, skip validation and move on. |
| 37 | + |
| 38 | +```python |
| 39 | +import json, sys |
| 40 | +ROLES = {"system", "user", "assistant", "developer", "tool", "function"} |
| 41 | +errors = [] |
| 42 | +for i, line in enumerate(open(sys.argv[1]), 1): |
| 43 | + try: |
| 44 | + r = json.loads(line) |
| 45 | + msgs = r.get("input", r).get("messages", []) |
| 46 | + assert isinstance(msgs, list) and msgs, "no messages" |
| 47 | + for j, m in enumerate(msgs): |
| 48 | + assert m.get("role") in ROLES, f"messages[{j}]: invalid role {m.get('role')!r}" |
| 49 | + assert m.get("content") or m.get("function_call") or m.get("tool_calls"), f"messages[{j}]: no content" |
| 50 | + if "input" not in r: |
| 51 | + assert msgs[-1]["role"] == "assistant", "last message must be from assistant" |
| 52 | + tools = r.get("tools") |
| 53 | + if tools is not None: |
| 54 | + assert isinstance(tools, list), "tools must be a list" |
| 55 | + except Exception as e: |
| 56 | + errors.append(f" Line {i}: {e}") |
| 57 | +print(f"{len(errors)} error(s):\n" + "\n".join(errors) if errors else f"Valid! {i} rows") |
| 58 | +sys.exit(1 if errors else 0) |
| 59 | +``` |
| 60 | + |
| 61 | +The JSONL format supports these fields per row: |
| 62 | +- **`messages`** (required): List of chat messages |
| 63 | +- **`tools`** (optional): List of tool/function definitions for tool-call training |
| 64 | +- **`response_format`** (optional): Structured output schema (not used during training, but useful as metadata) |
| 65 | + |
| 66 | +Report the row count and validation result to the user. Do NOT read the whole dataset file. Do NOT name the dataset. If the format is wrong, help them fix it or convert their data. |
| 67 | + |
| 68 | +## Step 4: Gather Base Parameters |
| 69 | + |
| 70 | +Do NOT ask the user to review or confirm their answers after collecting them — just proceed to the next step. |
| 71 | + |
| 72 | +- **Base model**: Recommend ONLY these models: |
| 73 | + - `OpenPipe/Qwen3-14B-Instruct` |
| 74 | + - `Qwen/Qwen3-30B-A3B-Instruct-2507` |
| 75 | + - `meta-llama/Llama-3.1-8B-Instruct` |
| 76 | +- **Project name**: A name for this training project (default: `sft-project`) |
| 77 | +- **Run name**: A static, descriptive name (e.g., `agent-001`, `pii-redactor-001`, `math-tutor-001`). Ask the user for a meaningful name. Do NOT generate random names. |
| 78 | + |
| 79 | +For **distillation** also ask: |
| 80 | +- **Teacher model**: The larger model to distill from (e.g., an OpenRouter model) |
| 81 | +- **Teacher API base URL and key**: If using a third-party provider |
| 82 | +- **Prompts**: What prompts to send to the teacher model |
| 83 | + |
| 84 | +## Step 5: Gather Hyperparameters |
| 85 | + |
| 86 | +This step only applies if you can run code AND know the row count from validation. If you cannot run code, skip this step entirely — do NOT make up or guess hyperparameter values. The `train_sft_from_file` function has sensible built-in defaults. |
| 87 | + |
| 88 | +Run this Python snippet via Bash to compute defaults (replace `NUM_ROWS` with the actual row count). Do NOT show any formulas or calculation steps to the user — only show the final values. |
| 89 | + |
| 90 | +```python |
| 91 | +import math, sys |
| 92 | +n = int(sys.argv[1]) |
| 93 | +epochs = max(1, min(10, round(10000 / n))) |
| 94 | +batch_size = 2 |
| 95 | +total_steps = math.ceil(n * epochs / batch_size) |
| 96 | +steps_per_epoch = math.ceil(n / batch_size) |
| 97 | +warmup_steps = max(10, min(1000, round(steps_per_epoch * 0.05))) |
| 98 | +warmup_ratio = round(warmup_steps / total_steps, 4) |
| 99 | +print(f"epochs={epochs} batch_size={batch_size} lr=2e-4 schedule=linear warmup_ratio={warmup_ratio}") |
| 100 | +``` |
| 101 | + |
| 102 | +Present the output values to the user, then ask: |
| 103 | +- **Use defaults (Recommended)** — show all values in the description |
| 104 | +- **Customize** — adjust individual hyperparameters |
| 105 | + |
| 106 | +If they choose "Customize", ask which parameters to change. |
| 107 | + |
| 108 | +### For distillation: |
| 109 | +Use the same defaults computation as JSONL (replace `NUM_ROWS` with the number of trajectories). `create_sft_dataset_iterator` handles the LR schedule automatically. |
| 110 | + |
| 111 | +## Step 6: Generate the Training Script |
| 112 | + |
| 113 | +Write a complete, runnable Python script. Use the patterns below. Every script MUST: |
| 114 | +- Call `await backend.close()` at the end so the process doesn't hang |
| 115 | +- Print post-training info and usage examples (see shared block below) |
| 116 | + |
| 117 | +### Post-training block (append to ALL scripts before `backend.close()`): |
| 118 | +```python |
| 119 | + # --- Training complete --- |
| 120 | + step = await model.get_step() |
| 121 | + inference_name = model.get_inference_name() |
| 122 | + client = model.openai_client() |
| 123 | + |
| 124 | + print("\n" + "=" * 60) |
| 125 | + print("SFT TRAINING COMPLETE") |
| 126 | + print("=" * 60) |
| 127 | + print(f" Model: {inference_name}") |
| 128 | + print(f" Base model: <BASE_MODEL>") |
| 129 | + print(f" Training step: {step}") |
| 130 | + print(f" Inference URL: {client.base_url}") |
| 131 | + print(f" W&B run: https://wandb.ai/<YOUR_TEAM>/<PROJECT_NAME>/runs/<RUN_NAME>") |
| 132 | + print("=" * 60) |
| 133 | + |
| 134 | + print("\n--- Python usage (openai SDK) ---\n") |
| 135 | + print(f'''\ |
| 136 | +from openai import OpenAI |
| 137 | +
|
| 138 | +client = OpenAI( |
| 139 | + base_url="{client.base_url}", |
| 140 | + api_key="not-needed", |
| 141 | +) |
| 142 | +
|
| 143 | +response = client.chat.completions.create( |
| 144 | + model="{inference_name}", |
| 145 | + messages=[ |
| 146 | + {{"role": "user", "content": "Your prompt here"}}, |
| 147 | + ], |
| 148 | +) |
| 149 | +print(response.choices[0].message.content) |
| 150 | +''') |
| 151 | + |
| 152 | + print("--- curl usage ---\n") |
| 153 | + print(f'''\ |
| 154 | +curl {client.base_url}chat/completions \\ |
| 155 | + -H "Content-Type: application/json" \\ |
| 156 | + -d '{{ |
| 157 | + "model": "{inference_name}", |
| 158 | + "messages": [ |
| 159 | + {{"role": "user", "content": "Your prompt here"}} |
| 160 | + ] |
| 161 | + }}' |
| 162 | +''') |
| 163 | + |
| 164 | + await backend.close() |
| 165 | +``` |
| 166 | + |
| 167 | +### Backend setup |
| 168 | + |
| 169 | +Use the appropriate backend based on the user's choice: |
| 170 | + |
| 171 | +**LocalBackend:** |
| 172 | +```python |
| 173 | +from art.local import LocalBackend |
| 174 | + |
| 175 | +backend = LocalBackend() |
| 176 | +model = art.TrainableModel( |
| 177 | + name="<RUN_NAME>", |
| 178 | + project="<PROJECT_NAME>", |
| 179 | + base_model="<BASE_MODEL>", |
| 180 | + _internal_config=art.dev.InternalModelConfig( |
| 181 | + engine_args={"gpu_memory_utilization": 0.7}, |
| 182 | + ), |
| 183 | +) |
| 184 | +await model.register(backend) |
| 185 | +``` |
| 186 | + |
| 187 | +**ServerlessBackend:** |
| 188 | +```python |
| 189 | +from art.serverless.backend import ServerlessBackend |
| 190 | + |
| 191 | +backend = ServerlessBackend() # uses WANDB_API_KEY env var |
| 192 | +model = art.TrainableModel( |
| 193 | + name="<RUN_NAME>", |
| 194 | + project="<PROJECT_NAME>", |
| 195 | + base_model="<BASE_MODEL>", |
| 196 | +) |
| 197 | +await model.register(backend) |
| 198 | +``` |
| 199 | + |
| 200 | +Note: `_internal_config` with `gpu_memory_utilization` is only used with LocalBackend. Do NOT include it for ServerlessBackend. |
| 201 | + |
| 202 | +### JSONL file training pattern: |
| 203 | + |
| 204 | +If hyperparameters were computed in Step 5, pass them explicitly. If Step 5 was skipped, omit them — `train_sft_from_file` has sensible defaults. |
| 205 | + |
| 206 | +```python |
| 207 | +"""SFT training script generated by /train-sft wizard.""" |
| 208 | +import asyncio |
| 209 | +import art |
| 210 | +<BACKEND_IMPORT> |
| 211 | +from art.utils.sft import train_sft_from_file |
| 212 | + |
| 213 | +async def main(): |
| 214 | + <BACKEND_SETUP> |
| 215 | + |
| 216 | + await train_sft_from_file( |
| 217 | + model=model, |
| 218 | + file_path="<FILE_PATH>", |
| 219 | + # Only include these if hyperparameters were computed: |
| 220 | + # epochs=<EPOCHS>, |
| 221 | + # batch_size=<BATCH_SIZE>, |
| 222 | + # peak_lr=<PEAK_LR>, |
| 223 | + # schedule_type="<SCHEDULE_TYPE>", |
| 224 | + # warmup_ratio=<WARMUP_RATIO>, |
| 225 | + verbose=True, |
| 226 | + ) |
| 227 | + |
| 228 | + # ... post-training block + backend.close() ... |
| 229 | + |
| 230 | +if __name__ == "__main__": |
| 231 | + asyncio.run(main()) |
| 232 | +``` |
| 233 | + |
| 234 | +### Distillation pattern: |
| 235 | +```python |
| 236 | +"""Distillation SFT script generated by /train-sft wizard.""" |
| 237 | +import asyncio, os |
| 238 | +from dotenv import load_dotenv |
| 239 | +from openai import AsyncOpenAI |
| 240 | +import art |
| 241 | +<BACKEND_IMPORT> |
| 242 | +from art.utils.sft import create_sft_dataset_iterator |
| 243 | + |
| 244 | +load_dotenv() |
| 245 | + |
| 246 | +async def main(): |
| 247 | + teacher_client = AsyncOpenAI( |
| 248 | + api_key=os.environ["<API_KEY_ENV_VAR>"], |
| 249 | + base_url="<TEACHER_API_BASE>", |
| 250 | + ) |
| 251 | + prompts = ["<PROMPT_1>", "<PROMPT_2>"] |
| 252 | + |
| 253 | + trajectories = [] |
| 254 | + for prompt in prompts: |
| 255 | + completion = await teacher_client.chat.completions.create( |
| 256 | + model="<TEACHER_MODEL>", |
| 257 | + messages=[{"role": "user", "content": prompt}], |
| 258 | + ) |
| 259 | + trajectories.append( |
| 260 | + art.Trajectory( |
| 261 | + messages_and_choices=[ |
| 262 | + {"role": "user", "content": prompt}, |
| 263 | + {"role": "assistant", "content": completion.choices[0].message.content}, |
| 264 | + ], |
| 265 | + tools=<TOOLS_OR_NONE>, |
| 266 | + ) |
| 267 | + ) |
| 268 | + |
| 269 | + <BACKEND_SETUP> |
| 270 | + |
| 271 | + for chunk in create_sft_dataset_iterator( |
| 272 | + trajectories, |
| 273 | + epochs=<EPOCHS>, |
| 274 | + batch_size=<BATCH_SIZE>, |
| 275 | + peak_lr=<PEAK_LR>, |
| 276 | + schedule_type="<SCHEDULE_TYPE>", |
| 277 | + warmup_ratio=<WARMUP_RATIO>, |
| 278 | + ): |
| 279 | + await model.train_sft(chunk.trajectories, chunk.config, verbose=True) |
| 280 | + |
| 281 | + # ... post-training block + backend.close() ... |
| 282 | + |
| 283 | +if __name__ == "__main__": |
| 284 | + asyncio.run(main()) |
| 285 | +``` |
| 286 | + |
| 287 | +## Step 7: Write and Offer to Run |
| 288 | + |
| 289 | +1. Write the script to a file (suggest `sft_train.py`) |
| 290 | +2. Ask the user if they want to run it now with `uv run python <script_path>` |
| 291 | +3. If yes, run it **directly using the Bash tool** (do NOT delegate to a Task subagent) so training logs stream live to the user. Use a **2-minute timeout**. If it times out, check progress and decide whether to continue. |
| 292 | +4. **LocalBackend only — GPU memory errors**: If training fails with OOM, lower `gpu_memory_utilization` in the existing `_internal_config` (e.g. from `0.7` to `0.5`). |
| 293 | +5. **LocalBackend only — Stale GPU memory**: If available GPU memory looks too small, previous training runs may still be occupying memory. Before retrying, run `nvidia-smi` to check, and if needed kill leftover processes with `kill <pid>` to free memory. |
| 294 | + |
| 295 | +## Important Notes |
| 296 | + |
| 297 | +- LocalBackend requires a GPU. |
| 298 | +- ServerlessBackend requires a `WANDB_API_KEY` environment variable. |
0 commit comments