|
10 | 10 | """FlashAttention kernel benchmarks. |
11 | 11 |
|
12 | 12 | To run: python3 gpu_attention_benchmark.py > out.txt |
13 | | -Requires Jax >= 0.4.36. Sample numbers on H100 SXM5: |
| 13 | +Requires Jax >= 0.4.36. Sample numbers on H100 SXM5 with Jax == 0.4.36: |
14 | 14 | is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128, sw_sz=-1 |
15 | 15 | jax axlearn jax-cudnn |
16 | | -bs=1,seq_len=1024 0.020608 0.018656 0.023680 |
17 | | -bs=1,seq_len=4096 0.037856 0.022784 0.056704 |
18 | | -bs=1,seq_len=8192 0.033792 0.032768 0.104448 |
19 | | -bs=1,seq_len=131072 0.227808 0.198816 1.486752 |
20 | | -bs=4,seq_len=1024 0.021440 0.022208 0.024032 |
21 | | -bs=4,seq_len=4096 0.069728 0.054624 0.059584 |
22 | | -bs=4,seq_len=8192 0.081952 0.076064 0.105920 |
23 | | -bs=4,seq_len=131072 0.823104 0.705056 1.488832 |
24 | | -bs=8,seq_len=1024 0.032544 0.030688 0.024608 |
25 | | -bs=8,seq_len=4096 0.089728 0.071648 0.063584 |
26 | | -bs=8,seq_len=8192 0.129184 0.114944 0.109856 |
27 | | -bs=8,seq_len=131072 1.616800 1.376288 1.503360 |
28 | | -bs=16,seq_len=1024 0.050976 0.048608 0.037504 |
29 | | -bs=16,seq_len=4096 0.136768 0.117312 0.104224 |
30 | | -bs=16,seq_len=8192 0.234688 0.200128 0.190944 |
31 | | -bs=16,seq_len=131072 3.211200 2.727040 2.779872 |
32 | | -bs=32,seq_len=1024 0.078656 0.072992 0.061440 |
33 | | -bs=32,seq_len=4096 0.236576 0.204512 0.190752 |
34 | | -bs=32,seq_len=8192 0.443488 0.372352 0.361216 |
35 | | -bs=32,seq_len=131072 6.392320 5.453344 5.495488 |
| 16 | +bs=1,seq_len=1024 0.020832 0.017536 0.024128 |
| 17 | +bs=1,seq_len=4096 0.037472 0.021248 0.058656 |
| 18 | +bs=1,seq_len=8192 0.034016 0.032576 0.108576 |
| 19 | +bs=1,seq_len=131072 0.229856 0.198944 1.558464 |
| 20 | +bs=4,seq_len=1024 0.021632 0.023296 0.024352 |
| 21 | +bs=4,seq_len=4096 0.068064 0.055168 0.061312 |
| 22 | +bs=4,seq_len=8192 0.080352 0.075968 0.109696 |
| 23 | +bs=4,seq_len=131072 0.824576 0.703360 1.560768 |
| 24 | +bs=8,seq_len=1024 0.033536 0.030304 0.024448 |
| 25 | +bs=8,seq_len=4096 0.089056 0.071712 0.062944 |
| 26 | +bs=8,seq_len=8192 0.128960 0.114848 0.112736 |
| 27 | +bs=8,seq_len=131072 1.620032 1.373088 1.566208 |
| 28 | +bs=16,seq_len=1024 0.050368 0.048064 0.036608 |
| 29 | +bs=16,seq_len=4096 0.134816 0.116320 0.104320 |
| 30 | +bs=16,seq_len=8192 0.234880 0.200384 0.191936 |
| 31 | +bs=16,seq_len=131072 3.219008 2.726912 2.784768 |
| 32 | +bs=32,seq_len=1024 0.078112 0.070816 0.061568 |
| 33 | +bs=32,seq_len=4096 0.235648 0.203296 0.191936 |
| 34 | +bs=32,seq_len=8192 0.442080 0.371936 0.365152 |
| 35 | +bs=32,seq_len=131072 6.404832 5.448480 5.541504 |
36 | 36 | is_decode=True, use_bwd=False, num_heads=8, seq_len=32768, per_head_dim=128, sw_sz=-1 |
37 | 37 | jax axlearn jax-cudnn |
38 | | -bs=1,num_kv_heads=1 0.049280 0.059296 0.378304 |
39 | | -bs=1,num_kv_heads=8 0.076352 0.070912 0.377344 |
40 | | -bs=8,num_kv_heads=1 0.111072 0.080480 0.377696 |
41 | | -bs=8,num_kv_heads=8 0.425536 0.368576 0.386880 |
| 38 | +bs=1,num_kv_heads=1 0.027648 0.058464 0.398816 |
| 39 | +bs=1,num_kv_heads=8 0.076096 0.070368 0.398912 |
| 40 | +bs=8,num_kv_heads=1 0.101696 0.078560 0.399040 |
| 41 | +bs=8,num_kv_heads=8 0.426656 0.367616 0.403360 |
42 | 42 | is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128 |
43 | 43 | jax axlearn jax-cudnn |
44 | | -bs=1,seq_len=131072,sw_sz=-1 0.228640 0.199040 1.476928 |
45 | | -bs=1,seq_len=131072,sw_sz=4096 0.232320 0.053824 4.441376 |
46 | | -bs=1,seq_len=131072,sw_sz=16384 0.233696 0.061120 4.420992 |
47 | | -bs=8,seq_len=131072,sw_sz=-1 1.621696 1.374080 1.496224 |
48 | | -bs=8,seq_len=131072,sw_sz=4096 1.626016 0.193792 4.463296 |
49 | | -bs=8,seq_len=131072,sw_sz=16384 1.628704 0.318176 4.451648 |
| 44 | +bs=1,seq_len=131072,sw_sz=-1 0.230336 0.199968 1.559168 |
| 45 | +bs=1,seq_len=131072,sw_sz=4096 0.235296 0.051296 4.414048 |
| 46 | +bs=1,seq_len=131072,sw_sz=16384 0.235904 0.062976 4.385216 |
| 47 | +bs=8,seq_len=131072,sw_sz=-1 1.619008 1.372768 1.570272 |
| 48 | +bs=8,seq_len=131072,sw_sz=4096 1.635424 0.194720 4.390976 |
| 49 | +bs=8,seq_len=131072,sw_sz=16384 1.632832 0.321280 4.361984 |
50 | 50 | is_decode=False, use_bwd=False, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1 |
51 | 51 | jax axlearn jax-cudnn jax-pallas |
52 | | -bs=2 3.502944 0.915360 0.467744 0.845792 |
53 | | -bs=4 6.969376 1.753152 0.890496 1.617280 |
54 | | -bs=8 13.962816 3.415232 1.735232 3.150752 |
| 52 | +bs=2 3.583424 0.894912 0.488480 0.852960 |
| 53 | +bs=4 7.107168 1.712448 0.922592 1.629888 |
| 54 | +bs=8 14.202400 3.341568 1.801920 3.184064 |
55 | 55 | is_decode=False, use_bwd=False, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1 |
56 | 56 | jax axlearn jax-cudnn jax-pallas |
57 | | -num_heads=12 1.262560 0.393536 0.205952 0.362304 |
58 | | -num_heads=16 1.786816 0.498304 0.257664 0.459936 |
59 | | -num_heads=32 3.507488 2.591456 0.468672 2.443296 |
60 | | -num_heads=48 5.246336 1.338272 0.675968 1.231328 |
61 | | -num_heads=72 7.866848 1.961152 0.995712 1.805376 |
| 57 | +num_heads=12 1.287712 0.383200 0.214400 0.365120 |
| 58 | +num_heads=16 1.803232 0.485408 0.270496 0.463040 |
| 59 | +num_heads=32 3.578208 0.896576 0.488544 2.468096 |
| 60 | +num_heads=48 5.346112 1.305856 0.707872 1.241728 |
| 61 | +num_heads=72 8.001568 1.915776 1.035200 1.820288 |
62 | 62 | is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1 |
63 | 63 | jax axlearn jax-cudnn jax-pallas |
64 | | -seq_len=128 0.030592 0.011584 0.013024 0.012960 |
65 | | -seq_len=256 0.051520 0.015648 0.016640 0.015744 |
66 | | -seq_len=512 0.118720 0.038976 0.028224 0.037152 |
67 | | -seq_len=1024 0.310880 0.096256 0.054784 0.090368 |
68 | | -seq_len=2048 0.931072 0.277312 0.150784 0.256928 |
69 | | -seq_len=4096 3.516672 2.595872 0.465568 2.448128 |
| 64 | +seq_len=256 0.049184 0.015360 0.016352 0.015488 |
| 65 | +seq_len=512 0.110400 0.038624 0.028480 0.037760 |
| 66 | +seq_len=1024 0.302304 0.094560 0.056736 0.090464 |
| 67 | +seq_len=2048 0.936832 0.269856 0.154304 0.258944 |
| 68 | +seq_len=4096 3.584800 0.895776 0.487104 2.462560 |
| 69 | +seq_len=8192 14.260608 3.268320 1.742048 3.104640 |
70 | 70 | is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1 |
71 | 71 | jax axlearn jax-cudnn jax-pallas |
72 | | -per_head_dim=16 3.220960 0.487808 0.332928 0.478720 |
73 | | -per_head_dim=32 3.277824 0.530240 0.334624 0.515040 |
74 | | -per_head_dim=64 3.345376 0.696480 0.338944 0.631296 |
75 | | -per_head_dim=128 3.515616 2.594208 0.465824 2.442784 |
| 72 | +per_head_dim=16 3.262592 0.518912 0.356544 0.477120 |
| 73 | +per_head_dim=32 3.323552 0.563520 0.358944 0.533344 |
| 74 | +per_head_dim=64 3.411744 0.690464 0.360192 0.635296 |
| 75 | +per_head_dim=128 3.585920 0.896032 0.488416 2.461696 |
76 | 76 | is_decode=False, use_bwd=True, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1 |
77 | 77 | jax axlearn jax-cudnn jax-pallas |
78 | | -bs=2 10.780096 4.573344 2.080672 4.487104 |
79 | | -bs=4 21.426336 9.336192 3.988224 9.159904 |
80 | | -bs=8 42.808033 18.926559 7.975296 18.075487 |
| 78 | +bs=2 10.878624 3.924992 2.123008 4.504256 |
| 79 | +bs=4 21.626017 8.043040 4.071552 9.186080 |
| 80 | +bs=8 43.269279 16.195999 8.124896 18.184799 |
81 | 81 | is_decode=False, use_bwd=True, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1 |
82 | 82 | jax axlearn jax-cudnn jax-pallas |
83 | | -num_heads=12 4.128352 1.738016 0.882976 1.696704 |
84 | | -num_heads=16 5.467808 2.307488 1.120608 2.247904 |
85 | | -num_heads=32 10.782432 4.559456 2.082592 4.488448 |
86 | | -num_heads=48 16.119776 6.958272 3.027808 6.858144 |
87 | | -num_heads=72 24.140833 10.706656 4.560288 10.279136 |
| 83 | +num_heads=12 4.159424 1.519680 0.898816 1.711808 |
| 84 | +num_heads=16 5.486912 2.001952 1.142144 2.256960 |
| 85 | +num_heads=32 10.886848 3.928896 2.114496 4.502976 |
| 86 | +num_heads=48 16.224319 6.085408 3.093696 6.888640 |
| 87 | +num_heads=72 24.367489 9.190560 4.642720 10.323552 |
88 | 88 | is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1 |
89 | 89 | jax axlearn jax-cudnn jax-pallas |
90 | | -seq_len=128 0.058944 0.037824 0.039040 0.036384 |
91 | | -seq_len=256 0.100384 0.069024 0.052608 0.067872 |
92 | | -seq_len=512 0.317056 0.159904 0.111840 0.158912 |
93 | | -seq_len=1024 0.906400 0.431104 0.244160 0.421792 |
94 | | -seq_len=2048 2.861056 1.319648 0.655840 1.297728 |
95 | | -seq_len=4096 10.762560 4.576864 2.079904 4.489056 |
| 90 | +seq_len=256 0.094496 0.060096 0.053184 0.065760 |
| 91 | +seq_len=512 0.297440 0.139328 0.112736 0.161664 |
| 92 | +seq_len=1024 0.886304 0.361536 0.246848 0.418720 |
| 93 | +seq_len=2048 2.857952 1.118368 0.675168 1.294144 |
| 94 | +seq_len=4096 10.880512 3.914048 2.119808 4.503936 |
| 95 | +seq_len=8192 43.000095 14.913824 7.484128 16.730017 |
96 | 96 | is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1 |
97 | 97 | jax axlearn jax-cudnn jax-pallas |
98 | | -per_head_dim=16 10.084800 1.744640 1.263264 1.711296 |
99 | | -per_head_dim=32 10.204480 2.098816 1.291104 2.041184 |
100 | | -per_head_dim=64 10.374720 2.649888 1.335200 2.510304 |
101 | | -per_head_dim=128 10.779680 4.568096 2.079264 4.489792 |
| 98 | +per_head_dim=16 10.150080 1.826656 1.288192 1.718688 |
| 99 | +per_head_dim=32 10.277440 2.028608 1.316512 2.048864 |
| 100 | +per_head_dim=64 10.463904 2.569408 1.364448 2.540512 |
| 101 | +per_head_dim=128 10.875328 3.929568 2.124192 4.502912 |
102 | 102 | """ |
103 | 103 | # pylint: enable=line-too-long |
104 | 104 | import itertools |
@@ -365,8 +365,8 @@ def bench_flash_attention_fwd_bwd(use_bwd: bool): |
365 | 365 | libraries = ["jax", "axlearn", "jax-cudnn", "jax-pallas"] |
366 | 366 | benchmark_sweep(libraries, common_kwargs, bs=[2, 4, 8]) |
367 | 367 | benchmark_sweep(libraries, common_kwargs, num_heads=[12, 16, 32, 48, 72]) |
368 | | - # 128 to 4096. |
369 | | - benchmark_sweep(libraries, common_kwargs, seq_len=[int(2**i) for i in range(7, 13)]) |
| 368 | + # 256 to 8192. |
| 369 | + benchmark_sweep(libraries, common_kwargs, seq_len=[int(2**i) for i in range(8, 14)]) |
370 | 370 | benchmark_sweep(libraries, common_kwargs, per_head_dim=[16, 32, 64, 128]) |
371 | 371 |
|
372 | 372 |
|
|
0 commit comments