Skip to content

Commit 4a51281

Browse files
authored
[BLAS] Avoid blocking wait inside native-command callable (#681)
1 parent ca1ecc2 commit 4a51281

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/blas/backends/cublas/cublas_level3.cpp

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -95,15 +95,15 @@ inline void gemm_ex(DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, sycl::que
9595
auto c_ = sc.get_mem<cuDataType_C*>(c_acc);
9696
cublasStatus_t err;
9797
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
98+
CUBLAS_ERROR_FUNC(cublasGemmEx, err, handle, get_cublas_operation(transa),
99+
get_cublas_operation(transb), m, n, k, (cuDataType_C*)&alpha, a_,
100+
DT_A, lda, b_, DT_B, ldb, (cuDataType_C*)&beta, c_, DT_C, ldc, DT_C,
101+
CUBLAS_GEMM_DEFAULT);
102+
#else
98103
CUBLAS_ERROR_FUNC_SYNC(cublasGemmEx, err, handle, get_cublas_operation(transa),
99104
get_cublas_operation(transb), m, n, k, (cuDataType_C*)&alpha, a_,
100105
DT_A, lda, b_, DT_B, ldb, (cuDataType_C*)&beta, c_, DT_C, ldc,
101106
DT_C, CUBLAS_GEMM_DEFAULT);
102-
#else
103-
CUBLAS_ERROR_FUNC(cublasGemmEx, err, handle, get_cublas_operation(transa),
104-
get_cublas_operation(transb), m, n, k, (cuDataType_C *)&alpha,
105-
a_, DT_A, lda, b_, DT_B, ldb, (cuDataType_C *)&beta, c_, DT_C,
106-
ldc, DT_C, CUBLAS_GEMM_DEFAULT);
107107
#endif
108108
});
109109
});
@@ -500,15 +500,15 @@ inline sycl::event gemm_ex_usm(DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C
500500
auto c_ = reinterpret_cast<cuDataType_C*>(c);
501501
cublasStatus_t err;
502502
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
503+
CUBLAS_ERROR_FUNC(cublasGemmEx, err, handle, get_cublas_operation(transa),
504+
get_cublas_operation(transb), m, n, k, (cuDataType_C*)&alpha, a_,
505+
DT_A, lda, b_, DT_B, ldb, (cuDataType_C*)&beta, c_, DT_C, ldc, DT_C,
506+
CUBLAS_GEMM_DEFAULT);
507+
#else
503508
CUBLAS_ERROR_FUNC_SYNC(cublasGemmEx, err, handle, get_cublas_operation(transa),
504509
get_cublas_operation(transb), m, n, k, (cuDataType_C*)&alpha, a_,
505510
DT_A, lda, b_, DT_B, ldb, (cuDataType_C*)&beta, c_, DT_C, ldc,
506511
DT_C, CUBLAS_GEMM_DEFAULT);
507-
#else
508-
CUBLAS_ERROR_FUNC(cublasGemmEx, err, handle, get_cublas_operation(transa),
509-
get_cublas_operation(transb), m, n, k, (cuDataType_C *)&alpha,
510-
a_, DT_A, lda, b_, DT_B, ldb, (cuDataType_C *)&beta, c_, DT_C,
511-
ldc, DT_C, CUBLAS_GEMM_DEFAULT);
512512
#endif
513513
});
514514
});

0 commit comments

Comments
 (0)