Skip to content

Commit 98abcd7

Browse files
committed
Generate CSEL node conditions
1 parent 3029c64 commit 98abcd7

File tree

6 files changed

+401
-988
lines changed

6 files changed

+401
-988
lines changed

maint/gen_coll.py

Lines changed: 118 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def main():
1616
func_list = load_C_func_list(binding_dir, silent=True)
1717

1818
# Loading coll_algorithms.txt. It sets -
19-
# - G.restrictions: a list of boolean conditions that can be used as restrictions
19+
# - G.conditions: a list of conditions that can be used as restrictions and in JSON tuning files
2020
# - G.algos: a two level array: [func-commkind][algo]
2121
load_coll_algos("src/mpi/coll/coll_algorithms.txt")
2222
# Prepare a one level algo array for convenience -
@@ -47,12 +47,13 @@ def main():
4747
for a in G.coll_names:
4848
add_sched_auto_prototypes(a)
4949

50-
# initialize MPIR_Coll_algo_table[algo_id] -> algo_fn
50+
# initialize MPIR_Coll_algo_table and MPIR_Coll_algo_names
5151
dump_MPII_Coll_algo_init()
52-
# initialize MPIR_Coll_cvar_table
53-
dump_MPII_Coll_cvar_init()
54-
# create csel container from parsing json
55-
dump_MPII_Create_container()
52+
# initialize MPIR_Coll_cvar_table and MPIR_Coll_type_names
53+
dump_MPII_Coll_type_init()
54+
# parsing routines for loading JSONs
55+
dump_MPII_Csel_parse_container()
56+
dump_MPII_Csel_parse_operator()
5657
# routines for checking algorithm CVARs
5758
dump_MPIR_Coll_cvar_to_algo_id()
5859
dump_MPIR_Coll_init_algo_container()
@@ -62,6 +63,8 @@ def main():
6263
dump_MPIR_Csel_coll_type_e()
6364
# enum for algorithm id
6465
dump_MPIR_Csel_container_type_e()
66+
# enum CSEL_NODE_TYPE
67+
dump_MPIR_Csel_node_type_e()
6568
# algorithm container struct
6669
dump_MPII_Csel_container()
6770

@@ -162,19 +165,25 @@ def add_sched_auto_prototypes(coll_name):
162165
if not re.match(r'(scan|exscan|neighbor_)', coll_name):
163166
add_prototype("int MPIR_I%s_inter_sched_auto(%s)" % (coll_name, params))
164167

165-
def dump_MPII_Coll_cvar_init():
168+
def dump_MPII_Coll_type_init():
166169
G.out.append("")
167-
decl = "void MPII_Coll_cvar_init(void)"
170+
decl = "void MPII_Coll_type_init(void)"
168171
add_prototype(decl)
169172
G.out.append(decl)
170173
dump_open('{')
171174
for a in G.coll_names:
172175
for is_blocking in (True, False):
173-
G.out.append("MPIR_Coll_cvar_table[%s * 2] = %s;" % (coll_type(a, is_blocking), cvar_name(a, is_blocking, "intra")))
174-
if not re.match(r'(scan|exscan|neighbor_)', a):
175-
G.out.append("MPIR_Coll_cvar_table[%s * 2 + 1] = %s;" % (coll_type(a, is_blocking), cvar_name(a, is_blocking, "inter")))
176-
else:
177-
G.out.append("MPIR_Coll_cvar_table[%s * 2 + 1] = 0;" % (coll_type(a, is_blocking)))
176+
for commkind in ("intra", "inter"):
177+
if commkind == "inter" and re.match(r'(scan|exscan|neighbor_)', a):
178+
# CVARs for these inter-collectives do not exist
179+
G.out.append("MPIR_Coll_cvar_table[%s] = 0;" % (coll_type(a, is_blocking, commkind)))
180+
else:
181+
G.out.append("MPIR_Coll_cvar_table[%s] = %s;" % (coll_type(a, is_blocking, commkind), cvar_name(a, is_blocking, commkind)))
182+
183+
for a in G.coll_names:
184+
for is_blocking in (True, False):
185+
for commkind in ("intra", "inter"):
186+
G.out.append("MPIR_Coll_type_names[%s] = \"%s\";" % (coll_type(a, is_blocking, commkind), a + '-' + commkind))
178187
dump_close('}')
179188

180189
def dump_MPII_Coll_algo_init():
@@ -189,10 +198,16 @@ def dump_MPII_Coll_algo_init():
189198
if a['func-commkind'] != 'general':
190199
algo_funcname += "_cnt"
191200
G.out.append("MPIR_Coll_algo_table[%s] = %s;" % (idx, algo_funcname))
201+
for a in G.algo_list:
202+
algo_funcname = get_algo_funcname(a)
203+
idx = algo_id(algo_funcname)
204+
if a['func-commkind'] != 'general':
205+
algo_funcname += "_cnt"
206+
G.out.append("MPIR_Coll_algo_names[%s] = \"%s\";" % (idx, algo_funcname))
192207
dump_close('}')
193208

194209

195-
def dump_MPII_Create_container():
210+
def dump_MPII_Csel_parse_container():
196211
G.out.append("")
197212
def dump_json_foreach_open():
198213
dump_open("json_object_object_foreach(obj, key, val) {")
@@ -243,7 +258,7 @@ def dump_parse_params():
243258
G.out.append(" break;")
244259
dump_close('}') # switch
245260

246-
decl = "void *MPII_Create_container(struct json_object *obj)"
261+
decl = "void *MPII_Csel_parse_container(struct json_object *obj)"
247262
add_prototype(decl)
248263
G.out.append(decl)
249264
dump_open('{')
@@ -256,6 +271,44 @@ def dump_parse_params():
256271
G.out.append("return (void *) cnt;")
257272
dump_close('}')
258273

274+
def dump_MPII_Csel_parse_operator():
275+
decl = "int MPII_Csel_parse_operator(const char *ckey, MPIR_Csel_node_s *csel_node)"
276+
add_prototype(decl)
277+
G.out.append(decl)
278+
dump_open('{')
279+
dump_open("if (ckey[0] == '!') {")
280+
G.out.append("csel_node->condition.negate = true;")
281+
G.out.append("ckey++;")
282+
dump_else()
283+
G.out.append("csel_node->condition.negate = false;")
284+
dump_close('}')
285+
286+
if_clase = "if"
287+
for a in G.conditions:
288+
cond = a
289+
has_thresh = False
290+
if RE.match(r'(.+)\(thresh\)', a):
291+
cond = RE.m.group(1)
292+
has_thresh = True
293+
n = len(cond)
294+
if has_thresh:
295+
G.out.append("%s (strncmp(ckey, \"%s\", %d) == 0) {" % (if_clase, cond, n))
296+
G.out.append(" csel_node->type = %s;" % condition_id(cond))
297+
G.out.append(" MPIR_Assert(ckey[%d] == '(');" % n)
298+
G.out.append(" csel_node->condition.thresh = atoi(ckey + %d);" % (n + 1))
299+
else:
300+
G.out.append("%s (strcmp(ckey, \"%s\") == 0) {" % (if_clase, cond))
301+
G.out.append(" csel_node->type = %s;" % condition_id(cond))
302+
if_clase = "} else if"
303+
G.out.append("} else {")
304+
G.out.append(" MPIR_Assert(0);")
305+
G.out.append(" return MPI_ERR_OTHER;")
306+
G.out.append("}")
307+
308+
G.out.append("")
309+
G.out.append("return MPI_SUCCESS;")
310+
dump_close('}')
311+
259312
def dump_MPIR_Coll_cvar_to_algo_id():
260313
G.out.append("")
261314
def dump_cvar_cases(name, commkind):
@@ -330,19 +383,17 @@ def dump_MPIR_Coll_init_algo_container():
330383
def dump_MPIR_Coll_check_algo_restriction():
331384
G.out.append("")
332385
def dump_check_restriction(restriction):
333-
u = "coll_sig->%s" % coll_name
334386
r = restriction
335387
negate = False
336388
if restriction.startswith('!'):
337389
r = restriction[1:]
338390
negate = True
391+
if RE.match(r'.*\(.*\)', r):
392+
raise Exception("Threshold condition %s cannot be used as a restriction" % r)
339393

340394
cond = None
341-
if r in G.restrictions:
342-
if G.restrictions[r].startswith('MPIR_COLL_ATTR__'):
343-
cond = "(coll_sig->attr & %s)" % G.restrictions[r]
344-
else:
345-
cond = "%s(coll_sig)" % G.restrictions[r]
395+
if r in G.conditions:
396+
cond = "%s(coll_sig)" % G.conditions[r]
346397
else:
347398
raise Exception("Restriction %s not listed" % restriction)
348399

@@ -373,8 +424,9 @@ def dump_MPIR_Csel_coll_type_e():
373424
G.out2.append("typedef enum {")
374425
for a in G.coll_names:
375426
for is_blocking in (True, False):
376-
G.out2.append(" %s," % coll_type(a, is_blocking))
377-
G.out2.append(" %s" % coll_type("END", True))
427+
for commkind in ("intra", "inter"):
428+
G.out2.append(" %s," % coll_type(a, is_blocking, commkind))
429+
G.out2.append(" %s" % coll_type_END())
378430
G.out2.append("} MPIR_Csel_coll_type_e;")
379431

380432
def dump_MPIR_Csel_container_type_e():
@@ -383,9 +435,27 @@ def dump_MPIR_Csel_container_type_e():
383435
for a in G.algo_list:
384436
algo_funcname = get_algo_funcname(a)
385437
G.out2.append(" %s," % algo_id(algo_funcname))
386-
G.out2.append(" %s" % algo_id("Algorithm_count"))
438+
G.out2.append(" %s" % algo_id_END())
387439
G.out2.append("} MPIR_Csel_container_type_e;")
388440

441+
def dump_MPIR_Csel_node_type_e():
442+
G.out2.append("")
443+
G.out2.append("typedef enum {")
444+
for a in G.conditions:
445+
G.out2.append(" %s," % condition_id(a))
446+
G.out2.append(" CSEL_NODE_TYPE__OPERATOR__ANY,")
447+
G.out2.append(" CSEL_NODE_TYPE__CONTAINER,")
448+
G.out2.append("} MPIR_Csel_node_type_e;")
449+
450+
def dump_MPIR_Csel_node_s():
451+
G.out2.append("")
452+
G.out2.append("typedef struct MPIR_Csel_node {")
453+
G.out2.append(" MPIR_Csel_node_type_e type;")
454+
G.out2.append(" MPI_Aint thresh;")
455+
G.out2.append(" struct MPIR_Csel_node *success;")
456+
G.out2.append(" struct MPIR_Csel_node *failure;")
457+
G.out2.append("} MPIR_Csel_node_s;")
458+
389459
def dump_MPII_Csel_container():
390460
G.out2.append("")
391461
def dump_algo_params():
@@ -414,19 +484,18 @@ def add_prototype(l):
414484

415485
def load_coll_algos(algo_txt):
416486
G.algos = {}
417-
G.restrictions = {}
487+
G.conditions = {}
418488
with open(algo_txt) as In:
419489
(func_commkind, algo_list, algo) = (None, None, None)
420490
for line in In:
421-
if RE.match(r'(\w+-(intra|inter)|general):', line):
491+
if RE.match(r'(\w+-(intra|inter)|general|conditions):', line):
422492
func_commkind = RE.m.group(1)
423-
algo_list = []
424-
G.algos[func_commkind] = algo_list
425-
elif RE.match(r'restrictions:', line):
426-
func_commkind = "restrictions"
427-
elif func_commkind == "restrictions":
493+
if func_commkind != "conditions":
494+
algo_list = []
495+
G.algos[func_commkind] = algo_list
496+
elif func_commkind == "conditions":
428497
if RE.match(r'\s+([\w-]+):\s*(\w+)', line):
429-
G.restrictions[RE.m.group(1)] = RE.m.group(2)
498+
G.conditions[RE.m.group(1)] = RE.m.group(2)
430499
elif func_commkind:
431500
if RE.match(r'\s+(\w+)\s*$', line):
432501
algo = {"name": RE.m.group(1), "func-commkind": func_commkind}
@@ -462,6 +531,8 @@ def dump_coll_impl(name, blocking_type):
462531
G.out.append("coll_sig.is_persistent = false;")
463532
G.out.append("coll_sig.sched = NULL;")
464533

534+
# Emit C that zeroes the per-call selection cache. memset's signature is
# (dest, fill_byte, byte_count): the fill value 0 comes second and the
# size last, otherwise zero bytes are written and the cache stays
# uninitialized.
G.out.append("memset(&coll_sig.cache, 0, sizeof(coll_sig.cache));")
535+
465536
phash = {}
466537
for p in func['parameters']:
467538
if p['name'] == 'comm':
@@ -470,42 +541,6 @@ def dump_coll_impl(name, blocking_type):
470541
phash[p['name']] = 1
471542
G.out.append("coll_sig.u.%s.%s = %s;" % (name, p['name'], p['name']))
472543

473-
# msg_size
474-
def init_msg_size(count, datatype):
475-
G.out.append("MPIR_Datatype_get_size_macro(%s, coll_sig.msg_size);" % datatype)
476-
G.out.append("coll_sig.msg_size *= %s;" % count)
477-
478-
if 'count' in phash and 'datatype' in phash:
479-
init_msg_size('count', 'datatype')
480-
elif 'recvcount' in phash and 'recvtype' in phash:
481-
init_msg_size('recvcount', 'recvtype')
482-
elif 'recvcount' in phash and 'datatype' in phash:
483-
init_msg_size('recvcount', 'datatype')
484-
elif 'recvcounts' in phash and 'recvtype' in phash:
485-
# allgatherv, alltoallv - should we use total/max message size, or just skip?
486-
init_msg_size('recvcounts[0]', 'recvtype')
487-
elif 'recvcounts' in phash and 'recvtypes' in phash:
488-
# alltoallw - should we use total/max message size, or just skip?
489-
init_msg_size('recvcounts[0]', 'recvtypes[0]')
490-
else:
491-
raise Exception("init coll_sig: unhandled coll type")
492-
493-
494-
G.out.append("if (MPL_is_pof2(comm_ptr->local_size)) coll_sig.attr |= MPIR_COLL_ATTR__pof2;")
495-
496-
if 'sendbuf' in phash:
497-
G.out.append("if (sendbuf == MPI_IN_PLACE) coll_sig.attr |= MPIR_COLL_ATTR__inplace;")
498-
499-
if 'op' in phash:
500-
G.out.append("if (HANDLE_IS_BUILTIN(op)) {")
501-
G.out.append(" coll_sig.attr |= MPIR_COLL_ATTR__builtin_op | MPIR_COLL_ATTR__commutative;")
502-
G.out.append("} else if (MPIR_Op_is_commutative(op)) {")
503-
G.out.append(" coll_sig.attr |= MPIR_COLL_ATTR__commutative;")
504-
G.out.append("}")
505-
506-
G.out.append("MPIR_Init_coll_sig(&coll_sig);")
507-
G.out.append("MPID_Init_coll_sig(&coll_sig);")
508-
509544
# Call csel
510545
G.out.append("")
511546
G.out.append("mpi_errno = MPIR_Coll_composition_auto(&coll_sig);")
@@ -701,12 +736,15 @@ def get_func_params(func, name, blocking_type):
701736

702737
return ', '.join(params)
703738

704-
def coll_type(coll_name, is_blocking):
739+
def coll_type(coll_name, is_blocking, commkind):
705740
prefix = "MPIR_CSEL_COLL_TYPE"
706741
if is_blocking:
707-
return "%s__%s" % (prefix, coll_name.upper())
742+
return "%s__%s_%s" % (prefix, commkind.upper(), coll_name.upper())
708743
else:
709-
return "%s__I%s" % (prefix, coll_name.upper())
744+
return "%s__%s_I%s" % (prefix, commkind.upper(), coll_name.upper())
745+
746+
def coll_type_END():
747+
return "MPIR_CSEL_COLL_TYPE__END"
710748

711749
def cvar_name(coll_name, is_blocking, commkind):
712750
if is_blocking:
@@ -722,6 +760,15 @@ def algo_id(algo_funcname):
722760
else:
723761
return "%s__%s" % (prefix, algo_funcname)
724762

763+
def algo_id_END():
764+
return "MPII_CSEL_CONTAINER_TYPE__ALGORITHM__END"
765+
766+
def condition_id(name):
767+
prefix = "CSEL_NODE_TYPE__OPERATOR__"
768+
a = re.sub(r'-', '_', name)
769+
a = re.sub(r'\(thresh\)$', '', a)
770+
return prefix + a
771+
725772
def algo_struct_name(algo):
726773
algo_funcname = get_algo_funcname(algo)
727774
struct_name = re.sub(r'MPIR_', '', algo_funcname).lower()

src/include/mpir_coll.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ struct MPIR_Csel_coll_sig {
4444
void *sched;
4545
enum MPIR_sched_type sched_type;
4646
bool is_persistent;
47-
/* derived info to assist algorithm selection */
48-
MPI_Aint msg_size;
49-
uint64_t attr;
47+
48+
struct {
49+
bool is_gpu;
50+
} cache;
5051

5152
union {
5253
struct {

0 commit comments

Comments
 (0)