Commit 38027de

add more sage attn impl. (#143)
1 parent: b3b442c · commit: 38027de
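For orientation, this commit wires additional SageAttention backends (sage_fp8_sm90, sage_auto) into the hybrid-attention test below. A minimal sketch of how such a backend is selected, using the same constructor arguments as the test; the AttnType import path is an assumption (the diff only references the AttnType enum), and SAGE_AUTO's kernel auto-selection is inferred from its name:

from yunchang import LongContextAttention
from yunchang.kernels import AttnType  # import path is an assumption

# Hybrid (Ulysses + ring) long-context attention dispatching to a SageAttention
# backend; attn_processor stays None for everything except sparse_sage.
usp_attn = LongContextAttention(
    ring_impl_type="basic",
    attn_type=AttnType.SAGE_AUTO,
    attn_processor=None,
)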

File tree: 2 files changed (+270, -90 lines)


test/test_hybrid_attn.py

Lines changed: 179 additions & 75 deletions
@@ -1,11 +1,8 @@
 import os
-from yunchang import (
-    LongContextAttention,
-    set_seq_parallel_pg,
-    EXTRACT_FUNC_DICT
-)
+from yunchang import LongContextAttention, set_seq_parallel_pg, EXTRACT_FUNC_DICT
 import torch
 import torch.distributed as dist
+
 try:
     from flash_attn import flash_attn_func
 except ImportError:
@@ -14,32 +11,82 @@
 from test_utils import attention_ref
 import argparse

+
 def parse_args():
-    parser = argparse.ArgumentParser(description='Test hybrid attention with configurable sequence length')
-    parser.add_argument('--seqlen', type=int, default=1024,
-                        help='sequence length (default: 1024)')
-    parser.add_argument('--use_bwd', action='store_true',
-                        help='whether to test backward pass (default: False)')
-    parser.add_argument('--sp_ulysses_degree', type=int, default=None,
-                        help='sp_ulysses_degree (default: world_size)')
-    parser.add_argument('--ring_impl_type', type=str, default='basic',
-                        choices=['basic', 'zigzag', 'basic_flashinfer'],
-                        help='ring implementation type (default: basic)')
-    parser.add_argument('--causal', action='store_true',
-                        help='whether to use causal attention (default: False)')
-    parser.add_argument('--attn_impl', type=str, default='torch',
-                        choices=['torch', 'fa', 'fa3', 'flashinfer', 'sage_fp16', 'sage_fp16_triton', 'sage_fp8', 'sparse_sage'],
-                        help='attention implementation type (default: torch)')
-    parser.add_argument('--sparse_sage_l1', type=float, default=0.07,
-                        help='l1 for sparse sage attention (default: 0.07)')
-    parser.add_argument('--sparse_sage_pv_l1', type=float, default=0.08,
-                        help='pv_l1 for sparse sage attention (default: 0.08)')
-    parser.add_argument('--sparse_sage_tune_mode', action='store_true', default=False,
-                        help='enable tune mode for sparse sage attention (default: False)')
-    parser.add_argument('--sparse_sage_tune_path', type=str, default='./sparsesage_autotune.pt',
-                        help='path to the sparse sage autotune results (default: ./sparsesage_autotune.pt)')
+    parser = argparse.ArgumentParser(
+        description="Test hybrid attention with configurable sequence length"
+    )
+    parser.add_argument(
+        "--seqlen", type=int, default=1024, help="sequence length (default: 1024)"
+    )
+    parser.add_argument(
+        "--use_bwd",
+        action="store_true",
+        help="whether to test backward pass (default: False)",
+    )
+    parser.add_argument(
+        "--sp_ulysses_degree",
+        type=int,
+        default=None,
+        help="sp_ulysses_degree (default: world_size)",
+    )
+    parser.add_argument(
+        "--ring_impl_type",
+        type=str,
+        default="basic",
+        choices=["basic", "zigzag", "basic_flashinfer"],
+        help="ring implementation type (default: basic)",
+    )
+    parser.add_argument(
+        "--causal",
+        action="store_true",
+        help="whether to use causal attention (default: False)",
+    )
+    parser.add_argument(
+        "--attn_impl",
+        type=str,
+        default="torch",
+        choices=[
+            "torch",
+            "fa",
+            "fa3",
+            "flashinfer",
+            "sage_fp16",
+            "sage_fp8",
+            "sparse_sage",
+            "sage_fp8_sm90",
+            "sage_fp16_triton",
+            "sage_auto",
+        ],
+        help="attention implementation type (default: torch)",
+    )
+    parser.add_argument(
+        "--sparse_sage_l1",
+        type=float,
+        default=0.07,
+        help="l1 for sparse sage attention (default: 0.07)",
+    )
+    parser.add_argument(
+        "--sparse_sage_pv_l1",
+        type=float,
+        default=0.08,
+        help="pv_l1 for sparse sage attention (default: 0.08)",
+    )
+    parser.add_argument(
+        "--sparse_sage_tune_mode",
+        action="store_true",
+        default=False,
+        help="enable tune mode for sparse sage attention (default: False)",
+    )
+    parser.add_argument(
+        "--sparse_sage_tune_path",
+        type=str,
+        default="./sparsesage_autotune.pt",
+        help="path to the sparse sage autotune results (default: ./sparsesage_autotune.pt)",
+    )
     return parser.parse_args()

+
 def log(msg, a, rank0_only=False):
     world_size = dist.get_world_size()
     rank = dist.get_rank()
@@ -65,19 +112,20 @@ def log(msg, a, rank0_only=False):
     )
     dist.barrier()

+
 # test it with:
 # torchrun --nproc_per_node=4 test/test_hybrid_attn.py
 if __name__ == "__main__":
     args = parse_args()
-
+
     torch.random.manual_seed(0)

     dist.init_process_group("nccl")

     rank = dist.get_rank()
     world_size = dist.get_world_size()

-    # Inference mainly uses fp16; ROCM flash attention with bf16 precision is slightly larger, will be fixed soon
+    # Inference mainly uses fp16; ROCM flash attention with bf16 precision is slightly larger, will be fixed soon
     dtype = torch.bfloat16
     device = torch.device(f"cuda:{rank}")

@@ -88,7 +136,7 @@ def log(msg, a, rank0_only=False):
     dropout_p = 0
     causal = args.causal
     deterministic = False
-
+
     use_bwd = args.use_bwd

     assert seqlen % world_size == 0
@@ -98,13 +146,31 @@ def log(msg, a, rank0_only=False):

     # Prepare inputs
     q = torch.randn(
-        batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True if use_bwd else False
+        batch_size,
+        seqlen,
+        nheads,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True if use_bwd else False,
     )
     k = torch.randn(
-        batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True if use_bwd else False
+        batch_size,
+        seqlen,
+        nheads,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True if use_bwd else False,
     )
     v = torch.randn(
-        batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True if use_bwd else False
+        batch_size,
+        seqlen,
+        nheads,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True if use_bwd else False,
     )
     dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)

@@ -116,7 +182,9 @@ def log(msg, a, rank0_only=False):
     # prepare process group for hybrid sequence parallelism
     use_ring_low_dim = True

-    sp_ulysses_degree = args.sp_ulysses_degree if args.sp_ulysses_degree is not None else world_size
+    sp_ulysses_degree = (
+        args.sp_ulysses_degree if args.sp_ulysses_degree is not None else world_size
+    )
     sp_ring_degree = world_size // sp_ulysses_degree

     print(
@@ -126,19 +194,29 @@ def log(msg, a, rank0_only=False):
     set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

     # Use EXTRACT_FUNC_DICT to shard the tensors
-    local_q = EXTRACT_FUNC_DICT[ring_impl_type](
-        q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
-    ).detach().clone()
-
-
-    local_k = EXTRACT_FUNC_DICT[ring_impl_type](
-        k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
-    ).detach().clone()
+    local_q = (
+        EXTRACT_FUNC_DICT[ring_impl_type](
+            q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
+        .detach()
+        .clone()
+    )

+    local_k = (
+        EXTRACT_FUNC_DICT[ring_impl_type](
+            k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
+        .detach()
+        .clone()
+    )

-    local_v = EXTRACT_FUNC_DICT[ring_impl_type](
-        v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
-    ).detach().clone()
+    local_v = (
+        EXTRACT_FUNC_DICT[ring_impl_type](
+            v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
+        .detach()
+        .clone()
+    )

     if use_bwd:
         local_q.requires_grad = True
@@ -147,32 +225,46 @@ def log(msg, a, rank0_only=False):

     # Map argument to AttnType enum
     attn_impl_map = {
-        'torch': AttnType.TORCH,
-        'fa': AttnType.FA,
-        'fa3': AttnType.FA3,
-        'flashinfer': AttnType.FLASHINFER,
-        'sage_fp16': AttnType.SAGE_FP16,
-        'sage_fp16_triton': AttnType.SAGE_FP16_TRITON,
-        'sage_fp8': AttnType.SAGE_FP8,
-        'sparse_sage': AttnType.SPARSE_SAGE
+        "torch": AttnType.TORCH,
+        "fa": AttnType.FA,
+        "fa3": AttnType.FA3,
+        "flashinfer": AttnType.FLASHINFER,
+        "sage_fp16": AttnType.SAGE_FP16,
+        "sage_fp8": AttnType.SAGE_FP8,
+        "sage_fp8_sm90": AttnType.SAGE_FP8_SM90,
+        "sage_fp16_triton": AttnType.SAGE_FP16_TRITON,
+        "sage_auto": AttnType.SAGE_AUTO,
+        "sparse_sage": AttnType.SPARSE_SAGE,
     }

-    if args.attn_impl == 'sparse_sage':
+    if args.attn_impl == "sparse_sage":
         if use_bwd:
             raise RuntimeError("Sparse Sage attention does not support backward pass")
-        from spas_sage_attn.autotune import SparseAttentionMeansim, load_sparse_attention_state_dict
-        attn_processor = SparseAttentionMeansim(l1=args.sparse_sage_l1, pv_l1=args.sparse_sage_pv_l1, tune_pv=True)
+        from spas_sage_attn.autotune import (
+            SparseAttentionMeansim,
+            load_sparse_attention_state_dict,
+        )
+
+        attn_processor = SparseAttentionMeansim(
+            l1=args.sparse_sage_l1, pv_l1=args.sparse_sage_pv_l1, tune_pv=True
+        )
     else:
         attn_processor = None

-    usp_attn = LongContextAttention(ring_impl_type=ring_impl_type,
-                                    attn_type=attn_impl_map[args.attn_impl],
-                                    attn_processor=attn_processor)
+    usp_attn = LongContextAttention(
+        ring_impl_type=ring_impl_type,
+        attn_type=attn_impl_map[args.attn_impl],
+        attn_processor=attn_processor,
+    )

-    if args.attn_impl == 'sparse_sage':
+    if args.attn_impl == "sparse_sage":
         if not args.sparse_sage_tune_mode:
-            saved_state_dict = torch.load(args.sparse_sage_tune_path + f".rank{dist.get_rank()}")
-            load_sparse_attention_state_dict(usp_attn, saved_state_dict, multigpu=True, verbose=True)
+            saved_state_dict = torch.load(
+                args.sparse_sage_tune_path + f".rank{dist.get_rank()}"
+            )
+            load_sparse_attention_state_dict(
+                usp_attn, saved_state_dict, multigpu=True, verbose=True
+            )
         else:
             os.environ["sparse_sage_tune_mode"] = "1"

@@ -182,7 +274,7 @@ def log(msg, a, rank0_only=False):
         print("#" * 30)

     # common test parameters
-    window_size=(-1, -1)
+    window_size = (-1, -1)
     alibi_slopes, attn_bias = None, None
     dropout_mask = None

@@ -203,21 +295,25 @@ def log(msg, a, rank0_only=False):
     )

     # extract local dout
-    local_dout = EXTRACT_FUNC_DICT[ring_impl_type](
-        dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
-    ).detach().clone()
+    local_dout = (
+        EXTRACT_FUNC_DICT[ring_impl_type](
+            dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
+        .detach()
+        .clone()
+    )

-
-    max_memory = torch.cuda.max_memory_allocated(device) / (1024 * 1024)  # Convert to MB
+    max_memory = torch.cuda.max_memory_allocated(device) / (
+        1024 * 1024
+    )  # Convert to MB
     print(f"[Rank#{rank}] Maximum GPU memory used: {max_memory:.2f} MB")
     torch.cuda.reset_peak_memory_stats(device)  # Reset stats

-
     if rank == 0:
         print("#" * 30)
         print("# ds-ulysses backward:")
         print("#" * 30)
-
+
     # usp attn backward
     if use_bwd:
         local_out.backward(local_dout)
@@ -282,11 +378,19 @@ def log(msg, a, rank0_only=False):
     torch.testing.assert_close(local_out, local_out_ref, atol=1e-1, rtol=0)
     # torch.testing.assert_close(out_ref, out_pt_ref, atol=1e-2, rtol=0)

-    if args.attn_impl == 'sparse_sage':
-        from spas_sage_attn.autotune import SparseAttentionMeansim, extract_sparse_attention_state_dict
+    if args.attn_impl == "sparse_sage":
+        from spas_sage_attn.autotune import (
+            SparseAttentionMeansim,
+            extract_sparse_attention_state_dict,
+        )
+
         if args.sparse_sage_tune_mode:
-            saved_state_dict = extract_sparse_attention_state_dict(usp_attn, verbose=True)
-            torch.save(saved_state_dict, args.sparse_sage_tune_path + f".rank{dist.get_rank()}")
+            saved_state_dict = extract_sparse_attention_state_dict(
+                usp_attn, verbose=True
+            )
+            torch.save(
+                saved_state_dict, args.sparse_sage_tune_path + f".rank{dist.get_rank()}"
+            )

     if use_bwd:
         local_dq_ref = EXTRACT_FUNC_DICT[ring_impl_type](
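
For reference, the new backends are selected through the test's existing --attn_impl flag. Hypothetical invocations on 4 GPUs, following the "test it with" comment in the file (assuming the corresponding SageAttention kernels are installed; sage_fp8_sm90 presumably requires SM90-class hardware, as its name suggests):

torchrun --nproc_per_node=4 test/test_hybrid_attn.py --attn_impl sage_auto --seqlen 2048
torchrun --nproc_per_node=4 test/test_hybrid_attn.py --attn_impl sage_fp8_sm90 --causal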
