 import os
-from yunchang import (
-    LongContextAttention,
-    set_seq_parallel_pg,
-    EXTRACT_FUNC_DICT
-)
+from yunchang import LongContextAttention, set_seq_parallel_pg, EXTRACT_FUNC_DICT
 import torch
 import torch.distributed as dist
+
 try:
     from flash_attn import flash_attn_func
 except ImportError:
@@ -14,32 +11,82 @@
 from test_utils import attention_ref
 import argparse
 
+
 def parse_args():
-    parser = argparse.ArgumentParser(description='Test hybrid attention with configurable sequence length')
-    parser.add_argument('--seqlen', type=int, default=1024,
-                        help='sequence length (default: 1024)')
-    parser.add_argument('--use_bwd', action='store_true',
-                        help='whether to test backward pass (default: False)')
-    parser.add_argument('--sp_ulysses_degree', type=int, default=None,
-                        help='sp_ulysses_degree (default: world_size)')
-    parser.add_argument('--ring_impl_type', type=str, default='basic',
-                        choices=['basic', 'zigzag', 'basic_flashinfer'],
-                        help='ring implementation type (default: basic)')
-    parser.add_argument('--causal', action='store_true',
-                        help='whether to use causal attention (default: False)')
-    parser.add_argument('--attn_impl', type=str, default='torch',
-                        choices=['torch', 'fa', 'fa3', 'flashinfer', 'sage_fp16', 'sage_fp16_triton', 'sage_fp8', 'sparse_sage'],
-                        help='attention implementation type (default: torch)')
-    parser.add_argument('--sparse_sage_l1', type=float, default=0.07,
-                        help='l1 for sparse sage attention (default: 0.07)')
-    parser.add_argument('--sparse_sage_pv_l1', type=float, default=0.08,
-                        help='pv_l1 for sparse sage attention (default: 0.08)')
-    parser.add_argument('--sparse_sage_tune_mode', action='store_true', default=False,
-                        help='enable tune mode for sparse sage attention (default: False)')
-    parser.add_argument('--sparse_sage_tune_path', type=str, default='./sparsesage_autotune.pt',
-                        help='path to the sparse sage autotune results (default: ./sparsesage_autotune.pt)')
+    parser = argparse.ArgumentParser(
+        description="Test hybrid attention with configurable sequence length"
+    )
+    parser.add_argument(
+        "--seqlen", type=int, default=1024, help="sequence length (default: 1024)"
+    )
+    parser.add_argument(
+        "--use_bwd",
+        action="store_true",
+        help="whether to test backward pass (default: False)",
+    )
+    parser.add_argument(
+        "--sp_ulysses_degree",
+        type=int,
+        default=None,
+        help="sp_ulysses_degree (default: world_size)",
+    )
+    parser.add_argument(
+        "--ring_impl_type",
+        type=str,
+        default="basic",
+        choices=["basic", "zigzag", "basic_flashinfer"],
+        help="ring implementation type (default: basic)",
+    )
+    parser.add_argument(
+        "--causal",
+        action="store_true",
+        help="whether to use causal attention (default: False)",
+    )
+    parser.add_argument(
+        "--attn_impl",
+        type=str,
+        default="torch",
+        choices=[
+            "torch",
+            "fa",
+            "fa3",
+            "flashinfer",
+            "sage_fp16",
+            "sage_fp8",
+            "sparse_sage",
+            "sage_fp8_sm90",
+            "sage_fp16_triton",
+            "sage_auto",
+        ],
+        help="attention implementation type (default: torch)",
+    )
+    parser.add_argument(
+        "--sparse_sage_l1",
+        type=float,
+        default=0.07,
+        help="l1 for sparse sage attention (default: 0.07)",
+    )
+    parser.add_argument(
+        "--sparse_sage_pv_l1",
+        type=float,
+        default=0.08,
+        help="pv_l1 for sparse sage attention (default: 0.08)",
+    )
+    parser.add_argument(
+        "--sparse_sage_tune_mode",
+        action="store_true",
+        default=False,
+        help="enable tune mode for sparse sage attention (default: False)",
+    )
+    parser.add_argument(
+        "--sparse_sage_tune_path",
+        type=str,
+        default="./sparsesage_autotune.pt",
+        help="path to the sparse sage autotune results (default: ./sparsesage_autotune.pt)",
+    )
     return parser.parse_args()
 
+
 def log(msg, a, rank0_only=False):
     world_size = dist.get_world_size()
     rank = dist.get_rank()
@@ -65,19 +112,20 @@ def log(msg, a, rank0_only=False):
             )
         dist.barrier()
 
+
 # test it with:
 # torchrun --nproc_per_node=4 test/test_hybrid_attn.py
 if __name__ == "__main__":
     args = parse_args()
-
+
     torch.random.manual_seed(0)
 
     dist.init_process_group("nccl")
 
     rank = dist.get_rank()
     world_size = dist.get_world_size()
 
-    # Inference mainly uses fp16; ROCM flash attention with bf16 precision is slightly larger, will be fixed soon
+    # Inference mainly uses fp16; ROCM flash attention with bf16 precision is slightly larger, will be fixed soon
     dtype = torch.bfloat16
     device = torch.device(f"cuda:{rank}")
 
@@ -88,7 +136,7 @@ def log(msg, a, rank0_only=False):
     dropout_p = 0
     causal = args.causal
     deterministic = False
-
+
     use_bwd = args.use_bwd
 
     assert seqlen % world_size == 0
@@ -98,13 +146,31 @@ def log(msg, a, rank0_only=False):
 
     # Prepare inputs
     q = torch.randn(
-        batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True if use_bwd else False
+        batch_size,
+        seqlen,
+        nheads,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True if use_bwd else False,
     )
     k = torch.randn(
-        batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True if use_bwd else False
+        batch_size,
+        seqlen,
+        nheads,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True if use_bwd else False,
     )
     v = torch.randn(
-        batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True if use_bwd else False
+        batch_size,
+        seqlen,
+        nheads,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True if use_bwd else False,
     )
     dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
 
@@ -116,7 +182,9 @@ def log(msg, a, rank0_only=False):
     # prepare process group for hybrid sequence parallelism
     use_ring_low_dim = True
 
-    sp_ulysses_degree = args.sp_ulysses_degree if args.sp_ulysses_degree is not None else world_size
+    sp_ulysses_degree = (
+        args.sp_ulysses_degree if args.sp_ulysses_degree is not None else world_size
+    )
     sp_ring_degree = world_size // sp_ulysses_degree
 
     print(
@@ -126,19 +194,29 @@ def log(msg, a, rank0_only=False):
     set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
 
     # Use EXTRACT_FUNC_DICT to shard the tensors
-    local_q = EXTRACT_FUNC_DICT[ring_impl_type](
-        q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
-    ).detach().clone()
-
-
-    local_k = EXTRACT_FUNC_DICT[ring_impl_type](
-        k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
-    ).detach().clone()
+    local_q = (
+        EXTRACT_FUNC_DICT[ring_impl_type](
+            q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
+        .detach()
+        .clone()
+    )
 
+    local_k = (
+        EXTRACT_FUNC_DICT[ring_impl_type](
+            k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
+        .detach()
+        .clone()
+    )
 
-    local_v = EXTRACT_FUNC_DICT[ring_impl_type](
-        v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
-    ).detach().clone()
+    local_v = (
+        EXTRACT_FUNC_DICT[ring_impl_type](
+            v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
+        .detach()
+        .clone()
+    )
 
     if use_bwd:
         local_q.requires_grad = True
@@ -147,32 +225,46 @@ def log(msg, a, rank0_only=False):
 
     # Map argument to AttnType enum
     attn_impl_map = {
-        'torch': AttnType.TORCH,
-        'fa': AttnType.FA,
-        'fa3': AttnType.FA3,
-        'flashinfer': AttnType.FLASHINFER,
-        'sage_fp16': AttnType.SAGE_FP16,
-        'sage_fp16_triton': AttnType.SAGE_FP16_TRITON,
-        'sage_fp8': AttnType.SAGE_FP8,
-        'sparse_sage': AttnType.SPARSE_SAGE
+        "torch": AttnType.TORCH,
+        "fa": AttnType.FA,
+        "fa3": AttnType.FA3,
+        "flashinfer": AttnType.FLASHINFER,
+        "sage_fp16": AttnType.SAGE_FP16,
+        "sage_fp8": AttnType.SAGE_FP8,
+        "sage_fp8_sm90": AttnType.SAGE_FP8_SM90,
+        "sage_fp16_triton": AttnType.SAGE_FP16_TRITON,
+        "sage_auto": AttnType.SAGE_AUTO,
+        "sparse_sage": AttnType.SPARSE_SAGE,
     }
 
-    if args.attn_impl == 'sparse_sage':
+    if args.attn_impl == "sparse_sage":
         if use_bwd:
             raise RuntimeError("Sparse Sage attention does not support backward pass")
-        from spas_sage_attn.autotune import SparseAttentionMeansim, load_sparse_attention_state_dict
-        attn_processor = SparseAttentionMeansim(l1=args.sparse_sage_l1, pv_l1=args.sparse_sage_pv_l1, tune_pv=True)
+        from spas_sage_attn.autotune import (
+            SparseAttentionMeansim,
+            load_sparse_attention_state_dict,
+        )
+
+        attn_processor = SparseAttentionMeansim(
+            l1=args.sparse_sage_l1, pv_l1=args.sparse_sage_pv_l1, tune_pv=True
+        )
     else:
         attn_processor = None
 
-    usp_attn = LongContextAttention(ring_impl_type=ring_impl_type,
-                                    attn_type=attn_impl_map[args.attn_impl],
-                                    attn_processor=attn_processor)
+    usp_attn = LongContextAttention(
+        ring_impl_type=ring_impl_type,
+        attn_type=attn_impl_map[args.attn_impl],
+        attn_processor=attn_processor,
+    )
 
-    if args.attn_impl == 'sparse_sage':
+    if args.attn_impl == "sparse_sage":
         if not args.sparse_sage_tune_mode:
-            saved_state_dict = torch.load(args.sparse_sage_tune_path + f".rank{dist.get_rank()}")
-            load_sparse_attention_state_dict(usp_attn, saved_state_dict, multigpu=True, verbose=True)
+            saved_state_dict = torch.load(
+                args.sparse_sage_tune_path + f".rank{dist.get_rank()}"
+            )
+            load_sparse_attention_state_dict(
+                usp_attn, saved_state_dict, multigpu=True, verbose=True
+            )
         else:
             os.environ["sparse_sage_tune_mode"] = "1"
 
@@ -182,7 +274,7 @@ def log(msg, a, rank0_only=False):
         print("#" * 30)
 
     # common test parameters
-    window_size = (-1, -1)
+    window_size = (-1, -1)
     alibi_slopes, attn_bias = None, None
     dropout_mask = None
 
@@ -203,21 +295,25 @@ def log(msg, a, rank0_only=False):
     )
 
     # extract local dout
-    local_dout = EXTRACT_FUNC_DICT[ring_impl_type](
-        dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
-    ).detach().clone()
+    local_dout = (
+        EXTRACT_FUNC_DICT[ring_impl_type](
+            dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
+        .detach()
+        .clone()
+    )
 
-
-    max_memory = torch.cuda.max_memory_allocated(device) / (1024 * 1024)  # Convert to MB
+    max_memory = torch.cuda.max_memory_allocated(device) / (
+        1024 * 1024
+    )  # Convert to MB
     print(f"[Rank#{rank}] Maximum GPU memory used: {max_memory:.2f} MB")
     torch.cuda.reset_peak_memory_stats(device)  # Reset stats
 
-
     if rank == 0:
         print("#" * 30)
         print("# ds-ulysses backward:")
         print("#" * 30)
-
+
     # usp attn backward
     if use_bwd:
         local_out.backward(local_dout)
@@ -282,11 +378,19 @@ def log(msg, a, rank0_only=False):
     torch.testing.assert_close(local_out, local_out_ref, atol=1e-1, rtol=0)
     # torch.testing.assert_close(out_ref, out_pt_ref, atol=1e-2, rtol=0)
 
-    if args.attn_impl == 'sparse_sage':
-        from spas_sage_attn.autotune import SparseAttentionMeansim, extract_sparse_attention_state_dict
+    if args.attn_impl == "sparse_sage":
+        from spas_sage_attn.autotune import (
+            SparseAttentionMeansim,
+            extract_sparse_attention_state_dict,
+        )
+
         if args.sparse_sage_tune_mode:
-            saved_state_dict = extract_sparse_attention_state_dict(usp_attn, verbose=True)
-            torch.save(saved_state_dict, args.sparse_sage_tune_path + f".rank{dist.get_rank()}")
+            saved_state_dict = extract_sparse_attention_state_dict(
+                usp_attn, verbose=True
+            )
+            torch.save(
+                saved_state_dict, args.sparse_sage_tune_path + f".rank{dist.get_rank()}"
+            )
 
     if use_bwd:
         local_dq_ref = EXTRACT_FUNC_DICT[ring_impl_type](