Skip to content

Commit

Permalink
Merge pull request #2446 from devitocodes/cse-tuplets
Browse files Browse the repository at this point in the history
compiler: Add various performance optimization variants
  • Loading branch information
FabioLuporini authored Sep 2, 2024
2 parents 40b081e + b1a10ee commit f012589
Show file tree
Hide file tree
Showing 18 changed files with 661 additions and 409 deletions.
8 changes: 5 additions & 3 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ def _normalize_kwargs(cls, **kwargs):
# Fusion
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

# CSE
# Flops minimization
o['cse-min-cost'] = oo.pop('cse-min-cost', cls.CSE_MIN_COST)
o['cse-algo'] = oo.pop('cse-algo', cls.CSE_ALGO)
o['fact-schedule'] = oo.pop('fact-schedule', cls.FACT_SCHEDULE)

# Blocking
o['blockinner'] = oo.pop('blockinner', False)
Expand Down Expand Up @@ -168,14 +170,14 @@ def _specialize_clusters(cls, clusters, **kwargs):

# Reduce flops
clusters = cire(clusters, 'sops', sregistry, options, platform)
clusters = factorize(clusters)
clusters = factorize(clusters, **kwargs)
clusters = optimize_pows(clusters)

# The previous passes may have created fusion opportunities
clusters = fuse(clusters)

# Reduce flops
clusters = cse(clusters, sregistry, options)
clusters = cse(clusters, **kwargs)

# Blocking to improve data locality
if options['blocklazy']:
Expand Down
8 changes: 5 additions & 3 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ def _normalize_kwargs(cls, **kwargs):
# Fusion
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

# CSE
# Flops minimization
o['cse-min-cost'] = oo.pop('cse-min-cost', cls.CSE_MIN_COST)
o['cse-algo'] = oo.pop('cse-algo', cls.CSE_ALGO)
o['fact-schedule'] = oo.pop('fact-schedule', cls.FACT_SCHEDULE)

# Blocking
o['blockinner'] = oo.pop('blockinner', True)
Expand Down Expand Up @@ -196,14 +198,14 @@ def _specialize_clusters(cls, clusters, **kwargs):

# Reduce flops
clusters = cire(clusters, 'sops', sregistry, options, platform)
clusters = factorize(clusters)
clusters = factorize(clusters, **kwargs)
clusters = optimize_pows(clusters)

# The previous passes may have created fusion opportunities
clusters = fuse(clusters)

# Reduce flops
clusters = cse(clusters, sregistry, options)
clusters = cse(clusters, **kwargs)

# Blocking to define thread blocks
if options['blocklazy']:
Expand Down
13 changes: 13 additions & 0 deletions devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ class BasicOperator(Operator):
common sub=expression.
"""

CSE_ALGO = 'basic'
"""
The algorithm to use for common sub-expression elimination.
"""

FACT_SCHEDULE = 'basic'
"""
The schedule to use for the computation of factorizations.
"""

BLOCK_LEVELS = 1
"""
Loop blocking depth. So, 1 => "blocks", 2 => "blocks" and "sub-blocks",
Expand Down Expand Up @@ -159,6 +169,9 @@ def _check_kwargs(cls, **kwargs):
if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES:
raise InvalidOperator("Unsupported MPI mode `%s`" % oo['mpi'])

if oo['cse-algo'] not in ('basic', 'smartsort', 'advanced'):
raise InvalidArgument("Illegal `cse-algo` value")

if oo['deriv-schedule'] not in ('basic', 'smart'):
raise InvalidArgument("Illegal `deriv-schedule` value")
if oo['deriv-unroll'] not in (False, 'inner', 'full'):
Expand Down
2 changes: 2 additions & 0 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ def __new__(cls, *args, **kwargs):
# set of basic simplifications

# (a+b)+c -> a+b+c (flattening)
# TODO: use symbolics.flatten_args; not using it to avoid a circular import
nested, others = split(args, lambda e: isinstance(e, Add))
args = flatten(e.args for e in nested) + list(others)

Expand All @@ -533,6 +534,7 @@ def __new__(cls, *args, **kwargs):
# to avoid generating functional, but ugly, code

# (a*b)*c -> a*b*c (flattening)
# TODO: use symbolics.flatten_args; not using it to avoid a circular import
nested, others = split(args, lambda e: isinstance(e, Mul))
args = flatten(e.args for e in nested) + list(others)

Expand Down
7 changes: 4 additions & 3 deletions devito/ir/clusters/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,12 @@ def __init__(self, func, mode='dense'):
else:
self.cond = lambda c: True

def __call__(self, *args):
def __call__(self, *args, **kwargs):
if timed_pass.is_enabled():
maybe_timed = lambda *_args: timed_pass(self.func, self.func.__name__)(*_args)
maybe_timed = lambda *_args: \
timed_pass(self.func, self.func.__name__)(*_args, **kwargs)
else:
maybe_timed = lambda *_args: self.func(*_args)
maybe_timed = lambda *_args: self.func(*_args, **kwargs)
args = list(args)
maybe_clusters = args.pop(0)
if isinstance(maybe_clusters, Iterable):
Expand Down
6 changes: 5 additions & 1 deletion devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ def detect_accesses(exprs):
other_dims = set()
for e in as_tuple(exprs):
other_dims.update(i for i in e.free_symbols if isinstance(i, Dimension))
other_dims.update(e.implicit_dims or {})
try:
other_dims.update(e.implicit_dims or {})
except AttributeError:
# Not a types.Eq
pass
other_dims = filter_sorted(other_dims)
mapper[None] = Stencil([(i, 0) for i in other_dims])

Expand Down
Loading

0 comments on commit f012589

Please sign in to comment.