Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions src/mpid/ch4/netmod/ucx/ucx_am.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,19 @@ int MPIDI_UCX_do_am_recv(MPIR_Request * rreq)
MPIDI_UCX_ucp_request_t *ucp_request;
size_t received_length;
ucp_request_param_t param = {
.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_RECV_INFO,
.op_attr_mask =
UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_RECV_INFO | UCP_OP_ATTR_FIELD_USER_DATA,
.cb.recv_am = &MPIDI_UCX_am_recv_callback_nbx,
.recv_info.length = &received_length,
.user_data = rreq,
};
void *data_desc = MPIDI_UCX_AM_RECV_REQUEST(rreq, data_desc);
/* note: use in_data_sz to match promised data size */
ucp_request = ucp_am_recv_data_nbx(MPIDI_UCX_global.ctx[vci].worker,
data_desc, recv_buf, in_data_sz, &param);
if (ucp_request == NULL) {
/* completed immediately */
MPIDI_UCX_ucp_request_t tmp_ucp_request;
tmp_ucp_request.req = rreq;
MPIDI_UCX_am_recv_callback_nbx(&tmp_ucp_request, UCS_OK, received_length, NULL);
} else {
ucp_request->req = rreq;
MPIDI_UCX_am_recv_callback_nbx(NULL, UCS_OK, received_length, rreq);
}

return MPI_SUCCESS;
Expand Down Expand Up @@ -163,8 +161,7 @@ ucs_status_t MPIDI_UCX_am_nbx_handler(void *arg, const void *header, size_t head
void MPIDI_UCX_am_recv_callback_nbx(void *request, ucs_status_t status, size_t length,
void *user_data)
{
MPIDI_UCX_ucp_request_t *ucp_request = (MPIDI_UCX_ucp_request_t *) request;
MPIR_Request *rreq = ucp_request->req;
MPIR_Request *rreq = user_data;

/* FIXME: proper error handling */
MPIR_Assert(status == UCS_OK);
Expand All @@ -177,22 +174,21 @@ void MPIDI_UCX_am_recv_callback_nbx(void *request, ucs_status_t status, size_t l
MPIDIG_recv_done(length, rreq);
}
MPIDIG_REQUEST(rreq, req->target_cmpl_cb) (rreq);
ucp_request->req = NULL;
ucp_request_release(ucp_request);
if (request) {
ucp_request_release(request);
}
}

void MPIDI_UCX_am_isend_callback_nbx(void *request, ucs_status_t status, void *user_data)
{
/* note: only difference from MPIDI_UCX_am_isend_callback is we need
* MPL_free in stead of MPIR_gpu_free_host
*/
MPIDI_UCX_ucp_request_t *ucp_request = (MPIDI_UCX_ucp_request_t *) request;
MPIR_Request *req = ucp_request->req;
MPIR_Request *req = user_data;
int handler_id = MPIDI_UCX_AM_SEND_REQUEST(req, handler_id);

MPL_free(MPIDI_UCX_AM_SEND_REQUEST(req, pack_buffer));
MPIDI_UCX_AM_SEND_REQUEST(req, pack_buffer) = NULL;
MPIDIG_global.origin_cbs[handler_id] (req);
ucp_request->req = NULL;
}
#endif
51 changes: 20 additions & 31 deletions src/mpid/ch4/netmod/ucx/ucx_am.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,53 +42,42 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_am_isend(int rank,

#ifdef HAVE_UCP_AM_NBX
size_t header_size = sizeof(ucx_hdr) + am_hdr_sz;
void *send_buf, *header, *data_ptr;
/* note: since we are not copying large contig gpu data, it is less useful
* to use MPIR_gpu_malloc_host */
if (dt_contig) {
/* only need copy headers */
send_buf = MPL_malloc(header_size, MPL_MEM_OTHER);
MPIR_Assert(send_buf);
header = send_buf;
void *header;
const void *data_ptr;
ucp_request_param_t param = {
.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA,
.cb.send = &MPIDI_UCX_am_isend_callback_nbx,
.user_data = sreq,
};

header = MPL_malloc(header_size, MPL_MEM_OTHER);
MPIR_Assert(header);

MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr));
MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz);
MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr));
MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz);

if (dt_contig) {
data_ptr = (char *) data + dt_true_lb;
} else {
/* need copy headers and pack data */
send_buf = MPL_malloc(header_size + data_sz, MPL_MEM_OTHER);
MPIR_Assert(send_buf);
header = send_buf;
data_ptr = (char *) send_buf + header_size;

MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr));
MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz);

MPI_Aint actual_pack_bytes;
mpi_errno = MPIR_Typerep_pack(data, count, datatype, 0, data_ptr, data_sz,
&actual_pack_bytes, MPIR_TYPEREP_FLAG_NONE);
MPIR_ERR_CHECK(mpi_errno);
MPIR_Assert(actual_pack_bytes == data_sz);
param.op_attr_mask |= UCP_OP_ATTR_FIELD_DATATYPE;
param.datatype = dt_ptr->dev.netmod.ucx.ucp_datatype;
MPIR_Datatype_ptr_add_ref(dt_ptr);
data_ptr = data;
data_sz = count;
}
ucp_request_param_t param = {
.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK,
.cb.send = &MPIDI_UCX_am_isend_callback_nbx,
};
ucp_request = (MPIDI_UCX_ucp_request_t *) ucp_am_send_nbx(ep, MPIDI_UCX_AM_NBX_HANDLER_ID,
header, header_size,
data_ptr, data_sz, &param);
MPIDI_UCX_CHK_REQUEST(ucp_request);
/* if send is done, free all resources and complete the request */
if (ucp_request == NULL) {
MPL_free(send_buf);
MPL_free(header);
MPIDIG_global.origin_cbs[handler_id] (sreq);
goto fn_exit;
}

MPIDI_UCX_AM_SEND_REQUEST(sreq, pack_buffer) = send_buf;
MPIDI_UCX_AM_SEND_REQUEST(sreq, pack_buffer) = header;
MPIDI_UCX_AM_SEND_REQUEST(sreq, handler_id) = handler_id;
ucp_request->req = sreq;
ucp_request_release(ucp_request);

#else /* !HAVE_UCP_AM_NBX */
Expand Down