@@ -16,7 +16,7 @@ def main():
1616 func_list = load_C_func_list (binding_dir , silent = True )
1717
1818 # Loading coll_algorithms.txt. It sets -
19- # - G.restrictions : a list of boolean conditions that can be used as restrictions
19+ # - G.conditions : a list of conditions that can be used as restrictions and in JSON tuning files
2020 # - G.algos: a two level array: [func-commkind][algo]
2121 load_coll_algos ("src/mpi/coll/coll_algorithms.txt" )
2222 # Prepare a one level algo array for conveninece -
@@ -47,12 +47,13 @@ def main():
4747 for a in G .coll_names :
4848 add_sched_auto_prototypes (a )
4949
50- # initialize MPIR_Coll_algo_table[algo_id] -> algo_fn
50+ # initialize MPIR_Coll_algo_table and MPIR_Coll_algo_names
5151 dump_MPII_Coll_algo_init ()
52- # initialize MPIR_Coll_cvar_table
53- dump_MPII_Coll_cvar_init ()
54- # create csel container from parsing json
55- dump_MPII_Create_container ()
52+ # initialize MPIR_Coll_cvar_table and MPIR_Coll_type_names
53+ dump_MPII_Coll_type_init ()
54+ # parsing routines for loading JSONs
55+ dump_MPII_Csel_parse_container ()
56+ dump_MPII_Csel_parse_operator ()
5657 # routines for checking algorithm CVARs
5758 dump_MPIR_Coll_cvar_to_algo_id ()
5859 dump_MPIR_Coll_init_algo_container ()
@@ -62,6 +63,8 @@ def main():
6263 dump_MPIR_Csel_coll_type_e ()
6364 # enum for algorithm id
6465 dump_MPIR_Csel_container_type_e ()
66+ # enum CSEL_NODE_TYPE
67+ dump_MPIR_Csel_node_type_e ()
6568 # algorithm container struct
6669 dump_MPII_Csel_container ()
6770
@@ -162,19 +165,25 @@ def add_sched_auto_prototypes(coll_name):
162165 if not re .match (r'(scan|exscan|neighbor_)' , coll_name ):
163166 add_prototype ("int MPIR_I%s_inter_sched_auto(%s)" % (coll_name , params ))
164167
165- def dump_MPII_Coll_cvar_init ():
168+ def dump_MPII_Coll_type_init ():
166169 G .out .append ("" )
167- decl = "void MPII_Coll_cvar_init (void)"
170+ decl = "void MPII_Coll_type_init (void)"
168171 add_prototype (decl )
169172 G .out .append (decl )
170173 dump_open ('{' )
171174 for a in G .coll_names :
172175 for is_blocking in (True , False ):
173- G .out .append ("MPIR_Coll_cvar_table[%s * 2] = %s;" % (coll_type (a , is_blocking ), cvar_name (a , is_blocking , "intra" )))
174- if not re .match (r'(scan|exscan|neighbor_)' , a ):
175- G .out .append ("MPIR_Coll_cvar_table[%s * 2 + 1] = %s;" % (coll_type (a , is_blocking ), cvar_name (a , is_blocking , "inter" )))
176- else :
177- G .out .append ("MPIR_Coll_cvar_table[%s * 2 + 1] = 0;" % (coll_type (a , is_blocking )))
176+ for commkind in ("intra" , "inter" ):
177+ if commkind == "inter" and re .match (r'(scan|exscan|neighbor_)' , a ):
178+ # CVARs for these inter-collective does not exist
179+ G .out .append ("MPIR_Coll_cvar_table[%s] = 0;" % (coll_type (a , is_blocking , commkind )))
180+ else :
181+ G .out .append ("MPIR_Coll_cvar_table[%s] = %s;" % (coll_type (a , is_blocking , commkind ), cvar_name (a , is_blocking , commkind )))
182+
183+ for a in G .coll_names :
184+ for is_blocking in (True , False ):
185+ for commkind in ("intra" , "inter" ):
186+ G .out .append ("MPIR_Coll_type_names[%s] = \" %s\" ;" % (coll_type (a , is_blocking , commkind ), a + '-' + commkind ))
178187 dump_close ('}' )
179188
180189def dump_MPII_Coll_algo_init ():
@@ -189,10 +198,16 @@ def dump_MPII_Coll_algo_init():
189198 if a ['func-commkind' ] != 'general' :
190199 algo_funcname += "_cnt"
191200 G .out .append ("MPIR_Coll_algo_table[%s] = %s;" % (idx , algo_funcname ))
201+ for a in G .algo_list :
202+ algo_funcname = get_algo_funcname (a )
203+ idx = algo_id (algo_funcname )
204+ if a ['func-commkind' ] != 'general' :
205+ algo_funcname += "_cnt"
206+ G .out .append ("MPIR_Coll_algo_names[%s] = \" %s\" ;" % (idx , algo_funcname ))
192207 dump_close ('}' )
193208
194209
195- def dump_MPII_Create_container ():
210+ def dump_MPII_Csel_parse_container ():
196211 G .out .append ("" )
197212 def dump_json_foreach_open ():
198213 dump_open ("json_object_object_foreach(obj, key, val) {" )
@@ -243,7 +258,7 @@ def dump_parse_params():
243258 G .out .append (" break;" )
244259 dump_close ('}' ) # switch
245260
246- decl = "void *MPII_Create_container (struct json_object *obj)"
261+ decl = "void *MPII_Csel_parse_container (struct json_object *obj)"
247262 add_prototype (decl )
248263 G .out .append (decl )
249264 dump_open ('{' )
@@ -256,6 +271,44 @@ def dump_parse_params():
256271 G .out .append ("return (void *) cnt;" )
257272 dump_close ('}' )
258273
274+ def dump_MPII_Csel_parse_operator ():
275+ decl = "int MPII_Csel_parse_operator(const char *ckey, MPIR_Csel_node_s *csel_node)"
276+ add_prototype (decl )
277+ G .out .append (decl )
278+ dump_open ('{' )
279+ dump_open ("if (ckey[0] == '!') {" )
280+ G .out .append ("csel_node->condition.negate = true;" )
281+ G .out .append ("ckey++;" )
282+ dump_else ()
283+ G .out .append ("csel_node->condition.negate = false;" )
284+ dump_close ('}' )
285+
286+ if_clase = "if"
287+ for a in G .conditions :
288+ cond = a
289+ has_thresh = False
290+ if RE .match (r'(.+)\(thresh\)' , a ):
291+ cond = RE .m .group (1 )
292+ has_thresh = True
293+ n = len (cond )
294+ if has_thresh :
295+ G .out .append ("%s (strncmp(ckey, \" %s\" , %d) == 0) {" % (if_clase , cond , n ))
296+ G .out .append (" csel_node->type = %s;" % condition_id (cond ))
297+ G .out .append (" MPIR_Assert(ckey[%d] == '(');" % n )
298+ G .out .append (" csel_node->condition.thresh = atoi(ckey + %d);" % (n + 1 ))
299+ else :
300+ G .out .append ("%s (strcmp(ckey, \" %s\" ) == 0) {" % (if_clase , cond ))
301+ G .out .append (" csel_node->type = %s;" % condition_id (cond ))
302+ if_clase = "} else if"
303+ G .out .append ("} else {" )
304+ G .out .append (" MPIR_Assert(0);" )
305+ G .out .append (" return MPI_ERR_OTHER;" )
306+ G .out .append ("}" )
307+
308+ G .out .append ("" )
309+ G .out .append ("return MPI_SUCCESS;" )
310+ dump_close ('}' )
311+
259312def dump_MPIR_Coll_cvar_to_algo_id ():
260313 G .out .append ("" )
261314 def dump_cvar_cases (name , commkind ):
@@ -330,19 +383,17 @@ def dump_MPIR_Coll_init_algo_container():
330383def dump_MPIR_Coll_check_algo_restriction ():
331384 G .out .append ("" )
332385 def dump_check_restriction (restriction ):
333- u = "coll_sig->%s" % coll_name
334386 r = restriction
335387 negate = False
336388 if restriction .startswith ('!' ):
337389 r = restriction [1 :]
338390 negate = True
391+ if RE .match (r'.*\(.*\)' , r ):
392+ raise Exception ("Threshold condition %s cannot be used as a restriction" % r )
339393
340394 cond = None
341- if r in G .restrictions :
342- if G .restrictions [r ].startswith ('MPIR_COLL_ATTR__' ):
343- cond = "(coll_sig->attr & %s)" % G .restrictions [r ]
344- else :
345- cond = "%s(coll_sig)" % G .restrictions [r ]
395+ if r in G .conditions :
396+ cond = "%s(coll_sig)" % G .conditions [r ]
346397 else :
347398 raise Exception ("Restriction %s not listed" % restriction )
348399
@@ -373,8 +424,9 @@ def dump_MPIR_Csel_coll_type_e():
373424 G .out2 .append ("typedef enum {" )
374425 for a in G .coll_names :
375426 for is_blocking in (True , False ):
376- G .out2 .append (" %s," % coll_type (a , is_blocking ))
377- G .out2 .append (" %s" % coll_type ("END" , True ))
427+ for commkind in ("intra" , "inter" ):
428+ G .out2 .append (" %s," % coll_type (a , is_blocking , commkind ))
429+ G .out2 .append (" %s" % coll_type_END ())
378430 G .out2 .append ("} MPIR_Csel_coll_type_e;" )
379431
380432def dump_MPIR_Csel_container_type_e ():
@@ -383,9 +435,27 @@ def dump_MPIR_Csel_container_type_e():
383435 for a in G .algo_list :
384436 algo_funcname = get_algo_funcname (a )
385437 G .out2 .append (" %s," % algo_id (algo_funcname ))
386- G .out2 .append (" %s" % algo_id ( "Algorithm_count" ))
438+ G .out2 .append (" %s" % algo_id_END ( ))
387439 G .out2 .append ("} MPIR_Csel_container_type_e;" )
388440
441+ def dump_MPIR_Csel_node_type_e ():
442+ G .out2 .append ("" )
443+ G .out2 .append ("typedef enum {" )
444+ for a in G .conditions :
445+ G .out2 .append (" %s," % condition_id (a ))
446+ G .out2 .append (" CSEL_NODE_TYPE__OPERATOR__ANY," )
447+ G .out2 .append (" CSEL_NODE_TYPE__CONTAINER," )
448+ G .out2 .append ("} MPIR_Csel_node_type_e;" )
449+
450+ def dump_MPIR_Csel_node_s ():
451+ G .out2 .append ("" )
452+ G .out2 .append ("typedef struct MPIR_Csel_node {" )
453+ G .out2 .append (" MPIR_Csel_node_type_e type;" )
454+ G .out2 .append (" MPI_Aint thresh;" )
455+ G .out2 .append (" struct MPIR_Csel_node *success;" )
456+ G .out2 .append (" struct MPIR_Csel_node *failure;" )
457+ G .out2 .append ("} MPIR_Csel_node_s;" )
458+
389459def dump_MPII_Csel_container ():
390460 G .out2 .append ("" )
391461 def dump_algo_params ():
@@ -414,19 +484,18 @@ def add_prototype(l):
414484
415485def load_coll_algos (algo_txt ):
416486 G .algos = {}
417- G .restrictions = {}
487+ G .conditions = {}
418488 with open (algo_txt ) as In :
419489 (func_commkind , algo_list , algo ) = (None , None , None )
420490 for line in In :
421- if RE .match (r'(\w+-(intra|inter)|general):' , line ):
491+ if RE .match (r'(\w+-(intra|inter)|general|conditions ):' , line ):
422492 func_commkind = RE .m .group (1 )
423- algo_list = []
424- G .algos [func_commkind ] = algo_list
425- elif RE .match (r'restrictions:' , line ):
426- func_commkind = "restrictions"
427- elif func_commkind == "restrictions" :
493+ if func_commkind != "conditions" :
494+ algo_list = []
495+ G .algos [func_commkind ] = algo_list
496+ elif func_commkind == "conditions" :
428497 if RE .match (r'\s+([\w-]+):\s*(\w+)' , line ):
429- G .restrictions [RE .m .group (1 )] = RE .m .group (2 )
498+ G .conditions [RE .m .group (1 )] = RE .m .group (2 )
430499 elif func_commkind :
431500 if RE .match (r'\s+(\w+)\s*$' , line ):
432501 algo = {"name" : RE .m .group (1 ), "func-commkind" : func_commkind }
@@ -462,6 +531,8 @@ def dump_coll_impl(name, blocking_type):
462531 G .out .append ("coll_sig.is_persistent = false;" )
463532 G .out .append ("coll_sig.sched = NULL;" )
464533
534+ G .out .append ("memset(&coll_sig.cache, sizeof(coll_sig.cache), 0);" );
535+
465536 phash = {}
466537 for p in func ['parameters' ]:
467538 if p ['name' ] == 'comm' :
@@ -470,42 +541,6 @@ def dump_coll_impl(name, blocking_type):
470541 phash [p ['name' ]] = 1
471542 G .out .append ("coll_sig.u.%s.%s = %s;" % (name , p ['name' ], p ['name' ]))
472543
473- # msg_size
474- def init_msg_size (count , datatype ):
475- G .out .append ("MPIR_Datatype_get_size_macro(%s, coll_sig.msg_size);" % datatype )
476- G .out .append ("coll_sig.msg_size *= %s;" % count )
477-
478- if 'count' in phash and 'datatype' in phash :
479- init_msg_size ('count' , 'datatype' )
480- elif 'recvcount' in phash and 'recvtype' in phash :
481- init_msg_size ('recvcount' , 'recvtype' )
482- elif 'recvcount' in phash and 'datatype' in phash :
483- init_msg_size ('recvcount' , 'datatype' )
484- elif 'recvcounts' in phash and 'recvtype' in phash :
485- # allgatherv, alltoallv - should we use total/max message size, or just skip?
486- init_msg_size ('recvcounts[0]' , 'recvtype' )
487- elif 'recvcounts' in phash and 'recvtypes' in phash :
488- # alltoallw - should we use total/max message size, or just skip?
489- init_msg_size ('recvcounts[0]' , 'recvtypes[0]' )
490- else :
491- raise Exception ("init coll_sig: unhandled coll type" )
492-
493-
494- G .out .append ("if (MPL_is_pof2(comm_ptr->local_size)) coll_sig.attr |= MPIR_COLL_ATTR__pof2;" )
495-
496- if 'sendbuf' in phash :
497- G .out .append ("if (sendbuf == MPI_IN_PLACE) coll_sig.attr |= MPIR_COLL_ATTR__inplace;" )
498-
499- if 'op' in phash :
500- G .out .append ("if (HANDLE_IS_BUILTIN(op)) {" )
501- G .out .append (" coll_sig.attr |= MPIR_COLL_ATTR__builtin_op | MPIR_COLL_ATTR__commutative;" )
502- G .out .append ("} else if (MPIR_Op_is_commutative(op)) {" )
503- G .out .append (" coll_sig.attr |= MPIR_COLL_ATTR__commutative;" )
504- G .out .append ("}" )
505-
506- G .out .append ("MPIR_Init_coll_sig(&coll_sig);" )
507- G .out .append ("MPID_Init_coll_sig(&coll_sig);" )
508-
509544 # Call csel
510545 G .out .append ("" )
511546 G .out .append ("mpi_errno = MPIR_Coll_composition_auto(&coll_sig);" )
@@ -701,12 +736,15 @@ def get_func_params(func, name, blocking_type):
701736
702737 return ', ' .join (params )
703738
704- def coll_type (coll_name , is_blocking ):
739+ def coll_type (coll_name , is_blocking , commkind ):
705740 prefix = "MPIR_CSEL_COLL_TYPE"
706741 if is_blocking :
707- return "%s__%s" % (prefix , coll_name .upper ())
742+ return "%s__%s_% s" % (prefix , commkind . upper () , coll_name .upper ())
708743 else :
709- return "%s__I%s" % (prefix , coll_name .upper ())
744+ return "%s__%s_I%s" % (prefix , commkind .upper (), coll_name .upper ())
745+
746+ def coll_type_END ():
747+ return "MPIR_CSEL_COLL_TYPE__END"
710748
711749def cvar_name (coll_name , is_blocking , commkind ):
712750 if is_blocking :
@@ -722,6 +760,15 @@ def algo_id(algo_funcname):
722760 else :
723761 return "%s__%s" % (prefix , algo_funcname )
724762
763+ def algo_id_END ():
764+ return "MPII_CSEL_CONTAINER_TYPE__ALGORITHM__END"
765+
766+ def condition_id (name ):
767+ prefix = "CSEL_NODE_TYPE__OPERATOR__"
768+ a = re .sub (r'-' , '_' , name )
769+ a = re .sub (r'\(thresh\)$' , '' , a )
770+ return prefix + a
771+
725772def algo_struct_name (algo ):
726773 algo_funcname = get_algo_funcname (algo )
727774 struct_name = re .sub (r'MPIR_' , '' , algo_funcname ).lower ()
0 commit comments