
Commit 2c4f548

FlashAttentionImpl -> AttnType (#117)
1 parent: c7460f5

12 files changed (+49, -49 lines)

README.md (+4, -4)
@@ -46,7 +46,7 @@ As shown in the figure below, there are three usage methods based on the flash_a
 
 2. For A100, L40, hardware that supports FA v2, ring_flash_attn uses FA v2.
 
-3. For hardware such as NPUs that does not support FA, use torch to implement attention computation. In this case, there is no need to install `flash_attn`, and you should apply `LongContextAttention(ring_impl_type="basic", attn_type=FlashAttentionImpl.TORCH)`.
+3. For hardware such as NPUs that does not support FA, use torch to implement attention computation. In this case, there is no need to install `flash_attn`, and you should apply `LongContextAttention(ring_impl_type="basic", attn_type=AttnType.TORCH)`.
 
 Option 1: pip install
 
@@ -85,7 +85,7 @@ from yunchang import (
     set_seq_parallel_pg,
     EXTRACT_FUNC_DICT
 )
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
 
 sp_ulysses_degree = 2
 sp_ring_degree = 4
@@ -94,10 +94,10 @@ sp_ring_degree = 4
 set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
 
 # attn_type could be FA, FA3, TORCH.
-longctx_attn = LongContextAttention(ring_impl_type="zigzag", attn_type=FlashAttentionImpl.FA)
+longctx_attn = LongContextAttention(ring_impl_type="zigzag", attn_type=AttnType.FA)
 
 # if you use NPUs, where no flash_attn is supported, you can use the following code.
-# LongContextAttention(ring_impl_type="zigzag", attn_type=FlashAttentionImpl.TORCH)
+# LongContextAttention(ring_impl_type="zigzag", attn_type=AttnType.TORCH)
 
 # extract a local shard for the global Q, K, V.
 local_q = EXTRACT_FUNC_DICT["zigzag"](
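
For downstream code, the migration this commit asks for is mechanical: only the imported name and the enum references change. A minimal sketch of the post-rename call pattern, mirroring the README snippet above; it assumes `LongContextAttention` is exported from the top-level `yunchang` package and that process groups have already been set up via `set_seq_parallel_pg`:

```python
# Sketch only: mirrors the README diff above, not an additional API.
from yunchang import LongContextAttention, set_seq_parallel_pg
from yunchang.kernels import AttnType

# set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size) must
# have been called first, exactly as in the README example.

# FA    -> flash_attn v2 kernels (A100, L40, ...)
# FA3   -> flash_attn v3 kernels
# TORCH -> pure-PyTorch attention for hardware without flash_attn (e.g. NPUs)
attn_fa = LongContextAttention(ring_impl_type="zigzag", attn_type=AttnType.FA)
attn_npu = LongContextAttention(ring_impl_type="basic", attn_type=AttnType.TORCH)
```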

benchmark/benchmark_longctx.py (+2, -2)
@@ -160,8 +160,8 @@ def benchmark(num_iter=10, forward_only=True, log=True, profile=False):
         sp_ulysses_degree, sp_ring_degree, rank, world_size, args.use_ulysses_lowdim
     )
 
-    from yunchang.kernels import FlashAttentionImpl
-    attn_type = FlashAttentionImpl.from_string(args.attn_type)
+    from yunchang.kernels import AttnType
+    attn_type = AttnType.from_string(args.attn_type)
     if args.use_ulysses:
         longctx_attn = UlyssesAttention(attn_type=attn_type)
     else:
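
The benchmark now turns its CLI string into the enum with `AttnType.from_string`. A rough sketch of that pattern, assuming the flag is named `--attn_type` (only `args.attn_type` is visible in this hunk) and that `from_string` matches the lowercase member values:

```python
# Sketch only: the flag name and choices below are assumptions based on args.attn_type.
import argparse
from yunchang.kernels import AttnType

parser = argparse.ArgumentParser()
parser.add_argument("--attn_type", type=str, default="fa", choices=["fa", "fa3", "torch"])
args = parser.parse_args()

attn_type = AttnType.from_string(args.attn_type)  # e.g. "fa" -> AttnType.FA
```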

test/test_hybrid_attn.py (+5, -5)
@@ -6,7 +6,7 @@
 import torch
 import torch.distributed as dist
 from flash_attn import flash_attn_func
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
 from test_utils import attention_ref
 import argparse
 
@@ -133,11 +133,11 @@ def log(msg, a, rank0_only=False):
     local_k.requires_grad = True
     local_v.requires_grad = True
 
-    # Map argument to FlashAttentionImpl enum
+    # Map argument to AttnType enum
     attn_impl_map = {
-        'torch': FlashAttentionImpl.TORCH,
-        'fa': FlashAttentionImpl.FA,
-        'fa3': FlashAttentionImpl.FA3
+        'torch': AttnType.TORCH,
+        'fa': AttnType.FA,
+        'fa3': AttnType.FA3
     }
 
     usp_attn = LongContextAttention(ring_impl_type=ring_impl_type,

test/test_hybrid_qkvpacked_attn.py (+2, -2)
@@ -6,7 +6,7 @@
     EXTRACT_FUNC_DICT,
     RING_IMPL_QKVPACKED_DICT
 )
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
 
 
 def log(msg, a, rank0_only=False):
@@ -66,7 +66,7 @@ def test(ring_impl_type="zigzag"):
 
     set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
 
-    longctx_attn = LongContextAttentionQKVPacked(ring_impl_type=ring_impl_type, attn_type=FlashAttentionImpl.TORCH)
+    longctx_attn = LongContextAttentionQKVPacked(ring_impl_type=ring_impl_type, attn_type=AttnType.TORCH)
 
     ## prepare input and output tensors
 
test/test_ulysses_attn.py (+2, -2)
@@ -3,7 +3,7 @@
 from yunchang import UlyssesAttention
 
 from flash_attn import flash_attn_func
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
 
 def log(msg, a, rank0_only=False):
     world_size = dist.get_world_size()
@@ -79,7 +79,7 @@ def log(msg, a, rank0_only=False):
     # prcess_group == sequence_process_group
     sp_pg = None #dist.new_group(ranks=[i for i in range(world_size)])
 
-    dist_attn = UlyssesAttention(sp_pg, attn_type=FlashAttentionImpl.FA)
+    dist_attn = UlyssesAttention(sp_pg, attn_type=AttnType.FA)
 
     if rank == 0:
         print("#" * 30)

yunchang/hybrid/attn_layer.py (+3, -3)
@@ -8,7 +8,7 @@
 import torch.distributed as dist
 from .utils import RING_IMPL_DICT, RING_IMPL_QKVPACKED_DICT
 from yunchang.globals import PROCESS_GROUP
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
 
 
 class LongContextAttention(torch.nn.Module):
@@ -29,7 +29,7 @@ def __init__(
         ring_impl_type: str = "basic",
         use_pack_qkv: bool = False,
         use_sync: bool = False,
-        attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+        attn_type: AttnType = AttnType.FA,
     ) -> None:
 
         super(LongContextAttention, self).__init__()
@@ -157,7 +157,7 @@ def __init__(
         gather_idx: int = 1,
         ring_impl_type: str = "basic",
         use_sync: bool = False,
-        attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+        attn_type: AttnType = AttnType.FA,
     ) -> None:
 
         super(LongContextAttentionQKVPacked, self).__init__()

yunchang/kernels/__init__.py (+6, -6)
@@ -14,7 +14,7 @@
 if HAS_FLASH_ATTN:
     from flash_attn import flash_attn_func
 
-class FlashAttentionImpl(Enum):
+class AttnType(Enum):
     FA = "fa"
     FA3 = "fa3"
     TORCH = "torch"
@@ -26,8 +26,8 @@ def from_string(cls, s: str):
             return member
         raise ValueError(f"'{s}' is not a valid {cls.__name__}")
 
-def select_flash_attn_impl(impl_type: FlashAttentionImpl, stage : str = "fwd-bwd"):
-    if impl_type == FlashAttentionImpl.FA:
+def select_flash_attn_impl(impl_type: AttnType, stage : str = "fwd-bwd"):
+    if impl_type == AttnType.FA:
         if stage == "fwd-only":
             return flash_attn_forward
         elif stage == "bwd-only":
@@ -38,7 +38,7 @@ def select_flash_attn_impl(impl_type: FlashAttentionImpl, stage : str = "fwd-bwd
         else:
             raise ValueError(f"Unknown stage: {stage}")
 
-    elif impl_type == FlashAttentionImpl.FA3:
+    elif impl_type == AttnType.FA3:
         if stage == "fwd-only":
             return flash_attn3_func_forward
         elif stage == "bwd-only":
@@ -64,7 +64,7 @@ def fn(q,
         else:
             raise ValueError(f"Unknown stage: {stage}")
 
-    elif impl_type == FlashAttentionImpl.TORCH:
+    elif impl_type == AttnType.TORCH:
         if stage == "fwd-only":
             return pytorch_attn_forward
         elif stage == "bwd-only":
@@ -77,4 +77,4 @@ def fn(q,
     else:
         raise ValueError(f"Unknown flash attention implementation: {impl_type}")
 
-__all__ = ["flash_attn_forward", "flash_attn_backward", "flash_attn3_func_forward", "flash_attn3_func_forward", "FlashAttentionImpl"]
+__all__ = ["flash_attn_forward", "flash_attn_backward", "flash_attn3_func_forward", "flash_attn3_func_forward", "AttnType"]
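
The hunks above spell out the dispatch contract: `select_flash_attn_impl` maps an `AttnType` member plus a `stage` string ("fwd-only", "bwd-only", or "fwd-bwd") to a concrete kernel function, and unknown stages or implementations raise `ValueError`. A short sketch of how callers use it; the exact signatures of the returned kernels are not part of this diff:

```python
# Sketch only: exercises the dispatch shown in yunchang/kernels/__init__.py.
from yunchang.kernels import AttnType, select_flash_attn_impl

attn_type = AttnType.from_string("torch")                       # -> AttnType.TORCH
fwd_only = select_flash_attn_impl(attn_type, stage="fwd-only")  # -> pytorch_attn_forward
fwd_bwd = select_flash_attn_impl(attn_type, stage="fwd-bwd")    # full forward/backward path (default stage)

# An unrecognized stage string raises ValueError(f"Unknown stage: {stage}").
```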

yunchang/ring/ring_flash_attn.py (+6, -6)
@@ -2,7 +2,7 @@
 import torch.distributed as dist
 # from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
 from .utils import RingComm, update_out_and_lse
-from yunchang.kernels import select_flash_attn_impl, FlashAttentionImpl
+from yunchang.kernels import select_flash_attn_impl, AttnType
 
 def ring_flash_attn_forward(
     process_group,
@@ -16,7 +16,7 @@ def ring_flash_attn_forward(
     softcap=0.0,
     alibi_slopes=None,
     deterministic=False,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     comm = RingComm(process_group)
 
@@ -72,7 +72,7 @@ def ring_flash_attn_backward(
     softcap=0.0,
     alibi_slopes=None,
     deterministic=False,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     kv_comm = RingComm(process_group)
     d_kv_comm = RingComm(process_group)
@@ -227,7 +227,7 @@ def ring_flash_attn_qkvpacked_func(
     deterministic=False,
     return_attn_probs=False,
     group=None,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     return RingFlashAttnFunc.apply(
         qkv[:, :, 0],
@@ -258,7 +258,7 @@ def ring_flash_attn_kvpacked_func(
     deterministic=False,
     return_attn_probs=False,
     group=None,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     return RingFlashAttnFunc.apply(
         q,
@@ -290,7 +290,7 @@ def ring_flash_attn_func(
     deterministic=False,
     return_attn_probs=False,
     group=None,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     return RingFlashAttnFunc.apply(
         q,

yunchang/ring/ring_pytorch_attn.py (+2, -2)
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from typing import Any, Optional, Tuple
-from yunchang.kernels import select_flash_attn_impl, FlashAttentionImpl
+from yunchang.kernels import select_flash_attn_impl, AttnType
 from .utils import RingComm, update_out_and_lse
 from yunchang.kernels.attention import pytorch_attn_forward, pytorch_attn_backward
 
@@ -22,7 +22,7 @@ def ring_pytorch_attn_func(
     deterministic=False,
     return_attn_probs=False,
     group=None,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     return RingAttentionFunc.apply(group, q, k, v, softmax_scale, causal)
 
yunchang/ring/stripe_flash_attn.py (+7, -7)
@@ -1,5 +1,5 @@
 import torch
-from yunchang.kernels import select_flash_attn_impl, FlashAttentionImpl
+from yunchang.kernels import select_flash_attn_impl, AttnType
 from .utils import RingComm, update_out_and_lse
 
 
@@ -15,7 +15,7 @@ def stripe_flash_attn_forward(
     softcap=0.0,
     alibi_slopes=None,
     deterministic=False,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     assert (
         causal
@@ -91,7 +91,7 @@ def stripe_flash_attn_backward(
     softcap=0.0,
     alibi_slopes=None,
     deterministic=False,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     assert (
         causal
@@ -211,7 +211,7 @@ def forward(
         deterministic,
         return_softmax,
         group,
-        attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+        attn_type: AttnType = AttnType.FA,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -280,7 +280,7 @@ def stripe_flash_attn_qkvpacked_func(
     deterministic=False,
     return_attn_probs=False,
     group=None,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     return StripeFlashAttnFunc.apply(
         qkv[:, :, 0],
@@ -311,7 +311,7 @@ def stripe_flash_attn_kvpacked_func(
     deterministic=False,
     return_attn_probs=False,
     group=None,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     return StripeFlashAttnFunc.apply(
         q,
@@ -343,7 +343,7 @@ def stripe_flash_attn_func(
     deterministic=False,
     return_attn_probs=False,
     group=None,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     return StripeFlashAttnFunc.apply(
         q,

yunchang/ring/zigzag_ring_flash_attn.py (+6, -6)
@@ -1,6 +1,6 @@
 import torch
 from .utils import RingComm, update_out_and_lse
-from yunchang.kernels import FlashAttentionImpl, select_flash_attn_impl
+from yunchang.kernels import AttnType, select_flash_attn_impl
 
 def zigzag_ring_flash_attn_forward(
     process_group,
@@ -14,7 +14,7 @@ def zigzag_ring_flash_attn_forward(
     softcap=0.0,
     alibi_slopes=None,
     deterministic=False,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     assert causal == True, "zigzag ring is meaningless for causal=False"
     comm = RingComm(process_group)
@@ -91,7 +91,7 @@ def zigzag_ring_flash_attn_backward(
     softcap=0.0,
     alibi_slopes=None,
     deterministic=False,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     assert causal == True, "zigzag ring is meaningless for causal=False"
     kv_comm = RingComm(process_group)
@@ -268,7 +268,7 @@ def zigzag_ring_flash_attn_qkvpacked_func(
     deterministic=False,
     return_attn_probs=False,
     group=None,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     return ZigZagRingFlashAttnFunc.apply(
         qkv[:, :, 0],
@@ -299,7 +299,7 @@ def zigzag_ring_flash_attn_kvpacked_func(
     deterministic=False,
     return_attn_probs=False,
     group=None,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     return ZigZagRingFlashAttnFunc.apply(
         q,
@@ -331,7 +331,7 @@ def zigzag_ring_flash_attn_func(
     deterministic=False,
     return_attn_probs=False,
     group=None,
-    attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+    attn_type: AttnType = AttnType.FA,
 ):
     return ZigZagRingFlashAttnFunc.apply(
         q,

yunchang/ulysses/attn_layer.py (+4, -4)
@@ -7,7 +7,7 @@
 
 from typing import Any
 from torch import Tensor
-from yunchang.kernels import FlashAttentionImpl, select_flash_attn_impl
+from yunchang.kernels import AttnType, select_flash_attn_impl
 import torch.distributed as dist
 from yunchang.comm.all_to_all import SeqAllToAll4D
 
@@ -21,7 +21,7 @@ class UlyssesAttention(torch.nn.Module):
         scatter_idx (int): scatter_idx for all2all comm
         gather_idx (int): gather_idx for all2all comm
         use_sync (bool): whether to synchronize after all-to-all. This flag can save cuda memory but will slow down the speed.
-        attn_type (FlashAttentionImpl): attention type enum
+        attn_type (AttnType): attention type enum
     """
 
     def __init__(
@@ -30,7 +30,7 @@ def __init__(
         scatter_idx: int = 2,
         gather_idx: int = 1,
         use_sync: bool = False,
-        attn_type : FlashAttentionImpl = FlashAttentionImpl.FA,
+        attn_type : AttnType = AttnType.FA,
     ) -> None:
 
         super(UlyssesAttention, self).__init__()
@@ -43,7 +43,7 @@ def __init__(
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         gpu_name = torch.cuda.get_device_name(device)
         if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
-            self.attn_type = FlashAttentionImpl.TORCH
+            self.attn_type = AttnType.TORCH
         self.attn_fn = select_flash_attn_impl(self.attn_type, stage="fwd-bwd")
 
     def forward(
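
The last hunk above keeps the existing hardware guard: on Turing, Tesla, or T4 GPUs the constructor overrides the requested backend with `AttnType.TORCH` before resolving `self.attn_fn`. A minimal construction sketch mirroring test/test_ulysses_attn.py, assuming the script is launched under torchrun so a default process group can be initialized:

```python
# Sketch only: mirrors the UlyssesAttention usage in test/test_ulysses_attn.py.
import torch.distributed as dist
from yunchang import UlyssesAttention
from yunchang.kernels import AttnType

dist.init_process_group("nccl")  # assumption: launched via torchrun
sp_pg = None  # None -> default process group, as in the test

# On Turing/Tesla/T4 GPUs the constructor silently falls back to AttnType.TORCH.
dist_attn = UlyssesAttention(sp_pg, attn_type=AttnType.FA)
```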
