Skip to content

Commit 5b9dbfc

Browse files
committed
[BLAS] SYCL-Graph integration for native-command
To support applications calling the library with a sycl queue that is recording to a SYCL-Graph, check if the `ext_codeplay_enqueue_native_command` command-group is being recorded to a graph object. If so, use the native stream recording APIs to add the blas calls as nodes in the graph.

In particular, this fixes the llama.cpp unit test `MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0)` on CUDA with SYCL-Graph enabled. Previously this would throw an error:

```sh
$ GGML_SYCL_DISABLE_GRAPH=0 ./bin/test-backend-ops -b SYCL0 -o MUL_MAT -p type_a=f16,type_b=f32,m=16,n=1,k=256,bs=\\[1,1\\],nr=\\[2
UR CUDA ERROR:
Value: 700
Name: CUDA_ERROR_ILLEGAL_ADDRESS
Description: an illegal memory access was encountered
Function: operator()
Source Location: $HOME/dpcpp/unified-runtime/source/adapters/cuda/queue.cpp:154

Native API failed. Native API returns: 2147483646 (UR_RESULT_ERROR_UNKNOWN)
Exception caught at file:$HOME/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp, line:3598, func:operator()
SYCL error: CHECK_TRY_ERROR((stream)->wait()): Meet error in this line code!
in function ggml_backend_sycl_synchronize at $HOME/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp:3598
$HOME/llama.cpp/ggml/src/ggml-sycl/../ggml-sycl/common.hpp:118: SYCL error
Could not attach to process. If your uid matches the uid of the target process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try again as the root user. For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
```
1 parent 4a51281 commit 5b9dbfc

File tree

6 files changed

+247
-87
lines changed

6 files changed

+247
-87
lines changed

src/blas/backends/cublas/cublas_scope_handle.cpp

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,36 +32,80 @@ namespace cublas {
3232
*/
3333
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};
3434

35-
CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {}
35+
CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {
36+
// Initialize the `streamId` member to the CUstream associated with the queue `ih`
37+
// has been submitted to.
38+
streamId = ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
3639

37-
cublasHandle_t CublasScopedContextHandler::get_handle() {
40+
// Initialize the `cublasHandle_t` member `nativeHandle`
3841
CUdevice device = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
39-
CUstream streamId = get_stream();
40-
cublasStatus_t err;
41-
4242
auto it = handle_helper.cublas_handle_mapper_.find(device);
4343
if (it != handle_helper.cublas_handle_mapper_.end()) {
44-
cublasHandle_t nativeHandle = it->second;
44+
// Use existing handle if one already exists for the device, but update
45+
// the native stream.
46+
nativeHandle = it->second;
4547
cudaStream_t currentStreamId;
48+
cublasStatus_t err;
4649
CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, &currentStreamId);
4750
if (currentStreamId != streamId) {
4851
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
4952
}
50-
return nativeHandle;
5153
}
52-
53-
cublasHandle_t nativeHandle;
54-
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
55-
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
56-
57-
auto insert_iter =
54+
else {
55+
// Create a new handle if one doesn't already exist for the device
56+
cublasStatus_t err;
57+
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
58+
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
5859
handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle));
60+
}
61+
}
5962

60-
return nativeHandle;
63+
void CublasScopedContextHandler::begin_recording_if_graph() {
64+
// interop_handle graph methods only available from extension version 2
65+
#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
66+
if (!ih.ext_codeplay_has_graph()) {
67+
return;
68+
}
69+
70+
CUresult err;
71+
#if CUDA_VERSION >= 12030
72+
// From CUDA 12.3 onwards we can use cuStreamBeginCaptureToGraph to capture
73+
// the stream directly in the native graph, rather than needing to
74+
// instantiate the stream capture as a new graph.
75+
auto graph = ih.ext_codeplay_get_native_graph<sycl::backend::ext_oneapi_cuda>();
76+
CUDA_ERROR_FUNC(cuStreamBeginCaptureToGraph, err, streamId, graph, nullptr, nullptr, 0,
77+
CU_STREAM_CAPTURE_MODE_GLOBAL);
78+
#else
79+
CUDA_ERROR_FUNC(cuStreamBeginCapture, err, streamId, CU_STREAM_CAPTURE_MODE_GLOBAL);
80+
#endif // CUDA_VERSION
81+
#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
6182
}
6283

63-
CUstream CublasScopedContextHandler::get_stream() {
64-
return ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
84+
void CublasScopedContextHandler::end_recording_if_graph() {
85+
// interop_handle graph methods only available from extension version 2
86+
#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
87+
if (!ih.ext_codeplay_has_graph()) {
88+
return;
89+
}
90+
91+
auto graph = ih.ext_codeplay_get_native_graph<sycl::backend::ext_oneapi_cuda>();
92+
CUresult err;
93+
#if CUDA_VERSION >= 12030
94+
CUDA_ERROR_FUNC(cuStreamEndCapture, err, streamId, &graph);
95+
#else
96+
// cuStreamEndCapture returns a new graph, if we overwrite
97+
// "graph" it won't be picked up by the SYCL runtime, as
98+
// "ext_codeplay_get_native_graph" returns a passed-by-value pointer.
99+
CUgraph recorded_graph;
100+
CUDA_ERROR_FUNC(cuStreamEndCapture, err, streamId, &recorded_graph);
101+
102+
// Add graph to native graph as a child node
103+
// Need to return a node object for the node to be created,
104+
// can't be nullptr.
105+
CUgraphNode node;
106+
CUDA_ERROR_FUNC(cuGraphAddChildGraphNode, err, &node, graph, nullptr, 0, recorded_graph);
107+
#endif // CUDA_VERSION
108+
#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
65109
}
66110
} // namespace cublas
67111
} // namespace blas

src/blas/backends/cublas/cublas_scope_handle.hpp

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,49 @@ the handle must be destroyed when the context goes out of scope. This will bind
6363
class CublasScopedContextHandler {
6464
sycl::interop_handle& ih;
6565
static thread_local cublas_handle handle_helper;
66-
CUstream get_stream();
66+
cublasHandle_t nativeHandle;
67+
// Cache the native CU stream when the `CublasScopedContextHandler` object
68+
// is constructed. This avoids calling `get_native_queue(ih)` multiple
69+
// times which isn't guaranteed to return the same CUstream handle each
70+
// time, which causes problems when trying to start/end cuda
71+
// stream recording to a graph.
72+
CUstream streamId;
6773

6874
public:
75+
/**
76+
* @brief Constructor
77+
* @detail Creates the cublasHandle_t by implicitly imposing the advice
78+
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
79+
* per thread).
80+
*/
6981
CublasScopedContextHandler(sycl::interop_handle& ih);
7082

7183
/**
72-
* @brief get_handle: creates the handle by implicitly impose the advice
73-
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
74-
* per thread).
75-
* @return cublasHandle_t a handle to construct cublas routines
76-
*/
77-
cublasHandle_t get_handle();
84+
* @brief Start recording cuBlas calls to a graph.
85+
* @detail Checks if the command-group associated with \p ih is being added
86+
* to a graph, and if so, begins stream recording of the native CUDA stream
87+
* associated with \p ih to the native cuda-graph object.
88+
*/
89+
void begin_recording_if_graph();
90+
91+
/**
92+
* @brief End recording cuBlas calls to a graph.
93+
* @detail Checks if the command-group associated with \p ih is being added
94+
* to a graph, and if so, ends stream recording of the native CUDA stream
95+
* associated with \p ih to the native cuda-graph object, doing any
96+
* extra work to ensure that stream recorded calls get added as nodes to
97+
* the native graph object associated with \p ih.
98+
* @note Recording is ended on the native stream backing the sycl queue
99+
* that \p ih was submitted to; this method takes no parameters.
100+
*/
101+
void end_recording_if_graph();
102+
103+
/// @brief Query the cuBLAS handle created on construction
104+
/// @return cublasHandle_t a handle to construct cublas routines
105+
cublasHandle_t get_handle() const {
106+
return nativeHandle;
107+
}
108+
78109
// This is a work-around function for reinterpret_casting the memory. This
79110
// will be fixed when SYCL-2020 has been implemented for Pi backend.
80111
template <typename T, typename U>

src/blas/backends/cublas/cublas_task.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ static inline void host_task_internal(H& cgh, F f) {
6161
cgh.host_task([f](sycl::interop_handle ih) {
6262
#endif
6363
auto sc = CublasScopedContextHandler(ih);
64+
sc.begin_recording_if_graph();
6465
f(sc);
66+
sc.end_recording_if_graph();
6567
});
6668
}
6769
#endif

tests/unit_tests/blas/batch/gemm_batch_usm.cpp

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ extern std::vector<sycl::device*> devices;
4848
namespace {
4949

5050
template <typename Ta, typename Tb, typename Tc, typename Ts>
51-
int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
51+
int test(device* dev, oneapi::math::layout layout, int64_t group_count, bool graph_record = false) {
5252
// Catch asynchronous exceptions.
5353
auto exception_handler = [](exception_list exceptions) {
5454
for (std::exception_ptr const& e : exceptions) {
@@ -247,6 +247,15 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
247247

248248
try {
249249
#ifdef CALL_RT_API
250+
#ifdef SYCL_EXT_ONEAPI_GRAPH
251+
namespace sycl_exp = sycl::ext::oneapi::experimental;
252+
using modifiable_graph = sycl_exp::command_graph<sycl_exp::graph_state::modifiable>;
253+
std::unique_ptr<modifiable_graph> graph;
254+
if (graph_record) {
255+
graph = std::make_unique<modifiable_graph>(main_queue);
256+
graph->begin_recording(main_queue);
257+
}
258+
#endif
250259
switch (layout) {
251260
case oneapi::math::layout::col_major:
252261
done = oneapi::math::blas::column_major::gemm_batch(
@@ -262,7 +271,18 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
262271
break;
263272
default: break;
264273
}
265-
done.wait_and_throw();
274+
275+
#ifdef SYCL_EXT_ONEAPI_GRAPH
276+
if (graph_record) {
277+
graph->end_recording(main_queue);
278+
auto exec_graph = graph->finalize();
279+
main_queue.ext_oneapi_graph(exec_graph).wait_and_throw();
280+
}
281+
else
282+
#endif
283+
{
284+
done.wait_and_throw();
285+
}
266286
#else
267287
switch (layout) {
268288
case oneapi::math::layout::col_major:
@@ -365,58 +385,65 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
365385
}
366386

367387
class GemmBatchUsmTests
368-
: public ::testing::TestWithParam<std::tuple<sycl::device*, oneapi::math::layout>> {};
388+
: public ::testing::TestWithParam<std::tuple<sycl::device*, oneapi::math::layout, bool>> {
389+
virtual void SetUp() override {
390+
// Skip test if graph recording variant and device doesn't support sycl_ext_oneapi_graph
391+
if (std::get<2>(GetParam())) {
392+
CHECK_GRAPH_ON_DEVICE(std::get<0>(GetParam()));
393+
}
394+
}
395+
};
369396

370397
TEST_P(GemmBatchUsmTests, RealHalfPrecision) {
371398
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, sycl::half, sycl::half>(
372-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
399+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
373400
}
374401

375402
TEST_P(GemmBatchUsmTests, HalfHalfFloatPrecision) {
376-
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, float, float>(std::get<0>(GetParam()),
377-
std::get<1>(GetParam()), 5)));
403+
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, float, float>(
404+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
378405
}
379406

380407
TEST_P(GemmBatchUsmTests, Int8Int8SinglePrecision) {
381-
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, float, float>(std::get<0>(GetParam()),
382-
std::get<1>(GetParam()), 5)));
408+
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, float, float>(
409+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
383410
}
384411

385412
TEST_P(GemmBatchUsmTests, Int8Int8Int32Precision) {
386413
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, std::int32_t, float>(
387-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
414+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
388415
}
389416

390417
TEST_P(GemmBatchUsmTests, RealSinglePrecision) {
391-
EXPECT_TRUEORSKIP(
392-
(test<float, float, float, float>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
418+
EXPECT_TRUEORSKIP((test<float, float, float, float>(
419+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
393420
}
394421

395422
TEST_P(GemmBatchUsmTests, RealDoublePrecision) {
396423
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));
397424

398-
EXPECT_TRUEORSKIP((
399-
test<double, double, double, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
425+
EXPECT_TRUEORSKIP((test<double, double, double, double>(
426+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
400427
}
401428

402429
TEST_P(GemmBatchUsmTests, ComplexSinglePrecision) {
403430
EXPECT_TRUEORSKIP(
404431
(test<std::complex<float>, std::complex<float>, std::complex<float>, std::complex<float>>(
405-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
432+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
406433
}
407434

408435
TEST_P(GemmBatchUsmTests, ComplexDoublePrecision) {
409436
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));
410437

411-
EXPECT_TRUEORSKIP(
412-
(test<std::complex<double>, std::complex<double>, std::complex<double>,
413-
std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
438+
EXPECT_TRUEORSKIP((test<std::complex<double>, std::complex<double>, std::complex<double>,
439+
std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam()),
440+
5, std::get<2>(GetParam()))));
414441
}
415442

416443
INSTANTIATE_TEST_SUITE_P(GemmBatchUsmTestSuite, GemmBatchUsmTests,
417444
::testing::Combine(testing::ValuesIn(devices),
418445
testing::Values(oneapi::math::layout::col_major,
419-
oneapi::math::layout::row_major)),
420-
::LayoutDeviceNamePrint());
421-
446+
oneapi::math::layout::row_major),
447+
testing::Values(true, false)),
448+
::LayoutGraphDeviceNamePrint());
422449
} // anonymous namespace

0 commit comments

Comments
 (0)