Commit aa23b8d

SpecDec Bench: February Update

Signed-off-by: Izzy Putterman <[email protected]>

1 parent 9e38041 commit aa23b8d

28 files changed: +2234 −120 lines

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions

```diff
@@ -24,7 +24,9 @@ repos:
     hooks:
       - id: ruff-check
         args: [--fix, --exit-non-zero-on-fix]
+        exclude: ^examples/specdec_bench/specdec_bench/datasets/speed\.py$
       - id: ruff-format
+        exclude: ^examples/specdec_bench/specdec_bench/datasets/speed\.py$

   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.17.1
```

examples/specdec_bench/README.md

Lines changed: 46 additions & 0 deletions

@@ -41,6 +41,52 @@

### Running [SPEED-Bench](https://huggingface.co/datasets/nvidia/SPEED-Bench) on Llama 3.3 70B + Eagle 3

1. Install the requirements with `pip install -r requirements.txt`.

2. Prepare the data using the provided script:

```bash
python3 prepare_data.py --dataset speed --config all
```

The data will be saved to the `data/` directory, with each config type (qualitative, throughput_1k, ...) in its own subdirectory.

#### License

GOVERNING TERMS: This dataset is governed by the NVIDIA Evaluation Dataset License Agreement.

ADDITIONAL INFORMATION: MIT for bigcode/humanevalpack, RUCAIBox/MMATH, RUCAIBox/BAMBOO and EQ-Bench. Apache 2.0 for Writing Bench and Spec-Bench. CC BY 4.0 for FBK-MT/MCIF. MIT and Apache 2.0 for tianyang/repobench_python_v1.1, JetBrains-Research/lca-project-level-code-completion and tianyang/repobench_java_v1.1.

NOTICE: For each dataset a user elects to use, the user is responsible for checking that the dataset license is fit for the intended purpose. The `prepare_data.py` script automatically fetches data from all the source datasets.

Additional details are in the [HuggingFace dataset repository](https://huggingface.co/datasets/nvidia/SPEED-Bench).

#### Qualitative split

```bash
python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/qualitative --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress
```

#### Throughput split

```bash
python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/throughput_1k --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress
```

For longer contexts (>8192 tokens), use the following configuration when running with TRTLLM:

```yaml
engine_args:
  max_seq_len: 131072  # Model max context length (for Llama 3.3 70B)
  enable_chunked_prefill: true
```

```bash
python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/throughput_16k --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress --runtime_params runtime_args_long_context.yaml
```

## Notes

The goal of this benchmark is to provide an easy way to configure, run, and compare speculative decoding implementations across frameworks in an apples-to-apples manner.
Lines changed: 329 additions & 0 deletions

# Porting Spec-Bench Inference Runners to specdec_bench

This guide explains how to convert any `inference_*.py` runner from [Spec-Bench](https://github.com/hemingkx/Spec-Bench) to a model class compatible with `specdec_bench`.

## Overview

Spec-Bench inference runners follow a pattern where:

1. A `*_forward()` function handles the speculative decoding logic
2. The `run_eval()` function orchestrates evaluation with tokenized inputs
3. Models are loaded in `__main__` and passed to `run_eval()`

In contrast, `specdec_bench` uses a class-based approach where:

1. Models inherit from the `Model` base class
2. `__init__()` handles model loading
3. `run()` is an async method that processes single requests
4. `stop()` handles cleanup

## The specdec_bench Model Interface

```python
class Model:
    def __init__(self, model_dir, tokenizer, max_draft_length):
        raise NotImplementedError

    async def run(self, prompt_ids, sampling_params, request_id, turn_id):
        """
        prompt_ids: list of token IDs (not a tensor!)

        Returns dict with:
        - output_ids: list of lists of token chunks per step [[chunk1, chunk2, ...]]
        - output_logits: optional logits (usually None)
        - token_times: list of timestamps per decoding step
        """
        raise NotImplementedError

    def stop(self):
        pass
```
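To make the contract concrete, here is a minimal toy subclass (a hypothetical `EchoModel`, not part of specdec_bench) that "generates" by echoing the prompt back in fixed-size chunks, mimicking one accepted draft block per decoding step, and returns the expected dict shape:

```python
import asyncio
import time


class Model:
    def __init__(self, model_dir, tokenizer, max_draft_length):
        raise NotImplementedError

    async def run(self, prompt_ids, sampling_params, request_id, turn_id):
        raise NotImplementedError

    def stop(self):
        pass


class EchoModel(Model):
    """Toy model: echoes the prompt in chunks of max_draft_length tokens."""

    def __init__(self, model_dir=None, tokenizer=None, max_draft_length=3):
        self.chunk_size = max_draft_length

    async def run(self, prompt_ids, sampling_params, request_id, turn_id):
        token_times = [time.perf_counter()]  # timestamp before decoding starts
        chunks = []
        for i in range(0, len(prompt_ids), self.chunk_size):
            chunks.append(prompt_ids[i:i + self.chunk_size])  # one "step"
            token_times.append(time.perf_counter())
        # Outer list corresponds to beam_width == 1
        return {"output_ids": [chunks], "output_logits": None,
                "token_times": token_times}


result = asyncio.run(
    EchoModel(max_draft_length=2).run([1, 2, 3, 4, 5], {}, "req0", 0)
)
print(result["output_ids"])  # [[[1, 2], [3, 4], [5]]]
```

This is only a shape check, but it is a useful starting point before wiring in a real forward pass.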
## Step-by-Step Porting Guide

### Step 1: Identify the Key Components in Spec-Bench

Look at the `inference_*.py` file and identify:

1. **The forward function** (e.g., `medusa_forward`, `ea_forward`)
   - This contains the core speculative decoding loop
   - Signature: `forward_func(inputs, model, tokenizer, max_new_tokens, **kwargs)`
   - Returns: `(output_ids, new_token_count, num_steps, accept_length_list)`

2. **The model class** (e.g., `MedusaModel`, `EaModel`)
   - Found in the `model/<method>/` directory
   - Has a `from_pretrained()` class method

3. **Required utilities** from the method's module:
   - Buffer generation (e.g., `generate_medusa_buffers`)
   - Initialization functions (e.g., `initialize_medusa`, `initialize_past_key_values`)
   - Decoding functions (e.g., `tree_decoding`, `generate_candidates`)
   - State update functions (e.g., `update_inference_inputs`)

4. **Method-specific choices/configs** (e.g., `mc_sim_7b_63` for Medusa)

### Step 2: Create the specdec_bench Model Class

```python
# specdec_bench/specdec_bench/models/specbench_<method>.py

from .base import Model
import asyncio
import time
import torch

# Import dependencies from Spec-Bench
try:
    import sys
    import os

    spec_bench_path = os.path.join(os.getcwd(), "Spec-Bench")
    sys.path.insert(0, spec_bench_path)

    from model.<method>.<model_file> import <ModelClass>
    from model.<method>.kv_cache import initialize_past_key_values
    from model.<method>.utils import (
        # Import all required utilities
    )
    from model.<method>.<choices_file> import <default_choices>
except ImportError as e:
    print(f"<Method> dependencies not found: {e}")
    <ModelClass> = None


class SpecBench<Method>Model(Model):
    def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs):
        # 1. Validate dependencies
        if <ModelClass> is None:
            raise ImportError("<Method> dependencies not found.")

        # 2. Extract configuration from kwargs
        self.dtype = kwargs.get("dtype", "float16")
        self.max_steps = kwargs.get("max_steps", 512)
        self.temperature = sampling_kwargs.get("temperature", 0.0)
        # ... other method-specific parameters

        # 3. Set up the device (avoid device_map="auto" to sidestep multi-GPU issues)
        self.device = torch.device(kwargs.get("device", "cuda:0"))

        # 4. Convert the dtype string to a torch dtype
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        torch_dtype = dtype_map.get(self.dtype, torch.float16)

        # 5. Load the model
        self.model = <ModelClass>.from_pretrained(
            model_dir,
            # ... other args from Spec-Bench's __main__
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
        )
        self.model = self.model.to(self.device)

        self.sampling_kwargs = sampling_kwargs
```

### Step 3: Port the Forward Function

Convert the standalone `*_forward()` function to an internal method:

```python
def _forward(self, input_ids, max_new_tokens, end_id):
    """
    Port of the original *_forward function.

    Key changes from Spec-Bench:
    1. input_ids is already a tensor (converted in run())
    2. Add a timing list to track per-step timestamps
    3. Use self.device instead of model.base_model.device
    4. Return timing along with the other outputs
    """
    accept_length_list = []
    timing = [time.perf_counter()]  # ADD: track timing

    # === COPY THE FORWARD LOGIC FROM SPEC-BENCH ===
    # Replace: device=model.base_model.device
    # With:    device=self.device

    # Initialize buffers...
    # Initialize KV cache...
    # Main decoding loop...

    for idx in range(self.max_steps):
        # Generate candidates...
        # Tree decoding...
        # Evaluate posterior...
        # Update inputs...

        timing.append(time.perf_counter())  # ADD: record time per step

        # Check for EOS
        if end_id in input_ids[0, input_len:].tolist():
            break
        if new_token > max_new_tokens:
            break

    return input_ids, new_token, idx + 1, accept_length_list, timing  # ADD timing
```

### Step 4: Implement the run() Method

```python
async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
    """
    Async interface for specdec_bench.

    Args:
        prompt_ids: List of input token IDs (NOT a tensor)
        max_length: Maximum new tokens to generate
        end_id: EOS token ID
        request_id: Request identifier
        turn_id: Turn identifier

    Returns:
        dict with output_ids, output_logits, token_times
    """
    output_dict = {}

    # Convert the prompt_ids list to a tensor
    input_ids = torch.tensor(
        [prompt_ids], dtype=torch.long, device=self.device
    )

    # Run the forward pass (use asyncio.to_thread for sync code)
    result = await asyncio.to_thread(
        self._forward, input_ids, max_length, end_id
    )
    input_ids_out, new_token, num_steps, accept_length_list, timing = result

    # Extract the generated tokens (excluding the prompt)
    original_len = len(prompt_ids)
    generated_tokens = input_ids_out[0, original_len:].tolist()

    # Remove the EOS token if present
    if end_id in generated_tokens:
        eos_idx = generated_tokens.index(end_id)
        generated_tokens = generated_tokens[:eos_idx]

    # Format output_ids as a list of token chunks per step,
    # matching specdec_bench's expected format
    reformatted_output_ids = [[]]
    start = 0
    for accept_len in accept_length_list:
        if accept_len > 0 and start < len(generated_tokens):
            chunk = generated_tokens[start:start + accept_len]
            if chunk:
                reformatted_output_ids[0].append(chunk)
            start += accept_len

    # Handle any remaining tokens
    if start < len(generated_tokens):
        reformatted_output_ids[0].append(generated_tokens[start:])

    output_dict["output_ids"] = reformatted_output_ids
    output_dict["output_logits"] = None
    output_dict["token_times"] = timing

    return output_dict
```
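The chunk-reformatting loop in `run()` is easy to get subtly wrong, so it helps to pull it out as a standalone pure function (the name `chunk_by_accept_lengths` is a hypothetical extraction, not a specdec_bench API) and sanity-check it in isolation:

```python
def chunk_by_accept_lengths(generated_tokens, accept_length_list):
    """Group a flat list of generated tokens into per-step chunks,
    mirroring the reformatting logic used in run() above."""
    chunks = []
    start = 0
    for accept_len in accept_length_list:
        if accept_len > 0 and start < len(generated_tokens):
            chunk = generated_tokens[start:start + accept_len]
            if chunk:
                chunks.append(chunk)
            start += accept_len
    if start < len(generated_tokens):  # tokens not covered by any accept count
        chunks.append(generated_tokens[start:])
    return [chunks]  # outer list: beam_width == 1


print(chunk_by_accept_lengths([10, 11, 12, 13, 14], [2, 2, 1]))
# [[[10, 11], [12, 13], [14]]]
print(chunk_by_accept_lengths([10, 11, 12], [2]))
# [[[10, 11], [12]]]
```

Note the second case: if EOS stripping shortened the token list, leftover tokens still end up in a final chunk rather than being silently dropped.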
### Step 5: Implement stop() for Cleanup

```python
def stop(self):
    """Clean up resources."""
    # Clear any cached state
    if hasattr(self.model, "past_key_values"):
        del self.model.past_key_values
        del self.model.past_key_values_data
        del self.model.current_length_data

    # Clear method-specific buffers
    if hasattr(self.model, "<method>_buffers"):
        del self.model.<method>_buffers

    # Free GPU memory
    if hasattr(self, "model") and self.model is not None:
        del self.model
    torch.cuda.empty_cache()
```

### Step 6: Register the Model (Optional)

Add to `specdec_bench/specdec_bench/models/__init__.py`:

```python
from .specbench_<method> import SpecBench<Method>Model
```

## Key Differences Summary

| Aspect | Spec-Bench | specdec_bench |
|--------|-----------|---------------|
| Input format | `inputs.input_ids` (tensor from tokenizer) | `prompt_ids` (list of ints) |
| Output format | `(output_ids, new_token, steps, accept_lengths)` | `dict` with `output_ids`, `output_logits`, `token_times` |
| Output IDs | Full sequence tensor | List of token chunks per step |
| Timing | External (in `run_eval`) | Internal (in `run()`) |
| Device | `device_map="auto"` | Explicit single device |
| Interface | Function-based | Class-based with async `run()` |

## Common Pitfalls

1. **Device Mismatch**: Avoid `device_map="auto"`, which spreads the model across GPUs. Use an explicit `.to(device)`.

2. **Tensor vs List**: `prompt_ids` in specdec_bench is a Python list, not a tensor. Convert it in `run()`.

3. **Output Format**: specdec_bench expects `output_ids` as `[[chunk1, chunk2, ...]]` (a list of lists of lists for beam_width=1).

4. **Timing**: Add `time.perf_counter()` calls to track per-step latency.

5. **EOS Handling**: Strip EOS tokens from the output before formatting.

6. **Async Wrapper**: Use `asyncio.to_thread()` to wrap synchronous forward passes.
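The async-wrapper pitfall deserves a short illustration: `asyncio.to_thread()` runs a blocking call in a worker thread so the event loop stays free to serve other concurrent requests. A minimal self-contained sketch (the `blocking_forward` stand-in is hypothetical, not real model code):

```python
import asyncio
import time


def blocking_forward(request_id):
    """Stand-in for a synchronous GPU forward pass."""
    time.sleep(0.05)
    return f"{request_id}-done"


async def main():
    t0 = time.perf_counter()
    # Four blocking calls run in worker threads; they overlap instead of
    # serializing on the event loop (serial execution would take ~0.2 s).
    results = await asyncio.gather(
        *(asyncio.to_thread(blocking_forward, i) for i in range(4))
    )
    elapsed = time.perf_counter() - t0
    return results, elapsed


results, elapsed = asyncio.run(main())
print(results)  # ['0-done', '1-done', '2-done', '3-done']
```

Calling `self._forward` directly inside `run()` would instead block the loop and defeat the benchmark's concurrency setting.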
## Example: Mapping Spec-Bench Methods

| Spec-Bench File | Model Class | Forward Function | Key Utils |
|-----------------|-------------|------------------|-----------|
| `inference_medusa.py` | `MedusaModel` | `medusa_forward` | `generate_medusa_buffers`, `initialize_medusa` |
| `inference_eagle.py` | `EaModel` | `ea_forward` | `generate_tree_buffers`, `initialize_tree` |
| `inference_eagle2.py` | `EaModel` | `ea_forward` | Same as EAGLE |
| `inference_hydra.py` | `HydraModel` | `hydra_forward` | `generate_hydra_buffers`, `initialize_hydra` |
| `inference_lookahead.py` | `LookaheadModel` | `lookahead_forward` | Lookahead-specific utils |

## Testing Your Port

```python
import asyncio

async def test():
    model = SpecBench<Method>Model(
        model_dir="/path/to/model",
        max_concurrent_requests=1,
        sampling_kwargs={"temperature": 0.0},
        # method-specific kwargs...
    )

    result = await model.run(
        prompt_ids=[1, 2, 3, 4, 5],  # example token IDs
        max_length=100,
        end_id=2,  # EOS token
        request_id="test",
        turn_id=0,
    )

    print("Output chunks:", result["output_ids"])
    print("Timing:", result["token_times"])

    model.stop()

asyncio.run(test())
```
## Vicuna Chat Template

Adjust the vicuna setup so that the chat template lives in the tokenizer config. Insert the following into `tokenizer_config.json` (for vicuna):

```json
"chat_template": "{% set ns = namespace(system='') %}{% for m in messages %}{% if m['role'] == 'system' %}{% set ns.system = m['content'] %}{% endif %}{% endfor %}{{ ns.system | trim }}{% if ns.system | trim != '' %} {% endif %}{% for m in messages %}{% if m['role'] == 'user' %}USER: {{ m['content'] | trim }} ASSISTANT:{% elif m['role'] == 'assistant' %}{{ m['content'] | trim }}{% endif %}{% endfor %}"
```
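As a sanity check on what the template renders, the following pure-Python function mirrors its logic (this is a hypothetical illustration only; in practice the jinja2 template above is rendered by the tokenizer's `apply_chat_template()`):

```python
def render_vicuna(messages):
    """Mirror of the chat_template above: the last system message comes
    first, then 'USER: ... ASSISTANT:' turns with assistant replies
    appended inline (no trailing space after 'ASSISTANT:')."""
    system = ""
    for m in messages:
        if m["role"] == "system":
            system = m["content"]  # template keeps the last system message
    out = system.strip()
    if out:
        out += " "  # single space between system text and first turn
    for m in messages:
        if m["role"] == "user":
            out += f"USER: {m['content'].strip()} ASSISTANT:"
        elif m["role"] == "assistant":
            out += m["content"].strip()
    return out


prompt = render_vicuna([
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hello"},
])
print(prompt)  # You are helpful. USER: Hello ASSISTANT:
```

If the tokenizer-rendered prompt and this reference disagree, the template was likely pasted incorrectly into the config.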
