Skip to content

Commit

Permalink
Merge pull request #2310 from devitocodes/optimize-coeffs
Browse files Browse the repository at this point in the history
compiler: Revamp lowering of IndexDerivatives
  • Loading branch information
FabioLuporini authored Feb 22, 2024
2 parents 2f27c9d + 8f22f90 commit ccfb823
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 78 deletions.
8 changes: 6 additions & 2 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ def _normalize_kwargs(cls, **kwargs):
# Distributed parallelism
o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN)

# Misc
# Code generation options for derivatives
o['expand'] = oo.pop('expand', cls.EXPAND)
o['optcomms'] = oo.pop('optcomms', True)
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

# Misc
o['opt-comms'] = oo.pop('opt-comms', True)
o['linearize'] = oo.pop('linearize', False)
o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
Expand Down
8 changes: 6 additions & 2 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,13 @@ def _normalize_kwargs(cls, **kwargs):
# Distributed parallelism
o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN)

# Misc
# Code generation options for derivatives
o['expand'] = oo.pop('expand', cls.EXPAND)
o['optcomms'] = oo.pop('optcomms', True)
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

# Misc
o['opt-comms'] = oo.pop('opt-comms', True)
o['linearize'] = oo.pop('linearize', False)
o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
Expand Down
13 changes: 12 additions & 1 deletion devito/core/operator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterable

from devito.core.autotuning import autotune
from devito.exceptions import InvalidOperator
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.logger import warning
from devito.mpi.routines import mpi_registry
from devito.parameters import configuration
Expand Down Expand Up @@ -96,6 +96,12 @@ class BasicOperator(Operator):
finite-difference derivatives.
"""

DERIV_SCHEDULE = 'basic'
"""
The schedule to use for the computation of finite-difference derivatives.
Only meaningful when `EXPAND=False`.
"""

MPI_MODES = tuple(mpi_registry)
"""
The supported MPI modes.
Expand Down Expand Up @@ -144,6 +150,11 @@ def _check_kwargs(cls, **kwargs):
if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES:
raise InvalidOperator("Unsupported MPI mode `%s`" % oo['mpi'])

if oo['deriv-schedule'] not in ('basic', 'smart'):
raise InvalidArgument("Illegal `deriv-schedule` value")
if oo['deriv-unroll'] not in (False, 'inner', 'full'):
raise InvalidArgument("Illegal `deriv-unroll` value")

def _autotune(self, args, setup):
if setup in [False, 'off']:
return args
Expand Down
35 changes: 24 additions & 11 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from devito.logger import warning
from devito.tools import (as_tuple, filter_ordered, flatten, frozendict,
infer_dtype, is_integer, split)
from devito.types import (Array, DimensionTuple, Evaluable, Indexed, Spacing,
from devito.types import (Array, DimensionTuple, Evaluable, Indexed,
StencilDimension)

__all__ = ['Differentiable', 'IndexDerivative', 'EvalDerivative', 'Weights']
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
'Weights']


class Differentiable(sympy.Expr, Evaluable):
Expand Down Expand Up @@ -252,6 +253,14 @@ def __eq__(self, other):
return all(getattr(self, i, None) == getattr(other, i, None)
for i in self.__rkwargs__)

def _hashable_content(self):
# SymPy computes the hash of all Basic objects as:
# `hash((type(self).__name__,) + self._hashable_content())`
# However, our subclasses will be named after the main SymPy classes,
# for example sympy.Add -> differentiable.Add, so we need to override
# the hashable content to specify it's our own subclasses
return super()._hashable_content() + ('differentiable',)

@property
def name(self):
return "".join(f.name for f in self._functions)
Expand Down Expand Up @@ -583,7 +592,7 @@ class Mod(DifferentiableOp, sympy.Mod):
__sympy_class__ = sympy.Mod


class IndexSum(DifferentiableOp):
class IndexSum(sympy.Expr, Evaluable):

"""
Represent the summation over a multiindex, that is a collection of
Expand Down Expand Up @@ -685,8 +694,6 @@ def __init_finalize__(self, *args, **kwargs):
from devito.symbolics import pow_to_mul # noqa, sigh
weights = tuple(pow_to_mul(sympy.sympify(i)) for i in weights)

self._spacings = set().union(*[i.find(Spacing) for i in weights])

kwargs['scope'] = kwargs.get('scope', 'stack')
kwargs['initvalue'] = weights

Expand All @@ -708,10 +715,6 @@ def _hashable_content(self):
def dimension(self):
return self.dimensions[0]

@property
def spacings(self):
return self._spacings

weights = Array.initvalue

def _xreplace(self, rule):
Expand Down Expand Up @@ -803,9 +806,13 @@ def _evaluate(self, **kwargs):
return EvalDerivative(expr, base=self.base)


class DiffDerivative(IndexDerivative, DifferentiableOp):
pass


# SymPy args ordering is the same for Derivatives and IndexDerivatives
ordering_of_classes.insert(ordering_of_classes.index('Derivative') + 1,
'IndexDerivative')
for i in ('DiffDerivative', 'IndexDerivative'):
ordering_of_classes.insert(ordering_of_classes.index('Derivative') + 1, i)


class EvalDerivative(DifferentiableOp, sympy.Add):
Expand Down Expand Up @@ -917,6 +924,12 @@ def _diff2sympy(obj):
ax, af = _diff2sympy(a)
args.append(ax)
flag |= af

# Handle special objects
if isinstance(obj, DiffDerivative):
return IndexDerivative(*args, obj.mapper), True

# Handle generic objects such as arithmetic operations
try:
return obj.__sympy_class__(*args, evaluate=False), True
except AttributeError:
Expand Down
4 changes: 2 additions & 2 deletions devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sympy import sympify

from .differentiable import EvalDerivative, IndexDerivative, Weights
from .differentiable import EvalDerivative, DiffDerivative, Weights
from .tools import (numeric_weights, symbolic_weights, left, right,
generate_indices, centered, direct, transpose,
check_input, check_symbolic)
Expand Down Expand Up @@ -247,7 +247,7 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic
# Pure number
pass

deriv = IndexDerivative(expr*weights, {dim: indices.free_dim})
deriv = DiffDerivative(expr*weights, {dim: indices.free_dim})
else:
terms = []
for i, c in zip(indices, weights):
Expand Down
13 changes: 6 additions & 7 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from collections.abc import Iterable

from sympy import sympify

from devito.symbolics import retrieve_indexed, uxreplace, retrieve_dimensions
from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered
from devito.types import Dimension, IgnoreDimSort
Expand Down Expand Up @@ -89,7 +87,7 @@ def handle_indexed(indexed):
return ordering


def lower_exprs(expressions, **kwargs):
def lower_exprs(expressions, subs=None, **kwargs):
"""
Lowering an expression consists of the following passes:
Expand All @@ -101,9 +99,10 @@ def lower_exprs(expressions, **kwargs):
--------
f(x - 2*h_x, y) -> f[xi + 2, yi + 4] (assuming halo_size=4)
"""
# Normalize subs
subs = {k: sympify(v) for k, v in kwargs.get('subs', {}).items()}
return _lower_exprs(expressions, subs or {})


def _lower_exprs(expressions, subs):
processed = []
for expr in as_tuple(expressions):
try:
Expand All @@ -113,15 +112,15 @@ def lower_exprs(expressions, **kwargs):
dimension_map = {}

# Handle Functions (typical case)
mapper = {f: lower_exprs(f.indexify(subs=dimension_map), **kwargs)
mapper = {f: _lower_exprs(f.indexify(subs=dimension_map), subs)
for f in expr.find(AbstractFunction)}

# Handle Indexeds (from index notation)
for i in retrieve_indexed(expr):
f = i.function

# Introduce shifting to align with the computational domain
indices = [(lower_exprs(a) + o) for a, o in
indices = [_lower_exprs(a, subs) + o for a, o in
zip(i.indices, f._size_nodomain.left)]

# Substitute spacing (spacing only used in own dimension)
Expand Down
6 changes: 5 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections import OrderedDict, namedtuple
import ctypes
from operator import attrgetter
from math import ceil

from cached_property import cached_property
import ctypes
from sympy import sympify

from devito.arch import compiler_registry, platform_registry
from devito.data import default_allocator
Expand Down Expand Up @@ -1219,4 +1220,7 @@ def parse_kwargs(**kwargs):
kwargs['platform'])
)

# Normalize `subs`, if any
kwargs['subs'] = {k: sympify(v) for k, v in kwargs.get('subs', {}).items()}

return kwargs
Loading

0 comments on commit ccfb823

Please sign in to comment.