@@ -541,3 +541,111 @@ int MPIR_Coll_nb(MPIR_Csel_coll_sig_s * coll_sig, MPII_Csel_container_s * me)
541541 fn_fail :
542542 goto fn_exit ;
543543}
544+
/* Swap buffers and continue. This works around cases where a collective algorithm
 * cannot work with hybrid memory, or is inefficient when working with hybrid memory.
 */
548+ static void * host_alloc (MPI_Aint count , MPI_Datatype datatype )
549+ {
550+ return MPIR_alloc_buffer (count , datatype );
551+ }
552+
553+ static void * host_swap (const void * buf , MPI_Aint count , MPI_Datatype datatype )
554+ {
555+ void * host_buf = host_alloc (count , datatype );
556+ MPIR_Localcopy (buf , count , datatype , host_buf , count , datatype );
557+ return host_buf ;
558+ }
559+
/* Redirect the (sendbuf, recvbuf) pair of a collective to staging buffers.
 *
 * Both arguments must be assignable lvalues. The expanding scope must
 * provide the locals orig_recvbuf, host_recvbuf, sendcount, sendtype,
 * recvcount and recvtype. After expansion:
 *   - orig_recvbuf holds the receive buffer that was passed in,
 *   - recvbuf and host_recvbuf both point at the staging receive buffer,
 *   - sendbuf points at a staging copy of the input, unless it is
 *     MPI_IN_PLACE, in which case it is left untouched.
 *
 * With MPI_IN_PLACE the input is read from recvbuf and spans the full
 * input size. For reduce_scatter and reduce_scatter_block that is
 * sendcount elements, which is larger than recvcount, so the in-place
 * staging copy is sized by sendcount/sendtype rather than
 * recvcount/recvtype. Every user of this macro sets sendtype == recvtype,
 * and sendcount == recvcount wherever input and output sizes coincide, so
 * the other collectives are unaffected.
 */
#define DO_BUFFER_SWAP(sendbuf, recvbuf) \
    do { \
        orig_recvbuf = (recvbuf); \
        if ((sendbuf) == MPI_IN_PLACE) { \
            (recvbuf) = host_swap((recvbuf), sendcount, sendtype); \
        } else { \
            (sendbuf) = host_swap((sendbuf), sendcount, sendtype); \
            (recvbuf) = host_alloc(recvcount, recvtype); \
        } \
        host_recvbuf = (recvbuf); \
    } while (0)
571+
572+ int MPIR_Coll_buffer_swap (MPIR_Csel_coll_sig_s * coll_sig , MPII_Csel_container_s * me )
573+ {
574+ int mpi_errno = MPI_SUCCESS ;
575+
576+ void * orig_recvbuf = NULL ;
577+ void * host_recvbuf ;
578+ MPI_Aint sendcount = 0 , recvcount = 0 ;
579+ MPI_Datatype sendtype = MPI_DATATYPE_NULL , recvtype = MPI_DATATYPE_NULL ;
580+ switch (coll_sig -> coll_type ) {
581+ COLL_TYPE_ALL_CASE (ALLREDUCE ):
582+ sendcount = recvcount = coll_sig -> u .allreduce .count ;
583+ sendtype = recvtype = coll_sig -> u .allreduce .datatype ;
584+ DO_BUFFER_SWAP (coll_sig -> u .allreduce .sendbuf , coll_sig -> u .allreduce .recvbuf );
585+ break ;
586+ COLL_TYPE_ALL_CASE (REDUCE ):
587+ sendcount = recvcount = coll_sig -> u .reduce .count ;
588+ sendtype = recvtype = coll_sig -> u .reduce .datatype ;
589+ DO_BUFFER_SWAP (coll_sig -> u .reduce .sendbuf , coll_sig -> u .reduce .recvbuf );
590+ break ;
591+ COLL_TYPE_ALL_CASE (SCAN ):
592+ sendcount = recvcount = coll_sig -> u .scan .count ;
593+ sendtype = recvtype = coll_sig -> u .scan .datatype ;
594+ DO_BUFFER_SWAP (coll_sig -> u .scan .sendbuf , coll_sig -> u .scan .recvbuf );
595+ break ;
596+ COLL_TYPE_ALL_CASE (EXSCAN ):
597+ sendcount = recvcount = coll_sig -> u .exscan .count ;
598+ sendtype = recvtype = coll_sig -> u .exscan .datatype ;
599+ DO_BUFFER_SWAP (coll_sig -> u .exscan .sendbuf , coll_sig -> u .exscan .recvbuf );
600+ break ;
601+ COLL_TYPE_ALL_CASE (REDUCE_SCATTER_BLOCK ):
602+ {
603+ MPIR_Comm * comm_ptr = coll_sig -> comm_ptr ;
604+ recvcount = coll_sig -> u .reduce_scatter_block .recvcount ;
605+ sendcount = comm_ptr -> local_size * recvcount ;
606+ }
607+ sendtype = recvtype = coll_sig -> u .reduce_scatter_block .datatype ;
608+ DO_BUFFER_SWAP (coll_sig -> u .reduce_scatter_block .sendbuf ,
609+ coll_sig -> u .reduce_scatter_block .recvbuf );
610+ break ;
611+ COLL_TYPE_ALL_CASE (REDUCE_SCATTER ):
612+ {
613+ MPIR_Comm * comm_ptr = coll_sig -> comm_ptr ;
614+ const MPI_Aint * counts = coll_sig -> u .reduce_scatter .recvcounts ;
615+ recvcount = counts [comm_ptr -> rank ];
616+ sendcount = 0 ;
617+ for (int i = 0 ; i < comm_ptr -> local_size ; i ++ ) {
618+ sendcount += counts [i ];
619+ }
620+ }
621+ sendtype = recvtype = coll_sig -> u .reduce_scatter .datatype ;
622+ DO_BUFFER_SWAP (coll_sig -> u .reduce_scatter .sendbuf , coll_sig -> u .reduce_scatter .recvbuf );
623+ break ;
624+ COLL_TYPE_ALL_CASE (BCAST ):
625+ sendcount = coll_sig -> u .bcast .count ;
626+ sendtype = coll_sig -> u .bcast .datatype ;
627+ if (coll_sig -> comm_ptr -> rank == coll_sig -> u .bcast .root ) {
628+ coll_sig -> u .bcast .buffer = host_swap (coll_sig -> u .bcast .buffer , sendcount , sendtype );
629+ } else {
630+ orig_recvbuf = coll_sig -> u .bcast .buffer ;
631+ recvcount = sendcount ;
632+ recvtype = recvtype ;
633+ coll_sig -> u .bcast .buffer = host_alloc (sendcount , sendtype );
634+ host_recvbuf = coll_sig -> u .bcast .buffer ;
635+ }
636+ default :
637+ break ;
638+ }
639+
640+ mpi_errno = MPIR_Coll_auto (coll_sig , NULL );
641+ MPIR_ERR_CHECK (mpi_errno );
642+
643+ if (orig_recvbuf ) {
644+ MPIR_Localcopy (host_recvbuf , recvcount , recvtype , orig_recvbuf , recvcount , recvtype );
645+ }
646+
647+ fn_exit :
648+ return mpi_errno ;
649+ fn_fail :
650+ goto fn_exit ;
651+ }
0 commit comments