
Commit 8f2f7ad: add ptpc-fp8 amd blogpost

Signed-off-by: tanpinsiang <[email protected]>
1 parent ded60c9 commit 8f2f7ad

11 files changed (+301 -0 lines changed)

_posts/2025-02-24-ptpc-fp8-rocm.md (+297)

@@ -0,0 +1,297 @@
---
layout: post
title: "PTPC-FP8: Boosting vLLM Performance on AMD ROCm"
author: "AMD and Embedded LLM"
image: /assets/figures/ptpc/PTPC-tumbnail.png
thumbnail-img: /assets/figures/ptpc/PTPC-tumbnail.png
share-img: /assets/figures/ptpc/PTPC-tumbnail.png
math: true
---

**TL;DR**: vLLM on AMD ROCm now has better FP8 performance!

* **What's new?** [PTPC-FP8 quantization](https://github.com/vllm-project/vllm/pull/12501) is now supported in vLLM (v0.7.3+) on AMD ROCm.
* **Why is it good?** You get throughput on par with per-tensor FP8, but with accuracy much closer to the original BF16 model quality, making it the most accurate FP8 option for vLLM on ROCm.
* **How to use it:**
    1. Install ROCm.
    2. Get the latest vLLM (v0.7.3 or newer).
    3. Add the `--quantization ptpc_fp8` flag when running your Hugging Face model. No need to pre-quantize!

<img align="center" src="/assets/figures/ptpc/PTPC121.png" alt="What is PTPC-FP8" width="90%" height="90%">

**What is PTPC-FP8?** It's a method for quantizing both weights *and* activations to FP8. It uses per-token scaling for activations and per-channel scaling for weights, giving you better accuracy than traditional per-tensor FP8.

## Introduction

Large Language Models (LLMs) are revolutionizing how we interact with technology, but their immense computational demands can be a barrier. What if you could run these powerful models faster and more efficiently on your AMD GPUs, without sacrificing accuracy? Now you can! This post introduces a breakthrough: PTPC-FP8 quantization in vLLM, optimized for AMD's ROCm platform. Get ready for near-BF16 accuracy at FP8 speeds, directly using Hugging Face models – no pre-quantization needed! We'll show you how it works, benchmark its performance, and get you started.

### The Challenge of LLM Quantization and the PTPC-FP8 Solution

Running large language models is computationally expensive. FP8 (8-bit floating-point) offers a compelling solution by reducing memory footprint and accelerating matrix multiplications, but traditional quantization approaches face a critical challenge with LLMs.

#### The Outlier Problem

LLMs develop activation outliers as they scale beyond certain sizes. These unusually large values create significant quantization challenges:

- Most values receive few effective bits of precision when using per-tensor quantization
- Outliers appear persistently in specific channels across different tokens
- While weights are relatively uniform and easy to quantize, activations are not

#### PTPC: A Precision-Targeted Approach

PTPC-FP8 (Per-Token-Activation, Per-Channel-Weight FP8) addresses this challenge by using tailored scaling factors based on three key observations:

1. Outliers consistently appear in the same channels
2. Channel magnitudes within a token vary widely
3. The same channel's magnitude across different tokens remains relatively stable

This insight led to a dual-granularity approach:

* **Per-Token Activation Quantization**: Each input token receives its own scaling factor
* **Per-Channel Weight Quantization**: Each weight column gets a unique scaling factor

<div align="center">
<img src="/assets/figures/ptpc/PTPC-Diagram.png" alt="Per-Token Activation + Per-Channel Weight Quantization" width="80%">
</div>

#### Understanding the Diagram

The illustration shows two quantization approaches:

**Tensor Dimensions (Both Methods):**
- **$X$**: Input activation tensor ($T \times C_i$)
- **$W$**: Weight tensor ($C_i \times C_o$)
- **$T$**: Token sequence length
- **$C_i/C_o$**: Input/output channels
- **$*$**: Matrix multiplication

**Scaling Factors:**
- **Top (Per-Tensor)**: Single scalars $\Delta_X[1]$ and $\Delta_W[1]$ for entire tensors
- **Bottom (PTPC)**: Vector $\Delta_X[T \times 1]$ with one scale per token and $\Delta_W[1 \times C_o]$ with one scale per output channel (weight column)

This granular scaling approach allows PTPC-FP8 to achieve accuracy close to BF16 while maintaining the speed and memory benefits of 8-bit computation, as the sketch below illustrates.

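To make the scaling recipe concrete, here is a minimal, self-contained PyTorch sketch of the arithmetic (an illustration, not vLLM's actual kernel code). It computes one scale per row of $X$ (per token) and one scale per column of $W$ (per output channel), emulates FP8 E4M3 rounding, and compares the reconstruction error against per-tensor scaling. The tensor sizes and the injected outlier channel are made-up examples.

```python
import torch

# Illustrative sizes: T tokens, C_in input channels, C_out output channels.
T, C_in, C_out = 8, 64, 32
FP8_MAX = 448.0  # largest representable magnitude in FP8 E4M3

torch.manual_seed(0)
X = torch.randn(T, C_in)
X[:, 5] *= 40.0              # inject a persistent outlier channel, as seen in real LLM activations
W = torch.randn(C_in, C_out)

def fake_fp8(t, scale):
    """Quantize t / scale to the FP8 E4M3 grid, then dequantize (numerical emulation only)."""
    q = torch.clamp(t / scale, -FP8_MAX, FP8_MAX)
    q = q.to(torch.float8_e4m3fn).to(torch.float32)  # round to FP8, back to fp32 for the emulation
    return q * scale

ref = X @ W  # full-precision reference

# --- Per-tensor scaling: one scale for all of X, one for all of W ---
per_tensor = fake_fp8(X, X.abs().max() / FP8_MAX) @ fake_fp8(W, W.abs().max() / FP8_MAX)

# --- PTPC scaling: Delta_X has shape [T, 1] (per token), Delta_W has shape [1, C_out] (per column) ---
scale_x = X.abs().amax(dim=1, keepdim=True) / FP8_MAX   # [T, 1]
scale_w = W.abs().amax(dim=0, keepdim=True) / FP8_MAX   # [1, C_out]
ptpc = fake_fp8(X, scale_x) @ fake_fp8(W, scale_w)

def rel_err(y):
    return (y - ref).norm() / ref.norm()

print(f"per-tensor relative error: {rel_err(per_tensor):.4f}")
print(f"PTPC       relative error: {rel_err(ptpc):.4f}")
```

Because each row and column is scaled by its own maximum rather than the global one, the PTPC reconstruction error comes out lower, which is exactly the effect the diagram describes.
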
## Deep Dive: How PTPC-FP8 Works in vLLM (and the Fused Kernel)

PTPC-FP8's fine-grained scaling could slow things down without proper optimization. The key to maintaining speed is AMD ROCm's implementation of a **fused FP8 rowwise scaled GEMM** operation.

### The Challenge: 2-Step vs. Fused Approach

Without optimization, matrix multiplication with per-token and per-channel scaling would require two costly steps:

```python
# Naive 2-step approach:
output = torch._scaled_mm(input, weight)          # Step 1: FP8 GEMM
output = output * token_scales * channel_scales   # Step 2: Apply scaling factors
```

This creates a performance bottleneck, because the GPU must:
- Write large intermediate results to memory
- Read them back for the scaling operations
- Waste memory bandwidth and compute cycles in the process

### The Solution: Fusion

The fused approach combines matrix multiplication and scaling into a single hardware operation:

```python
# Optimized fused operation:
output = torch._scaled_mm(input, weight,
                          scale_a=token_scales,
                          scale_b=channel_scales)
```

<img align="center" src="/assets/figures/ptpc/FusedGEMM.svg" alt="Fused GEMM Operation" width="90%" height="90%">

### Why This Matters

This fusion leverages AMD GPUs' specialized hardware (particularly on MI300X with native FP8 support):

- **Memory Efficiency**: Scaling happens within on-chip memory before writing results
- **Computational Efficiency**: Eliminates redundant operations
- **Performance Boost**: Our tests show up to 2.5× speedup compared to the naive implementation

The fused operation makes PTPC-FP8 practical for real-world deployment, eliminating the performance penalty of using more granular scaling factors while maintaining the accuracy benefits. The back-of-the-envelope sketch below shows how much memory traffic the fusion avoids.

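For intuition, here is a rough Python estimate of the memory traffic at stake. The token count, feature width, and BF16 intermediate dtype are illustrative assumptions, and only the intermediate tensor's write/read traffic is counted.

```python
# Back-of-the-envelope estimate of the extra HBM traffic caused by the unfused
# 2-step approach. Sizes and dtypes are illustrative assumptions, not vLLM internals.

tokens = 4096          # rows of the activation matrix (a batch of tokens)
out_features = 8192    # columns of the weight matrix
bytes_per_elem = 2     # assume the intermediate GEMM output is kept in BF16

intermediate_bytes = tokens * out_features * bytes_per_elem

# Unfused: write the intermediate to HBM, read it back to apply the scales,
# then write the scaled result again. Fused: write the final result once.
unfused_traffic = 3 * intermediate_bytes
fused_traffic = 1 * intermediate_bytes
extra = unfused_traffic - fused_traffic

print(f"intermediate tensor: {intermediate_bytes / 2**20:.1f} MiB")
print(f"extra HBM traffic per GEMM (unfused vs. fused): {extra / 2**20:.1f} MiB")
```
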
## Benchmarking PTPC-FP8: Speed and Accuracy on MI300X

We extensively benchmarked PTPC-FP8 using vLLM on AMD MI300X GPUs (commit `4ea48fb35cf67d61a1c3f18e3981c362e1d8e26f`). Here's what we found:

### 1. Throughput Comparison (PTPC-FP8 vs. Per-Tensor FP8)

* **Model:** Llama-3.1-70B-Instruct
* **Dataset:** ShareGPT
* **GPU:** 1× MI300X
* **Result:** PTPC-FP8 achieves virtually identical throughput to per-tensor FP8 (even slightly *better* – a 1.01× improvement). This demonstrates that the fused kernel completely overcomes the potential overhead of PTPC-FP8's more complex scaling.

<img align="center" src="/assets/figures/ptpc/PTPCReqs.svg" alt="Throughput in Reqs/s across various input-output sequence lengths of Llama-3.1-70B-Instruct" width="90%" height="50%">

<img align="center" src="/assets/figures/ptpc/PTPCSpeedup.svg" alt="Request/s throughput gain over FP8 per-tensor quantization across different input/output token lengths" width="90%" height="50%">

### 2.1. Accuracy: Perplexity (Lower is Better)

* **Model:** Llama-3.1-8B-Instruct
* **Dataset:** Wikitext
* **Setup:** 2× MI300X GPUs with tensor parallelism

#### Understanding Perplexity: The Prediction Power Test

Think of perplexity as a measure of how "confused" the model is when predicting text. Like a student taking a quiz:
- **Lower perplexity = Better predictions** (the model confidently assigns high probability to the correct next words)
- **Higher perplexity = More uncertainty** (the model is frequently surprised by what comes next)

A small increase in perplexity (even 0.1) can indicate meaningful degradation in model quality, especially for large language models that have been extensively optimized. The short sketch below shows how the word-level perplexity reported here is computed.

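Word-level perplexity of the kind reported below is essentially the exponentiated total negative log-likelihood of the text, normalized by the number of words. A toy computation with made-up log-probabilities (not real model outputs):

```python
import math

# Toy illustration of word perplexity: exp(total negative log-likelihood / word count).
token_logprobs = [-2.1, -0.4, -1.3, -0.9, -3.2, -0.6]  # log p(token | context) for each token
num_words = 4                                           # whitespace-split word count of the text

total_nll = -sum(token_logprobs)
word_perplexity = math.exp(total_nll / num_words)

print(f"word perplexity: {word_perplexity:.4f}")  # lower = the model predicts the text better
```
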
#### Results: PTPC-FP8 Maintains BF16-Like Quality

<img align="center" src="/assets/figures/ptpc/PerplexityBits.png" alt="bits and byte perplexity" width="90%">

<img align="center" src="/assets/figures/ptpc/Perplexitywords.png" alt="Word Perplexity Comparison" width="90%">

| Precision | Word Perplexity | % Degradation |
|:----------|:----------------|:--------------|
| BF16 (baseline) | 9.4281 | - |
| PTPC-FP8 | 9.5093 | 0.86% |
| Standard FP8 | 9.5124 | 0.89% |

As shown in both the table and chart:

1. **PTPC-FP8 outperforms standard FP8** quantization (9.5093 vs 9.5124)
2. **The gap to BF16 is minimal** - only 0.86% degradation from the full-precision baseline
3. **Byte-level metrics** (bits_per_byte and byte_perplexity) show the same pattern of results

**Why This Matters:** While standard FP8 already provides decent results, PTPC-FP8's lower perplexity indicates it better preserves the model's ability to make accurate predictions. This is especially important for complex reasoning and generation tasks, where small quality drops can compound into noticeable differences in output quality.

### 2.2. Accuracy on GSM8K: Testing Mathematical Reasoning

#### What is GSM8K and Why It Matters

GSM8K tests a model's ability to solve grade school math word problems – one of the most challenging tasks for LLMs. Unlike simple text prediction, these problems require:
- Multi-step reasoning
- Numerical accuracy
- Logical consistency

This benchmark provides a strong indicator of whether quantization preserves a model's reasoning abilities.

#### Understanding the Results

We measured accuracy using two methods (a toy illustration of the difference follows below):
- **Flexible-extract**: Accepts answers if the correct number appears anywhere in the response
- **Strict-match**: Requires the exact answer in the expected format

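The snippet below illustrates the difference between the two scoring modes. The regular expressions are simplified stand-ins, not the exact filters lm-evaluation-harness uses, and the model response is fabricated.

```python
import re

# Simplified illustration of the two GSM8K scoring modes.
response = "She sells 16 - 3 - 4 = 9 eggs daily... so she makes 9 * 2 = 18 dollars.\n#### 18"
gold = "18"

# Flexible-extract: take the last number that appears anywhere in the response.
numbers = re.findall(r"-?\d+(?:\.\d+)?", response)
flexible_ok = bool(numbers) and numbers[-1] == gold

# Strict-match: require a GSM8K-style "#### <answer>" line at the end of the response.
strict = re.search(r"#### (-?\d+(?:\.\d+)?)\s*$", response)
strict_ok = strict is not None and strict.group(1) == gold

print(f"flexible-extract correct: {flexible_ok}")
print(f"strict-match correct:     {strict_ok}")
```
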
<img align="center" src="/assets/figures/ptpc/GSM8K8B.png" alt="Accuracy Comparison on Llama-3.1-8B" width="80%" height="80%">

**8B Model Results at a Glance:**

| Method | Strict-match Accuracy | % of BF16 Performance |
|:-------|:----------------------|:----------------------|
| BF16 (baseline) | 73.2% | 100% |
| PTPC-FP8 | 70.8% | 96.7% |
| Standard FP8 | 69.2% | 94.5% |

**70B Model Results:**

<img align="center" src="/assets/figures/ptpc/GSM8K70B.png" alt="Accuracy Comparison on Llama-3.1-70B" width="80%" height="80%">

For the larger 70B model:
- PTPC-FP8 achieves **87.3%** strict-match accuracy
- This is actually **slightly better** than BF16's 86.3%
- Both outperform standard FP8 in strict-match conditions

#### Why These Results Matter

1. **Preservation of reasoning abilities**: Mathematical reasoning is often the first capability to degrade with quantization

2. **PTPC-FP8 consistently outperforms standard FP8** across both model sizes

3. **Near-BF16 quality** with substantially reduced memory and improved performance

4. **Scaling advantage**: The performance gap between quantization methods narrows as model size increases, suggesting PTPC-FP8 is especially valuable for large models

These results demonstrate that PTPC-FP8 quantization preserves the model's ability to perform complex reasoning tasks while delivering the speed and efficiency benefits of 8-bit precision.

## Getting Started

1. **Install ROCm:** Make sure you have a recent version.
2. **Build vLLM:** Clone the latest vLLM, build the ROCm Docker image, and start a container to explore this new feature:

```bash
$ git clone https://github.com/vllm-project/vllm.git
$ cd vllm
$ DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
$ docker run -it \
   --network=host \
   --group-add=video \
   --ipc=host \
   --cap-add=SYS_PTRACE \
   --security-opt seccomp=unconfined \
   --device /dev/kfd \
   --device /dev/dri \
   -v <path/to/model>:/app/model \
   vllm-rocm \
   bash
```

3. **Run vLLM with the `--quantization ptpc_fp8` flag:**

```bash
VLLM_USE_TRITON_FLASH_ATTN=0 vllm serve <your-model> \
    --max-seq-len-to-capture 16384 \
    --enable-chunked-prefill=False \
    --num-scheduler-steps 15 \
    --max-num-seqs 1024 \
    --quantization ptpc_fp8
```

(Replace `<your-model>` with any Hugging Face model; vLLM will automatically quantize the weights on the fly, so no pre-quantized checkpoint is needed.)

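Once the server is up, you can sanity-check it through the OpenAI-compatible API that `vllm serve` exposes. The sketch below assumes the default address `http://localhost:8000` and the `requests` package; substitute the model name you actually served.

```python
# Quick smoke test against the OpenAI-compatible completions endpoint.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "<your-model>",          # must match the model passed to `vllm serve`
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "temperature": 0,
    },
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```
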
## Conclusion: The Accuracy-Speed Sweet Spot

PTPC-FP8 quantization in vLLM on AMD ROCm represents a significant step towards democratizing access to powerful LLMs. By making near-BF16 accuracy achievable at FP8 speeds, we're breaking down the computational barriers that have limited wider adoption. This advancement empowers a broader community – from individual researchers to resource-constrained organizations – to leverage the power of large language models on accessible AMD hardware. We invite you to explore PTPC-FP8, share your experiences, contribute to the vLLM project, and help us build a future where efficient and accurate AI is available to everyone.

## Appendix

**lm-evaluation-harness Commands (Wikitext Perplexity):**

```bash
# Unquantized (Bfloat16)
MODEL=meta-llama/Llama-3.1-8B-Instruct
HIP_VISIBLE_DEVICES=0,1 lm_eval \
  --model vllm \
  --model_args pretrained=$MODEL,add_bos_token=True,tensor_parallel_size=2,kv_cache_dtype=auto,max_model_len=2048,gpu_memory_utilization=0.6 \
  --tasks wikitext --batch_size 16

# Per-Tensor FP8 Quantization
MODEL=meta-llama/Llama-3.1-8B-Instruct
HIP_VISIBLE_DEVICES=0,1 lm_eval \
  --model vllm \
  --model_args pretrained=$MODEL,add_bos_token=True,tensor_parallel_size=2,quantization=fp8,kv_cache_dtype=fp8_e4m3,max_model_len=2048,gpu_memory_utilization=0.6 \
  --tasks wikitext --batch_size 16

# Per-Token-Activation Per-Channel-Weight FP8 Quantization
MODEL=meta-llama/Llama-3.1-8B-Instruct
HIP_VISIBLE_DEVICES=0,1 lm_eval \
  --model vllm \
  --model_args pretrained=$MODEL,add_bos_token=True,tensor_parallel_size=2,quantization=ptpc_fp8,kv_cache_dtype=fp8_e4m3,max_model_len=2048,gpu_memory_utilization=0.6 \
  --tasks wikitext --batch_size 16
```

**lm-evaluation-harness Commands (GSM8K, 8B Model - adjust for 70B):**

```bash
# FP8 (Per-Tensor)
MODEL=/app/model/Llama-3.1-8B-Instruct/  # Or Llama-3.1-70B-Instruct
lm_eval \
  --model vllm \
  --model_args pretrained=$MODEL,add_bos_token=True,quantization=fp8,kv_cache_dtype=fp8_e4m3 \
  --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250

# PTPC FP8
MODEL=/app/model/Llama-3.1-8B-Instruct/  # Or Llama-3.1-70B-Instruct
lm_eval \
  --model vllm \
  --model_args pretrained=$MODEL,add_bos_token=True,quantization=ptpc_fp8,kv_cache_dtype=fp8_e4m3 \
  --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250

# BF16
MODEL=/app/model/Llama-3.1-8B-Instruct/  # Or Llama-3.1-70B-Instruct
lm_eval \
  --model vllm \
  --model_args pretrained=$MODEL,add_bos_token=True,kv_cache_dtype=auto \
  --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250
```

assets/figures/ptpc/FusedGEMM.svg (+2)
assets/figures/ptpc/GSM8K70B.png (43.9 KB)
assets/figures/ptpc/GSM8K8B.png (43 KB)
assets/figures/ptpc/PTPC-Diagram.png (58 KB)
assets/figures/ptpc/PTPC-tumbnail.png (47.6 KB)
assets/figures/ptpc/PTPC121.png (26.7 KB)
assets/figures/ptpc/PTPCReqs.svg (+1)
