Skip to content

Commit a73f4ba

Browse files
committed
ch4: add CS_YIELD in netmod dynamic_sendrecv
The dynamic_sendrecv is used in MPI_Intercomm_create. Mismatching between threads is prevented by the user-provided tag, so it is okay to yield during the blocking progress. Without the yield, MPI_Intercomm_create may block another thread's progress when the remote processes are not yet present (e.g., because they are blocked by other communications).
1 parent aeecc1d commit a73f4ba

File tree

4 files changed

+39
-15
lines changed

4 files changed

+39
-15
lines changed

src/mpid/ch4/ch4_api.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ Non Native API:
9191
dynamic_recv : int
9292
NM : tag, buf-2, size, timeout
9393
dynamic_sendrecv : int
94-
NM : remote_lpid, tag, send_buf, send_size, recv_buf, recv_size, timeout
94+
NM : remote_lpid, peer_comm, tag, send_buf, send_size, recv_buf, recv_size, timeout
9595
mpi_comm_commit_pre_hook : int
9696
NM : comm
9797
SHM : comm
@@ -489,6 +489,7 @@ PARAM:
489489
origin_count: MPI_Aint
490490
origin_datatype: MPI_Datatype
491491
partner: MPIR_Request *
492+
peer_comm: MPIR_Comm *
492493
port_name: const char *
493494
port_name-2: char *
494495
ptr: void *

src/mpid/ch4/netmod/ofi/ofi_spawn.c

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
static int cancel_dynamic_request(MPIDI_OFI_dynamic_process_request_t * dynamic_req, bool is_send);
1919
static uint64_t get_dynamic_connection_match_bits(int tag);
20-
static uint64_t get_dynamic_match_bits(MPIR_Lpid lpid, int tag);
20+
static uint64_t get_dynamic_match_bits(MPIR_Lpid lpid, int context_id, int tag);
2121

2222
int MPIDI_OFI_dynamic_send(MPIR_Lpid remote_lpid, int tag, const void *buf, int size, int timeout)
2323
{
@@ -111,7 +111,7 @@ int MPIDI_OFI_dynamic_recv(int tag, void *buf, int size, int timeout)
111111
goto fn_exit;
112112
}
113113

114-
int MPIDI_OFI_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
114+
int MPIDI_OFI_dynamic_sendrecv(MPIR_Lpid remote_lpid, MPIR_Comm * peer_comm, int tag,
115115
const void *send_buf, int send_size, void *recv_buf, int recv_size,
116116
int timeout)
117117
{
@@ -132,7 +132,7 @@ int MPIDI_OFI_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
132132
send_req.event_id = MPIDI_OFI_EVENT_DYNPROC_DONE;
133133

134134
if (send_size > 0) {
135-
uint64_t match_bits = get_dynamic_match_bits(MPIR_Process.rank, tag);
135+
uint64_t match_bits = get_dynamic_match_bits(MPIR_Process.rank, peer_comm->context_id, tag);
136136
MPIDI_OFI_CALL_RETRY(fi_tsend(MPIDI_OFI_global.ctx[ctx_idx].tx,
137137
send_buf, send_size, NULL,
138138
remote_addr, match_bits, (void *) &send_req.context),
@@ -147,7 +147,7 @@ int MPIDI_OFI_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
147147

148148
if (recv_size > 0) {
149149
uint64_t mask_bits = 0;
150-
uint64_t match_bits = get_dynamic_match_bits(remote_lpid, tag);
150+
uint64_t match_bits = get_dynamic_match_bits(remote_lpid, peer_comm->recvcontext_id, tag);
151151
MPIDI_OFI_CALL_RETRY(fi_trecv(MPIDI_OFI_global.ctx[ctx_idx].rx,
152152
recv_buf, recv_size, NULL,
153153
remote_addr, match_bits, mask_bits, &recv_req.context),
@@ -182,6 +182,8 @@ int MPIDI_OFI_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
182182
break;
183183
}
184184
}
185+
MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
186+
MPID_THREAD_CS_YIELD(VCI, MPIDI_VCI_LOCK(vci));
185187
}
186188

187189
fn_exit:
@@ -295,12 +297,13 @@ int MPIDI_OFI_insert_upid(MPIR_Lpid lpid, const char *upid, int upid_len)
295297
/* -- internal static routines */
296298

297299
/* NOTE: used by MPIDI_OFI_dynamic_sendrecv, exact source match */
298-
static uint64_t get_dynamic_match_bits(MPIR_Lpid lpid, int tag)
300+
static uint64_t get_dynamic_match_bits(MPIR_Lpid lpid, int context_id, int tag)
299301
{
300302
/* normalize tag within (MPIDI_OFI_TAG_BITS - 1) bits, reserve 1 bit for dynamic connect/accept */
301303
tag &= (1 << (MPIDI_OFI_TAG_BITS - 1)) - 1;
302304

303-
uint64_t match_bits = MPIDI_OFI_DYNPROC_SEND | tag;
305+
uint64_t match_bits;
306+
match_bits = context_id;
304307

305308
if (!MPIDI_OFI_ENABLE_DATA) {
306309
/* FI_DIRECTED_RECV is not enabled, we have to embed source in the match_bits */
@@ -314,9 +317,15 @@ static uint64_t get_dynamic_match_bits(MPIR_Lpid lpid, int tag)
314317
HASH_VALUE(upid, sz, upid_hash);
315318
upid_hash &= (1 << MPIDI_OFI_SOURCE_BITS) - 1;
316319

317-
match_bits |= (upid_hash << MPIDI_OFI_TAG_BITS);
320+
match_bits <<= MPIDI_OFI_SOURCE_BITS;
321+
match_bits |= upid_hash;
318322
}
319323

324+
match_bits <<= MPIDI_OFI_TAG_BITS;
325+
match_bits |= tag;
326+
327+
match_bits |= MPIDI_OFI_DYNPROC_SEND;
328+
320329
return match_bits;
321330
}
322331

src/mpid/ch4/netmod/ucx/ucx_spawn.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,12 @@ int MPIDI_UCX_dynamic_recv(int tag, void *buf, int size, int timeout)
120120
return mpi_errno;
121121
}
122122

123-
int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
123+
static uint64_t get_dynamic_match_bits(int context_id, int tag)
124+
{
125+
return MPIDI_UCX_DYNPROC_MASK | ((uint64_t) context_id << MPIDI_UCX_TAG_BITS) | tag;
126+
}
127+
128+
int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, MPIR_Comm * peer_comm, int tag,
124129
const void *send_buf, int send_size, void *recv_buf, int recv_size,
125130
int timeout)
126131
{
@@ -132,7 +137,6 @@ int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
132137
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(vci));
133138
#endif
134139

135-
uint64_t ucx_tag = MPIDI_UCX_DYNPROC_MASK + tag;
136140
uint64_t tag_mask = 0xffffffffffffffff; /* for recv */
137141
MPIDI_av_entry_t *av = MPIDIU_lpid_to_av_slow(remote_lpid);
138142
ucp_ep_h ep = MPIDI_UCX_AV_TO_EP(av, vci, vci);
@@ -142,6 +146,8 @@ int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
142146
/* send */
143147
bool send_done = false;
144148
if (send_size > 0) {
149+
uint64_t ucx_tag = get_dynamic_match_bits(peer_comm->context_id, tag);
150+
145151
ucp_request_param_t send_param = {
146152
.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA,
147153
.cb.send = dynamic_send_cb,
@@ -163,6 +169,8 @@ int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
163169
/* recv */
164170
bool recv_done = false;
165171
if (recv_size > 0) {
172+
uint64_t ucx_tag = get_dynamic_match_bits(peer_comm->recvcontext_id, tag);
173+
166174
ucp_request_param_t recv_param = {
167175
.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA,
168176
.cb.recv = dynamic_recv_cb,
@@ -198,6 +206,8 @@ int MPIDI_UCX_dynamic_sendrecv(MPIR_Lpid remote_lpid, int tag,
198206
break;
199207
}
200208
}
209+
MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
210+
MPID_THREAD_CS_YIELD(VCI, MPIDI_VCI_LOCK(vci));
201211
}
202212

203213
fn_exit:

src/mpid/ch4/src/ch4_comm.c

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,8 @@ int MPID_Comm_set_hints(MPIR_Comm * comm_ptr, MPIR_Info * info_ptr)
324324
* 1. leader exchange data.
325325
* 2. leader broadcast over local_comm.
326326
*/
327-
static int leader_exchange(MPIR_Comm * local_comm, MPIR_Lpid remote_lpid, int tag,
327+
static int leader_exchange(MPIR_Comm * local_comm, MPIR_Lpid remote_lpid,
328+
MPIR_Comm * peer_comm, int tag,
328329
int context_id, int *remote_data_size_out, void **remote_data_out,
329330
int timeout);
330331
static int prepare_local_lpids(MPIR_Comm * local_comm, MPIR_Lpid ** lpids_out,
@@ -357,7 +358,7 @@ int MPID_Intercomm_exchange(MPIR_Comm * local_comm, int local_leader,
357358
void *remote_data = NULL;
358359
if (is_local_leader) {
359360
MPIR_Lpid remote_lpid = MPIR_comm_rank_to_lpid(peer_comm, remote_leader);
360-
mpi_errno = leader_exchange(local_comm, remote_lpid, tag, context_id,
361+
mpi_errno = leader_exchange(local_comm, remote_lpid, peer_comm, tag, context_id,
361362
&remote_data_size, &remote_data, timeout);
362363
}
363364

@@ -625,7 +626,8 @@ static int extract_remote_data(void *remote_data, int *remote_size_out,
625626
}
626627

627628
/* exchange data between leaders */
628-
static int leader_exchange(MPIR_Comm * local_comm, MPIR_Lpid remote_lpid, int tag, int context_id,
629+
static int leader_exchange(MPIR_Comm * local_comm, MPIR_Lpid remote_lpid,
630+
MPIR_Comm * peer_comm, int tag, int context_id,
629631
int *remote_data_size_out, void **remote_data_out, int timeout)
630632
{
631633
int mpi_errno = MPI_SUCCESS;
@@ -673,14 +675,16 @@ static int leader_exchange(MPIR_Comm * local_comm, MPIR_Lpid remote_lpid, int ta
673675
/* exchange */
674676
int remote_data_size;
675677
void *remote_data;
676-
mpi_errno = MPIDI_NM_dynamic_sendrecv(remote_lpid, tag, &local_data_size, sizeof(int),
678+
mpi_errno = MPIDI_NM_dynamic_sendrecv(remote_lpid, peer_comm, tag,
679+
&local_data_size, sizeof(int),
677680
&remote_data_size, sizeof(int), timeout);
678681
MPIR_ERR_CHECK(mpi_errno);
679682

680683
remote_data = MPL_malloc(remote_data_size, MPL_MEM_OTHER);
681684
MPIR_ERR_CHKANDJUMP(!remote_data, mpi_errno, MPI_ERR_OTHER, "**nomem");
682685

683-
mpi_errno = MPIDI_NM_dynamic_sendrecv(remote_lpid, tag, local_data, local_data_size,
686+
mpi_errno = MPIDI_NM_dynamic_sendrecv(remote_lpid, peer_comm, tag,
687+
local_data, local_data_size,
684688
remote_data, remote_data_size, timeout);
685689
MPIR_ERR_CHECK(mpi_errno);
686690

0 commit comments

Comments
 (0)