@@ -44,14 +44,45 @@ MPL_STATIC_INLINE_PREFIX bool MPIDI_POSIX_check_release_gather(MPIR_Csel_coll_si
4444
4545 /* check coll_type */
4646 MPIDI_POSIX_release_gather_opcode_t opcode ;
47+ MPI_Datatype datatype_for_reduce = MPI_DATATYPE_NULL ;
48+ MPI_Op op_for_reduce ;
4749 switch (coll_sig -> coll_type ) {
4850 case MPIR_CSEL_COLL_TYPE__INTRA_BCAST :
4951 opcode = MPIDI_POSIX_RELEASE_GATHER_OPCODE_BCAST ;
5052 break ;
53+ case MPIR_CSEL_COLL_TYPE__INTRA_REDUCE :
54+ opcode = MPIDI_POSIX_RELEASE_GATHER_OPCODE_REDUCE ;
55+ datatype_for_reduce = coll_sig -> u .reduce .datatype ;
56+ op_for_reduce = coll_sig -> u .reduce .op ;
57+ break ;
58+ case MPIR_CSEL_COLL_TYPE__INTRA_ALLREDUCE :
59+ opcode = MPIDI_POSIX_RELEASE_GATHER_OPCODE_ALLREDUCE ;
60+ datatype_for_reduce = coll_sig -> u .allreduce .datatype ;
61+ op_for_reduce = coll_sig -> u .allreduce .op ;
62+ break ;
63+ case MPIR_CSEL_COLL_TYPE__INTRA_BARRIER :
64+ opcode = MPIDI_POSIX_RELEASE_GATHER_OPCODE_BARRIER ;
65+ break ;
5166 default :
5267 return false;
5368 }
5469
70+ if (datatype_for_reduce != MPI_DATATYPE_NULL ) {
71+ MPI_Aint type_size , dummy_lb , extent , true_extent ;
72+ MPIR_Datatype_get_size_macro (datatype_for_reduce , type_size );
73+ MPIR_Type_get_extent_impl (datatype_for_reduce , & dummy_lb , & extent );
74+ MPIR_Type_get_true_extent_impl (datatype_for_reduce , & dummy_lb , & true_extent );
75+ extent = MPL_MAX (extent , true_extent );
76+ if (MPL_MAX (type_size , extent ) >=
77+ MPIR_CVAR_REDUCE_INTRANODE_BUFFER_TOTAL_SIZE / MPIR_CVAR_REDUCE_INTRANODE_NUM_CELLS ) {
78+ return false;
79+ }
80+
81+ if (!MPIR_Csel_op_is_commutative (op )) {
82+ return false;
83+ }
84+ }
85+
5586 /* Check repeats if the algorithm CVAR is not set */
5687 if (!(coll_sig -> flags & MPIR_COLL_SIG_FLAG__CVAR )) {
5788 MPIDI_POSIX_COMM (comm_ptr , release_gather ).num_collective_calls ++ ;
@@ -217,25 +248,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_release_gather(const void *s
217248 MPIR_Type_get_extent_impl (datatype , & lb , & extent );
218249 MPIR_Type_get_true_extent_impl (datatype , & lb , & true_extent );
219250 extent = MPL_MAX (extent , true_extent );
220- if (MPL_MAX (type_size , extent ) >=
221- MPIR_CVAR_REDUCE_INTRANODE_BUFFER_TOTAL_SIZE / MPIR_CVAR_REDUCE_INTRANODE_NUM_CELLS ) {
222- goto fallback ;
223- }
224-
225- MPIDI_POSIX_COMM (comm_ptr , release_gather ).num_collective_calls ++ ;
226- if (MPIDI_POSIX_COMM (comm_ptr , release_gather ).num_collective_calls <
227- MPIR_CVAR_POSIX_NUM_COLLS_THRESHOLD ) {
228- /* Fallback to pt2pt algorithms if the total number of release_gather collective calls is
229- * less than the specified threshold */
230- goto fallback ;
231- }
232-
233- /* Lazy initialization of release_gather specific struct */
234- mpi_errno =
235- MPIDI_POSIX_mpi_release_gather_comm_init (comm_ptr ,
236- MPIDI_POSIX_RELEASE_GATHER_OPCODE_REDUCE );
237- MPII_COLLECTIVE_FALLBACK_CHECK (MPIR_Comm_rank (comm_ptr ), !mpi_errno , mpi_errno ,
238- "release_gather reduce cannot create more shared memory. Falling back to pt2pt algorithms.\n" );
239251
240252 if (sendbuf == MPI_IN_PLACE ) {
241253 sendbuf = recvbuf ;
@@ -271,10 +283,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_release_gather(const void *s
271283 return mpi_errno ;
272284 fn_fail :
273285 goto fn_exit ;
274- fallback :
275- /* FIXME: proper error */
276- mpi_errno = MPI_ERR_OTHER ;
277- goto fn_exit ;
278286}
279287
280288/* Intra-node allreduce is implemented as a gather step followed by a release step in release_gather
@@ -309,25 +317,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce_release_gather(const void
309317 MPIR_Type_get_extent_impl (datatype , & lb , & extent );
310318 MPIR_Type_get_true_extent_impl (datatype , & lb , & true_extent );
311319 extent = MPL_MAX (extent , true_extent );
312- if (MPL_MAX (type_size , extent ) >=
313- MPIR_CVAR_REDUCE_INTRANODE_BUFFER_TOTAL_SIZE / MPIR_CVAR_REDUCE_INTRANODE_NUM_CELLS ) {
314- goto fallback ;
315- }
316-
317- MPIDI_POSIX_COMM (comm_ptr , release_gather ).num_collective_calls ++ ;
318- if (MPIDI_POSIX_COMM (comm_ptr , release_gather ).num_collective_calls <
319- MPIR_CVAR_POSIX_NUM_COLLS_THRESHOLD ) {
320- /* Fallback to pt2pt algorithms if the total number of release_gather collective calls is
321- * less than the specified threshold */
322- goto fallback ;
323- }
324-
325- /* Lazy initialization of release_gather specific struct */
326- mpi_errno =
327- MPIDI_POSIX_mpi_release_gather_comm_init (comm_ptr ,
328- MPIDI_POSIX_RELEASE_GATHER_OPCODE_ALLREDUCE );
329- MPII_COLLECTIVE_FALLBACK_CHECK (MPIR_Comm_rank (comm_ptr ), !mpi_errno , mpi_errno ,
330- "release_gather allreduce cannot create more shared memory. Falling back to pt2pt algorithms.\n" );
331320
332321 if (sendbuf == MPI_IN_PLACE ) {
333322 sendbuf = recvbuf ;
@@ -365,11 +354,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce_release_gather(const void
365354
366355 fn_fail :
367356 goto fn_exit ;
368-
369- fallback :
370- /* FIXME: proper error */
371- mpi_errno = MPI_ERR_OTHER ;
372- goto fn_exit ;
373357}
374358
375359/* Intra-node barrier is implemented as a gather step followed by a release step in release_gather
@@ -382,21 +366,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier_release_gather(MPIR_Comm *
382366
383367 MPIR_FUNC_ENTER ;
384368
385- MPIDI_POSIX_COMM (comm_ptr , release_gather ).num_collective_calls ++ ;
386- if (MPIDI_POSIX_COMM (comm_ptr , release_gather ).num_collective_calls <
387- MPIR_CVAR_POSIX_NUM_COLLS_THRESHOLD ) {
388- /* Fallback to pt2pt algorithms if the total number of release_gather collective calls is
389- * less than the specified threshold */
390- goto fallback ;
391- }
392-
393- /* Lazy initialization of release_gather specific struct */
394- mpi_errno =
395- MPIDI_POSIX_mpi_release_gather_comm_init (comm_ptr ,
396- MPIDI_POSIX_RELEASE_GATHER_OPCODE_BARRIER );
397- MPII_COLLECTIVE_FALLBACK_CHECK (MPIR_Comm_rank (comm_ptr ), !mpi_errno , mpi_errno ,
398- "release_gather barrier cannot create more shared memory. Falling back to pt2pt algorithms.\n" );
399-
400369 mpi_errno =
401370 MPIDI_POSIX_mpi_release_gather_gather (NULL , NULL , 0 , MPI_DATATYPE_NULL , MPI_OP_NULL , 0 ,
402371 comm_ptr , coll_attr ,
@@ -414,11 +383,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier_release_gather(MPIR_Comm *
414383
415384 fn_fail :
416385 goto fn_exit ;
417-
418- fallback :
419- /* FIXME: proper error */
420- mpi_errno = MPI_ERR_OTHER ;
421- goto fn_exit ;
422386}
423387
424388#endif /* POSIX_COLL_RELEASE_GATHER_H_INCLUDED */
0 commit comments