feat: add use_sync switch to ulysses (#103)
Eigensystem authored Nov 15, 2024
1 parent 1fb7f00 · commit 7bbaf56
Showing 3 changed files with 13 additions and 11 deletions.
yunchang/hybrid/attn_layer.py (1 addition, 0 deletions)
@@ -96,6 +96,7 @@ def forward(
                 deterministic=deterministic,
                 return_attn_probs=return_attn_probs,
                 group=self.ring_pg,
+                attn_type=self.attn_type,
             )
         else:
             query_layer = SeqAllToAll4D.apply(
yunchang/kernels/__init__.py (4 additions, 0 deletions)
@@ -30,6 +30,8 @@ def select_flash_attn_impl(impl_type: FlashAttentionImpl, stage : str = "fwd-bwd
         elif stage == "fwd-bwd":
             print(f"flash_attn_func: {flash_attn_func} here")
             return flash_attn_func
+        else:
+            raise ValueError(f"Unknown stage: {stage}")

     elif impl_type == FlashAttentionImpl.FA3:
         if stage == "fwd-only":
@@ -52,6 +54,8 @@ def fn(q,
                 return flash3_attn_func(q, k, v, softmax_scale=softmax_scale, causal=causal)

             return fn
+        else:
+            raise ValueError(f"Unknown stage: {stage}")

     elif impl_type == FlashAttentionImpl.TORCH:
         if stage == "fwd-bwd":
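
To make the effect of the new guards concrete, here is a minimal sketch, assuming yunchang and the flash-attn backend its FA path imports are installed: an unrecognized stage string now fails fast with a ValueError instead of silently falling through.

    from yunchang.kernels import FlashAttentionImpl, select_flash_attn_impl

    # A recognized stage still resolves to a callable kernel, as before.
    attn_fn = select_flash_attn_impl(FlashAttentionImpl.FA, stage="fwd-bwd")

    # With the new else branches, a typo in `stage` raises immediately rather
    # than falling off the end of the if/elif chain and returning None.
    try:
        select_flash_attn_impl(FlashAttentionImpl.FA, stage="fwd_bwd")  # underscore typo
    except ValueError as err:
        print(err)  # -> Unknown stage: fwd_bwd
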
yunchang/ulysses/attn_layer.py (8 additions, 11 deletions)
@@ -10,8 +10,6 @@
 from yunchang.kernels import FlashAttentionImpl, select_flash_attn_impl
 import torch.distributed as dist
 from yunchang.comm.all_to_all import SeqAllToAll4D
-import torch.nn.functional as F
-


 class UlyssesAttention(torch.nn.Module):
@@ -32,20 +30,21 @@ def __init__(
         scatter_idx: int = 2,
         gather_idx: int = 1,
-        attn_type : FlashAttentionImpl = FlashAttentionImpl.FA
+        use_sync: bool = False,
+        attn_type : FlashAttentionImpl = FlashAttentionImpl.FA,
     ) -> None:

         super(UlyssesAttention, self).__init__()
         self.spg = sequence_process_group
         self.scatter_idx = scatter_idx
         self.gather_idx = gather_idx
+        self.use_sync = use_sync
         self.attn_type = attn_type
-        self.attn_fn = select_flash_attn_impl(attn_type)

         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         gpu_name = torch.cuda.get_device_name(device)
         if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
             self.attn_type = FlashAttentionImpl.TORCH
+        self.attn_fn = select_flash_attn_impl(self.attn_type, stage="fwd-bwd")

     def forward(
         self,
@@ -79,15 +78,13 @@ def forward(
         # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)

         # scatter 2, gather 1
-        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
-        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
-        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)
-
-
+        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx, self.use_sync)
+        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx, self.use_sync)
+        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx, self.use_sync)

         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** -0.5

         context_layer = self.attn_fn(
             q,
             k,
@@ -108,7 +105,7 @@ def forward(
         # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
         # scatter 1, gather 2
         output = SeqAllToAll4D.apply(
-            self.spg, context_layer, self.gather_idx, self.scatter_idx
+            self.spg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync
         )

         # out e.g., [s/p::h]
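
For reference, a hedged usage sketch of the new switch. Everything below that is not visible in this diff (the import path for UlyssesAttention, the torchrun launch, forward()'s remaining defaults, and the tensor sizes) is an assumption rather than something this commit specifies; use_sync itself is simply forwarded to every SeqAllToAll4D.apply call in the layer.

    import torch
    import torch.distributed as dist
    from yunchang.kernels import FlashAttentionImpl
    from yunchang.ulysses.attn_layer import UlyssesAttention  # import path assumed

    # Assumes a single-node multi-GPU launch, e.g. `torchrun --nproc_per_node=4 example.py`.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    sp_group = dist.new_group(list(range(dist.get_world_size())))

    attn = UlyssesAttention(
        sequence_process_group=sp_group,
        use_sync=True,  # the new switch added by this commit
        attn_type=FlashAttentionImpl.FA,
    )

    # Per-rank shard: (bs, seq_len/N, head_cnt, head_size), per the comments in forward().
    bs, local_seq, heads, head_dim = 2, 1024, 16, 64
    q = torch.randn(bs, local_seq, heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = attn(q, k, v)  # remaining forward() arguments left at their (assumed) defaults
    print(out.shape)     # output comes back in the same local shard layout as the inputs

    dist.destroy_process_group()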
