@@ -11,6 +11,8 @@ static int issue_send(MPII_cga_request_queue * queue, const void *buf, MPI_Aint
1111 int peer_rank , int chunk_id , int * req_idx_out );
1212static int issue_recv (MPII_cga_request_queue * queue , void * buf , MPI_Aint count ,
1313 int peer_rank , int chunk_id , int * req_idx_out );
14+ static void add_pending (MPII_cga_request_queue * queue , int chunk_id , int req_idx );
15+ static void remove_pending (MPII_cga_request_queue * queue , int chunk_id , int req_idx );
1416static int wait_if_full (MPII_cga_request_queue * queue );
1517static int wait_for_request (MPII_cga_request_queue * queue , int i );
1618static int reduce_local (MPII_cga_request_queue * queue , int block );
@@ -245,8 +247,7 @@ int MPII_cga_bcast_send(MPII_cga_request_queue * queue, int block, int peer_rank
245247
246248 int chunk_id = block ;
247249 if (queue -> pending_blocks [chunk_id ] >= 0 ) {
248- int i = queue -> pending_blocks [chunk_id ];
249- mpi_errno = wait_for_request (queue , i );
250+ mpi_errno = wait_for_request (queue , queue -> pending_blocks [chunk_id ]);
250251 MPIR_ERR_CHECK (mpi_errno );
251252 }
252253
@@ -299,9 +300,7 @@ int MPII_cga_bcast_recv(MPII_cga_request_queue * queue, int block, int peer_rank
299300 mpi_errno = issue_recv (queue , buf , count , peer_rank , block , & req_idx );
300301 MPIR_ERR_CHECK (mpi_errno );
301302
302- int chunk_id = block ;
303- queue -> requests [req_idx ].chunk_id = chunk_id ;
304- queue -> pending_blocks [chunk_id ] = req_idx ;
303+ add_pending (queue , block , req_idx );
305304
306305 fn_exit :
307306 return mpi_errno ;
@@ -335,8 +334,7 @@ int MPII_cga_allgather_send(MPII_cga_request_queue * queue, int root, int block,
335334
336335 int chunk_id = block + root * queue -> num_chunks ;
337336 if (queue -> pending_blocks [chunk_id ] >= 0 ) {
338- int i = queue -> pending_blocks [chunk_id ];
339- mpi_errno = wait_for_request (queue , i );
337+ mpi_errno = wait_for_request (queue , queue -> pending_blocks [chunk_id ]);
340338 MPIR_ERR_CHECK (mpi_errno );
341339 }
342340
@@ -392,16 +390,15 @@ int MPII_cga_allgather_recv(MPII_cga_request_queue * queue, int root, int block,
392390 mpi_errno = issue_recv (queue , buf , count , peer_rank , chunk_id , & req_idx );
393391 MPIR_ERR_CHECK (mpi_errno );
394392
395- queue -> requests [req_idx ].chunk_id = chunk_id ;
396- queue -> pending_blocks [chunk_id ] = req_idx ;
393+ add_pending (queue , chunk_id , req_idx );
397394
398395 fn_exit :
399396 return mpi_errno ;
400397 fn_fail :
401398 goto fn_exit ;
402399}
403400
404- static int allgather_recv_unpack (MPII_cga_request_queue * queue , int block )
401+ static int allgather_recv_unpack (MPII_cga_request_queue * queue , int chunk_id )
405402{
406403 int mpi_errno = MPI_SUCCESS ;
407404
@@ -428,24 +425,31 @@ int MPII_cga_reduce_send(MPII_cga_request_queue * queue, int block, int peer_ran
428425{
429426 int mpi_errno = MPI_SUCCESS ;
430427
431- /* in reduce, send may depend on recv from the previous rounds; recv may depend
432- * on both the previous send and the previous recv */
428+ /* Dependency considerations for reduce:
429+ * * Recv - there are two operations, recv into tmp_buf and reduce into recvbuf.
430+ * Recv into tmp_buf requires clearing the previous recv (with the same block).
431+ * Reduction requires clearing previous sends (with the same block).
432+ * * Send - assuming we always issue send before recv, send requires clearing the previous
433+ * recv (from previous rounds with the same block) and clearing all previous
434+ * sends as part of recv completion (the reduction dependency). However,
435+ * if there is no pending recv, multiple pending sends are ok.
436+ * Thus, we may have a single pending recv request and multiple pending send requests.
437+ */
433438
434- int chunk_id = block ;
435- if (queue -> pending_blocks [chunk_id ] >= 0 ) {
436- int i = queue -> pending_blocks [chunk_id ];
437- mpi_errno = wait_for_request (queue , i );
439+ int pending = queue -> pending_blocks [block ];
440+ if (pending >= 0 && queue -> requests [pending ].op_type == MPII_CGA_OP_RECV ) {
441+ mpi_errno = wait_for_request (queue , pending );
438442 MPIR_ERR_CHECK (mpi_errno );
439443 }
440444
441445 MPI_Aint count = GET_BLOCK_COUNT (block );
442446 void * buf = GET_BLOCK_BUF (queue -> u .reduce .recvbuf , block );
443447
444448 int req_idx ;
445- mpi_errno = issue_send (queue , buf , count , peer_rank , chunk_id , & req_idx );
449+ mpi_errno = issue_send (queue , buf , count , peer_rank , block , & req_idx );
446450 MPIR_ERR_CHECK (mpi_errno );
447451
448- queue -> pending_blocks [ chunk_id ] = req_idx ;
452+ add_pending ( queue , block , req_idx ) ;
449453
450454 fn_exit :
451455 return mpi_errno ;
@@ -457,16 +461,12 @@ int MPII_cga_reduce_recv(MPII_cga_request_queue * queue, int block, int peer_ran
457461{
458462 int mpi_errno = MPI_SUCCESS ;
459463
460- /* reduction receives the same block from multiple processes and may reduce
461- * into the pending send buffer, thus it may depend on either previous send
462- * or recv.
463- *
464- * However, strictly, only the reduce_local depend on previous send. We will need
465- * separaten pending_blocks tracking to make it precise.
466- * */
467- if (queue -> pending_blocks [block ] >= 0 ) {
468- int i = queue -> pending_blocks [block ];
469- mpi_errno = wait_for_request (queue , i );
464- /* Reference the comments in MPII_cga_reduce_send. Issuing recv only needs to clear
465+ * previous pending recvs
466+ */
467+ int pending = queue -> pending_blocks [block ];
468+ if (pending >= 0 && queue -> requests [pending ].op_type == MPII_CGA_OP_RECV ) {
469+ mpi_errno = wait_for_request (queue , pending );
470470 MPIR_ERR_CHECK (mpi_errno );
471471 }
472472
@@ -477,9 +477,7 @@ int MPII_cga_reduce_recv(MPII_cga_request_queue * queue, int block, int peer_ran
477477 mpi_errno = issue_recv (queue , buf , count , peer_rank , block , & req_idx );
478478 MPIR_ERR_CHECK (mpi_errno );
479479
480- int chunk_id = block ;
481- queue -> requests [req_idx ].chunk_id = chunk_id ;
482- queue -> pending_blocks [chunk_id ] = req_idx ;
480+ add_pending (queue , block , req_idx );
483481
484482 fn_exit :
485483 return mpi_errno ;
@@ -581,6 +579,56 @@ static int issue_recv(MPII_cga_request_queue * queue, void *buf, MPI_Aint count,
581579 goto fn_exit ;
582580}
583581
582+ static void add_pending (MPII_cga_request_queue * queue , int chunk_id , int req_idx )
583+ {
584+ if (queue -> coll_type != MPII_CGA_REDUCE ) {
585+ /* simple case - at most a single pending request per block */
586+ queue -> requests [req_idx ].next_req_id = -1 ;
587+ queue -> pending_blocks [chunk_id ] = req_idx ;
588+ } else {
589+ /* reduction may have a single pending recv and multiple pending sends */
590+ if (queue -> requests [req_idx ].op_type == MPII_CGA_OP_RECV ) {
591+ /* prepend */
592+ queue -> requests [req_idx ].next_req_id = queue -> pending_blocks [chunk_id ];
593+ queue -> pending_blocks [chunk_id ] = req_idx ;
594+ } else {
595+ /* there is no dependency between pending sends, just insert after the pending recv (if exist) */
596+ int pending = queue -> pending_blocks [chunk_id ];
597+ if (pending < 0 ) {
598+ /* no pending request */
599+ queue -> requests [req_idx ].next_req_id = -1 ;
600+ queue -> pending_blocks [chunk_id ] = req_idx ;
601+ } else if (queue -> requests [pending ].op_type == MPII_CGA_OP_RECV ) {
602+ /* insert after the first pending recv */
603+ queue -> requests [req_idx ].next_req_id = queue -> requests [pending ].next_req_id ;
604+ queue -> requests [pending ].next_req_id = req_idx ;
605+ } else {
606+ /* prepend */
607+ queue -> requests [req_idx ].next_req_id = queue -> pending_blocks [chunk_id ];
608+ queue -> pending_blocks [chunk_id ] = req_idx ;
609+ }
610+ }
611+ }
612+ }
613+
614+ static void remove_pending (MPII_cga_request_queue * queue , int chunk_id , int req_idx )
615+ {
616+ if (queue -> pending_blocks [chunk_id ] == req_idx ) {
617+ queue -> pending_blocks [chunk_id ] = queue -> requests [req_idx ].next_req_id ;
618+ } else {
619+ /* not the head of the pending list. This can happen in a reduce in wait_if_full or
620+ * MPII_cga_waitall when we need to complete a send request not due to a dependency
621+ */
622+ int pending = queue -> pending_blocks [chunk_id ];
623+ MPIR_Assert (pending >= 0 );
624+ while (queue -> requests [pending ].next_req_id != req_idx ) {
625+ pending = queue -> requests [pending ].next_req_id ;
626+ MPIR_Assert (pending >= 0 );
627+ }
628+ queue -> requests [pending ].next_req_id = queue -> requests [req_idx ].next_req_id ;
629+ }
630+ }
631+
584632static int wait_if_full (MPII_cga_request_queue * queue )
585633{
586634 int mpi_errno = MPI_SUCCESS ;
@@ -607,10 +655,10 @@ static int wait_for_request(MPII_cga_request_queue * queue, int i)
607655 mpi_errno = MPIC_Wait (queue -> requests [i ].req );
608656 MPIR_ERR_CHECK (mpi_errno );
609657
658+ int chunk_id = queue -> requests [i ].chunk_id ;
610659 if (queue -> requests [i ].op_type == MPII_CGA_OP_RECV ) {
611660 /* it's a recv, update pending_blocks */
612- int chunk_id = queue -> requests [i ].chunk_id ;
613- queue -> pending_blocks [chunk_id ] = -1 ;
661+ remove_pending (queue , chunk_id , i );
614662 if (queue -> coll_type == MPII_CGA_BCAST ) {
615663 if (queue -> u .bcast .need_pack ) {
616664 mpi_errno = bcast_recv_unpack (queue , chunk_id );
@@ -622,9 +670,18 @@ static int wait_for_request(MPII_cga_request_queue * queue, int i)
622670 MPIR_ERR_CHECK (mpi_errno );
623671 }
624672 } else if (queue -> coll_type == MPII_CGA_REDUCE ) {
673+ /* clear all pending sends */
674+ while (queue -> pending_blocks [chunk_id ] >= 0 ) {
675+ int req_id = queue -> pending_blocks [chunk_id ];
676+ MPIR_Assert (queue -> requests [req_id ].op_type == MPII_CGA_OP_SEND );
677+ mpi_errno = wait_for_request (queue , req_id );
678+ MPIR_ERR_CHECK (mpi_errno );
679+ }
625680 mpi_errno = reduce_local (queue , chunk_id );
626681 MPIR_ERR_CHECK (mpi_errno );
627682 }
683+ } else {
684+ remove_pending (queue , chunk_id , i );
628685 }
629686 MPIR_Request_free (queue -> requests [i ].req );
630687 queue -> requests [i ].req = NULL ;
0 commit comments