Skip to content

Commit f012589

Browse files
Merge pull request #2446 from devitocodes/cse-tuplets
compiler: Add various performance optimization variants
2 parents 40b081e + b1a10ee commit f012589

File tree

18 files changed

+661
-409
lines changed

18 files changed

+661
-409
lines changed

devito/core/cpu.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -36,8 +36,10 @@ def _normalize_kwargs(cls, **kwargs):
3636
# Fusion
3737
o['fuse-tasks'] = oo.pop('fuse-tasks', False)
3838

39-
# CSE
39+
# Flops minimization
4040
o['cse-min-cost'] = oo.pop('cse-min-cost', cls.CSE_MIN_COST)
41+
o['cse-algo'] = oo.pop('cse-algo', cls.CSE_ALGO)
42+
o['fact-schedule'] = oo.pop('fact-schedule', cls.FACT_SCHEDULE)
4143

4244
# Blocking
4345
o['blockinner'] = oo.pop('blockinner', False)
@@ -168,14 +170,14 @@ def _specialize_clusters(cls, clusters, **kwargs):
168170

169171
# Reduce flops
170172
clusters = cire(clusters, 'sops', sregistry, options, platform)
171-
clusters = factorize(clusters)
173+
clusters = factorize(clusters, **kwargs)
172174
clusters = optimize_pows(clusters)
173175

174176
# The previous passes may have created fusion opportunities
175177
clusters = fuse(clusters)
176178

177179
# Reduce flops
178-
clusters = cse(clusters, sregistry, options)
180+
clusters = cse(clusters, **kwargs)
179181

180182
# Blocking to improve data locality
181183
if options['blocklazy']:

devito/core/gpu.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ def _normalize_kwargs(cls, **kwargs):
4646
# Fusion
4747
o['fuse-tasks'] = oo.pop('fuse-tasks', False)
4848

49-
# CSE
49+
# Flops minimization
5050
o['cse-min-cost'] = oo.pop('cse-min-cost', cls.CSE_MIN_COST)
51+
o['cse-algo'] = oo.pop('cse-algo', cls.CSE_ALGO)
52+
o['fact-schedule'] = oo.pop('fact-schedule', cls.FACT_SCHEDULE)
5153

5254
# Blocking
5355
o['blockinner'] = oo.pop('blockinner', True)
@@ -196,14 +198,14 @@ def _specialize_clusters(cls, clusters, **kwargs):
196198

197199
# Reduce flops
198200
clusters = cire(clusters, 'sops', sregistry, options, platform)
199-
clusters = factorize(clusters)
201+
clusters = factorize(clusters, **kwargs)
200202
clusters = optimize_pows(clusters)
201203

202204
# The previous passes may have created fusion opportunities
203205
clusters = fuse(clusters)
204206

205207
# Reduce flops
206-
clusters = cse(clusters, sregistry, options)
208+
clusters = cse(clusters, **kwargs)
207209

208210
# Blocking to define thread blocks
209211
if options['blocklazy']:

devito/core/operator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ class BasicOperator(Operator):
2727
common sub-expression.
2828
"""
2929

30+
CSE_ALGO = 'basic'
31+
"""
32+
The algorithm to use for common sub-expression elimination.
33+
"""
34+
35+
FACT_SCHEDULE = 'basic'
36+
"""
37+
The schedule to use for the computation of factorizations.
38+
"""
39+
3040
BLOCK_LEVELS = 1
3141
"""
3242
Loop blocking depth. So, 1 => "blocks", 2 => "blocks" and "sub-blocks",
@@ -159,6 +169,9 @@ def _check_kwargs(cls, **kwargs):
159169
if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES:
160170
raise InvalidOperator("Unsupported MPI mode `%s`" % oo['mpi'])
161171

172+
if oo['cse-algo'] not in ('basic', 'smartsort', 'advanced'):
173+
raise InvalidArgument("Illegal `cse-algo` value")
174+
162175
if oo['deriv-schedule'] not in ('basic', 'smart'):
163176
raise InvalidArgument("Illegal `deriv-schedule` value")
164177
if oo['deriv-unroll'] not in (False, 'inner', 'full'):

devito/finite_differences/differentiable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ def __new__(cls, *args, **kwargs):
511511
# set of basic simplifications
512512

513513
# (a+b)+c -> a+b+c (flattening)
514+
# TODO: use symbolics.flatten_args; not using it to avoid a circular import
514515
nested, others = split(args, lambda e: isinstance(e, Add))
515516
args = flatten(e.args for e in nested) + list(others)
516517

@@ -533,6 +534,7 @@ def __new__(cls, *args, **kwargs):
533534
# to avoid generating functional, but ugly, code
534535

535536
# (a*b)*c -> a*b*c (flattening)
537+
# TODO: use symbolics.flatten_args; not using it to avoid a circular import
536538
nested, others = split(args, lambda e: isinstance(e, Mul))
537539
args = flatten(e.args for e in nested) + list(others)
538540

devito/ir/clusters/visitors.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,12 @@ def __init__(self, func, mode='dense'):
204204
else:
205205
self.cond = lambda c: True
206206

207-
def __call__(self, *args):
207+
def __call__(self, *args, **kwargs):
208208
if timed_pass.is_enabled():
209-
maybe_timed = lambda *_args: timed_pass(self.func, self.func.__name__)(*_args)
209+
maybe_timed = lambda *_args: \
210+
timed_pass(self.func, self.func.__name__)(*_args, **kwargs)
210211
else:
211-
maybe_timed = lambda *_args: self.func(*_args)
212+
maybe_timed = lambda *_args: self.func(*_args, **kwargs)
212213
args = list(args)
213214
maybe_clusters = args.pop(0)
214215
if isinstance(maybe_clusters, Iterable):

devito/ir/support/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,11 @@ def detect_accesses(exprs):
186186
other_dims = set()
187187
for e in as_tuple(exprs):
188188
other_dims.update(i for i in e.free_symbols if isinstance(i, Dimension))
189-
other_dims.update(e.implicit_dims or {})
189+
try:
190+
other_dims.update(e.implicit_dims or {})
191+
except AttributeError:
192+
# Not a types.Eq
193+
pass
190194
other_dims = filter_sorted(other_dims)
191195
mapper[None] = Stencil([(i, 0) for i in other_dims])
192196

0 commit comments

Comments
 (0)