
Commit 342993a (1 parent: 007c061)

coll: add comp algorithm MPIR_Coll_buffer_swap

Recast the collective buffer swap as a composition algorithm. TODO: support all collective types.
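For orientation, the buffer-swap composition follows a swap-run-copy-back pattern: stage the collective's buffers into host memory, run the regular composition selection (MPIR_Coll_auto) on the host copies, then copy the result back. The following is a minimal standalone sketch of that pattern in plain MPI C, not MPICH-internal code: malloc/memcpy stand in for real host staging of a device buffer, and the datatype is assumed contiguous.

#include <mpi.h>
#include <stdlib.h>
#include <string.h>

/* Illustrative only: the swap-run-copy-back pattern that the new
 * MPIR_Coll_buffer_swap composition applies inside MPICH, written with
 * plain MPI calls.  memcpy stands in for a real device-to-host copy. */
static int allreduce_via_host_staging(const void *sendbuf, void *recvbuf,
                                      int count, MPI_Datatype datatype,
                                      MPI_Op op, MPI_Comm comm)
{
    int type_size;
    MPI_Type_size(datatype, &type_size);
    size_t nbytes = (size_t) count * type_size;

    /* 1. stage the input into host memory */
    void *host_send = malloc(nbytes);
    void *host_recv = malloc(nbytes);
    memcpy(host_send, sendbuf, nbytes);

    /* 2. run the collective on the host copies */
    int rc = MPI_Allreduce(host_send, host_recv, count, datatype, op, comm);

    /* 3. copy the result back into the caller's buffer */
    memcpy(recvbuf, host_recv, nbytes);

    free(host_send);
    free(host_recv);
    return rc;
}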

File tree: 4 files changed, +124 -1 lines changed

src/mpi/coll/coll_algorithms.txt

Lines changed: 3 additions & 0 deletions

@@ -70,13 +70,16 @@ conditions:
     avg_msg_size(thresh): MPIR_Csel_avg_msg_size
     total_msg_size(thresh): MPIR_Csel_total_msg_size

+    need_buffer_swap: MPIR_Csel_need_buffer_swap
+
 # conditional conditions - only call the condition function under macro_guard
     MPIDI_CH4_release_gather: MPIDI_POSIX_check_release_gather #if defined(MPIDI_CH4_SHM_POSIX)

 # ----
 general:
     MPIR_Coll_auto
     MPIR_Coll_nb
+    MPIR_Coll_buffer_swap

 # ----
 barrier-intra:

src/mpi/coll/coll_composition.json

Lines changed: 8 additions & 1 deletion

@@ -1,3 +1,10 @@
 {
-    "algorithm=MPIR_Coll_auto":{}
+    "need_buffer_swap":
+    {
+        "algorithm=MPIR_Coll_buffer_swap":{}
+    },
+    "any":
+    {
+        "algorithm=MPIR_Coll_auto":{}
+    }
 }
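Read top-down, the JSON says: if the need_buffer_swap condition evaluates true for this call, select the MPIR_Coll_buffer_swap composition; otherwise fall through to MPIR_Coll_auto as before. A hand-written equivalent of that dispatch, for illustration only (the real dispatch is generated from these files; cnt is a stand-in for the matched csel container):

/* Illustration only -- not the generated code. */
if (MPIR_Csel_need_buffer_swap(coll_sig)) {
    mpi_errno = MPIR_Coll_buffer_swap(coll_sig, cnt);
} else {
    mpi_errno = MPIR_Coll_auto(coll_sig, cnt);
}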

src/mpi/coll/include/coll_csel.h

Lines changed: 5 additions & 0 deletions

@@ -307,6 +307,11 @@ MPL_STATIC_INLINE_PREFIX bool MPIR_Csel_is_node_canonical(MPIR_Csel_coll_sig_s *
     return MPII_Comm_is_node_canonical(coll_sig->comm_ptr);
 }

+MPL_STATIC_INLINE_PREFIX bool MPIR_Csel_need_buffer_swap(MPIR_Csel_coll_sig_s * coll_sig)
+{
+    return false;
+}
+
 #include "coll_autogen.h"

 #endif /* COLL_CSEL_H_INCLUDED */
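As committed, MPIR_Csel_need_buffer_swap is a stub that always returns false, so the new composition is never selected yet. A hedged sketch of the shape a real predicate might take, where is_device_buffer() is a hypothetical helper (not an existing MPICH function) standing in for whatever GPU-pointer query is eventually used:

/* Hypothetical sketch: is_device_buffer() is a placeholder, not a real API. */
MPL_STATIC_INLINE_PREFIX bool MPIR_Csel_need_buffer_swap(MPIR_Csel_coll_sig_s * coll_sig)
{
    switch (coll_sig->coll_type) {
        COLL_TYPE_ALL_CASE(ALLREDUCE):
            return is_device_buffer(coll_sig->u.allreduce.sendbuf) ||
                is_device_buffer(coll_sig->u.allreduce.recvbuf);
        default:
            return false;
    }
}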

src/mpi/coll/src/coll_impl.c

Lines changed: 108 additions & 0 deletions

@@ -541,3 +541,111 @@ int MPIR_Coll_nb(MPIR_Csel_coll_sig_s * coll_sig, MPII_Csel_container_s * me)
   fn_fail:
     goto fn_exit;
 }
+
+/* Swap buffers and continue. This works around cases where collective
+ * algorithms cannot handle hybrid memory, or handle it inefficiently.
+ */
+static void *host_alloc(MPI_Aint count, MPI_Datatype datatype)
+{
+    return MPIR_alloc_buffer(count, datatype);
+}
+
+static void *host_swap(const void *buf, MPI_Aint count, MPI_Datatype datatype)
+{
+    void *host_buf = host_alloc(count, datatype);
+    MPIR_Localcopy(buf, count, datatype, host_buf, count, datatype);
+    return host_buf;
+}
+
+#define DO_BUFFER_SWAP(sendbuf, recvbuf) \
+    do { \
+        orig_recvbuf = (recvbuf); \
+        if ((sendbuf) == MPI_IN_PLACE) { \
+            (recvbuf) = host_swap((recvbuf), recvcount, recvtype); \
+        } else { \
+            (sendbuf) = host_swap((sendbuf), sendcount, sendtype); \
+            (recvbuf) = host_alloc(recvcount, recvtype); \
+        } \
+        host_recvbuf = (recvbuf); \
+    } while (0)
+
+int MPIR_Coll_buffer_swap(MPIR_Csel_coll_sig_s * coll_sig, MPII_Csel_container_s * me)
+{
+    int mpi_errno = MPI_SUCCESS;
+
+    void *orig_recvbuf = NULL;
+    void *host_recvbuf = NULL;
+    MPI_Aint sendcount = 0, recvcount = 0;
+    MPI_Datatype sendtype = MPI_DATATYPE_NULL, recvtype = MPI_DATATYPE_NULL;
+    switch (coll_sig->coll_type) {
+        COLL_TYPE_ALL_CASE(ALLREDUCE):
+            sendcount = recvcount = coll_sig->u.allreduce.count;
+            sendtype = recvtype = coll_sig->u.allreduce.datatype;
+            DO_BUFFER_SWAP(coll_sig->u.allreduce.sendbuf, coll_sig->u.allreduce.recvbuf);
+            break;
+        COLL_TYPE_ALL_CASE(REDUCE):
+            sendcount = recvcount = coll_sig->u.reduce.count;
+            sendtype = recvtype = coll_sig->u.reduce.datatype;
+            DO_BUFFER_SWAP(coll_sig->u.reduce.sendbuf, coll_sig->u.reduce.recvbuf);
+            break;
+        COLL_TYPE_ALL_CASE(SCAN):
+            sendcount = recvcount = coll_sig->u.scan.count;
+            sendtype = recvtype = coll_sig->u.scan.datatype;
+            DO_BUFFER_SWAP(coll_sig->u.scan.sendbuf, coll_sig->u.scan.recvbuf);
+            break;
+        COLL_TYPE_ALL_CASE(EXSCAN):
+            sendcount = recvcount = coll_sig->u.exscan.count;
+            sendtype = recvtype = coll_sig->u.exscan.datatype;
+            DO_BUFFER_SWAP(coll_sig->u.exscan.sendbuf, coll_sig->u.exscan.recvbuf);
+            break;
+        COLL_TYPE_ALL_CASE(REDUCE_SCATTER_BLOCK):
+            {
+                MPIR_Comm *comm_ptr = coll_sig->comm_ptr;
+                recvcount = coll_sig->u.reduce_scatter_block.recvcount;
+                sendcount = comm_ptr->local_size * recvcount;
+            }
+            sendtype = recvtype = coll_sig->u.reduce_scatter_block.datatype;
+            DO_BUFFER_SWAP(coll_sig->u.reduce_scatter_block.sendbuf,
+                           coll_sig->u.reduce_scatter_block.recvbuf);
+            break;
+        COLL_TYPE_ALL_CASE(REDUCE_SCATTER):
+            {
+                MPIR_Comm *comm_ptr = coll_sig->comm_ptr;
+                const MPI_Aint *counts = coll_sig->u.reduce_scatter.recvcounts;
+                recvcount = counts[comm_ptr->rank];
+                sendcount = 0;
+                for (int i = 0; i < comm_ptr->local_size; i++) {
+                    sendcount += counts[i];
+                }
+            }
+            sendtype = recvtype = coll_sig->u.reduce_scatter.datatype;
+            DO_BUFFER_SWAP(coll_sig->u.reduce_scatter.sendbuf, coll_sig->u.reduce_scatter.recvbuf);
+            break;
+        COLL_TYPE_ALL_CASE(BCAST):
+            sendcount = coll_sig->u.bcast.count;
+            sendtype = coll_sig->u.bcast.datatype;
+            if (coll_sig->comm_ptr->rank == coll_sig->u.bcast.root) {
+                coll_sig->u.bcast.buffer = host_swap(coll_sig->u.bcast.buffer, sendcount, sendtype);
+            } else {
+                orig_recvbuf = coll_sig->u.bcast.buffer;
+                recvcount = sendcount;
+                recvtype = sendtype;
+                coll_sig->u.bcast.buffer = host_alloc(sendcount, sendtype);
+                host_recvbuf = coll_sig->u.bcast.buffer;
+            }
+        default:
+            break;
+    }
+
+    mpi_errno = MPIR_Coll_auto(coll_sig, NULL);
+    MPIR_ERR_CHECK(mpi_errno);
+
+    if (orig_recvbuf) {
+        MPIR_Localcopy(host_recvbuf, recvcount, recvtype, orig_recvbuf, recvcount, recvtype);
+    }
+
+  fn_exit:
+    return mpi_errno;
+  fn_fail:
+    goto fn_exit;
+}
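For reference, the non-MPI_IN_PLACE branch of DO_BUFFER_SWAP(coll_sig->u.allreduce.sendbuf, coll_sig->u.allreduce.recvbuf) expands to roughly the following, written out by hand to show what the macro does; the names are the locals of MPIR_Coll_buffer_swap above.

/* Hand expansion, sendbuf != MPI_IN_PLACE case. */
orig_recvbuf = coll_sig->u.allreduce.recvbuf;       /* remember the user's recvbuf */
coll_sig->u.allreduce.sendbuf =
    host_swap(coll_sig->u.allreduce.sendbuf, sendcount, sendtype);   /* copy sendbuf to a host buffer */
coll_sig->u.allreduce.recvbuf = host_alloc(recvcount, recvtype);     /* fresh host recvbuf */
host_recvbuf = coll_sig->u.allreduce.recvbuf;       /* result is copied back from here after MPIR_Coll_auto */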
