
Commit 7a7225c

Release v1.12

2 parents c27ee60 + 7f2afaa

105 files changed, +4410 -1575 lines changed

.github/workflows/trigger-ci.yml

+3

@@ -36,6 +36,9 @@ jobs:
        || github.actor == 'yaox12'
        || github.actor == 'huanghua1994'
        || github.actor == 'mgoldfarb-nvidia'
+        || github.actor == 'pggPL'
+        || github.actor == 'vasunvidia'
+        || github.actor == 'erhoo82'
      )
    steps:
      - name: Check if comment is issued by authorized person

.gitignore

+1

@@ -39,3 +39,4 @@ develop-eggs/
dist/
downloads/
.pytest_cache/
+compile_commands.json

3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 146 files

build_tools/VERSION.txt

+1 -1

@@ -1 +1 @@
-1.11.0
+1.12.0

build_tools/pytorch.py

+4 -4

@@ -80,10 +80,10 @@ def setup_pytorch_extension(
        )
    )

-    if "80" in cuda_architectures:
-        nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
-    if "90" in cuda_architectures:
-        nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
+    for arch in cuda_architectures.split(";"):
+        if arch == "70":
+            continue  # Already handled
+        nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

    # Libraries
    library_dirs = []
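
For reference, a minimal sketch (not part of the commit) of what the new loop emits, assuming `cuda_architectures` is the usual semicolon-separated architecture string such as "70;80;90":

    # Minimal sketch: expand a semicolon-separated architecture list into nvcc
    # -gencode flags, skipping "70" as the commit's loop does.
    cuda_architectures = "70;80;90"  # assumed example value
    nvcc_flags = []
    for arch in cuda_architectures.split(";"):
        if arch == "70":
            continue  # handled elsewhere, per the comment in the diff
        nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

    print(nvcc_flags)
    # ['-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_90,code=sm_90']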

docs/faq.rst (new file)

+75

..
    Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Frequently Asked Questions (FAQ)
================================

FP8 checkpoint compatibility
----------------------------

Transformer Engine began supporting FP8 attention in version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As FP8 attention support has expanded from one backend to multiple backends, the location of the `._extra_state` key has also shifted.

Here, we take the `MultiheadAttention` module as an example. Its FP8 attention metadata in Transformer Engine 1.11 is stored as `core_attention._extra_state`, as shown below.

.. code-block:: python

    >>> import torch
    >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
    >>> with fp8_model_init(enabled=True):
    ...     mha = MultiheadAttention(
    ...         hidden_size=1024,
    ...         num_attention_heads=16,
    ...         bias=True,
    ...         params_dtype=torch.bfloat16,
    ...         input_layernorm=False,
    ...         fuse_qkv_params=True,
    ...         attention_type="self",
    ...         qkv_weight_interleaved=True,
    ...     ).to(dtype=torch.bfloat16, device="cuda")
    ...
    >>> state_dict = mha.state_dict()
    >>> print(state_dict.keys())
    odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])

Here is a full list of the checkpoint save/load behaviors across Transformer Engine versions.

.. list-table::

   * - **Version: <= 1.5**

       - Saves no FP8 metadata since FP8 attention is not supported
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Loads no FP8 metadata
         :> 1.5: Error: unexpected key
   * - **Version: 1.6, 1.7**

       - Saves FP8 metadata to `core_attention.fused_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :1.6, 1.7: Loads FP8 metadata from checkpoint
         :>= 1.8: Error: unexpected key
   * - **Version: >= 1.8, <= 1.11**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :1.6, 1.7: This save/load combination relies on users to map the 1.6/1.7 key to the 1.8-1.11 key; otherwise, it initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes. In this `MultiheadAttention` example, the mapping can be done by

           .. code-block:: python

               >>> state_dict["core_attention._extra_state"] = \
                       state_dict["core_attention.fused_attention._extra_state"]
               >>> del state_dict["core_attention.fused_attention._extra_state"]

         :>= 1.8: Loads FP8 metadata from checkpoint
   * - **Version: >= 1.12**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :>= 1.6: Loads FP8 metadata from checkpoint
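
To illustrate the remapping the FAQ describes, here is a hedged sketch of loading a 1.6/1.7 checkpoint into a module running Transformer Engine >= 1.8. The checkpoint path is hypothetical, and `mha` is assumed to be a `MultiheadAttention` module built under `fp8_model_init` as in the FAQ snippet.

    # Hypothetical sketch: move the FP8 metadata from the 1.6/1.7 key to the
    # >= 1.8 key before calling load_state_dict.
    import torch

    state_dict = torch.load("mha_ckpt_te_1_7.pt")  # hypothetical checkpoint file
    old_key = "core_attention.fused_attention._extra_state"
    new_key = "core_attention._extra_state"
    if old_key in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)

    mha.load_state_dict(state_dict)  # mha: MultiheadAttention built under fp8_model_init, as above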

docs/index.rst

+1

@@ -30,6 +30,7 @@ Transformer Engine documentation

   installation
   examples/quickstart.ipynb
+   faq

.. toctree::
   :hidden:

examples/README.md (new file)

+43

# Examples

We provide a variety of examples for deep learning frameworks including [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/jax-ml/jax), and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle).
Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/TransformerEngine/tree/main/docs/examples) and a selection of [third-party examples](#third-party). Please be aware that these third-party examples might need specific, older versions of dependencies to function properly.

# PyTorch

- [Accelerate Hugging Face Llama models with TE](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb)
  - Provides code examples and explanations for integrating TE with LLaMA2 models.
- [PyTorch FSDP with FP8](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/fsdp)
  - **Distributed Training**: How to set up and run distributed training using PyTorch's FullyShardedDataParallel (FSDP) strategy.
  - **TE Integration**: Instructions on integrating TE/FP8 with PyTorch for optimized performance.
  - **Checkpointing**: Methods for applying activation checkpointing to manage memory usage during training.
- [Attention backends in TE](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/attention/attention.ipynb)
  - **Attention Backends**: Describes the attention backends supported by Transformer Engine, including framework-native, fused, and flash-attention backends, and their performance benefits.
  - **Flash vs. Non-Flash**: Compares the flash algorithm with the standard non-flash algorithm, highlighting memory and computational efficiency improvements.
  - **Backend Selection**: Details the logic for selecting the most appropriate backend based on availability and performance, and provides user control options for backend selection.
- [Overlapping Communication with GEMM](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/comm_gemm_overlap)
  - Training a TE module with GEMM and communication overlap, including various configurations and command-line arguments for customization.
- [Performance Optimizations](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/advanced_optimizations.ipynb)
  - **Multi-GPU Training**: How to use TE with data, tensor, and sequence parallelism.
  - **Gradient Accumulation Fusion**: Utilizing Tensor Cores to accumulate outputs directly into FP32 for better numerical accuracy.
  - **FP8 Weight Caching**: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency.
- [Introduction to FP8](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/fp8_primer.ipynb)
  - Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage.
- [TE Quickstart](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb)
  - Introduction to TE, building a Transformer layer using PyTorch, and instructions on integrating TE modules like Linear and LayerNorm.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/mnist)

# JAX

- [Basic Transformer Encoder Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/encoder)
  - Single GPU Training: Demonstrates setting up and training a Transformer model using a single GPU.
  - Data Parallelism: Scale training across multiple GPUs using data parallelism.
  - Model Parallelism: Divide a model across multiple GPUs for parallel training.
  - Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist)

# PaddlePaddle

- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/paddle/mnist)

# Third party

- [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine)
  - Scripts for training with Accelerate and TE. Supports single GPU, and multi-GPU via DDP, FSDP, and DeepSpeed ZeRO 1-3.

examples/jax/encoder/common.py (new file)

+14

# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache

from transformer_engine.transformer_engine_jax import get_device_compute_capability


@lru_cache
def is_bf16_supported():
    """Return whether the hardware supports BF16."""
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 80

examples/jax/encoder/test_model_parallel_encoder.py

+4

@@ -22,6 +22,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

+from common import is_bf16_supported
+
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"

@@ -434,6 +436,7 @@ def setUpClass(cls):
        """Run 3 epochs for testing"""
        cls.args = encoder_parser(["--epochs", "3"])

+    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)

@@ -446,6 +449,7 @@ def test_te_fp8(self):
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.45 and actual[1] > 0.79

+    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_sp(self):
        """Test Transformer Engine with BF16 + SP"""
        self.args.enable_sp = True

examples/jax/encoder/test_multigpu_encoder.py

+3

@@ -22,6 +22,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

+from common import is_bf16_supported
+
DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"

@@ -402,6 +404,7 @@ def setUpClass(cls):
        """Run 3 epochs for testing"""
        cls.args = encoder_parser(["--epochs", "3"])

+    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)

examples/jax/encoder/test_multiprocessing_encoder.py

+8 -4

@@ -24,6 +24,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

+from common import is_bf16_supported
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"

@@ -552,8 +554,9 @@ def encoder_parser(args):
def query_gpu(q):
    """Query GPU info on the system"""
    gpu_has_fp8, reason = te.fp8.is_fp8_available()
+    gpu_has_bf16 = is_bf16_supported()
    num_gpu = len(jax.devices())
-    q.put([num_gpu, gpu_has_fp8, reason])
+    q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason])


def unittest_query_gpu():

@@ -566,15 +569,15 @@ def unittest_query_gpu():
    q = mp.Queue()
    p = mp.Process(target=query_gpu, args=(q,))
    p.start()
-    num_gpu, gpu_has_fp8, reason = q.get()
+    num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get()
    p.join()
-    return num_gpu, gpu_has_fp8, reason
+    return num_gpu, gpu_has_fp8, gpu_has_bf16, reason


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

-    num_gpu, gpu_has_fp8, reason = unittest_query_gpu()
+    num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu()

    def exec(self, use_fp8):
        """Run 3 epochs for testing"""

@@ -598,6 +601,7 @@ def exec(self, use_fp8):

        return results

+    @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        results = self.exec(False)

examples/jax/encoder/test_single_gpu_encoder.py

+3

@@ -19,6 +19,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

+from common import is_bf16_supported
+
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"

@@ -321,6 +323,7 @@ def setUpClass(cls):
        """Run 4 epochs for testing"""
        cls.args = encoder_parser(["--epochs", "3"])

+    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)

pylintrc

+2

@@ -8,7 +8,9 @@ extension-pkg-whitelist=flash_attn_2_cuda,
extension-pkg-allow-list=transformer_engine.transformer_engine_jax

disable=too-many-locals,
+        too-few-public-methods,
        too-many-public-methods,
+        too-many-positional-arguments,
        invalid-name,
        too-many-arguments,
        abstract-method,

qa/L0_jax_lint/test.sh

+1 -1

@@ -6,7 +6,7 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

-pip install cpplint==1.6.0 pylint==2.13.5
+pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH

qa/L0_jax_unittest/test.sh

-2

@@ -18,7 +18,5 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt

pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist

-# Make encoder tests to have run-to-run deterministic to have the stable CI results
-export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py

qa/L0_paddle_lint/test.sh

+1 -1

@@ -6,7 +6,7 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

-pip install cpplint==1.6.0 pylint==2.13.5
+pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH

qa/L0_paddle_wheel/test.sh

+7 -3

@@ -6,7 +6,11 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

-pip install wheel==0.44.0 pydantic
+# Install dependencies
+# Note: Need to install wheel locally since PaddlePaddle container
+# already contains APT install.
+pip install pydantic
+pip install --user wheel==0.44.0

cd $TE_PATH
pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle

@@ -16,11 +20,11 @@ WHL_BASE="transformer_engine-${VERSION}"

# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
-wheel unpack dist/*
+python -m wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
-wheel pack ${WHL_BASE}
+python -m wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel

qa/L0_pytorch_lint/test.sh

+1 -1

@@ -6,7 +6,7 @@ set -e

: "${TE_PATH:=/opt/transformerengine}"

-pip install cpplint==1.6.0 pylint==2.13.5
+pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH

qa/L0_pytorch_unittest/test.sh

+2 -4

@@ -6,21 +6,19 @@ set -e

: ${TE_PATH:=/opt/transformerengine}

-pip install pytest==8.2.1 onnxruntime==1.13.1
+pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
+NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
-pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py

qa/L1_pytorch_context_parallel_test/test.sh

-10
This file was deleted.
