@@ -42,52 +42,41 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_am_isend(int rank,
4242
4343#ifdef HAVE_UCP_AM_NBX
4444 size_t header_size = sizeof (ucx_hdr ) + am_hdr_sz ;
45- void * send_buf , * header , * data_ptr ;
46- /* note: since we are not copying large contig gpu data, it is less useful
47- * to use MPIR_gpu_malloc_host */
48- if (dt_contig ) {
49- /* only need copy headers */
50- send_buf = MPL_malloc (header_size , MPL_MEM_OTHER );
51- MPIR_Assert (send_buf );
52- header = send_buf ;
53-
54- MPIR_Memcpy (header , & ucx_hdr , sizeof (ucx_hdr ));
55- MPIR_Memcpy ((char * ) header + sizeof (ucx_hdr ), am_hdr , am_hdr_sz );
56-
57- data_ptr = (char * ) data + dt_true_lb ;
58- } else {
59- /* need copy headers and pack data */
60- send_buf = MPL_malloc (header_size + data_sz , MPL_MEM_OTHER );
61- MPIR_Assert (send_buf );
62- header = send_buf ;
63- data_ptr = (char * ) send_buf + header_size ;
64-
65- MPIR_Memcpy (header , & ucx_hdr , sizeof (ucx_hdr ));
66- MPIR_Memcpy ((char * ) header + sizeof (ucx_hdr ), am_hdr , am_hdr_sz );
67-
68- MPI_Aint actual_pack_bytes ;
69- mpi_errno = MPIR_Typerep_pack (data , count , datatype , 0 , data_ptr , data_sz ,
70- & actual_pack_bytes , MPIR_TYPEREP_FLAG_NONE );
71- MPIR_ERR_CHECK (mpi_errno );
72- MPIR_Assert (actual_pack_bytes == data_sz );
73- }
45+ void * header ;
46+ const void * data_ptr ;
7447 ucp_request_param_t param = {
7548 .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA ,
7649 .cb .send = & MPIDI_UCX_am_isend_callback_nbx ,
7750 .user_data = sreq ,
7851 };
52+
53+ header = MPL_malloc (header_size , MPL_MEM_OTHER );
54+ MPIR_Assert (header );
55+
56+ MPIR_Memcpy (header , & ucx_hdr , sizeof (ucx_hdr ));
57+ MPIR_Memcpy ((char * ) header + sizeof (ucx_hdr ), am_hdr , am_hdr_sz );
58+
59+ if (dt_contig ) {
60+ data_ptr = (char * ) data + dt_true_lb ;
61+ } else {
62+ param .op_attr_mask |= UCP_OP_ATTR_FIELD_DATATYPE ;
63+ param .datatype = dt_ptr -> dev .netmod .ucx .ucp_datatype ;
64+ MPIR_Datatype_ptr_add_ref (dt_ptr );
65+ data_ptr = data ;
66+ data_sz = count ;
67+ }
7968 ucp_request = (MPIDI_UCX_ucp_request_t * ) ucp_am_send_nbx (ep , MPIDI_UCX_AM_NBX_HANDLER_ID ,
8069 header , header_size ,
8170 data_ptr , data_sz , & param );
8271 MPIDI_UCX_CHK_REQUEST (ucp_request );
8372 /* if send is done, free all resources and complete the request */
8473 if (ucp_request == NULL ) {
85- MPL_free (send_buf );
74+ MPL_free (header );
8675 MPIDIG_global .origin_cbs [handler_id ] (sreq );
8776 goto fn_exit ;
8877 }
8978
90- MPIDI_UCX_AM_SEND_REQUEST (sreq , pack_buffer ) = send_buf ;
79+ MPIDI_UCX_AM_SEND_REQUEST (sreq , pack_buffer ) = header ;
9180 MPIDI_UCX_AM_SEND_REQUEST (sreq , handler_id ) = handler_id ;
9281 ucp_request_release (ucp_request );
9382
0 commit comments