From 7413843fd7d9b4a98f9abdb8843b24821a6b96a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Mon, 20 May 2024 09:33:04 -0700 Subject: [PATCH 01/10] [PyTorch] Fixed bug with loading calibrated weights (#771) * Calibration fix Signed-off-by: Pawel Gadzinski * Lint fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Co-authored-by: Pawel Gadzinski --- qa/L0_pytorch_unittest/test.sh | 1 + tests/pytorch/test_torch_save_load.py | 37 +++++++++++++++++++++-- transformer_engine/pytorch/module/base.py | 19 ++++++++++++ 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 2c14664dce..2aa58e6018 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -17,3 +17,4 @@ NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_a pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py +pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py \ No newline at end of file diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index 85ec7685b3..211030fe6d 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -65,6 +65,9 @@ def __init__(self, precision, use_bias): self.inp_type = tex.DType.kFloat8E4M3 self.weights_type = tex.DType.kFloat8E4M3 self.outp_type = precision + + def get_fp8_weights_scratchpad(self, is_first_microbatch): + raise RuntimeError("Method get_fp8_weights_scratchpad is dummy and should not be invoked.") def forward(self, inp, weight): inp_fp8 = cast_to_fp8( @@ -145,14 +148,11 @@ def test_fp8_model_checkpoint( params_dtype=dtype, device=device, ) - # Keep track of model output x = torch.randn(dims, dtype=dtype, device=device) with te.fp8_autocast(): y_ref = model(x.detach().clone()).detach().clone() - # Keep track of weights and FP8 scaling factors - weight_ref = model.weight.float().detach().clone() fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} } with te.fp8_autocast(), torch.no_grad(): fp8_meta_fwd = model.fp8_meta["scaling_fwd"] @@ -168,6 +168,18 @@ def test_fp8_model_checkpoint( fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"]) fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"]) del fp8_meta_fwd, fp8_meta_bwd + + # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] + # This line copies the fp8 scale_inv from the model metadata to the weight fp8 tensor. + # The sole purpose of the following lines is to set the scale_inv of the weight tensor, which is the simplest method. + # It is essential for these values to be equal, so setting scale_inv only in the model metadata is insufficient. + model.weight.data.copy_(model.weight.float().cuda()) + # After copying, the tensor computes the meta scale_inv based on the amax history; we then reset these values. 
+ model.fp8_meta["scaling_fwd"].scale = fp8_meta_fwd_ref["scale"] + model.fp8_meta["scaling_fwd"].scale_inv = fp8_meta_fwd_ref["scale_inv"] + + # Keep track of weights and FP8 scaling factors + weight_ref = model.weight.float().detach().clone() # Save checkpoint byte_stream = io.BytesIO() @@ -214,6 +226,18 @@ def test_fp8_model_checkpoint( with pytest.raises(AssertionError): torch.testing.assert_close(y, y_ref, **tols) + + # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] + # When save_fp8_model=True, we load a model with weights in high precision, + # which does not include _scale_inv, + # but has the fp8 scaling factor in the meta data. This scenario can occur + # when using te.fp8_autocast(enabled=False, calibrating=True). + # + # In such cases, the default behavior of load_state_dict is incorrect - it loads tensors first, + # followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior + # is corrected by overriding the _load_state_dict method from PyTorch in TransformerEngineBaseModule, + # to load the fp8 metadata before loading tensors. + # # Load checkpoint model.load_state_dict(torch.load(io.BytesIO(model_bytes))) del model_bytes @@ -232,3 +256,10 @@ def test_fp8_model_checkpoint( with te.fp8_autocast(): y = model(x.detach().clone()) torch.testing.assert_close(y, y_ref, **tols) + + if load_fp8_model: + # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] + # We need to ensure that the tensor's scale_inv parameter matches its meta data. + # This is crucial to avoid confusion about which value is correct. + meta_index = model.weight._fp8_meta_index + torch.testing.assert_close(model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()) \ No newline at end of file diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0803b474f6..31011be897 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -858,3 +858,22 @@ def get_fp8_weights_scratchpad( is_first_microbatch: Union[bool, None], ) -> List[torch.Tensor]: """Needs override.""" + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """ + This function loads tensors and extra state including fp8 metadata. + This metadata is essential for copying fp8 tensors, as the copy_ function + uses the scale_inv parameter from fp8_meta to set the correct scaling factor + for the new tensor. + Hence, this extra state must be loaded before the tensor copying process, + not after, as is typically done in _load_from_state_dict. + Tensors are copied into fp8 tensors only when self.primary_weights_in_fp8=True, + otherwise, this behavior is not required. 
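+
+        A minimal sketch of the calibration-then-load flow this ordering enables
+        (illustrative only, not part of this patch; layer sizes, file name, and
+        input shape are assumptions):
+
+            import torch
+            import transformer_engine.pytorch as te
+
+            # Calibrate in high precision: weights stay in FP32/BF16, but the
+            # fp8_meta (amax history / scaling factors) is recorded and saved
+            # as part of the module's extra state.
+            model = te.Linear(768, 768)
+            with te.fp8_autocast(enabled=False, calibrating=True):
+                model(torch.randn(32, 768, device="cuda"))
+            torch.save(model.state_dict(), "calibrated.pt")
+
+            # Load into a module whose primary weights are FP8: the extra state
+            # must be applied before the weight copy so the FP8 weight picks up
+            # the calibrated scale_inv.
+            with te.fp8_model_init(enabled=True):
+                fp8_model = te.Linear(768, 768)
+            fp8_model.load_state_dict(torch.load("calibrated.pt"))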
+ """ + if self.primary_weights_in_fp8: + extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) From b2f2e1dc09faa9329f17fb36f9fed6357e0e9c50 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 21 May 2024 15:41:36 -0500 Subject: [PATCH 02/10] [PyTorch] Replaced deprecated `pkg_resources` with `packaging` (#860) replaced deprecated pkg_resources with packaging Signed-off-by: Alp Dener --- setup.py | 1 + transformer_engine/pytorch/attention.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 769d62a25b..e7bf2f38b7 100644 --- a/setup.py +++ b/setup.py @@ -246,6 +246,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: install_reqs: List[str] = [ "pydantic", "importlib-metadata>=1.0; python_version<'3.8'", + "packaging", ] test_reqs: List[str] = ["pytest"] diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d4198e688d..841f2ba8af 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5,14 +5,14 @@ """Attention.""" import collections from contextlib import nullcontext -from importlib.metadata import version +from importlib.metadata import version as get_pkg_version import math import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings import numpy as np -from pkg_resources import packaging +from packaging.version import Version as PkgVersion import torch import torch.nn.functional as F @@ -67,13 +67,13 @@ from transformer_engine.pytorch.graph import is_graph_capturing -_flash_attn_version = packaging.version.Version(version("flash-attn")) -_flash_attn_version_required = packaging.version.Version("2.0.6") -_flash_attn_max_version = packaging.version.Version("2.5.8") -_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1") -_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3") -_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4") -_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1") +_flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) +_flash_attn_version_required = PkgVersion("2.0.6") +_flash_attn_max_version = PkgVersion("2.5.8") +_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") +_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") +_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") +_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") if _flash_attn_version >= _flash_attn_version_required: from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module From 5895eab18609829c793c2112c6a3d1b358a5aee9 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Tue, 21 May 2024 17:01:26 -0700 Subject: [PATCH 03/10] [Common] Added Alignment Requirements for CuBLAS heuristics (#845) * added alignment requirements for CuBLAS heuristics Signed-off-by: Phuong Nguyen * minor rewords Signed-off-by: Phuong Nguyen * added unit test for gemm with unaligned inputs Signed-off-by: Phuong Nguyen * added pytest skip if fp8 is not available Signed-off-by: Phuong Nguyen * changed offset so that it has 
alignment with 128 Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- tests/pytorch/test_sanity.py | 62 +++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 28 ++++++++- 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index cf17eccd1b..91e67e8f9a 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -29,6 +29,10 @@ get_cpu_offload_context, ) from transformer_engine.common import recipe +import transformer_engine_extensions as tex +from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8 +from transformer_engine.pytorch.module.base import get_workspace +from test_onnx_export import create_meta # Only run FP8 tests on H100. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -924,3 +928,61 @@ def test_model_multiple_cast(): y2 = m(a) assert y2.dtype == torch.float16 + + +@pytest.mark.parametrize("N", [32]) +@pytest.mark.parametrize("offset", [1, 3, 5]) +@pytest.mark.parametrize("datatype", param_types) +def test_sanity_gemm_with_unalignment(N, offset, datatype): + scratchpad = torch.randn(N*N + 2*offset, device="cuda", dtype=datatype) + inp = torch.reshape(scratchpad[offset:-offset], (N, N)) + weight = torch.reshape(scratchpad[offset*2:], (N, N)) + + _, _, _ = gemm( + A=weight, + B=inp, + dtype=datatype, + workspace=get_workspace()) + torch.cuda.synchronize() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("N", [32]) +@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) +def test_sanity_fp8_gemm_with_unalignment(N, datatype): + offset = 16 + scratchpad = torch.randn(N*N + offset, device="cuda", dtype=datatype) + + fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT + fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT + + nb_inp_scales, nb_weight_scales = 1, N + scale_factor = 1. 
+ meta_inp = create_meta(scale_factor, nb_inp_scales) + meta_weight = create_meta(scale_factor, nb_weight_scales) + inp_type = tex.DType.kFloat8E4M3 + weights_type = tex.DType.kFloat8E4M3 + outp_type = datatype + + scratchpad_fp8 = cast_to_fp8( + scratchpad, + meta_weight, + fp8_tensor_inp, + inp_type) + inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N)) + weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N)) + _, _ = fp8_gemm( + weight_fp8, + meta_weight.scale_inv, + fp8_tensor_weight, + inp_type, + inp_fp8, + meta_inp.scale_inv, + fp8_tensor_inp, + weights_type, + outp_type, + get_workspace(), + bias=None, + use_bias=False, + use_split_accumulator=False) + torch.cuda.synchronize() diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index d68c21cd19..a4c65661dc 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -9,6 +9,7 @@ #include #include #include +#include #include #include "../common.h" @@ -34,6 +35,16 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { } } +uint32_t _getAlignment(uintptr_t address) { + // alignment are in bytes + uint32_t alignment = 256; + for (; ; alignment /= 2) { + if (address % alignment == 0) { + return alignment; + } + } +} + } // namespace namespace transformer_engine { @@ -260,6 +271,22 @@ void cublas_gemm(const Tensor *inputA, NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); + const auto A_alignment = _getAlignment(reinterpret_cast(A)); + const auto B_alignment = _getAlignment(reinterpret_cast(B)); + const auto C_alignment = _getAlignment(reinterpret_cast(C)); + const auto D_alignment = _getAlignment(reinterpret_cast(D)); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, + &A_alignment, sizeof(A_alignment))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, + &B_alignment, sizeof(B_alignment))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, + &C_alignment, sizeof(C_alignment))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, + &D_alignment, sizeof(D_alignment))); const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, &heuristicResult, @@ -271,7 +298,6 @@ void cublas_gemm(const Tensor *inputA, if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, static_cast(&one), /* alpha */ From 08042a509c999844685dfeda7d4332be2da12c7e Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 22 May 2024 13:52:38 -0500 Subject: [PATCH 04/10] [PyTorch] Support `torch.amp.autocast` in TE checkpoint (#791) TE checkpoint now preserves the torch autocast context from the forward pass during the recompute phase Signed-off-by: Alp Dener --- transformer_engine/pytorch/distributed.py | 46 +++++++++++++++++++---- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index caaef91985..b0fb80b6a1 100644 --- a/transformer_engine/pytorch/distributed.py +++ 
b/transformer_engine/pytorch/distributed.py @@ -228,6 +228,26 @@ def in_fp8_activation_recompute_phase() -> bool: return _FP8_ACTIVATION_RECOMPUTE_PHASE +def _get_active_autocast_contexts(): + """ + Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state + at the time of this function's execution. + """ + autocast_cached = torch.is_autocast_cache_enabled() + + gpu_autocast_enabled = torch.is_autocast_enabled() + gpu_autocast_dtype = torch.get_autocast_gpu_dtype() + gpu_autocast_ctx = torch.cuda.amp.autocast( + gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached) + + cpu_autocast_enabled = torch.is_autocast_cpu_enabled() + cpu_autocast_dtype = torch.get_autocast_cpu_dtype() + cpu_autocast_ctx = torch.cpu.amp.autocast( + cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached) + + return gpu_autocast_ctx, cpu_autocast_ctx + + class _CheckpointFunction(torch.autograd.Function): """This function is adapted from torch.utils.checkpoint with two main changes: @@ -262,6 +282,10 @@ def forward( forward_ctx, recompute_ctx = context_fn() else: forward_ctx, recompute_ctx = noop_context_fn() + + # Preserve torch autocast context for the backward pass + torch_gpu_amp_ctx, torch_cpu_amp_ctx = _get_active_autocast_contexts() + with torch.no_grad(), forward_ctx: with activation_recompute_forward( activation_recompute=True, recompute_phase=False @@ -287,6 +311,8 @@ def forward( ctx.get_rng_state_tracker = get_rng_state_tracker ctx.tp_group = tp_group ctx.recompute_ctx = recompute_ctx + ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx + ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx ctx.kwargs = kwargs return outputs @@ -331,11 +357,11 @@ def backward( # Compute the forward pass. detached_inputs = detach_variable(inputs) - with torch.enable_grad(), ctx.recompute_ctx: - with activation_recompute_forward( - activation_recompute=True, recompute_phase=True - ): - outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) + with (torch.enable_grad(), ctx.recompute_ctx, + ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, + activation_recompute_forward( + activation_recompute=True, recompute_phase=True)): + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) # Set the states back to what it was at the start of this function. torch.set_rng_state(bwd_cpu_rng_state) @@ -639,8 +665,13 @@ def checkpoint( user_forward_ctx, user_recompute_ctx = context_fn() te_forward_ctx, te_recompute_ctx = get_activation_recompute_contexts() + # Preserve the torch autocast contexts from the forward pass during recompute phase. + torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts() + def recompute_fn(*args, **kwargs): - with torch.autograd.enable_grad(), te_recompute_ctx, user_recompute_ctx: + with (torch.autograd.enable_grad(), + te_recompute_ctx, user_recompute_ctx, + torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx): function(*args, **kwargs) # Initialize a new checkpoint frame for each new forward pass. 
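# Illustrative usage sketch (not part of the diff above; layer sizes and tensor
# shapes are assumptions): with this change, a block checkpointed under
# torch.amp.autocast is recomputed under the same autocast state during the
# backward pass.
import torch
import transformer_engine.pytorch as te

block = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                            num_attention_heads=16)
x = torch.randn(128, 8, 1024, device="cuda", requires_grad=True)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = te.checkpoint(block, x)  # forward pass runs under bf16 autocast
y.sum().backward()  # recompute during backward now reuses the same autocast state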
@@ -650,7 +681,8 @@ def recompute_fn(*args, **kwargs): ) new_frame.cache_rng_states(forward=True) - with _checkpoint_hook(new_frame, args, kwargs), te_forward_ctx, user_forward_ctx: + with (_checkpoint_hook(new_frame, args, kwargs), + te_forward_ctx, user_forward_ctx): out = function(*args, **kwargs) return out From 7190c30a4d9159a0b5466d2f85f5bb29e63fe3f9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 24 May 2024 18:42:32 -0700 Subject: [PATCH 05/10] [C] Allow bias support for sm80/86/89 for cuDNN 9+ (#863) allow bias support for sm80/86/89 for cuDNN 9+ Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 2d9759898f..71f8e6c6d9 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -148,7 +148,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && sm_arch_ == 90) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS - && sm_arch_ == 90)))) + && sm_arch_ == 90))) + || ((cudnn_runtime_version >= 90000) + && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS + && sm_arch_ >= 80))) && ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || ((cudnn_runtime_version >= 8906) && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK From 0c4cc05d369acd7b103bc0a49e46355334459446 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 22 May 2024 14:33:22 -0400 Subject: [PATCH 06/10] [JAX] Fixed the shape miss-matching issue in MLP. (#859) * Fixed the shape mismatching issue in MLP. 
Signed-off-by: Ming Huang * Add a corresponding test Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- tests/jax/test_layer.py | 2 ++ transformer_engine/jax/flax/module.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 1493b50cf0..a3a506f1c1 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -177,6 +177,8 @@ def enable_fused_attn(): _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias", }, { _KEY_OF_ATTENTION_DROPOUT: 0.3, +}, { + _KEY_OF_MLP_ACTIVATIONS: (('relu', 'relu')), }] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 442396d47c..1f827b505a 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1148,8 +1148,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): x_i = _convert_to_activation_function(act_fn)(x[idx]) activations.append(x_i) z = functools.reduce(operator.mul, activations) - if num_activations == 1: - z = jnp.reshape(z, (*z.shape[:-2], -1)) + # Remove act axis + z = jnp.reshape(z, (*z.shape[:-2], -1)) z = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_hidden_dropout_dims, From ad24fc549bb276c015b2e50c4ec1141626cf3e43 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 28 May 2024 16:38:55 -0700 Subject: [PATCH 07/10] Use correct FP8 group in multi-GPU docs (#852) * Use correct FP8 group in multi-GPU docs FP8 process group should be tensor-parallel group Signed-off-by: Tim Moon * Synchronize FP8 scales over world group in multi-GPU docs Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon --- docs/examples/advanced_optimizations.ipynb | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/examples/advanced_optimizations.ipynb b/docs/examples/advanced_optimizations.ipynb index c7cd006dae..3d889859ba 100644 --- a/docs/examples/advanced_optimizations.ipynb +++ b/docs/examples/advanced_optimizations.ipynb @@ -115,12 +115,13 @@ "# Configure parallel groups\n", "import os\n", "import torch\n", - "world_group = torch.distributed.init_process_group(\n", + "torch.distributed.init_process_group(\n", " \"nccl\",\n", " init_method=\"file:///tmp/rdzv\",\n", " world_size=1,\n", " rank=0,\n", ")\n", + "world_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")\n", "data_parallel_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")\n", "tensor_parallel_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")" ] @@ -132,7 +133,9 @@ "source": [ "We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. 
The rows are tensor-parallel groups and the columns are data-parallel groups.\n", "\n", - "Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). FP8 training requires extra synchronization for the scaling factors, so the data-parallel process group must also be passed to the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group. In this case, the tensor parallel group must also be passed to the **fp8_group** argument in the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager, either directly or as a subset of a larger distributed group." + "Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n", + "\n", + "One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager." 
] }, { @@ -166,7 +169,7 @@ ")\n", "\n", "# Training step\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group):\n", + "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=world_group):\n", " y = parallel_transformer(x, attention_mask=None)\n", "y.backward(dy)\n", "\n", @@ -179,7 +182,7 @@ " fp8_autocast_kwargs = {\n", " \"enabled\": True,\n", " \"fp8_recipe\": fp8_recipe,\n", - " \"fp8_group\": data_parallel_group,\n", + " \"fp8_group\": world_group,\n", " },\n", ")" ] From 4e4aecbd11faefbba6d5e2789a7747bca73890b4 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 28 May 2024 18:25:25 -0700 Subject: [PATCH 08/10] [PyTorch] Make sure RoPE frequencies are in FP32 (#875) Make sure RoPE frequencies are in FP32 Signed-off-by: Tim Moon --- transformer_engine/pytorch/attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 841f2ba8af..a6e2a7a21a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1432,6 +1432,8 @@ def forward( tensor_format: str = "sbhd", cu_seqlens: Union[torch.Tensor, None] = None, ) -> torch.Tensor: + if freqs.dtype != torch.float32: + freqs = freqs.float() if tensor_format == "sbhd": output = tex.fused_rope_forward(t, freqs, False) elif tensor_format == "bshd": From 61ffb58357291cac967bc1d1579f31b9afff46b8 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 29 May 2024 09:50:57 -0700 Subject: [PATCH 09/10] New NVIDIA footer in documentation (#876) * Change the documentation footer Signed-off-by: Przemek Tredak * Update docs toolchain versions Signed-off-by: Przemek Tredak --------- Signed-off-by: Przemek Tredak --- .github/workflows/docs.yml | 4 ++-- docs/_static/NVIDIA-LogoBlack.svg | 1 + docs/_static/css/nvidia_footer.css | 29 +++++++++++++++++++++++++++++ docs/_templates/footer.html | 23 +++++++++++++++++++++++ docs/_templates/layout.html | 4 ---- docs/conf.py | 2 ++ 6 files changed, 57 insertions(+), 6 deletions(-) create mode 100644 docs/_static/NVIDIA-LogoBlack.svg create mode 100644 docs/_static/css/nvidia_footer.css create mode 100644 docs/_templates/footer.html diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index b4eeefa70b..581ff1e935 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,8 +17,8 @@ jobs: uses: actions/checkout@v3 - name: 'Install dependencies' run: | - pip install sphinx==5.1.1 sphinx_rtd_theme==1.0.0 nbsphinx==0.8.10 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==2.15.7 - pip install breathe==4.34.0 sphinx-autoapi==2.0.1 + pip install sphinx==7.1.2 sphinx_rtd_theme==2.0.0 nbsphinx==0.9.4 IPython ipython_genutils==0.2.0 ipywidgets==8.1.3 astroid==3.2.2 + pip install breathe==4.35.0 sphinx-autoapi==3.1.1 sudo apt-get install -y pandoc graphviz doxygen export GIT_SHA=$(git show-ref --hash HEAD) - name: 'Build docs' diff --git a/docs/_static/NVIDIA-LogoBlack.svg b/docs/_static/NVIDIA-LogoBlack.svg new file mode 100644 index 0000000000..c612396c71 --- /dev/null +++ b/docs/_static/NVIDIA-LogoBlack.svg @@ -0,0 +1 @@ +NVIDIA-LogoBlack \ No newline at end of file diff --git a/docs/_static/css/nvidia_footer.css b/docs/_static/css/nvidia_footer.css new file mode 100644 index 0000000000..9d18fb3b47 --- /dev/null +++ b/docs/_static/css/nvidia_footer.css @@ -0,0 +1,29 @@ +footer img { + display: block; + width: 137.5px; + position: relative; + left: 
-9px;
+ margin: 0 0 15px 0;
+}
+
+footer p {
+ color: #666666;
+ font-weight: normal;
+ font-size: 12px;
+ line-height: 1.25em;
+}
+
+footer p:not(.notices) {
+ display: inline;
+ margin: 0;
+}
+
+footer p a,
+footer p a:link,
+footer p a:visited {
+ color: #666666;
+}
+
+footer p a:hover {
+ color: #666666;
+}
diff --git a/docs/_templates/footer.html b/docs/_templates/footer.html
new file mode 100644
index 0000000000..1ef5505d34
--- /dev/null
+++ b/docs/_templates/footer.html
@@ -0,0 +1,23 @@
+{% extends '!footer.html' %}
+
+{% block contentinfo %}
+
+Privacy Policy
+|
+Manage My Privacy
+|
+Do Not Sell or Share My Data
+|
+Terms of Service
+|
+Accessibility
+|
+Corporate Policies
+|
+Product Security
+|
+Contact
+{{ super() }}
+{% endblock %}
diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html
index 65b5b90931..cb372b3a72 100644
--- a/docs/_templates/layout.html
+++ b/docs/_templates/layout.html
@@ -61,10 +61,6 @@
 }
 
- {% endblock %}
-
- {% block footer %}
 {{ super() }}
-