Update cudnn-frontend to 1.0.3 to fix cuDNN v9 SDPA NaNs (#650)
* Update cudnn-frontend to 1.0.3 to fix cuDNN v9 NaNs

Signed-off-by: Charlene Yang <[email protected]>

* Make d_out contiguous in the backward pass

Signed-off-by: Charlene Yang <[email protected]>

* Remove the cudnnDestroy call and let PyTorch manage the cuDNN handle

Signed-off-by: Charlene Yang <[email protected]>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: cyanguwa <[email protected]>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: cyanguwa <[email protected]>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: cyanguwa <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
2 people authored and ptrendx committed Feb 3, 2024
1 parent df9c29e commit 5b90b7f
Showing 3 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 49 files
+1 −1 CMakeLists.txt
+5 −0 README.FE.1.0.md
+7 −1 README.md
+0 −2 include/cudnn_backend_base.h
+3 −1 include/cudnn_frontend.h
+3 −3 include/cudnn_frontend/cudnn_interface.h
+4 −2 include/cudnn_frontend/node/batchnorm.h
+5 −2 include/cudnn_frontend/node/batchnorm_inference.h
+4 −2 include/cudnn_frontend/node/bn_finalize.h
+4 −2 include/cudnn_frontend/node/conv_dgrad.h
+5 −2 include/cudnn_frontend/node/conv_fprop.h
+4 −2 include/cudnn_frontend/node/conv_wgrad.h
+4 −2 include/cudnn_frontend/node/dbn.h
+5 −2 include/cudnn_frontend/node/dbn_weight.h
+4 −2 include/cudnn_frontend/node/dln.h
+4 −2 include/cudnn_frontend/node/genstats.h
+8 −4 include/cudnn_frontend/node/instancenorm.h
+5 −2 include/cudnn_frontend/node/layernorm.h
+5 −2 include/cudnn_frontend/node/matmul.h
+4 −2 include/cudnn_frontend/node/pointwise.h
+4 −2 include/cudnn_frontend/node/reduction.h
+5 −2 include/cudnn_frontend/node/reshape.h
+9 −4 include/cudnn_frontend/node/rmsnorm.h
+4 −2 include/cudnn_frontend/node/rng.h
+119 −4 include/cudnn_frontend/node/scaled_dot_product_flash_attention.h
+27 −4 include/cudnn_frontend/node_interface.h
+0 −3 include/cudnn_frontend_ConvDesc.h
+0 −3 include/cudnn_frontend_Engine.h
+0 −3 include/cudnn_frontend_EngineConfig.h
+0 −1 include/cudnn_frontend_EngineFallbackList.h
+0 −3 include/cudnn_frontend_ExecutionPlan.h
+0 −2 include/cudnn_frontend_Filters.h
+0 −3 include/cudnn_frontend_Heuristics.h
+0 −3 include/cudnn_frontend_MatMulDesc.h
+0 −3 include/cudnn_frontend_Operation.h
+0 −3 include/cudnn_frontend_OperationGraph.h
+0 −3 include/cudnn_frontend_PointWiseDesc.h
+0 −3 include/cudnn_frontend_ReductionDesc.h
+0 −3 include/cudnn_frontend_Resample.h
+0 −3 include/cudnn_frontend_Rng.h
+0 −3 include/cudnn_frontend_VariantPack.h
+9 −5 python_bindings/properties.cpp
+0 −3 samples/legacy_samples/conv_sample.h
+26 −18 samples/legacy_samples/cpu_references.h
+1 −1 samples/legacy_samples/norm_samples.cpp
+9 −9 samples/legacy_samples/test_list.cpp
+4 −2 samples/python/test_conv_bias.py
+21 −9 samples/python/test_mhas.py
+1 −1 setup.py
5 changes: 0 additions & 5 deletions transformer_engine/common/fused_attn/utils.h
@@ -152,11 +152,6 @@ class cudnnExecutionPlanManager {
   }
 
   ~cudnnExecutionPlanManager() {
-    static thread_local std::once_flag flag;
-    std::call_once(flag, [&] {
-      if (handle_ != nullptr) {
-        cudnnDestroy(handle_);
-      }});
   }
 
  private:
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/attention.py
@@ -1733,6 +1733,7 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias,
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         qkv, out, cu_seqlens = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
@@ -1802,6 +1803,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
@@ -1883,6 +1885,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
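All three backward methods apply the same one-line guard before handing d_out to the cuDNN fused-attention backward kernel, which expects a dense, contiguous gradient buffer. A minimal sketch of the failure mode in plain PyTorch, not this repository's FusedAttnFunc classes (the Square op below is purely illustrative):

import torch

class Square(torch.autograd.Function):
    # Stand-in for a fused op whose backward assumes contiguous gradients.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, d_out):
        # Autograd can deliver d_out with non-dense strides (expanded or
        # transposed upstream gradients); normalize it first, mirroring
        # the guard added in this commit.
        d_out = d_out.contiguous()
        (x,) = ctx.saved_tensors
        return 2 * x * d_out

x = torch.randn(4, 8, requires_grad=True)
y = Square.apply(x)
# Transposing downstream means the gradient reaching Square.backward
# carries non-contiguous strides; without the guard, a stride-oblivious
# kernel would read the wrong elements.
y.t().sum().backward()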
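As a closing usage note, not part of this commit: a hedged sanity-check sketch for the fix, assuming transformer_engine.pytorch.DotProductAttention, the NVTE_FLASH_ATTN / NVTE_FUSED_ATTN environment variables (read at import time) for backend selection, and a CUDA GPU with cuDNN v9. Shapes, dtypes, and parameter values are illustrative.

import os
# Assumed backend selectors; set before transformer_engine.pytorch is imported.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"

import torch
import transformer_engine.pytorch as te

seq, batch, heads, head_dim = 512, 2, 16, 64
attn = te.DotProductAttention(num_attention_heads=heads, kv_channels=head_dim,
                              attention_dropout=0.0)

# Default qkv layout is (seq, batch, heads, head_dim).
q, k, v = [torch.randn(seq, batch, heads, head_dim, dtype=torch.bfloat16,
                       device="cuda", requires_grad=True) for _ in range(3)]

out = attn(q, k, v)
out.sum().backward()
assert not torch.isnan(out).any(), "NaN in fused-attention forward output"
assert not torch.isnan(q.grad).any(), "NaN in fused-attention dgrad"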
