Skip to content

Commit f003208

Browse files
committed
ch4/ucx: Use UCX datatypes in noncontig AM send path
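UCX supports noncontiguous datatypes in the NBX active message send interface (ucp_am_send_nbx). Use that support, as we already do in the tagged send path, instead of packing noncontiguous data into a temporary send buffer.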
1 parent 6ff7da4 commit f003208

File tree

1 file changed

+20
-31
lines changed

1 file changed

+20
-31
lines changed

src/mpid/ch4/netmod/ucx/ucx_am.h

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -42,52 +42,41 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_am_isend(int rank,
4242

4343
#ifdef HAVE_UCP_AM_NBX
4444
size_t header_size = sizeof(ucx_hdr) + am_hdr_sz;
45-
void *send_buf, *header, *data_ptr;
46-
/* note: since we are not copying large contig gpu data, it is less useful
47-
* to use MPIR_gpu_malloc_host */
48-
if (dt_contig) {
49-
/* only need copy headers */
50-
send_buf = MPL_malloc(header_size, MPL_MEM_OTHER);
51-
MPIR_Assert(send_buf);
52-
header = send_buf;
53-
54-
MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr));
55-
MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz);
56-
57-
data_ptr = (char *) data + dt_true_lb;
58-
} else {
59-
/* need copy headers and pack data */
60-
send_buf = MPL_malloc(header_size + data_sz, MPL_MEM_OTHER);
61-
MPIR_Assert(send_buf);
62-
header = send_buf;
63-
data_ptr = (char *) send_buf + header_size;
64-
65-
MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr));
66-
MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz);
67-
68-
MPI_Aint actual_pack_bytes;
69-
mpi_errno = MPIR_Typerep_pack(data, count, datatype, 0, data_ptr, data_sz,
70-
&actual_pack_bytes, MPIR_TYPEREP_FLAG_NONE);
71-
MPIR_ERR_CHECK(mpi_errno);
72-
MPIR_Assert(actual_pack_bytes == data_sz);
73-
}
45+
void *header;
46+
const void *data_ptr;
7447
ucp_request_param_t param = {
7548
.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA,
7649
.cb.send = &MPIDI_UCX_am_isend_callback_nbx,
7750
.user_data = sreq,
7851
};
52+
53+
header = MPL_malloc(header_size, MPL_MEM_OTHER);
54+
MPIR_Assert(header);
55+
56+
MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr));
57+
MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz);
58+
59+
if (dt_contig) {
60+
data_ptr = (char *) data + dt_true_lb;
61+
} else {
62+
param.op_attr_mask |= UCP_OP_ATTR_FIELD_DATATYPE;
63+
param.datatype = dt_ptr->dev.netmod.ucx.ucp_datatype;
64+
MPIR_Datatype_ptr_add_ref(dt_ptr);
65+
data_ptr = data;
66+
data_sz = count;
67+
}
7968
ucp_request = (MPIDI_UCX_ucp_request_t *) ucp_am_send_nbx(ep, MPIDI_UCX_AM_NBX_HANDLER_ID,
8069
header, header_size,
8170
data_ptr, data_sz, &param);
8271
MPIDI_UCX_CHK_REQUEST(ucp_request);
8372
/* if send is done, free all resources and complete the request */
8473
if (ucp_request == NULL) {
85-
MPL_free(send_buf);
74+
MPL_free(header);
8675
MPIDIG_global.origin_cbs[handler_id] (sreq);
8776
goto fn_exit;
8877
}
8978

90-
MPIDI_UCX_AM_SEND_REQUEST(sreq, pack_buffer) = send_buf;
79+
MPIDI_UCX_AM_SEND_REQUEST(sreq, pack_buffer) = header;
9180
MPIDI_UCX_AM_SEND_REQUEST(sreq, handler_id) = handler_id;
9281
ucp_request_release(ucp_request);
9382

0 commit comments

Comments
 (0)