
Commit 5b90b7f

cyanguwa authored and timmoon10 committed
Update cudnn-frontend to 1.0.3 to fix cuDNN v9 SDPA NaNs (#650)
* Update cudnn frontend to 1.0.3 to fix cuDNN v9 NaNs
* Make d_out contiguous for bwd
* Remove cudnnDestroy to let torch handle it
* Apply review suggestions to transformer_engine/pytorch/attention.py (co-authored by Tim Moon)

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
1 parent df9c29e commit 5b90b7f

File tree

3 files changed (+4, -6 lines)


3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 49 files

transformer_engine/common/fused_attn/utils.h (-5 lines)

@@ -152,11 +152,6 @@ class cudnnExecutionPlanManager {
   }
 
   ~cudnnExecutionPlanManager() {
-    static thread_local std::once_flag flag;
-    std::call_once(flag, [&] {
-      if (handle_ != nullptr) {
-        cudnnDestroy(handle_);
-      }});
   }
 
  private:

transformer_engine/pytorch/attention.py (+3 lines)

@@ -1733,6 +1733,7 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias,
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         qkv, out, cu_seqlens = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()

@@ -1802,6 +1803,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()

@@ -1883,6 +1885,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()

0 commit comments
