
Commit 6e8b003

bump version to 0.6.2 (#140)
1 parent 270ef55 commit 6e8b003

File tree

- README.md (+15 -2)
- pyproject.toml (+1 -1)
- yunchang/__init__.py (+1 -1)
- yunchang/globals.py (+10 -1)

4 files changed: +27 -5 lines changed


README.md

+15 -2

@@ -122,7 +122,12 @@ local_out = usp_attn(
 
 ### 3. Test
 
-- Causal Attention Test
+If you have not installed yunchang, add the project root directory to the PYTHONPATH:
+```
+export PYTHONPATH=$PWD:$PYTHONPATH
+```
+
+- FlashAttn/Torch Test
 ```bash
 torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --ring_impl_type "zigzag" --causal --attn_impl fa --use_bwd
 torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --ring_impl_type "zigzag" --causal --attn_impl torch
@@ -134,13 +139,21 @@ torchrun --nproc_per_node 8 test/test_hybrid_qkvpacked_attn.py
 You need to install [SpargeAttn](https://github.com/thu-ml/SpargeAttn) and [SageAttention](https://github.com/thu-ml/SageAttention) from source.
 
 ```bash
-torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 4 --attn_impl sage_fp8
+torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --attn_impl sage_fp8
 ```
 
 ```bash
 torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 4 --attn_impl sparse_sage --sparse_sage_tune_mode
 ```
 
+- FlashInfer Test (fwd only)
+
+Install FlashInfer from [here](https://docs.flashinfer.ai/installation.html#quick-start).
+
+```bash
+torchrun --nproc_per_node=4 --master_port=1234 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --ring_impl_type 'basic_flashinfer' --attn_impl flashinfer
+```
+
 ### 4. Verified in Megatron-LM
 The loss curves for Data Parallel (DP) and Unified Sequence Parallel (ulysses=2+ring=2) are closely aligned, as illustrated in the figure. This alignment confirms the accuracy of the unified sequence parallel.

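The PYTHONPATH note added in this hunk assumes the tests import `yunchang` directly from the source tree. As a quick sanity check before launching `torchrun`, the following sketch (not part of the commit; it assumes you are in the repository root and have either exported PYTHONPATH as shown above or pip-installed the package) confirms the import resolves:

```python
# Sketch: confirm yunchang is importable before running the distributed tests.
# Assumes PYTHONPATH includes the repo root, or yunchang is installed.
import yunchang

print(yunchang.__file__)     # should point into this repository when using PYTHONPATH
print(yunchang.__version__)  # "0.6.2" after this commit
```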
pyproject.toml

+1 -1

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "yunchang"
-version = "0.6.1"
+version = "0.6.2"
 authors = [
   { name="Jiarui Fang", email="[email protected]" },
 ]

yunchang/__init__.py

+1 -1

@@ -4,5 +4,5 @@
 from .globals import set_seq_parallel_pg
 from .comm.extract_local import stripe_extract_local, basic_extract_local, zigzag_extract_local, EXTRACT_FUNC_DICT
 
-__version__ = "0.6.1"
+__version__ = "0.6.2"
 

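The version string now lives in two places, pyproject.toml and yunchang/__init__.py, and this commit bumps both. A small consistency check (a sketch, not part of the commit; it assumes yunchang has been installed, e.g. with `pip install -e .`) can catch the two drifting apart on future bumps:

```python
# Sketch: verify the module-level __version__ matches the installed package
# metadata generated from pyproject.toml. Assumes yunchang is installed.
from importlib.metadata import version

import yunchang

assert yunchang.__version__ == version("yunchang"), (
    f"__init__.py reports {yunchang.__version__}, "
    f"package metadata reports {version('yunchang')}"
)
print(f"yunchang {yunchang.__version__} is consistent")
```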
yunchang/globals.py

+10 -1

@@ -1,4 +1,5 @@
 import torch
+import os
 
 
 class Singleton:
@@ -98,6 +99,13 @@ def set_seq_parallel_pg(
 try:
     from flashinfer.prefill import single_prefill_with_kv_cache
     HAS_FLASHINFER = True
+    def get_cuda_arch():
+        major, minor = torch.cuda.get_device_capability()
+        return f"{major}.{minor}"
+
+    cuda_arch = get_cuda_arch()
+    os.environ['TORCH_CUDA_ARCH_LIST'] = cuda_arch
+    print(f"Set TORCH_CUDA_ARCH_LIST to {cuda_arch}")
 except ImportError:
     HAS_FLASHINFER = False
 
@@ -111,4 +119,5 @@ def set_seq_parallel_pg(
     import spas_sage_attn
     HAS_SPARSE_SAGE_ATTENTION = True
 except ImportError:
-    HAS_SPARSE_SAGE_ATTENTION = False
+    HAS_SPARSE_SAGE_ATTENTION = False
+

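The new code in the FlashInfer branch reads the local GPU's compute capability and exports it via TORCH_CUDA_ARCH_LIST, presumably so that any subsequent JIT compilation builds kernels only for that architecture. The same pattern in isolation (a minimal sketch, not the library's API; it assumes a CUDA-capable device is visible) looks like this:

```python
# Sketch of the arch-detection pattern used above: query the device's compute
# capability and narrow TORCH_CUDA_ARCH_LIST to it, so later JIT compilation
# targets only the local GPU. Assumes PyTorch with CUDA support.
import os

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
    print(f"TORCH_CUDA_ARCH_LIST={os.environ['TORCH_CUDA_ARCH_LIST']}")
else:
    print("No CUDA device visible; TORCH_CUDA_ARCH_LIST left unset")
```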