
Commit 3183842

MARD1NO and StrongSpoon authored
Dev attention[SiliconFlow] (#236)
* init attention
* triton 2.3.0 cannot use make_block_ptr like the tutorial; change to use the original tl.load
* fix some names and add GQA support
* support different seqlens for query and kv
* support attn bias
* add mask to prevent illegal memory access; add unittest
* add perf script
* add coverage test for attention
* refine test script
* fix perf script
* fix unittest; remove writing M
* add more configs and preload_v
* [bugfix]
* address review comments
* fix invalid config; set allow_tf32=false to improve perf; prune some configs
* add more dtypes
* stage3 slightly improved
* add missing and
* add more shapes
* remove fp32 dtype
* [Operator] early prune configs to be compatible with triton 2.2
* [bugfix] deliver positional args for triton v3

---------

Co-authored-by: strongspoon <[email protected]>
1 parent f562266 commit 3183842
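For orientation, here is a minimal usage sketch, assuming a CUDA build of FlagGems: flag_gems.enable() installs the ATen overrides registered in src/flag_gems/__init__.py below, after which torch.nn.functional.scaled_dot_product_attention is expected to dispatch to the new Triton kernel. The shapes and dtypes are illustrative only, not part of this commit.

import torch
import flag_gems

# Patch the registered ATen ops, including the new
# scaled_dot_product_attention entry added in this commit.
flag_gems.enable()

# Illustrative tensors in (batch, num_heads, seq_len, head_size) layout,
# matching the benchmark shapes added in benchmark/core_shapes.yaml.
q = torch.randn(4, 8, 1024, 128, device="cuda", dtype=torch.float16)
k = torch.randn(4, 8, 1024, 128, device="cuda", dtype=torch.float16)
v = torch.randn(4, 8, 1024, 128, device="cuda", dtype=torch.float16)

# With FlagGems enabled, this call is expected to route to the Triton
# attention kernel instead of the stock PyTorch implementation.
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)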

File tree

8 files changed: +533 -1 lines changed


benchmark/core_shapes.yaml (+8 lines)

@@ -170,3 +170,11 @@ ConvBenchmark:
     - [16, 32, 24, 24, 24, 3, 3, 1, 1, 2]
     - [16, 32, 24, 24, 24, 3, 3, 2, 2, 2]
     - [16, 32, 24, 24, 24, 3, 3, 1, 2, 2]
+
+AttentionBenchmark:
+  shapes:
+    - [4, 8, 512, 128]
+    - [4, 8, 1024, 128]
+    - [4, 8, 2048, 128]
+    - [4, 8, 3072, 128]
+    - [4, 8, 4096, 128]
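Each shape appears to follow the (batch, num_heads, seq_len, head_size) layout used by the perf test below. As a back-of-the-envelope sketch (the 2-FLOPs-per-multiply-add convention and the causal halving are standard approximations, not something stated in this commit), the largest shape works out as follows:

# Rough FLOP count for the largest AttentionBenchmark shape, assuming the
# (batch, num_heads, seq_len, head_size) layout of the perf test below.
batch, num_heads, seq_len, head_size = 4, 8, 4096, 128

# Two matmuls dominate (Q @ K^T and P @ V), each contributing
# 2 * seq_len^2 * head_size FLOPs per (batch, head) pair.
total_flops = batch * num_heads * 4 * seq_len**2 * head_size
causal_flops = total_flops / 2  # the causal mask skips roughly half the tiles

print(f"full: {total_flops / 1e9:.0f} GFLOPs, causal: {causal_flops / 1e9:.0f} GFLOPs")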

benchmark/test_attention_perf.py (new file, +35 lines)

@@ -0,0 +1,35 @@
+import pytest
+import torch
+
+from .performance_utils import GenericBenchmark
+
+
+class AttentionBenchmark(GenericBenchmark):
+    """
+    benchmark for attention
+    """
+
+    def set_more_shapes(self):
+        # self.shapes is a list of tuples, each containing four elements:
+        # (batch, num_heads, seq_len, head_size).
+        return None
+
+
+@pytest.mark.attention
+def test_perf_scaled_dot_product_attention():
+    def scaled_dot_product_attention_kwargs(shape, dtype, device):
+        query = torch.randn(shape, device=device, dtype=dtype)
+        key = torch.randn(shape, device=device, dtype=dtype)
+        value = torch.randn(shape, device=device, dtype=dtype)
+        yield query, key, value, None, 0.0, True
+
+    bench = AttentionBenchmark(
+        op_name="scaled_dot_product_attention",
+        input_fn=scaled_dot_product_attention_kwargs,
+        torch_op=torch.nn.functional.scaled_dot_product_attention,
+        dtypes=[
+            torch.float16,
+            torch.bfloat16,
+        ],
+    )
+    bench.run()
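The positional tuple yielded by scaled_dot_product_attention_kwargs lines up with the leading parameters of torch.nn.functional.scaled_dot_product_attention. The keyword form below is a sketch for readability only; the tensor sizes are stand-ins, not values from this commit.

import torch

# Stand-in tensors with the benchmark's layout: (batch, num_heads, seq_len, head_size).
query = torch.randn(4, 8, 512, 128, device="cuda", dtype=torch.float16)
key = torch.randn(4, 8, 512, 128, device="cuda", dtype=torch.float16)
value = torch.randn(4, 8, 512, 128, device="cuda", dtype=torch.float16)

# Keyword form of the yielded positional tuple (query, key, value, None, 0.0, True):
out = torch.nn.functional.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,   # no additive attention bias
    dropout_p=0.0,    # dropout disabled for benchmarking
    is_causal=True,   # causal masking enabled
)

Because of the @pytest.mark.attention marker, this benchmark can be selected on its own with pytest's -m attention filter.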

src/flag_gems/__init__.py (+5 lines)

@@ -146,6 +146,11 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
         ("prod.dim_int", prod_dim, Autograd.disable),
         ("sum", sum, Autograd.disable),
         ("sum.dim_IntList", sum_dim, Autograd.disable),
+        (
+            "scaled_dot_product_attention",
+            scaled_dot_product_attention,
+            Autograd.disable,
+        ),
         ("all", all, Autograd.disable),
         ("all.dim", all_dim, Autograd.disable),
         ("all.dims", all_dims, Autograd.disable),

src/flag_gems/ops/__init__.py (+2 lines)

@@ -6,6 +6,7 @@
 from .any import any, any_dim, any_dims
 from .arange import arange, arange_start
 from .argmax import argmax
+from .attention import scaled_dot_product_attention
 from .bitwise_and import (
     bitwise_and_scalar,
     bitwise_and_scalar_tensor,

@@ -274,6 +275,7 @@
     "repeat_interleave_self_int",
     "vstack",
     "repeat_interleave_tensor",
+    "scaled_dot_product_attention",
     "conv2d",
     "conv1d",
     "_conv_depthwise2d",
