Skip to content

Commit e9507f5

Browse files
committed
ch4: add CS_YIELD in netmod dynamic_sendrecv
The dynamic_sendrecv is used in MPI_Intercomm_create. The mismatching between threads are protected by the user provided tag, thus it is okay to yield during the blocking progress. Without the yield, MPI_Intercomm_create may block another thread's progress when the remote processes are not present (blocked by other communications). In the dynamic process accept/connect path, we force peer_comm's context id to 0. This is okay because the leader exchange is established with a specific pair of addresses and there is no other communications yet during leader_exchange.
1 parent 495e529 commit e9507f5

File tree

5 files changed

+44
-15
lines changed

5 files changed

+44
-15
lines changed

src/mpid/ch4/ch4_api.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ Non Native API:
9191
dynamic_recv : int
9292
NM : tag, buf-2, size, timeout
9393
dynamic_sendrecv : int
94-
NM : remote_lpid, tag, send_buf, send_size, recv_buf, recv_size, timeout
94+
NM : remote_lpid, peer_comm, tag, send_buf, send_size, recv_buf, recv_size, timeout
9595
mpi_comm_commit_pre_hook : int
9696
NM : comm
9797
SHM : comm
@@ -489,6 +489,7 @@ PARAM:
489489
origin_count: MPI_Aint
490490
origin_datatype: MPI_Datatype
491491
partner: MPIR_Request *
492+
peer_comm: MPIR_Comm *
492493
port_name: const char *
493494
port_name-2: char *
494495
ptr: void *

src/mpid/ch4/netmod/ofi/ofi_spawn.c

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
static int cancel_dynamic_request(MPIDI_OFI_dynamic_process_request_t * dynamic_req, bool is_send);
1919
static uint64_t get_dynamic_connection_match_bits(int tag);
20-
static uint64_t get_dynamic_match_bits(MPIR_Lpid lpid, int tag);
20+
static uint64_t get_dynamic_match_bits(MPIR_Lpid lpid, int context_id, int tag);
2121

2222
int MPIDI_OFI_dynamic_send(MPIR_Lpid remote_lpid, int tag, const void *buf, int size, int timeout)
2323
{
@@ -111,7 +111,7 @@ int MPIDI_OFI_dynamic_recv(int tag, void *buf, int size, int timeout)
111111
goto fn_exit;
112112
}
113113

114-
int MPIDI_OFI_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
114+
int MPIDI_OFI_dynamic_sendrecv(MPIR_Lpid remote_lpid, MPIR_Comm * peer_comm, int tag,
115115
const void *send_buf, int send_size, void *recv_buf, int recv_size,
116116
int timeout)
117117
{
@@ -132,7 +132,7 @@ int MPIDI_OFI_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
132132
send_req.event_id = MPIDI_OFI_EVENT_DYNPROC_DONE;
133133

134134
if (send_size > 0) {
135-
uint64_t match_bits = get_dynamic_match_bits(MPIR_Process.rank, tag);
135+
uint64_t match_bits = get_dynamic_match_bits(MPIR_Process.rank, peer_comm->context_id, tag);
136136
MPIDI_OFI_CALL_RETRY(fi_tsend(MPIDI_OFI_global.ctx[ctx_idx].tx,
137137
send_buf, send_size, NULL,
138138
remote_addr, match_bits, (void *) &send_req.context),
@@ -147,7 +147,7 @@ int MPIDI_OFI_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
147147

148148
if (recv_size > 0) {
149149
uint64_t mask_bits = 0;
150-
uint64_t match_bits = get_dynamic_match_bits(remote_lpid, tag);
150+
uint64_t match_bits = get_dynamic_match_bits(remote_lpid, peer_comm->recvcontext_id, tag);
151151
MPIDI_OFI_CALL_RETRY(fi_trecv(MPIDI_OFI_global.ctx[ctx_idx].rx,
152152
recv_buf, recv_size, NULL,
153153
remote_addr, match_bits, mask_bits, &recv_req.context),
@@ -182,6 +182,8 @@ int MPIDI_OFI_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
182182
break;
183183
}
184184
}
185+
MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
186+
MPID_THREAD_CS_YIELD(VCI, MPIDI_VCI_LOCK(vci));
185187
}
186188

187189
fn_exit:
@@ -295,12 +297,13 @@ int MPIDI_OFI_insert_upid(MPIR_Lpid lpid, const char *upid, int upid_len)
295297
/* -- internal static routines */
296298

297299
/* NOTE: used by MPIDI_OFI_dynamic_sendrecv, exact source match */
298-
static uint64_t get_dynamic_match_bits(MPIR_Lpid lpid, int tag)
300+
static uint64_t get_dynamic_match_bits(MPIR_Lpid lpid, int context_id, int tag)
299301
{
300302
/* normalize tag within (MPIDI_OFI_TAG_BITS - 1) bits, reserve 1 bit for dynamic connect/accept */
301303
tag &= (1 << (MPIDI_OFI_TAG_BITS - 1)) - 1;
302304

303-
uint64_t match_bits = MPIDI_OFI_DYNPROC_SEND | tag;
305+
uint64_t match_bits;
306+
match_bits = context_id;
304307

305308
if (!MPIDI_OFI_ENABLE_DATA) {
306309
/* FI_DIRECTED_RECV is not enabled, we have to embed source in the match_bits */
@@ -314,9 +317,15 @@ static uint64_t get_dynamic_match_bits(MPIR_Lpid lpid, int tag)
314317
HASH_VALUE(upid, sz, upid_hash);
315318
upid_hash &= (1 << MPIDI_OFI_SOURCE_BITS) - 1;
316319

317-
match_bits |= (upid_hash << MPIDI_OFI_TAG_BITS);
320+
match_bits <<= MPIDI_OFI_SOURCE_BITS;
321+
match_bits |= upid_hash;
318322
}
319323

324+
match_bits <<= MPIDI_OFI_TAG_BITS;
325+
match_bits |= tag;
326+
327+
match_bits |= MPIDI_OFI_DYNPROC_SEND;
328+
320329
return match_bits;
321330
}
322331

src/mpid/ch4/netmod/ucx/ucx_spawn.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,12 @@ int MPIDI_UCX_dynamic_recv(int tag, void *buf, int size, int timeout)
120120
return mpi_errno;
121121
}
122122

123-
int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
123+
static uint64_t get_dynamic_match_bits(int context_id, int tag)
124+
{
125+
return MPIDI_UCX_DYNPROC_MASK | ((uint64_t) context_id << MPIDI_UCX_TAG_BITS) | tag;
126+
}
127+
128+
int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, MPIR_Comm * peer_comm, int tag,
124129
const void *send_buf, int send_size, void *recv_buf, int recv_size,
125130
int timeout)
126131
{
@@ -132,7 +137,6 @@ int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
132137
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(vci));
133138
#endif
134139

135-
uint64_t ucx_tag = MPIDI_UCX_DYNPROC_MASK + tag;
136140
uint64_t tag_mask = 0xffffffffffffffff; /* for recv */
137141
MPIDI_av_entry_t *av = MPIDIU_lpid_to_av_slow(remote_lpid);
138142
ucp_ep_h ep = MPIDI_UCX_AV_TO_EP(av, vci, vci);
@@ -142,6 +146,8 @@ int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
142146
/* send */
143147
bool send_done = false;
144148
if (send_size > 0) {
149+
uint64_t ucx_tag = get_dynamic_match_bits(peer_comm->context_id, tag);
150+
145151
ucp_request_param_t send_param = {
146152
.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA,
147153
.cb.send = dynamic_send_cb,
@@ -163,6 +169,8 @@ int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
163169
/* recv */
164170
bool recv_done = false;
165171
if (recv_size > 0) {
172+
uint64_t ucx_tag = get_dynamic_match_bits(peer_comm->recvcontext_id, tag);
173+
166174
ucp_request_param_t recv_param = {
167175
.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA,
168176
.cb.recv = dynamic_recv_cb,
@@ -198,6 +206,8 @@ int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
198206
break;
199207
}
200208
}
209+
MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
210+
MPID_THREAD_CS_YIELD(VCI, MPIDI_VCI_LOCK(vci));
201211
}
202212

203213
fn_exit:

src/mpid/ch4/src/ch4_comm.c

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,8 @@ int MPID_Comm_set_hints(MPIR_Comm * comm_ptr, MPIR_Info * info_ptr)
323323
* 1. leader exchange data.
324324
* 2. leader broadcast over local_comm.
325325
*/
326-
static int leader_exchange(MPIR_Comm * local_comm, MPIR_Lpid remote_lpid, int tag,
326+
static int leader_exchange(MPIR_Comm * local_comm, MPIR_Lpid remote_lpid,
327+
MPIR_Comm * peer_comm, int tag,
327328
int context_id, int *remote_data_size_out, void **remote_data_out,
328329
int timeout);
329330
static int prepare_local_lpids(MPIR_Comm * local_comm, MPIR_Lpid ** lpids_out,
@@ -356,7 +357,7 @@ int MPID_Intercomm_exchange(MPIR_Comm * local_comm, int local_leader,
356357
void *remote_data = NULL;
357358
if (is_local_leader) {
358359
MPIR_Lpid remote_lpid = MPIR_comm_rank_to_lpid(peer_comm, remote_leader);
359-
mpi_errno = leader_exchange(local_comm, remote_lpid, tag, context_id,
360+
mpi_errno = leader_exchange(local_comm, remote_lpid, peer_comm, tag, context_id,
360361
&remote_data_size, &remote_data, timeout);
361362
}
362363

@@ -624,7 +625,8 @@ static int extract_remote_data(void *remote_data, int *remote_size_out,
624625
}
625626

626627
/* exchange data between leaders */
627-
static int leader_exchange(MPIR_Comm * local_comm, MPIR_Lpid remote_lpid, int tag, int context_id,
628+
static int leader_exchange(MPIR_Comm * local_comm, MPIR_Lpid remote_lpid,
629+
MPIR_Comm * peer_comm, int tag, int context_id,
628630
int *remote_data_size_out, void **remote_data_out, int timeout)
629631
{
630632
int mpi_errno = MPI_SUCCESS;
@@ -672,14 +674,16 @@ static int leader_exchange(MPIR_Comm * local_comm, MPIR_Lpid remote_lpid, int ta
672674
/* exchange */
673675
int remote_data_size;
674676
void *remote_data;
675-
mpi_errno = MPIDI_NM_dynamic_sendrecv(remote_lpid, tag, &local_data_size, sizeof(int),
677+
mpi_errno = MPIDI_NM_dynamic_sendrecv(remote_lpid, peer_comm, tag,
678+
&local_data_size, sizeof(int),
676679
&remote_data_size, sizeof(int), timeout);
677680
MPIR_ERR_CHECK(mpi_errno);
678681

679682
remote_data = MPL_malloc(remote_data_size, MPL_MEM_OTHER);
680683
MPIR_ERR_CHKANDJUMP(!remote_data, mpi_errno, MPI_ERR_OTHER, "**nomem");
681684

682-
mpi_errno = MPIDI_NM_dynamic_sendrecv(remote_lpid, tag, local_data, local_data_size,
685+
mpi_errno = MPIDI_NM_dynamic_sendrecv(remote_lpid, peer_comm, tag,
686+
local_data, local_data_size,
683687
remote_data, remote_data_size, timeout);
684688
MPIR_ERR_CHECK(mpi_errno);
685689

src/mpid/ch4/src/ch4_spawn.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,11 @@ static int dynamic_intercomm_create(const char *port_name, MPIR_Info * info, int
373373
peer_comm->local_size = 1;
374374
peer_comm->rank = 0;
375375
peer_comm->local_group = NULL;
376+
/* We have not exchanged context_id yet, set them to 0. This is okay since
377+
* the dynamic exchange is established between a pair of addresses (lpids) that
378+
* no other communications can happen yet. */
379+
peer_comm->context_id = 0;
380+
peer_comm->recvcontext_id = 0;
376381

377382
MPIR_Group_create_stride(1, 0, NULL, remote_lpid, 1, &peer_comm->remote_group);
378383

0 commit comments

Comments
 (0)