Skip to content

Commit 4bfd585

Browse files
authored
feat: add support for flash_attn>=2.6.0 (#70)
1 parent 478c3a2 commit 4bfd585

24 files changed

+121
-7
lines changed

benchmark/benchmark_longctx.py

+4
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False):
157157
dropout_p=dropout_p,
158158
causal=causal,
159159
window_size=(-1, -1),
160+
softcap=0.0,
160161
alibi_slopes=None,
161162
deterministic=deterministic,
162163
return_attn_probs=False,
@@ -171,6 +172,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False):
171172
dropout_p=dropout_p,
172173
causal=causal,
173174
window_size=(-1, -1),
175+
softcap=0.0,
174176
alibi_slopes=None,
175177
deterministic=deterministic,
176178
return_attn_probs=False,
@@ -194,6 +196,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False):
194196
dropout_p=dropout_p,
195197
causal=causal,
196198
window_size=(-1, -1),
199+
softcap=0.0,
197200
alibi_slopes=None,
198201
deterministic=deterministic,
199202
return_attn_probs=False,
@@ -215,6 +218,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False):
215218
dropout_p=dropout_p,
216219
causal=causal,
217220
window_size=(-1, -1),
221+
softcap=0.0,
218222
alibi_slopes=None,
219223
deterministic=deterministic,
220224
return_attn_probs=False,

benchmark/benchmark_longctx_qkvpacked.py

+3
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def benchmark(num_iter=100, forward_only=True, log=True):
117117
dropout_p=dropout_p,
118118
causal=causal,
119119
window_size=(-1, -1),
120+
softcap=0.0,
120121
alibi_slopes=None,
121122
deterministic=deterministic,
122123
return_attn_probs=False,
@@ -134,6 +135,7 @@ def benchmark(num_iter=100, forward_only=True, log=True):
134135
dropout_p=dropout_p,
135136
causal=causal,
136137
window_size=(-1, -1),
138+
softcap=0.0,
137139
alibi_slopes=None,
138140
deterministic=deterministic,
139141
return_attn_probs=False,
@@ -147,6 +149,7 @@ def benchmark(num_iter=100, forward_only=True, log=True):
147149
dropout_p=dropout_p,
148150
causal=causal,
149151
window_size=(-1, -1),
152+
softcap=0.0,
150153
alibi_slopes=None,
151154
deterministic=deterministic,
152155
return_attn_probs=False,

benchmark/benchmark_qkvpacked_func.py

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
5656
dropout_p=dropout_p,
5757
causal=causal,
5858
window_size=(-1, -1),
59+
softcap=0.0,
5960
alibi_slopes=None,
6061
deterministic=deterministic,
6162
return_attn_probs=False,
@@ -65,6 +66,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
6566
dropout_p=dropout_p,
6667
causal=causal,
6768
window_size=(-1, -1),
69+
softcap=0.0,
6870
alibi_slopes=None,
6971
deterministic=deterministic,
7072
return_attn_probs=False,
@@ -82,6 +84,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
8284
dropout_p=dropout_p,
8385
causal=causal,
8486
window_size=(-1, -1),
87+
softcap=0.0,
8588
alibi_slopes=None,
8689
deterministic=deterministic,
8790
return_attn_probs=False,
@@ -95,6 +98,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
9598
dropout_p=dropout_p,
9699
causal=causal,
97100
window_size=(-1, -1),
101+
softcap=0.0,
98102
alibi_slopes=None,
99103
deterministic=deterministic,
100104
return_attn_probs=False,

benchmark/benchmark_ring_func.py

+4
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
7979
dropout_p=dropout_p,
8080
causal=causal,
8181
window_size=(-1, -1),
82+
softcap=0.0,
8283
alibi_slopes=None,
8384
deterministic=deterministic,
8485
return_attn_probs=False,
@@ -90,6 +91,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
9091
dropout_p=dropout_p,
9192
causal=causal,
9293
window_size=(-1, -1),
94+
softcap=0.0,
9395
alibi_slopes=None,
9496
deterministic=deterministic,
9597
return_attn_probs=False,
@@ -109,6 +111,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
109111
dropout_p=dropout_p,
110112
causal=causal,
111113
window_size=(-1, -1),
114+
softcap=0.0,
112115
alibi_slopes=None,
113116
deterministic=deterministic,
114117
return_attn_probs=False,
@@ -126,6 +129,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
126129
dropout_p=dropout_p,
127130
causal=causal,
128131
window_size=(-1, -1),
132+
softcap=0.0,
129133
alibi_slopes=None,
130134
deterministic=deterministic,
131135
return_attn_probs=False,

benchmark/benchmark_varlen_qkvpacked_func.py

+2
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
6464
dropout_p=dropout_p,
6565
causal=causal,
6666
window_size=(-1, -1),
67+
softcap=0.0,
6768
alibi_slopes=None,
6869
deterministic=deterministic,
6970
return_attn_probs=False,
@@ -78,6 +79,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
7879
dropout_p=dropout_p,
7980
causal=causal,
8081
window_size=(-1, -1),
82+
softcap=0.0,
8183
alibi_slopes=None,
8284
deterministic=deterministic,
8385
return_attn_probs=False,

scripts/run_gqa.sh

100644 → 100755
File mode changed.

scripts/run_qkvpack_compare.sh

100644 → 100755
File mode changed.

setup.py

+2 -2
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
setup(
44
name="yunchang",
5-
version="0.2",
5+
version="0.3",
66
author="Jiarui Fang, Zilin Zhu, Yang Yu",
77
url="https://github.com/feifeibear/long-context-attention",
88
packages=find_packages(exclude=['test', 'benchmark']),
99
install_requires=[
10-
'flash-attn',
10+
'flash-attn>=2.6.0',
1111
],
1212
)

test/test_hybrid_attn.py

+2
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def log(msg, a, rank0_only=False):
110110
dropout_p=dropout_p,
111111
causal=causal,
112112
window_size=(-1, -1),
113+
softcap=0.0,
113114
alibi_slopes=None,
114115
deterministic=deterministic,
115116
return_attn_probs=True,
@@ -137,6 +138,7 @@ def log(msg, a, rank0_only=False):
137138
dropout_p=dropout_p,
138139
causal=causal,
139140
window_size=(-1, -1),
141+
softcap=0.0,
140142
alibi_slopes=None,
141143
deterministic=deterministic,
142144
return_attn_probs=True,

test/test_hybrid_qkvpacked_attn.py

+2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def test(ring_impl_type="basic"):
111111
dropout_p=dropout_p,
112112
causal=causal,
113113
window_size=(-1, -1),
114+
softcap=0.0,
114115
alibi_slopes=None,
115116
deterministic=deterministic,
116117
return_attn_probs=True,
@@ -124,6 +125,7 @@ def test(ring_impl_type="basic"):
124125
dropout_p=dropout_p,
125126
causal=causal,
126127
window_size=(-1, -1),
128+
softcap=0.0,
127129
alibi_slopes=None,
128130
deterministic=deterministic,
129131
return_attn_probs=True,

test/test_ring_flash_attn_func.py

+2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def log(msg, a, rank0_only=False):
7373
dropout_p=dropout_p,
7474
causal=causal,
7575
window_size=(-1, -1),
76+
softcap=0.0,
7677
alibi_slopes=None,
7778
deterministic=deterministic,
7879
return_attn_probs=True,
@@ -88,6 +89,7 @@ def log(msg, a, rank0_only=False):
8889
dropout_p=dropout_p,
8990
causal=causal,
9091
window_size=(-1, -1),
92+
softcap=0.0,
9193
alibi_slopes=None,
9294
deterministic=deterministic,
9395
return_attn_probs=True,

test/test_ring_flash_attn_varlen_func.py

+2
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def extract_lse(lse, cu_seqlens):
9999
dropout_p=dropout_p,
100100
causal=causal,
101101
window_size=(-1, -1),
102+
softcap=0.0,
102103
alibi_slopes=None,
103104
deterministic=deterministic,
104105
return_attn_probs=True,
@@ -114,6 +115,7 @@ def extract_lse(lse, cu_seqlens):
114115
dropout_p=dropout_p,
115116
causal=causal,
116117
window_size=(-1, -1),
118+
softcap=0.0,
117119
alibi_slopes=None,
118120
deterministic=deterministic,
119121
return_attn_probs=True,

test/test_stripe_flash_attn_func.py

+2
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def extract_local(value, rank, world_size, dim=1):
8080
dropout_p=dropout_p,
8181
causal=causal,
8282
window_size=(-1, -1),
83+
softcap=0.0,
8384
alibi_slopes=None,
8485
deterministic=deterministic,
8586
return_attn_probs=True,
@@ -93,6 +94,7 @@ def extract_local(value, rank, world_size, dim=1):
9394
dropout_p=dropout_p,
9495
causal=causal,
9596
window_size=(-1, -1),
97+
softcap=0.0,
9698
alibi_slopes=None,
9799
deterministic=deterministic,
98100
return_attn_probs=True,

test/test_ulysses_attn.py

+2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def log(msg, a, rank0_only=False):
9393
dropout_p=dropout_p,
9494
causal=causal,
9595
window_size=(-1, -1),
96+
softcap=0.0,
9697
alibi_slopes=None,
9798
deterministic=deterministic,
9899
return_attn_probs=True,
@@ -119,6 +120,7 @@ def log(msg, a, rank0_only=False):
119120
dropout_p=dropout_p,
120121
causal=causal,
121122
window_size=(-1, -1),
123+
softcap=0.0,
122124
alibi_slopes=None,
123125
deterministic=deterministic,
124126
return_attn_probs=True,

test/test_zigzag_ring_flash_attn_func.py

+2
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def extract_local(value, rank, world_size, dim=1):
8080
dropout_p=dropout_p,
8181
causal=causal,
8282
window_size=(-1, -1),
83+
softcap=0.0,
8384
alibi_slopes=None,
8485
deterministic=deterministic,
8586
return_attn_probs=True,
@@ -93,6 +94,7 @@ def extract_local(value, rank, world_size, dim=1):
9394
dropout_p=dropout_p,
9495
causal=causal,
9596
window_size=(-1, -1),
97+
softcap=0.0,
9698
alibi_slopes=None,
9799
deterministic=deterministic,
98100
return_attn_probs=True,

test/test_zigzag_ring_flash_attn_varlen_func.py

+2
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def extract_lse(lse, cu_seqlens):
104104
dropout_p=dropout_p,
105105
causal=causal,
106106
window_size=(-1, -1),
107+
softcap=0.0,
107108
alibi_slopes=None,
108109
deterministic=deterministic,
109110
return_attn_probs=True,
@@ -119,6 +120,7 @@ def extract_lse(lse, cu_seqlens):
119120
dropout_p=dropout_p,
120121
causal=causal,
121122
window_size=(-1, -1),
123+
softcap=0.0,
122124
alibi_slopes=None,
123125
deterministic=deterministic,
124126
return_attn_probs=True,

yunchang/hybrid/async_attn_layer.py

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def forward(
5050
softmax_scale=None,
5151
causal=False,
5252
window_size=(-1, -1),
53+
softcap=0.0,
5354
alibi_slopes=None,
5455
deterministic=False,
5556
return_attn_probs=False,
@@ -148,6 +149,7 @@ def forward(
148149
softmax_scale=softmax_scale,
149150
causal=causal,
150151
window_size=window_size,
152+
softcap=softcap,
151153
alibi_slopes=alibi_slopes,
152154
deterministic=deterministic,
153155
return_attn_probs=return_attn_probs,

yunchang/hybrid/attn_layer.py

+5
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def forward(
4949
softmax_scale=None,
5050
causal=False,
5151
window_size=(-1, -1),
52+
softcap=0.0,
5253
alibi_slopes=None,
5354
deterministic=False,
5455
return_attn_probs=False,
@@ -84,6 +85,7 @@ def forward(
8485
softmax_scale=softmax_scale,
8586
causal=causal,
8687
window_size=window_size,
88+
softcap=softcap,
8789
alibi_slopes=alibi_slopes,
8890
deterministic=deterministic,
8991
return_attn_probs=return_attn_probs,
@@ -108,6 +110,7 @@ def forward(
108110
softmax_scale=softmax_scale,
109111
causal=causal,
110112
window_size=window_size,
113+
softcap=softcap,
111114
alibi_slopes=alibi_slopes,
112115
deterministic=deterministic,
113116
return_attn_probs=return_attn_probs,
@@ -166,6 +169,7 @@ def forward(
166169
softmax_scale=None,
167170
causal=False,
168171
window_size=(-1, -1),
172+
softcap=0.0,
169173
alibi_slopes=None,
170174
deterministic=False,
171175
return_attn_probs=False,
@@ -198,6 +202,7 @@ def forward(
198202
softmax_scale=softmax_scale,
199203
causal=causal,
200204
window_size=window_size,
205+
softcap=softcap,
201206
alibi_slopes=alibi_slopes,
202207
deterministic=deterministic,
203208
return_attn_probs=return_attn_probs,

0 commit comments

Comments
 (0)