@@ -28,35 +28,71 @@ def __init__(self, ops, edges):
2828 self .successors_of = successors_of
2929 self .available = available
3030
def copy(self):
    """Return a new tracker whose dependency state is independent of this one.

    The instance is created with ``__new__``, so ``__init__`` is not run.
    ``predecessors_of``, ``successors_of`` and ``available`` are rebuilt with
    fresh dicts and fresh sets; the operators stored in those sets are
    shared with the original, not duplicated.

    Returns
    -------
    An instance of ``type(self)`` carrying copies of the three mappings.
    """
    cls = type(self)
    duplicate = cls.__new__(cls)
    duplicate.predecessors_of = dict(
        (op, set(preds)) for op, preds in iteritems(self.predecessors_of))
    duplicate.successors_of = dict(
        (op, set(succs)) for op, succs in iteritems(self.successors_of))
    # ``available`` stays a defaultdict so missing types can still be
    # indexed and filled in place.
    duplicate.available = defaultdict(
        set,
        ((op_type, set(ops)) for op_type, ops in iteritems(self.available)))
    return duplicate
39+
def greedy_choose_type(self):
    """Pick the operator type that currently has the most available ops.

    Returns
    -------
    (chosen_type, candidates)
        The selected type and the set stored for it in ``self.available``
        (the same object, not a copy).

    Raises
    ------
    IndexError
        If ``self.available`` is empty.
    """
    # Ascending stable sort by set size; the last entry is the largest, and
    # among equally large sets the one that comes last in iteration order.
    by_size = sorted(self.available.items(), key=lambda item: len(item[1]))
    chosen_type, candidates = by_size[-1]
    return chosen_type, candidates
3645
def remove_op(self, op, update_available=True):
    """Remove a single op from the dependency graph.

    Parameters
    ----------
    op
        The op to remove.
    update_available : bool, optional
        When true, ``op`` is also dropped from ``self.available`` and every
        op reported by ``_remove_op_from_predsucc`` as having no remaining
        predecessors is added to it. When false, ``self.available`` is left
        untouched.
    """
    freed = self._remove_op_from_predsucc(op)
    if not update_available:
        return

    op_type = type(op)
    if op_type in self.available:
        bucket = self.available[op_type]
        bucket.discard(op)
        # Drop the entry once its set is empty.
        if not bucket:
            del self.available[op_type]

    for freed_op in freed:
        self.available[type(freed_op)].add(freed_op)
56+
def remove_ops(self, ops, update_available=True):
    """Remove several ops from the dependency graph in one call.

    Parameters
    ----------
    ops : iterable
        The ops to remove. Materialized into a list first because it is
        traversed more than once.
    update_available : bool, optional
        When true, the removed ops are dropped from ``self.available``, ops
        that ``_remove_op_from_predsucc`` reports as having no remaining
        predecessors are added to it, and types left with an empty set are
        deleted. When false, ``self.available`` is left untouched.
    """
    ops = list(ops)
    freed = []
    for op in ops:
        freed.extend(self._remove_op_from_predsucc(op))

    if not update_available:
        return

    for op in ops:
        op_type = type(op)
        if op_type in self.available:
            self.available[op_type].discard(op)

    # An op freed by an earlier removal may itself be part of this batch;
    # it is gone from the graph and must not be reported as available.
    removed = set(ops)
    for op in freed:
        if op not in removed:
            self.available[type(op)].add(op)

    # Collect the empty types first: entries cannot be deleted from a dict
    # while it is being iterated.
    empty_types = [
        op_type for op_type, members in self.available.items() if not members]
    for op_type in empty_types:
        del self.available[op_type]
4871
def remove_op_type(self, chosen_type, chosen_ops, update_available=True):
    """Remove a batch of ops of one type from the dependency graph.

    Parameters
    ----------
    chosen_type
        Key in ``self.available`` under which ``chosen_ops`` are stored.
    chosen_ops : iterable
        The ops to remove.
    update_available : bool, optional
        When true, ``chosen_ops`` are dropped from the set for
        ``chosen_type``, ops reported by ``_remove_op_from_predsucc`` as
        having no remaining predecessors are added under their own type,
        and the entry for ``chosen_type`` is deleted if it ends up empty.
        When false, ``self.available`` is left untouched.
    """
    freed = [
        freed_op
        for op in chosen_ops
        for freed_op in self._remove_op_from_predsucc(op)
    ]

    if not update_available:
        return

    self.available[chosen_type].difference_update(chosen_ops)
    for freed_op in freed:
        self.available[type(freed_op)].add(freed_op)
    # Checked last: a freed op may have the chosen type and keep it alive.
    if not self.available[chosen_type]:
        del self.available[chosen_type]
5183
5284 def _remove_op_from_predsucc (self , op ):
85+ available = []
5386 for op2 in self .successors_of [op ]:
5487 preds = self .predecessors_of [op2 ]
5588 preds .remove (op )
5689 if len (preds ) == 0 :
57- self .available [type (op2 )].add (op2 )
90+ available .append (op2 )
91+ for op2 in self .predecessors_of [op ]:
92+ self .successors_of [op2 ].remove (op )
5893 del self .predecessors_of [op ]
5994 del self .successors_of [op ]
95+ return available
6096
6197
6298def _greedy_nonoverlapping (ops ):
@@ -130,3 +166,93 @@ def greedy_planner(operators):
130166 assert len (operators ) == sum (len (p [1 ]) for p in rval )
131167 # print('greedy_planner: Program len:', len(rval))
132168 return rval
169+
170+
def parallel_planner(operators):
    """Plan the order of operators using groups of independent operators.

    For every operator type, groups of same-type operators are precomputed
    such that no member is a (transitive) successor or predecessor of the
    group's seed operator. At each step the available operators of every
    type are compared with the group that contains them: a type whose whole
    group is available is scheduled first; otherwise the type with the
    smallest available fraction of its group is chosen.

    Parameters
    ----------
    operators : collection of Operator
        The operators to schedule. Iterated and measured with ``len``.

    Returns
    -------
    list of (type, ops)
        One entry per scheduling step: the chosen operator type and the ops
        returned by ``_greedy_nonoverlapping`` for that step.

    Raises
    ------
    ValueError
        If operators remain but none is available ("Cycles in the op
        graph"), or if no precomputed group contains the available ops of
        a type.
    """
    edges = operator_depencency_graph(operators)

    for op, dests in iteritems(edges):
        assert isinstance(op, Operator)
        assert all(isinstance(op2, Operator) for op2 in dests)

    deps = DependencyTracker(operators, edges)

    # --- determine all (transitive) successors of each op
    # Work on a copy and repeatedly peel off the ops that have no remaining
    # successors, so the successors of an op are resolved before the op.
    temp_deps = deps.copy()
    all_successors = {}
    while temp_deps.successors_of:
        queued = [
            op for op, succs in iteritems(temp_deps.successors_of)
            if not succs]
        assert queued

        for op in queued:
            op_succs = set()
            for succ in deps.successors_of[op]:
                op_succs.add(succ)
                op_succs.update(all_successors[succ])

            all_successors[op] = op_succs

        temp_deps.remove_ops(queued, update_available=False)

    # --- determine which operators of the same type are independent (parallel)
    operators_by_type = defaultdict(set)
    for op in operators:
        operators_by_type[type(op)].add(op)

    parallel_by_type = {}
    for op_type, ops in iteritems(operators_by_type):
        ops = set(ops)  # copy, because the loop below consumes it

        groups = []
        while ops:
            op = ops.pop()
            group = set([op])
            for op2 in ops:
                if (op2 not in all_successors[op] and
                        op not in all_successors[op2]):
                    group.add(op2)
            groups.append(group)

        parallel_by_type[op_type] = groups

    rval = []
    while deps.predecessors_of:
        if not deps.available:
            raise ValueError("Cycles in the op graph")

        fracs = {}
        for op_type, ops in iteritems(deps.available):
            for parallel_ops in parallel_by_type[op_type]:
                if ops.issubset(parallel_ops):
                    break
            else:
                raise ValueError("Could not find superset group")

            if len(ops) == len(parallel_ops):
                fracs[op_type] = -1  # to choose this first
            else:
                fracs[op_type] = float(len(ops)) / len(parallel_ops)

        # Smallest fraction wins; the first such type on ties.
        chosen_type = min(iteritems(fracs), key=lambda x: x[1])[0]
        candidates = deps.available[chosen_type]

        # --- greedily pick non-overlapping ops
        chosen_ops = _greedy_nonoverlapping(candidates)

        # --- schedule ops
        assert chosen_ops
        rval.append((chosen_type, chosen_ops))

        # --- update predecessors and successors of unscheduled ops
        deps.remove_op_type(chosen_type, chosen_ops)

    assert len(operators) == sum(len(p[1]) for p in rval)
    return rval
0 commit comments