From 5b90b7f5ed67b373bc5f843d1ac3b7a8999df08e Mon Sep 17 00:00:00 2001
From: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Date: Fri, 2 Feb 2024 20:36:10 -0800
Subject: [PATCH] Update cudnn-frontend to 1.0.3 to fix cuDNN v9 SDPA NaNs
 (#650)

* Update cudnn-frontend to 1.0.3 to fix cuDNN v9 NaNs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* make d_out contiguous for bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove cudnnDestroy to let torch handle it

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
---
 3rdparty/cudnn-frontend                      | 2 +-
 transformer_engine/common/fused_attn/utils.h | 5 -----
 transformer_engine/pytorch/attention.py      | 3 +++
 3 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend
index 9f82dda5c0..a86ad708db 160000
--- a/3rdparty/cudnn-frontend
+++ b/3rdparty/cudnn-frontend
@@ -1 +1 @@
-Subproject commit 9f82dda5c029d15a5f371f0fe003dc0c74a0c987
+Subproject commit a86ad708db725e4d29919bb6fadf8e6cdfa5dc06
diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h
index 9da0dc553a..44288dd754 100644
--- a/transformer_engine/common/fused_attn/utils.h
+++ b/transformer_engine/common/fused_attn/utils.h
@@ -152,11 +152,6 @@ class cudnnExecutionPlanManager {
   }
 
   ~cudnnExecutionPlanManager() {
-    static thread_local std::once_flag flag;
-    std::call_once(flag, [&] {
-      if (handle_ != nullptr) {
-        cudnnDestroy(handle_);
-      }});
   }
 
  private:
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index b7a98de0cd..27c031e267 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1733,6 +1733,7 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias,
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         qkv, out, cu_seqlens = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
@@ -1802,6 +1803,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
@@ -1883,6 +1885,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
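
A note on the d_out change, for reviewers: torch.autograd does not guarantee that
the gradient handed to backward() is contiguous. If code downstream of the fused
attention op transposes or reshapes the output, d_out arrives as a strided view,
and a kernel that assumes dense row-major memory will read it incorrectly. The
sketch below is illustrative only (ToyFusedOp is a hypothetical stand-in, not a
Transformer Engine API); it shows a non-contiguous gradient arising naturally and
the same densify-on-entry fix applied in this patch.

    import torch

    class ToyFusedOp(torch.autograd.Function):
        """Stand-in for a fused op whose backward assumes dense strides."""

        @staticmethod
        def forward(ctx, x):
            return x * 2.0

        @staticmethod
        def backward(ctx, d_out):
            # With the driver code below, this prints False: the incoming
            # gradient is a transposed (strided) view.
            print("incoming d_out contiguous?", d_out.is_contiguous())
            # Mirrors the patch: densify before handing the buffer to a
            # kernel that expects contiguous row-major memory.
            d_out = d_out.contiguous()
            return d_out * 2.0

    x = torch.randn(4, 8, requires_grad=True)
    y = torch.randn(8, 4)
    out = ToyFusedOp.apply(x)
    # A transpose downstream of the op makes the gradient flowing back into
    # backward() a non-contiguous view (its strides are swapped).
    loss = (out.t() * y).sum()
    loss.backward()

Note that .contiguous() returns the tensor unchanged when it is already dense, so
the fix costs nothing in the common case.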