Skip to content

Commit 4c7e4f5

Browse files
committed
Added parallel planner
1 parent 8e55d55 commit 4c7e4f5

File tree

1 file changed

+138
-12
lines changed

1 file changed

+138
-12
lines changed

nengo_ocl/planners.py

Lines changed: 138 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,35 +28,71 @@ def __init__(self, ops, edges):
2828
self.successors_of = successors_of
2929
self.available = available
3030

31+
def copy(self):
    """Return an independent copy of this tracker's dependency state.

    Every set held in ``predecessors_of``, ``successors_of`` and
    ``available`` is duplicated, so adding or removing ops on the copy
    does not touch the original's dicts or sets. The op objects
    themselves are shared, not copied. The new instance is created with
    ``__new__`` only, so ``__init__`` is not run again.
    """
    clone = type(self).__new__(type(self))
    clone.predecessors_of = dict(
        (op, set(preds)) for op, preds in iteritems(self.predecessors_of))
    clone.successors_of = dict(
        (op, set(succs)) for op, succs in iteritems(self.successors_of))
    clone.available = defaultdict(
        set,
        ((op_type, set(ops)) for op_type, ops in iteritems(self.available)))
    return clone
39+
3140
def greedy_choose_type(self):
    """Pick the op type that currently has the most available ops.

    Returns
    -------
    (chosen_type, candidates)
        The selected type and the set stored for it in
        ``self.available`` (the same set object, not a copy).
    """
    # Stable ascending sort by set size; the last entry is the largest.
    by_count = sorted(self.available.items(), key=lambda item: len(item[1]))
    chosen_type, candidates = by_count[-1]
    return chosen_type, candidates
3645

37-
def remove_op(self, op, update_available=True):
    """Remove a single op from the dependency graph.

    Parameters
    ----------
    op
        The op to remove from ``predecessors_of`` / ``successors_of``.
    update_available : bool
        When true (the default), ``self.available`` is updated as well:
        ``op`` is dropped from the set for its type (and that entry is
        deleted if it becomes empty), then every op that lost its last
        predecessor is added under its own type. When false,
        ``self.available`` is left untouched.
    """
    newly_free = self._remove_op_from_predsucc(op)
    if not update_available:
        return

    op_type = type(op)
    if op_type in self.available:
        self.available[op_type].discard(op)
        if not self.available[op_type]:
            del self.available[op_type]
    for freed in newly_free:
        self.available[type(freed)].add(freed)
56+
57+
def remove_ops(self, ops, update_available=True):
    """Remove several ops from the dependency graph at once.

    Parameters
    ----------
    ops : iterable
        The ops to remove from ``predecessors_of`` / ``successors_of``.
    update_available : bool
        When true (the default), ``self.available`` is updated as well:
        the removed ops are dropped from the sets for their types, every
        op that lost its last predecessor is added under its own type,
        and finally any type whose set ended up empty is deleted. When
        false, ``self.available`` is left untouched.
    """
    ops = list(ops)  # iterated twice below; tolerate one-shot iterables
    newly_free = []
    for op in ops:
        newly_free.extend(self._remove_op_from_predsucc(op))

    if not update_available:
        return

    for op in ops:
        if type(op) in self.available:
            self.available[type(op)].discard(op)
    for op in newly_free:
        self.available[type(op)].add(op)

    # Collect the empty types first: deleting keys while iterating over
    # the dict itself is an error, and iterating the dict directly yields
    # keys only (not (key, value) pairs).
    emptied = [
        op_type for op_type, pending in self.available.items()
        if not pending]
    for op_type in emptied:
        del self.available[op_type]
4871

72+
def remove_op_type(self, chosen_type, chosen_ops, update_available=True):
    """Remove a batch of ops that all belong to ``chosen_type``.

    Parameters
    ----------
    chosen_type
        Key into ``self.available`` for the ops being removed.
    chosen_ops : iterable
        The ops to remove from ``predecessors_of`` / ``successors_of``.
    update_available : bool
        When true (the default), ``chosen_ops`` are dropped from
        ``self.available[chosen_type]``, every op that lost its last
        predecessor is added under its own type, and the
        ``chosen_type`` entry is deleted if it is empty afterwards.
        When false, ``self.available`` is left untouched.
    """
    newly_free = [
        freed
        for op in chosen_ops
        for freed in self._remove_op_from_predsucc(op)]

    if not update_available:
        return

    remaining = self.available[chosen_type]
    remaining.difference_update(chosen_ops)
    for freed in newly_free:
        self.available[type(freed)].add(freed)
    # Checked only after the additions: a freed op may be of chosen_type.
    if not remaining:
        del self.available[chosen_type]
5183

5284
def _remove_op_from_predsucc(self, op):
85+
available = []
5386
for op2 in self.successors_of[op]:
5487
preds = self.predecessors_of[op2]
5588
preds.remove(op)
5689
if len(preds) == 0:
57-
self.available[type(op2)].add(op2)
90+
available.append(op2)
91+
for op2 in self.predecessors_of[op]:
92+
self.successors_of[op2].remove(op)
5893
del self.predecessors_of[op]
5994
del self.successors_of[op]
95+
return available
6096

6197

6298
def _greedy_nonoverlapping(ops):
@@ -130,3 +166,93 @@ def greedy_planner(operators):
130166
assert len(operators) == sum(len(p[1]) for p in rval)
131167
# print('greedy_planner: Program len:', len(rval))
132168
return rval
169+
170+
171+
def parallel_planner(operators):
    """Plan the order of operators using groups of same-type parallel ops.

    The plan is built in three stages:

    1. For every operator, compute the set of all operators that come
       after it in the dependency graph (direct and transitive
       successors).
    2. For every operator type, build groups: each group holds one
       operator plus the not-yet-grouped operators of that type that
       are neither its successors nor have it as a successor.
    3. Repeatedly choose one available operator type, schedule a
       non-overlapping subset of its available ops (as selected by
       ``_greedy_nonoverlapping``), and remove those ops from the
       dependency tracker. A type whose available ops fill their whole
       group is preferred; otherwise the type with the smallest
       available fraction of its group is chosen.

    Parameters
    ----------
    operators : sequence of Operator
        The operators to schedule.

    Returns
    -------
    list of (type, ops)
        One entry per scheduled step, in execution order.

    Raises
    ------
    ValueError
        If unscheduled ops remain but none is available ("Cycles in the
        op graph"), or if no group contains all available ops of a type.
    """
    edges = operator_depencency_graph(operators)

    for op, dests in iteritems(edges):
        assert isinstance(op, Operator)
        assert all(isinstance(op2, Operator) for op2 in dests)

    deps = DependencyTracker(operators, edges)

    # --- determine all successors of each op
    # Peel ops off a scratch copy, starting with those that have no
    # successors left, so that every successor of an op has already been
    # processed by the time the op itself is visited.
    temp_deps = deps.copy()
    all_successors = {}
    while temp_deps.successors_of:
        queued = [
            op for op, succs in iteritems(temp_deps.successors_of) if not succs]
        assert queued

        for op in queued:
            # Read from `deps`, not `temp_deps`: the scratch copy's
            # successor sets are emptied as ops are peeled off.
            successors = set()
            for succ in deps.successors_of[op]:
                successors.add(succ)
                successors.update(all_successors[succ])
            all_successors[op] = successors

        temp_deps.remove_ops(queued, update_available=False)

    # --- determine which operators of the same type are independent (parallel)
    operators_by_type = defaultdict(set)
    for op in operators:
        operators_by_type[type(op)].add(op)

    parallel_by_type = {}
    for op_type, ops in iteritems(operators_by_type):
        ops = set(ops)

        groups = []
        while ops:
            op = ops.pop()
            group = {op}
            group.update(
                op2 for op2 in ops
                if op2 not in all_successors[op]
                and op not in all_successors[op2])
            groups.append(group)

        parallel_by_type[op_type] = groups

    rval = []
    while len(deps.predecessors_of) > 0:
        if len(deps.available) == 0:
            raise ValueError("Cycles in the op graph")

        fracs = {}
        for op_type, ops in iteritems(deps.available):
            for parallel_ops in parallel_by_type[op_type]:
                if ops.issubset(parallel_ops):
                    break
            else:
                raise ValueError("Could not find superset group")

            if len(ops) == len(parallel_ops):
                fracs[op_type] = -1  # to choose this first
            else:
                fracs[op_type] = float(len(ops)) / len(parallel_ops)

        # Smallest value wins, so a complete group (-1) beats any
        # partially available one.
        chosen_type = min(iteritems(fracs), key=lambda x: x[1])[0]
        candidates = deps.available[chosen_type]

        # --- greedily pick non-overlapping ops
        chosen_ops = _greedy_nonoverlapping(candidates)

        # --- schedule ops
        assert chosen_ops
        rval.append((chosen_type, chosen_ops))

        # --- update predecessors and successors of unscheduled ops
        deps.remove_op_type(chosen_type, chosen_ops)

    assert len(operators) == sum(len(p[1]) for p in rval)
    return rval

0 commit comments

Comments
 (0)