
Commit 6559036

Flash2: support cross attention and dropout (#905)

* Support cross attention and dropout
* Fix comments
* Disable cudnn dropout

1 parent ee7d60d · commit 6559036
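The kernel changes live in gpu_attention.py, whose diff is collapsed below. For orientation only, here is a minimal non-fused JAX sketch of the semantics the commit title describes: scaled dot-product cross attention with dropout applied to the attention probabilities. The function name, argument layout, and signature are illustrative assumptions, not the axlearn API.

```python
import jax
import jax.numpy as jnp

def reference_cross_attention(q, k, v, *, dropout_rate=0.0, prng_key=None):
    """Non-fused reference semantics (illustrative; not the axlearn API).

    q: [batch, target_len, num_heads, head_dim]
    k, v: [batch, source_len, num_heads, head_dim]
    Cross attention here just means target_len and source_len may differ.
    """
    # Scaled dot-product logits over the source sequence.
    logits = jnp.einsum("btnh,bsnh->bnts", q, k) / jnp.sqrt(q.shape[-1])
    probs = jax.nn.softmax(logits, axis=-1)
    if dropout_rate > 0.0:
        # Inverted dropout on the attention probabilities (one common choice
        # in fused attention kernels).
        keep = jax.random.bernoulli(prng_key, 1.0 - dropout_rate, probs.shape)
        probs = jnp.where(keep, probs / (1.0 - dropout_rate), 0.0)
    # Weighted sum of values back to [batch, target_len, num_heads, head_dim].
    return jnp.einsum("bnts,bsnh->btnh", probs, v)
```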

File tree

8 files changed: +804 -700 lines changed

axlearn/common/flash_attention/gpu_attention.py

Lines changed: 396 additions & 533 deletions
Large diffs are not rendered by default.

axlearn/common/flash_attention/gpu_attention_benchmark.py

Lines changed: 69 additions & 69 deletions
@@ -10,95 +10,95 @@
 """FlashAttention kernel benchmarks.

 To run: python3 gpu_attention_benchmark.py > out.txt
-Requires Jax >= 0.4.36. Sample numbers on H100 SXM5:
+Requires Jax >= 0.4.36. Sample numbers on H100 SXM5 with Jax == 0.4.36:
 is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128, sw_sz=-1
                                         jax    axlearn  jax-cudnn
-bs=1,seq_len=1024                  0.020608   0.018656   0.023680
-bs=1,seq_len=4096                  0.037856   0.022784   0.056704
-bs=1,seq_len=8192                  0.033792   0.032768   0.104448
-bs=1,seq_len=131072                0.227808   0.198816   1.486752
-bs=4,seq_len=1024                  0.021440   0.022208   0.024032
-bs=4,seq_len=4096                  0.069728   0.054624   0.059584
-bs=4,seq_len=8192                  0.081952   0.076064   0.105920
-bs=4,seq_len=131072                0.823104   0.705056   1.488832
-bs=8,seq_len=1024                  0.032544   0.030688   0.024608
-bs=8,seq_len=4096                  0.089728   0.071648   0.063584
-bs=8,seq_len=8192                  0.129184   0.114944   0.109856
-bs=8,seq_len=131072                1.616800   1.376288   1.503360
-bs=16,seq_len=1024                 0.050976   0.048608   0.037504
-bs=16,seq_len=4096                 0.136768   0.117312   0.104224
-bs=16,seq_len=8192                 0.234688   0.200128   0.190944
-bs=16,seq_len=131072               3.211200   2.727040   2.779872
-bs=32,seq_len=1024                 0.078656   0.072992   0.061440
-bs=32,seq_len=4096                 0.236576   0.204512   0.190752
-bs=32,seq_len=8192                 0.443488   0.372352   0.361216
-bs=32,seq_len=131072               6.392320   5.453344   5.495488
+bs=1,seq_len=1024                  0.020832   0.017536   0.024128
+bs=1,seq_len=4096                  0.037472   0.021248   0.058656
+bs=1,seq_len=8192                  0.034016   0.032576   0.108576
+bs=1,seq_len=131072                0.229856   0.198944   1.558464
+bs=4,seq_len=1024                  0.021632   0.023296   0.024352
+bs=4,seq_len=4096                  0.068064   0.055168   0.061312
+bs=4,seq_len=8192                  0.080352   0.075968   0.109696
+bs=4,seq_len=131072                0.824576   0.703360   1.560768
+bs=8,seq_len=1024                  0.033536   0.030304   0.024448
+bs=8,seq_len=4096                  0.089056   0.071712   0.062944
+bs=8,seq_len=8192                  0.128960   0.114848   0.112736
+bs=8,seq_len=131072                1.620032   1.373088   1.566208
+bs=16,seq_len=1024                 0.050368   0.048064   0.036608
+bs=16,seq_len=4096                 0.134816   0.116320   0.104320
+bs=16,seq_len=8192                 0.234880   0.200384   0.191936
+bs=16,seq_len=131072               3.219008   2.726912   2.784768
+bs=32,seq_len=1024                 0.078112   0.070816   0.061568
+bs=32,seq_len=4096                 0.235648   0.203296   0.191936
+bs=32,seq_len=8192                 0.442080   0.371936   0.365152
+bs=32,seq_len=131072               6.404832   5.448480   5.541504
 is_decode=True, use_bwd=False, num_heads=8, seq_len=32768, per_head_dim=128, sw_sz=-1
                                         jax    axlearn  jax-cudnn
-bs=1,num_kv_heads=1                0.049280   0.059296   0.378304
-bs=1,num_kv_heads=8                0.076352   0.070912   0.377344
-bs=8,num_kv_heads=1                0.111072   0.080480   0.377696
-bs=8,num_kv_heads=8                0.425536   0.368576   0.386880
+bs=1,num_kv_heads=1                0.027648   0.058464   0.398816
+bs=1,num_kv_heads=8                0.076096   0.070368   0.398912
+bs=8,num_kv_heads=1                0.101696   0.078560   0.399040
+bs=8,num_kv_heads=8                0.426656   0.367616   0.403360
 is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128
                                         jax    axlearn  jax-cudnn
-bs=1,seq_len=131072,sw_sz=-1       0.228640   0.199040   1.476928
-bs=1,seq_len=131072,sw_sz=4096     0.232320   0.053824   4.441376
-bs=1,seq_len=131072,sw_sz=16384    0.233696   0.061120   4.420992
-bs=8,seq_len=131072,sw_sz=-1       1.621696   1.374080   1.496224
-bs=8,seq_len=131072,sw_sz=4096     1.626016   0.193792   4.463296
-bs=8,seq_len=131072,sw_sz=16384    1.628704   0.318176   4.451648
+bs=1,seq_len=131072,sw_sz=-1       0.230336   0.199968   1.559168
+bs=1,seq_len=131072,sw_sz=4096     0.235296   0.051296   4.414048
+bs=1,seq_len=131072,sw_sz=16384    0.235904   0.062976   4.385216
+bs=8,seq_len=131072,sw_sz=-1       1.619008   1.372768   1.570272
+bs=8,seq_len=131072,sw_sz=4096     1.635424   0.194720   4.390976
+bs=8,seq_len=131072,sw_sz=16384    1.632832   0.321280   4.361984
 is_decode=False, use_bwd=False, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
                                         jax    axlearn  jax-cudnn jax-pallas
-bs=2                               3.502944   0.915360   0.467744   0.845792
-bs=4                               6.969376   1.753152   0.890496   1.617280
-bs=8                              13.962816   3.415232   1.735232   3.150752
+bs=2                               3.583424   0.894912   0.488480   0.852960
+bs=4                               7.107168   1.712448   0.922592   1.629888
+bs=8                              14.202400   3.341568   1.801920   3.184064
 is_decode=False, use_bwd=False, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
                                         jax    axlearn  jax-cudnn jax-pallas
-num_heads=12                       1.262560   0.393536   0.205952   0.362304
-num_heads=16                       1.786816   0.498304   0.257664   0.459936
-num_heads=32                       3.507488   2.591456   0.468672   2.443296
-num_heads=48                       5.246336   1.338272   0.675968   1.231328
-num_heads=72                       7.866848   1.961152   0.995712   1.805376
+num_heads=12                       1.287712   0.383200   0.214400   0.365120
+num_heads=16                       1.803232   0.485408   0.270496   0.463040
+num_heads=32                       3.578208   0.896576   0.488544   2.468096
+num_heads=48                       5.346112   1.305856   0.707872   1.241728
+num_heads=72                       8.001568   1.915776   1.035200   1.820288
 is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1
                                         jax    axlearn  jax-cudnn jax-pallas
-seq_len=128                        0.030592   0.011584   0.013024   0.012960
-seq_len=256                        0.051520   0.015648   0.016640   0.015744
-seq_len=512                        0.118720   0.038976   0.028224   0.037152
-seq_len=1024                       0.310880   0.096256   0.054784   0.090368
-seq_len=2048                       0.931072   0.277312   0.150784   0.256928
-seq_len=4096                       3.516672   2.595872   0.465568   2.448128
+seq_len=256                        0.049184   0.015360   0.016352   0.015488
+seq_len=512                        0.110400   0.038624   0.028480   0.037760
+seq_len=1024                       0.302304   0.094560   0.056736   0.090464
+seq_len=2048                       0.936832   0.269856   0.154304   0.258944
+seq_len=4096                       3.584800   0.895776   0.487104   2.462560
+seq_len=8192                      14.260608   3.268320   1.742048   3.104640
 is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1
                                         jax    axlearn  jax-cudnn jax-pallas
-per_head_dim=16                    3.220960   0.487808   0.332928   0.478720
-per_head_dim=32                    3.277824   0.530240   0.334624   0.515040
-per_head_dim=64                    3.345376   0.696480   0.338944   0.631296
-per_head_dim=128                   3.515616   2.594208   0.465824   2.442784
+per_head_dim=16                    3.262592   0.518912   0.356544   0.477120
+per_head_dim=32                    3.323552   0.563520   0.358944   0.533344
+per_head_dim=64                    3.411744   0.690464   0.360192   0.635296
+per_head_dim=128                   3.585920   0.896032   0.488416   2.461696
 is_decode=False, use_bwd=True, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
                                         jax    axlearn  jax-cudnn jax-pallas
-bs=2                              10.780096   4.573344   2.080672   4.487104
-bs=4                              21.426336   9.336192   3.988224   9.159904
-bs=8                              42.808033  18.926559   7.975296  18.075487
+bs=2                              10.878624   3.924992   2.123008   4.504256
+bs=4                              21.626017   8.043040   4.071552   9.186080
+bs=8                              43.269279  16.195999   8.124896  18.184799
 is_decode=False, use_bwd=True, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
                                         jax    axlearn  jax-cudnn jax-pallas
-num_heads=12                       4.128352   1.738016   0.882976   1.696704
-num_heads=16                       5.467808   2.307488   1.120608   2.247904
-num_heads=32                      10.782432   4.559456   2.082592   4.488448
-num_heads=48                      16.119776   6.958272   3.027808   6.858144
-num_heads=72                      24.140833  10.706656   4.560288  10.279136
+num_heads=12                       4.159424   1.519680   0.898816   1.711808
+num_heads=16                       5.486912   2.001952   1.142144   2.256960
+num_heads=32                      10.886848   3.928896   2.114496   4.502976
+num_heads=48                      16.224319   6.085408   3.093696   6.888640
+num_heads=72                      24.367489   9.190560   4.642720  10.323552
 is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1
                                         jax    axlearn  jax-cudnn jax-pallas
-seq_len=128                        0.058944   0.037824   0.039040   0.036384
-seq_len=256                        0.100384   0.069024   0.052608   0.067872
-seq_len=512                        0.317056   0.159904   0.111840   0.158912
-seq_len=1024                       0.906400   0.431104   0.244160   0.421792
-seq_len=2048                       2.861056   1.319648   0.655840   1.297728
-seq_len=4096                      10.762560   4.576864   2.079904   4.489056
+seq_len=256                        0.094496   0.060096   0.053184   0.065760
+seq_len=512                        0.297440   0.139328   0.112736   0.161664
+seq_len=1024                       0.886304   0.361536   0.246848   0.418720
+seq_len=2048                       2.857952   1.118368   0.675168   1.294144
+seq_len=4096                      10.880512   3.914048   2.119808   4.503936
+seq_len=8192                      43.000095  14.913824   7.484128  16.730017
 is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1
                                         jax    axlearn  jax-cudnn jax-pallas
-per_head_dim=16                   10.084800   1.744640   1.263264   1.711296
-per_head_dim=32                   10.204480   2.098816   1.291104   2.041184
-per_head_dim=64                   10.374720   2.649888   1.335200   2.510304
-per_head_dim=128                  10.779680   4.568096   2.079264   4.489792
+per_head_dim=16                   10.150080   1.826656   1.288192   1.718688
+per_head_dim=32                   10.277440   2.028608   1.316512   2.048864
+per_head_dim=64                   10.463904   2.569408   1.364448   2.540512
+per_head_dim=128                  10.875328   3.929568   2.124192   4.502912
 """
 # pylint: enable=line-too-long
 import itertools
@@ -365,8 +365,8 @@ def bench_flash_attention_fwd_bwd(use_bwd: bool):
     libraries = ["jax", "axlearn", "jax-cudnn", "jax-pallas"]
     benchmark_sweep(libraries, common_kwargs, bs=[2, 4, 8])
     benchmark_sweep(libraries, common_kwargs, num_heads=[12, 16, 32, 48, 72])
-    # 128 to 4096.
-    benchmark_sweep(libraries, common_kwargs, seq_len=[int(2**i) for i in range(7, 13)])
+    # 256 to 8192.
+    benchmark_sweep(libraries, common_kwargs, seq_len=[int(2**i) for i in range(8, 14)])
     benchmark_sweep(libraries, common_kwargs, per_head_dim=[16, 32, 64, 128])
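For reference, the new comprehension in this sweep expands to six powers of two, matching the updated comment; a trivial check in plain Python:

```python
# Sequence lengths swept after this change: 2**8 through 2**13.
seq_lens = [int(2**i) for i in range(8, 14)]
assert seq_lens == [256, 512, 1024, 2048, 4096, 8192]
```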
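The docstring tables above do not spell out the measurement harness. As a hedged sketch only (assuming a jittable function returning a single array, and per-call times reported in milliseconds as the tables suggest), latencies of this kind are typically collected in JAX by jitting, warming up once, and synchronizing before reading the clock:

```python
import time
import jax

def time_ms(fn, *args, iters: int = 100) -> float:
    """Average milliseconds per call; illustrative, not this file's harness."""
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()  # compile and warm up
    start = time.perf_counter()
    for _ in range(iters):
        out = jitted(*args)
    out.block_until_ready()  # flush JAX's async dispatch before stopping the clock
    return (time.perf_counter() - start) / iters * 1e3
```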