1515#include <uct/cuda/base/cuda_md.h>
1616#include <ucs/type/class.h>
1717#include <ucs/sys/string.h>
18- #include <ucs/async/async .h>
18+ #include <ucs/async/eventfd .h>
1919#include <ucs/arch/cpu.h>
2020
2121
@@ -73,7 +73,7 @@ static ucs_status_t uct_cuda_copy_iface_query(uct_iface_h tl_iface,
7373{
7474 uct_cuda_copy_iface_t * iface = ucs_derived_of (tl_iface , uct_cuda_copy_iface_t );
7575
76- uct_base_iface_query (& iface -> super , iface_attr );
76+ uct_base_iface_query (& iface -> super . super , iface_attr );
7777
7878 iface_attr -> iface_addr_len = sizeof (uct_cuda_copy_iface_addr_t );
7979 iface_attr -> device_addr_len = 0 ;
@@ -87,7 +87,7 @@ static ucs_status_t uct_cuda_copy_iface_query(uct_iface_h tl_iface,
8787
8888 iface_attr -> cap .event_flags = UCT_IFACE_FLAG_EVENT_SEND_COMP |
8989 UCT_IFACE_FLAG_EVENT_RECV |
90- UCT_IFACE_FLAG_EVENT_ASYNC_CB ;
90+ UCT_IFACE_FLAG_EVENT_FD ;
9191
9292 iface_attr -> cap .put .max_short = UINT_MAX ;
9393 iface_attr -> cap .put .max_bcopy = 0 ;
@@ -209,22 +209,6 @@ static unsigned uct_cuda_copy_iface_progress(uct_iface_h tl_iface)
209209 return count ;
210210}
211211
212- #if (__CUDACC_VER_MAJOR__ >= 100000 )
213- static void CUDA_CB myHostFn (void * cuda_copy_iface )
214- #else
215- static void CUDA_CB myHostCallback (CUstream hStream , CUresult status ,
216- void * cuda_copy_iface )
217- #endif
218- {
219- uct_cuda_copy_iface_t * iface = cuda_copy_iface ;
220-
221- ucs_assert (iface -> async .event_cb != NULL );
222- /* notify user */
223- UCS_ASYNC_BLOCK (iface -> super .worker -> async );
224- iface -> async .event_cb (iface -> async .event_arg , 0 );
225- UCS_ASYNC_UNBLOCK (iface -> super .worker -> async );
226- }
227-
228212static ucs_status_t uct_cuda_copy_iface_event_fd_arm (uct_iface_h tl_iface ,
229213 unsigned events )
230214{
@@ -242,18 +226,30 @@ static ucs_status_t uct_cuda_copy_iface_event_fd_arm(uct_iface_h tl_iface,
242226 }
243227 }
244228
229+ status = ucs_async_eventfd_poll (iface -> super .eventfd );
230+ if (status == UCS_OK ) {
231+ return UCS_ERR_BUSY ;
232+ } else if (status == UCS_ERR_IO_ERROR ) {
233+ return status ;
234+ }
235+
236+ ucs_assertv (status == UCS_ERR_NO_PROGRESS , "%s" , ucs_status_string (status ));
237+
245238 ucs_queue_for_each_safe (q_desc , iter , & iface -> active_queue , queue ) {
246239 event_q = & q_desc -> event_queue ;
247240 stream = & q_desc -> stream ;
248241 if (!ucs_queue_is_empty (event_q )) {
249242 status =
250243#if (__CUDACC_VER_MAJOR__ >= 100000 )
251- UCT_CUDADRV_FUNC_LOG_ERR (cuLaunchHostFunc (* stream ,
252- myHostFn , iface ));
244+ UCT_CUDADRV_FUNC_LOG_ERR (
245+ cuLaunchHostFunc (* stream ,
246+ uct_cuda_base_iface_stream_cb_fxn ,
247+ & iface -> super ));
253248#else
254- UCT_CUDADRV_FUNC_LOG_ERR (cuStreamAddCallback (* stream ,
255- myHostCallback ,
256- iface , 0 ));
249+ UCT_CUDADRV_FUNC_LOG_ERR (
250+ cuStreamAddCallback (* stream ,
251+ uct_cuda_base_iface_stream_cb_fxn ,
252+ & iface -> super , 0 ));
257253#endif
258254 if (UCS_OK != status ) {
259255 return status ;
@@ -280,7 +276,7 @@ static uct_iface_ops_t uct_cuda_copy_iface_ops = {
280276 .iface_progress_enable = uct_base_iface_progress_enable ,
281277 .iface_progress_disable = uct_base_iface_progress_disable ,
282278 .iface_progress = uct_cuda_copy_iface_progress ,
283- .iface_event_fd_get = ( uct_iface_event_fd_get_func_t ) ucs_empty_function_return_success ,
279+ .iface_event_fd_get = uct_cuda_base_iface_event_fd_get ,
284280 .iface_event_arm = uct_cuda_copy_iface_event_fd_arm ,
285281 .iface_close = UCS_CLASS_DELETE_FUNC_NAME (uct_cuda_copy_iface_t ),
286282 .iface_query = uct_cuda_copy_iface_query ,
@@ -409,11 +405,9 @@ static UCS_CLASS_INIT_FUNC(uct_cuda_copy_iface_t, uct_md_h md, uct_worker_h work
409405 ucs_memory_type_t src , dst ;
410406 ucs_mpool_params_t mp_params ;
411407
412- UCS_CLASS_CALL_SUPER_INIT (uct_base_iface_t , & uct_cuda_copy_iface_ops ,
408+ UCS_CLASS_CALL_SUPER_INIT (uct_cuda_iface_t , & uct_cuda_copy_iface_ops ,
413409 & uct_cuda_copy_iface_internal_ops , md , worker ,
414- params ,
415- tl_config UCS_STATS_ARG (params -> stats_root )
416- UCS_STATS_ARG ("cuda_copy" ));
410+ params , tl_config , "cuda_copy" );
417411
418412 if (strncmp (params -> mode .device .dev_name ,
419413 UCT_CUDA_DEV_NAME , strlen (UCT_CUDA_DEV_NAME )) != 0 ) {
@@ -438,9 +432,6 @@ static UCS_CLASS_INIT_FUNC(uct_cuda_copy_iface_t, uct_md_h md, uct_worker_h work
438432 return UCS_ERR_IO_ERROR ;
439433 }
440434
441- uct_iface_set_async_event_params (params , & self -> async .event_cb ,
442- & self -> async .event_arg );
443-
444435 ucs_queue_head_init (& self -> active_queue );
445436
446437 for (src = 0 ; src < UCS_MEMORY_TYPE_LAST ; ++ src ) {
@@ -463,7 +454,7 @@ static UCS_CLASS_CLEANUP_FUNC(uct_cuda_copy_iface_t)
463454 ucs_queue_head_t * event_q ;
464455 ucs_memory_type_t src , dst ;
465456
466- uct_base_iface_progress_disable (& self -> super .super ,
457+ uct_base_iface_progress_disable (& self -> super .super . super ,
467458 UCT_PROGRESS_SEND | UCT_PROGRESS_RECV );
468459
469460 UCT_CUDADRV_FUNC_LOG_ERR (cuCtxGetCurrent (& cuda_context ));
@@ -494,7 +485,7 @@ static UCS_CLASS_CLEANUP_FUNC(uct_cuda_copy_iface_t)
494485 ucs_mpool_cleanup (& self -> cuda_event_desc , 1 );
495486}
496487
497- UCS_CLASS_DEFINE (uct_cuda_copy_iface_t , uct_base_iface_t );
488+ UCS_CLASS_DEFINE (uct_cuda_copy_iface_t , uct_cuda_iface_t );
498489UCS_CLASS_DEFINE_NEW_FUNC (uct_cuda_copy_iface_t , uct_iface_t , uct_md_h , uct_worker_h ,
499490 const uct_iface_params_t * , const uct_iface_config_t * );
500491static UCS_CLASS_DEFINE_DELETE_FUNC (uct_cuda_copy_iface_t , uct_iface_t ) ;
0 commit comments