1616* limitations under the License.
1717*
1818**************************************************************************/
19- #include < CL/sycl/detail/pi.hpp>
2019#include " cublas_helper.hpp"
21- #include " cublas_scope_handle.hpp"
20+ #include " cublas_task.hpp"
21+
2222#include " oneapi/mkl/exceptions.hpp"
2323#include " oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp"
2424
@@ -42,12 +42,12 @@ inline void gemm_batch(Func func, cl::sycl::queue &queue, transpose transa, tran
4242 auto a_acc = a.template get_access <cl::sycl::access::mode::read>(cgh);
4343 auto b_acc = b.template get_access <cl::sycl::access::mode::read>(cgh);
4444 auto c_acc = c.template get_access <cl::sycl::access::mode::read_write>(cgh);
45- cgh.interop_task ([=](cl::sycl::interop_handler ih) {
46- auto sc = CublasScopedContextHandler (queue);
45+ onemkl_cublas_host_task (cgh, queue,[=](CublasScopedContextHandler sc) {
4746 auto handle = sc.get_handle (queue);
48- auto a_ = sc.get_mem <cuDataType *>(ih, a_acc);
49- auto b_ = sc.get_mem <cuDataType *>(ih, b_acc);
50- auto c_ = sc.get_mem <cuDataType *>(ih, c_acc);
47+
48+ auto a_ = sc.get_mem <cuDataType *>(a_acc);
49+ auto b_ = sc.get_mem <cuDataType *>(b_acc);
50+ auto c_ = sc.get_mem <cuDataType *>(c_acc);
5151 cublasStatus_t err;
5252 CUBLAS_ERROR_FUNC (func, err, handle, get_cublas_operation (transa),
5353 get_cublas_operation (transb), m, n, k, (cuDataType *)&alpha, a_, lda,
@@ -122,9 +122,9 @@ inline cl::sycl::event gemm_batch(Func func, cl::sycl::queue &queue, transpose t
122122 for (int64_t i = 0 ; i < num_events; i++) {
123123 cgh.depends_on (dependencies[i]);
124124 }
125- cgh.interop_task ([=](cl::sycl::interop_handler ih) {
126- auto sc = CublasScopedContextHandler (queue);
125+ onemkl_cublas_host_task (cgh, queue,[=](CublasScopedContextHandler sc) {
127126 auto handle = sc.get_handle (queue);
127+
128128 auto a_ = reinterpret_cast <const cuDataType *>(a);
129129 auto b_ = reinterpret_cast <const cuDataType *>(b);
130130 auto c_ = reinterpret_cast <cuDataType *>(c);
@@ -170,9 +170,9 @@ inline cl::sycl::event gemm_batch(Func func, cl::sycl::queue &queue, transpose *
170170 for (int64_t i = 0 ; i < num_events; i++) {
171171 cgh.depends_on (dependencies[i]);
172172 }
173- cgh.interop_task ([=](cl::sycl::interop_handler ih) {
174- auto sc = CublasScopedContextHandler (queue);
173+ onemkl_cublas_host_task (cgh, queue,[=](CublasScopedContextHandler sc) {
175174 auto handle = sc.get_handle (queue);
175+
176176 int64_t offset = 0 ;
177177 cublasStatus_t err;
178178 for (int64_t i = 0 ; i < group_count; i++) {
0 commit comments