Commit 08a85d3

Release v1.10
2 parents e79d915 + a7e9d3e commit 08a85d3

129 files changed, +7742 -2560 lines changed

.github/workflows/blossom-ci.yml

+5 -2
@@ -23,9 +23,12 @@ jobs:
       args: ${{ env.args }}

     # This job only runs for pull request comments
-    if: |
-      contains( ',ptrendx,ksivaman,', format(',{0},', github.actor)) &&
+    if: >
       github.event.comment.body == '/blossom-ci'
+      && (
+        github.actor == 'ptrendx'
+        || github.actor == 'ksivaman'
+      )
     steps:
       - name: Check if comment is issued by authorized person
         run: blossom-ci

.github/workflows/build.yml

+10 -4
@@ -12,7 +12,7 @@ jobs:
     name: 'Core'
     runs-on: ubuntu-latest
     container:
-      image: nvcr.io/nvidia/cuda:12.5.0-devel-ubuntu22.04
+      image: nvcr.io/nvidia/cuda:12.0.0-devel-ubuntu22.04
       options: --user root
     steps:
       - name: 'Dependencies'
@@ -35,9 +35,14 @@ jobs:
     name: 'PyTorch'
     runs-on: ubuntu-latest
     container:
-      image: nvcr.io/nvidia/pytorch:24.05-py3
+      image: nvcr.io/nvidia/cuda:12.5.0-devel-ubuntu22.04
       options: --user root
     steps:
+      - name: 'Dependencies'
+        run: |
+          apt-get update
+          apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
+          pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11
       - name: 'Checkout'
         uses: actions/checkout@v3
         with:
@@ -48,7 +53,8 @@ jobs:
           NVTE_FRAMEWORK: pytorch
           MAX_JOBS: 1
       - name: 'Sanity check'
-        run: python tests/pytorch/test_sanity_import.py
+        if: false # Sanity import test requires Flash Attention
+        run: python3 tests/pytorch/test_sanity_import.py
   jax:
     name: 'JAX'
     runs-on: ubuntu-latest
@@ -70,7 +76,7 @@ jobs:
     name: 'PaddlePaddle'
     runs-on: ubuntu-latest
     container:
-      image: nvcr.io/nvidia/paddlepaddle:24.05-py3
+      image: nvcr.io/nvidia/paddlepaddle:24.07-py3
       options: --user root
     steps:
       - name: 'Checkout'

.github/workflows/trigger-ci.yml

+18 -2
@@ -15,9 +15,25 @@ jobs:
       args: ${{ env.args }}

     # This job only runs for pull request comments
-    if: |
-      contains( ',ptrendx,ksivaman,schetlur-nv,timmoon10,zlsh80826,mingxu1067,cyanguwa,nzmora-nvidia,galagam,nouiz,denera,sudhakarsingh27,Oleg-Goncharov,phu0ngng,nvcforster,', format(',{0},', github.actor)) &&
+    if: >
       startsWith(github.event.comment.body, '/te-ci')
+      && (
+        github.actor == 'ptrendx'
+        || github.actor == 'ksivaman'
+        || github.actor == 'schetlur-nv'
+        || github.actor == 'timmoon10'
+        || github.actor == 'zlsh80826'
+        || github.actor == 'mingxu1067'
+        || github.actor == 'cyanguwa'
+        || github.actor == 'nzmora-nvidia'
+        || github.actor == 'galagam'
+        || github.actor == 'nouiz'
+        || github.actor == 'denera'
+        || github.actor == 'sudhakarsingh27'
+        || github.actor == 'Oleg-Goncharov'
+        || github.actor == 'phu0ngng'
+        || github.actor == 'xrennvidia'
+      )
     steps:
       - name: Check if comment is issued by authorized person
         run: blossom-ci
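
Both workflow diffs above replace a comma-delimited substring test with explicit equality checks. The sketch below models the two matching strategies in Python to show where they can disagree. The function names are hypothetical, and Python string comparison is used as a stand-in: it does not reproduce GitHub Actions expression semantics (for example, case handling), only the structure of the two checks.

```python
from typing import Sequence


def allowed_by_substring(actor: str, allow_list: Sequence[str]) -> bool:
    """Old style: wrap the actor in commas and search a comma-joined list."""
    haystack = "," + ",".join(allow_list) + ","
    return f",{actor}," in haystack


def allowed_by_equality(actor: str, allow_list: Sequence[str]) -> bool:
    """New style: compare the actor against each allowed name."""
    return any(actor == name for name in allow_list)


if __name__ == "__main__":
    allow_list = ["ptrendx", "ksivaman"]
    # The last candidate contains a comma, which the substring form accepts.
    for candidate in ["ptrendx", "ptr", "ptrendx,ksivaman"]:
        print(
            candidate,
            allowed_by_substring(candidate, allow_list),
            allowed_by_equality(candidate, allow_list),
        )
```

For ordinary names the two forms agree. They differ only for an input that itself contains the delimiter, which the equality form rejects.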

3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 118 files

README.rst

+3 -3
@@ -149,8 +149,8 @@ Installation
 Pre-requisites
 ^^^^^^^^^^^^^^^^^^^^
 * Linux x86_64
-* CUDA 11.8+ for Hopper and CUDA 12.1+ for Ada
-* NVIDIA Driver supporting CUDA 11.8 or later
+* CUDA 12.0+ for Hopper and CUDA 12.1+ for Ada
+* NVIDIA Driver supporting CUDA 12.0 or later
 * cuDNN 8.1 or later
 * For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.

@@ -182,7 +182,7 @@ From source

 Compiling with FlashAttention-2
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance.
+Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance.

 It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue.

benchmarks/attention/benchmark_attention.py

+7 -13
@@ -11,9 +11,7 @@
 import transformer_engine
 from tests.pytorch.fused_attn.test_fused_attn import (
     ModelConfig,
-    _is_flash_attention_supported,
-    _is_fused_attention_supported,
-    _is_unfused_attention_supported,
+    _get_attention_backends,
     _run_dot_product_attention,
 )

@@ -29,8 +27,6 @@
 workspace_opt = True
 # QKV memory layout
 qkv_layout = "bshd_bshd_bshd"
-# sliding window attention
-swa = False
 # padding between sequences for qkv_format=thd
 pad_between_seqs = False
 # training mode
@@ -64,7 +60,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
                 ckpt_attn,
                 qkv_layout,
                 workspace_opt,
-                swa,
                 pad_between_seqs,
                 is_training,
             )
@@ -76,7 +71,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
                 ckpt_attn,
                 qkv_layout,
                 workspace_opt,
-                swa,
                 pad_between_seqs,
                 is_training,
             )
@@ -97,7 +91,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
                 ckpt_attn,
                 qkv_layout,
                 workspace_opt,
-                swa,
                 pad_between_seqs,
                 is_training,
             )
@@ -115,7 +108,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
                 ckpt_attn,
                 qkv_layout,
                 workspace_opt,
-                swa,
                 pad_between_seqs,
                 is_training,
             )
@@ -205,13 +197,15 @@ def main():
     )
     for model in model_configs.keys():
         config = model_configs[model]
-        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+        available_backends, fused_attn_backends = _get_attention_backends(
             config,
-            dtype,
+            qkv_dtype=dtype,
             qkv_layout=qkv_layout,
+            window_size=config.window_size,
+            pad_between_seqs=pad_between_seqs,
         )
-        fused_attn_supported = fused_attn_supported and not swa
-        flash_attn_supported = _is_flash_attention_supported(config)
+        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+
         print(
             f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
             f'{" and flash-attention" if flash_attn_supported else ""}...'

build_tools/VERSION.txt

+1 -1
@@ -1 +1 @@
-1.9.0
+1.10.0

build_tools/build_ext.py

+7 -2
@@ -10,6 +10,7 @@
 import sys
 import sysconfig
 import copy
+import time

 from pathlib import Path
 from subprocess import CalledProcessError
@@ -69,8 +70,8 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
             configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")

         # CMake build and install commands
-        build_command = [_cmake_bin, "--build", build_dir]
-        install_command = [_cmake_bin, "--install", build_dir]
+        build_command = [_cmake_bin, "--build", build_dir, "--verbose"]
+        install_command = [_cmake_bin, "--install", build_dir, "--verbose"]

         # Check whether parallel build is restricted
         max_jobs = get_max_jobs_for_parallel_build()
@@ -81,13 +82,17 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
             build_command.append(str(max_jobs))

         # Run CMake commands
+        start_time = time.perf_counter()
         for command in [configure_command, build_command, install_command]:
             print(f"Running command {' '.join(command)}")
             try:
                 subprocess.run(command, cwd=build_dir, check=True)
             except (CalledProcessError, OSError) as e:
                 raise RuntimeError(f"Error when running CMake: {e}")

+        total_time = time.perf_counter() - start_time
+        print(f"Time for build_ext: {total_time:.2f} seconds")
+

 def get_build_ext(extension_cls: Type[setuptools.Extension]):
     class _CMakeBuildExtension(extension_cls):
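
The timing added to `_build_cmake` above wraps the whole command loop in a single `time.perf_counter()` interval. A minimal standalone sketch of that pattern follows. The helper name `run_timed` and the stand-in commands are assumptions for illustration and are not part of the commit.

```python
import subprocess
import sys
import time
from typing import List, Sequence


def run_timed(commands: Sequence[List[str]]) -> float:
    """Run each command in order and return the total elapsed seconds."""
    start_time = time.perf_counter()
    for command in commands:
        print(f"Running command {' '.join(command)}")
        try:
            subprocess.run(command, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            raise RuntimeError(f"Error when running command: {e}") from e
    return time.perf_counter() - start_time


if __name__ == "__main__":
    # Stand-in commands; the commit times CMake configure, build, and install.
    elapsed = run_timed([[sys.executable, "-c", "pass"]] * 2)
    print(f"Time for commands: {elapsed:.2f} seconds")
```

Because the interval is only reported after the loop finishes, a failing command raises before any timing is printed, which matches the control flow in the diff.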

build_tools/jax.py

+23 -1
@@ -2,7 +2,8 @@
 #
 # See LICENSE for license information.

-"""Paddle-paddle related extensions."""
+"""JAX related extensions."""
+import os
 from pathlib import Path

 import setuptools
@@ -12,6 +13,25 @@
 from typing import List


+def xla_path() -> str:
+    """XLA root path lookup.
+    Throws FileNotFoundError if XLA source is not found."""
+
+    try:
+        from jax.extend import ffi
+    except ImportError:
+        if os.getenv("XLA_HOME"):
+            xla_home = Path(os.getenv("XLA_HOME"))
+        else:
+            xla_home = "/opt/xla"
+    else:
+        xla_home = ffi.include_dir()
+
+    if not os.path.isdir(xla_home):
+        raise FileNotFoundError("Could not find xla source.")
+    return xla_home
+
+
 def setup_jax_extension(
     csrc_source_files,
     csrc_header_files,
@@ -27,12 +47,14 @@ def setup_jax_extension(

     # Header files
     cuda_home, _ = cuda_path()
+    xla_home = xla_path()
     include_dirs = [
         cuda_home / "include",
         common_header_files,
         common_header_files / "common",
         common_header_files / "common" / "include",
         csrc_header_files,
+        xla_home,
     ]

     # Compile flags
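
The `xla_path` helper added above tries three sources in order: the include directory reported by `jax.extend.ffi`, then the `XLA_HOME` environment variable, then `/opt/xla`. The sketch below mirrors that fallback order without importing JAX, so it can be exercised anywhere. The function name `resolve_xla_home` and its parameters are hypothetical, and returning a `Path` from every branch is a choice of this example rather than the commit's behavior.

```python
from pathlib import Path
from typing import Mapping, Optional


def resolve_xla_home(
    ffi_include_dir: Optional[str],
    env: Mapping[str, str],
    default: str = "/opt/xla",
) -> Path:
    """Pick the XLA directory using the same priority order as the diff."""
    if ffi_include_dir is not None:
        # Stands in for the case where the jax.extend.ffi import succeeded.
        candidate = Path(ffi_include_dir)
    elif env.get("XLA_HOME"):
        candidate = Path(env["XLA_HOME"])
    else:
        candidate = Path(default)

    if not candidate.is_dir():
        raise FileNotFoundError(f"Could not find XLA source at {candidate}")
    return candidate
```

Passing the environment as a mapping instead of reading `os.environ` directly keeps the lookup easy to test.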

build_tools/paddle.py

+13 -6
@@ -6,6 +6,7 @@
 from pathlib import Path

 import setuptools
+import os

 from .utils import cuda_version

@@ -61,12 +62,18 @@ def setup_paddle_extension(
     except FileNotFoundError:
         print("Could not determine CUDA Toolkit version")
     else:
-        if version >= (11, 2):
-            nvcc_flags.extend(["--threads", "4"])
-        if version >= (11, 0):
-            nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
-        if version >= (11, 8):
-            nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
+        if version < (12, 0):
+            raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
+        nvcc_flags.extend(
+            (
+                "--threads",
+                os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
+                "-gencode",
+                "arch=compute_80,code=sm_80",
+                "-gencode",
+                "arch=compute_90,code=sm_90",
+            )
+        )

     # Construct Paddle CUDA extension
     sources = [str(path) for path in sources]

build_tools/pytorch.py

+12 -6
@@ -67,12 +67,18 @@ def setup_pytorch_extension(
     except FileNotFoundError:
         print("Could not determine CUDA Toolkit version")
     else:
-        if version >= (11, 2):
-            nvcc_flags.extend(["--threads", "4"])
-        if version >= (11, 0):
-            nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
-        if version >= (11, 8):
-            nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
+        if version < (12, 0):
+            raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
+        nvcc_flags.extend(
+            (
+                "--threads",
+                os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
+                "-gencode",
+                "arch=compute_80,code=sm_80",
+                "-gencode",
+                "arch=compute_90,code=sm_90",
+            )
+        )

     # Libraries
     library_dirs = []
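
The Paddle and PyTorch build scripts above now share the same logic: reject CUDA older than 12.0, then always emit the `--threads` flag (read from `NVTE_BUILD_THREADS_PER_JOB`, default `"1"`) and both `-gencode` targets. A minimal sketch of that logic as a pure function is shown below. The name `build_nvcc_flags` and the explicit `env` parameter are assumptions for illustration; the commit extends an existing `nvcc_flags` list in place.

```python
from typing import List, Mapping, Tuple


def build_nvcc_flags(version: Tuple[int, int], env: Mapping[str, str]) -> List[str]:
    """Return the version-dependent nvcc flags added by the diff."""
    # Tuples compare element by element, so (11, 8) < (12, 0) is True.
    if version < (12, 0):
        raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
    return [
        "--threads",
        env.get("NVTE_BUILD_THREADS_PER_JOB", "1"),
        "-gencode",
        "arch=compute_80,code=sm_80",
        "-gencode",
        "arch=compute_90,code=sm_90",
    ]


if __name__ == "__main__":
    print(build_nvcc_flags((12, 5), {"NVTE_BUILD_THREADS_PER_JOB": "4"}))
```

Compared with the removed code, the default thread count drops from a hard-coded `"4"` to `"1"` unless the environment variable overrides it.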
