Skip to content

Commit 57ced10

Browse files
committed
coll/circ_graph: refactor dependency tracking for reduce
In bcast and allgather the dependency tracking is simple: a recv has no dependency, and a send depends on at most a single recv. For reduction, we may have multiple pending sends and a single pending recv.
1 parent 9c03e6c commit 57ced10

File tree

2 files changed

+98
-33
lines changed

2 files changed

+98
-33
lines changed

src/mpi/coll/algorithms/circ_graph/cga_request_queue.c

Lines changed: 90 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ static int issue_send(MPII_cga_request_queue * queue, const void *buf, MPI_Aint
1111
int peer_rank, int chunk_id, int *req_idx_out);
1212
static int issue_recv(MPII_cga_request_queue * queue, void *buf, MPI_Aint count,
1313
int peer_rank, int chunk_id, int *req_idx_out);
14+
static void add_pending(MPII_cga_request_queue * queue, int chunk_id, int req_idx);
15+
static void remove_pending(MPII_cga_request_queue * queue, int chunk_id, int req_idx);
1416
static int wait_if_full(MPII_cga_request_queue * queue);
1517
static int wait_for_request(MPII_cga_request_queue * queue, int i);
1618
static int reduce_local(MPII_cga_request_queue * queue, int block);
@@ -245,8 +247,7 @@ int MPII_cga_bcast_send(MPII_cga_request_queue * queue, int block, int peer_rank
245247

246248
int chunk_id = block;
247249
if (queue->pending_blocks[chunk_id] >= 0) {
248-
int i = queue->pending_blocks[chunk_id];
249-
mpi_errno = wait_for_request(queue, i);
250+
mpi_errno = wait_for_request(queue, queue->pending_blocks[chunk_id]);
250251
MPIR_ERR_CHECK(mpi_errno);
251252
}
252253

@@ -299,9 +300,7 @@ int MPII_cga_bcast_recv(MPII_cga_request_queue * queue, int block, int peer_rank
299300
mpi_errno = issue_recv(queue, buf, count, peer_rank, block, &req_idx);
300301
MPIR_ERR_CHECK(mpi_errno);
301302

302-
int chunk_id = block;
303-
queue->requests[req_idx].chunk_id = chunk_id;
304-
queue->pending_blocks[chunk_id] = req_idx;
303+
add_pending(queue, block, req_idx);
305304

306305
fn_exit:
307306
return mpi_errno;
@@ -335,8 +334,7 @@ int MPII_cga_allgather_send(MPII_cga_request_queue * queue, int root, int block,
335334

336335
int chunk_id = block + root * queue->num_chunks;
337336
if (queue->pending_blocks[chunk_id] >= 0) {
338-
int i = queue->pending_blocks[chunk_id];
339-
mpi_errno = wait_for_request(queue, i);
337+
mpi_errno = wait_for_request(queue, queue->pending_blocks[chunk_id]);
340338
MPIR_ERR_CHECK(mpi_errno);
341339
}
342340

@@ -392,16 +390,15 @@ int MPII_cga_allgather_recv(MPII_cga_request_queue * queue, int root, int block,
392390
mpi_errno = issue_recv(queue, buf, count, peer_rank, chunk_id, &req_idx);
393391
MPIR_ERR_CHECK(mpi_errno);
394392

395-
queue->requests[req_idx].chunk_id = chunk_id;
396-
queue->pending_blocks[chunk_id] = req_idx;
393+
add_pending(queue, chunk_id, req_idx);
397394

398395
fn_exit:
399396
return mpi_errno;
400397
fn_fail:
401398
goto fn_exit;
402399
}
403400

404-
static int allgather_recv_unpack(MPII_cga_request_queue * queue, int block)
401+
static int allgather_recv_unpack(MPII_cga_request_queue * queue, int chunk_id)
405402
{
406403
int mpi_errno = MPI_SUCCESS;
407404

@@ -428,24 +425,31 @@ int MPII_cga_reduce_send(MPII_cga_request_queue * queue, int block, int peer_ran
428425
{
429426
int mpi_errno = MPI_SUCCESS;
430427

431-
/* in reduce, send may depend on recv from the previous rounds; recv may depend
432-
* on both the previous send and the previous recv */
428+
/* Dependency consideration for reduce:
429+
* * Recv - there are two operations, recv into tmp_buf and reduce into recvbuf.
430+
* Recv into tmp_buf require clear of previous recv (with the same block).
431+
* Reduction require clear of previous sends (with the same block).
432+
* * Send - assume we always issue send before recv, send require clear previous
433+
* recv (from previous rounds with the same block) and clear all previous
434+
* sends as part of recv completion (the reduction dpendency). However,
435+
* if there is no pending recv, multiple pending sends are ok.
436+
* Thus, we may have a single pending recv request and multiple pending send requests.
437+
*/
433438

434-
int chunk_id = block;
435-
if (queue->pending_blocks[chunk_id] >= 0) {
436-
int i = queue->pending_blocks[chunk_id];
437-
mpi_errno = wait_for_request(queue, i);
439+
int pending = queue->pending_blocks[block];
440+
if (pending >= 0 && queue->requests[pending].op_type == MPII_CGA_OP_RECV) {
441+
mpi_errno = wait_for_request(queue, pending);
438442
MPIR_ERR_CHECK(mpi_errno);
439443
}
440444

441445
MPI_Aint count = GET_BLOCK_COUNT(block);
442446
void *buf = GET_BLOCK_BUF(queue->u.reduce.recvbuf, block);
443447

444448
int req_idx;
445-
mpi_errno = issue_send(queue, buf, count, peer_rank, chunk_id, &req_idx);
449+
mpi_errno = issue_send(queue, buf, count, peer_rank, block, &req_idx);
446450
MPIR_ERR_CHECK(mpi_errno);
447451

448-
queue->pending_blocks[chunk_id] = req_idx;
452+
add_pending(queue, block, req_idx);
449453

450454
fn_exit:
451455
return mpi_errno;
@@ -457,16 +461,12 @@ int MPII_cga_reduce_recv(MPII_cga_request_queue * queue, int block, int peer_ran
457461
{
458462
int mpi_errno = MPI_SUCCESS;
459463

460-
/* reduction receives the same block from multiple processes and may reduce
461-
* into the pending send buffer, thus it may depend on either previous send
462-
* or recv.
463-
*
464-
* However, strictly, only the reduce_local depend on previous send. We will need
465-
* separaten pending_blocks tracking to make it precise.
466-
* */
467-
if (queue->pending_blocks[block] >= 0) {
468-
int i = queue->pending_blocks[block];
469-
mpi_errno = wait_for_request(queue, i);
464+
/* Reference the comments in MPII_cga_reduce_send. Issuing recv only needs to clear
465+
* previous pending recvs
466+
*/
467+
int pending = queue->pending_blocks[block];
468+
if (pending >= 0 && queue->requests[pending].op_type == MPII_CGA_OP_RECV) {
469+
mpi_errno = wait_for_request(queue, pending);
470470
MPIR_ERR_CHECK(mpi_errno);
471471
}
472472

@@ -477,9 +477,7 @@ int MPII_cga_reduce_recv(MPII_cga_request_queue * queue, int block, int peer_ran
477477
mpi_errno = issue_recv(queue, buf, count, peer_rank, block, &req_idx);
478478
MPIR_ERR_CHECK(mpi_errno);
479479

480-
int chunk_id = block;
481-
queue->requests[req_idx].chunk_id = chunk_id;
482-
queue->pending_blocks[chunk_id] = req_idx;
480+
add_pending(queue, block, req_idx);
483481

484482
fn_exit:
485483
return mpi_errno;
@@ -581,6 +579,56 @@ static int issue_recv(MPII_cga_request_queue * queue, void *buf, MPI_Aint count,
581579
goto fn_exit;
582580
}
583581

582+
static void add_pending(MPII_cga_request_queue * queue, int chunk_id, int req_idx)
583+
{
584+
if (queue->coll_type != MPII_CGA_REDUCE) {
585+
/* simple case - at most a single pending request per block */
586+
queue->requests[req_idx].next_req_id = -1;
587+
queue->pending_blocks[chunk_id] = req_idx;
588+
} else {
589+
/* reduction may have a single pending recv and multiple pending sends */
590+
if (queue->requests[req_idx].op_type == MPII_CGA_OP_RECV) {
591+
/* prepend */
592+
queue->requests[req_idx].next_req_id = queue->pending_blocks[chunk_id];
593+
queue->pending_blocks[chunk_id] = req_idx;
594+
} else {
595+
/* there is no dependency between pending sends, just insert after the pending recv (if exist) */
596+
int pending = queue->pending_blocks[chunk_id];
597+
if (pending < 0) {
598+
/* no pending request */
599+
queue->requests[req_idx].next_req_id = -1;
600+
queue->pending_blocks[chunk_id] = req_idx;
601+
} else if (queue->requests[pending].op_type == MPII_CGA_OP_RECV) {
602+
/* insert after the single pending recv */
603+
queue->requests[req_idx].next_req_id = queue->requests[pending].next_req_id;
604+
queue->requests[pending].next_req_id = req_idx;
605+
} else {
606+
/* prepend */
607+
queue->requests[req_idx].next_req_id = queue->pending_blocks[chunk_id];
608+
queue->pending_blocks[chunk_id] = req_idx;
609+
}
610+
}
611+
}
612+
}
613+
614+
static void remove_pending(MPII_cga_request_queue * queue, int chunk_id, int req_idx)
615+
{
616+
if (queue->pending_blocks[chunk_id] == req_idx) {
617+
queue->pending_blocks[chunk_id] = queue->requests[req_idx].next_req_id;
618+
} else {
619+
/* not the first pending recv. This can happen in a reduce in wait_if_full or
620+
* MPII_cga_waitall when we need complete a send request not due to dependency
621+
*/
622+
int pending = queue->pending_blocks[chunk_id];
623+
MPIR_Assert(pending >= 0);
624+
while (queue->requests[pending].next_req_id != req_idx) {
625+
pending = queue->requests[pending].next_req_id;
626+
MPIR_Assert(pending >= 0);
627+
}
628+
queue->requests[pending].next_req_id = queue->requests[req_idx].next_req_id;
629+
}
630+
}
631+
584632
static int wait_if_full(MPII_cga_request_queue * queue)
585633
{
586634
int mpi_errno = MPI_SUCCESS;
@@ -607,10 +655,10 @@ static int wait_for_request(MPII_cga_request_queue * queue, int i)
607655
mpi_errno = MPIC_Wait(queue->requests[i].req);
608656
MPIR_ERR_CHECK(mpi_errno);
609657

658+
int chunk_id = queue->requests[i].chunk_id;
610659
if (queue->requests[i].op_type == MPII_CGA_OP_RECV) {
611660
/* it's a recv, update pending_blocks */
612-
int chunk_id = queue->requests[i].chunk_id;
613-
queue->pending_blocks[chunk_id] = -1;
661+
remove_pending(queue, chunk_id, i);
614662
if (queue->coll_type == MPII_CGA_BCAST) {
615663
if (queue->u.bcast.need_pack) {
616664
mpi_errno = bcast_recv_unpack(queue, chunk_id);
@@ -622,9 +670,18 @@ static int wait_for_request(MPII_cga_request_queue * queue, int i)
622670
MPIR_ERR_CHECK(mpi_errno);
623671
}
624672
} else if (queue->coll_type == MPII_CGA_REDUCE) {
673+
/* clear all pending sends */
674+
while (queue->pending_blocks[chunk_id] >= 0) {
675+
int req_id = queue->pending_blocks[chunk_id];
676+
MPIR_Assert(queue->requests[req_id].op_type == MPII_CGA_OP_SEND);
677+
mpi_errno = wait_for_request(queue, req_id);
678+
MPIR_ERR_CHECK(mpi_errno);
679+
}
625680
mpi_errno = reduce_local(queue, chunk_id);
626681
MPIR_ERR_CHECK(mpi_errno);
627682
}
683+
} else {
684+
remove_pending(queue, chunk_id, i);
628685
}
629686
MPIR_Request_free(queue->requests[i].req);
630687
queue->requests[i].req = NULL;

src/mpi/coll/algorithms/circ_graph/circ_graph.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ typedef struct {
9191
void *recvbuf;
9292
MPI_Op op;
9393
} reduce;
94+
struct {
95+
void *tmp_buf;
96+
void *recvbuf;
97+
MPI_Op op;
98+
} allreduce;
9499
} u;
95100

96101
MPIR_Comm *comm;
@@ -104,6 +109,9 @@ typedef struct {
104109
MPIR_Request *req;
105110
int chunk_id;
106111
int round;
112+
/* for reduction, we may have multiple requests concurrent on the same block,
113+
* thus we may need a linked list */
114+
int next_req_id;
107115
} *requests; /* requests[q_len] */
108116
int q_head;
109117
int q_tail;

0 commit comments

Comments
 (0)