diff --git a/devito/core/cpu.py b/devito/core/cpu.py index 561e00e5dd..81611da9b4 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -65,9 +65,13 @@ def _normalize_kwargs(cls, **kwargs): # Distributed parallelism o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN) - # Misc + # Code generation options for derivatives o['expand'] = oo.pop('expand', cls.EXPAND) - o['optcomms'] = oo.pop('optcomms', True) + o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE) + o['deriv-unroll'] = oo.pop('deriv-unroll', False) + + # Misc + o['opt-comms'] = oo.pop('opt-comms', True) o['linearize'] = oo.pop('linearize', False) o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE) o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE) diff --git a/devito/core/gpu.py b/devito/core/gpu.py index 0d3b1969fa..de070c6890 100644 --- a/devito/core/gpu.py +++ b/devito/core/gpu.py @@ -80,9 +80,13 @@ def _normalize_kwargs(cls, **kwargs): # Distributed parallelism o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN) - # Misc + # Code generation options for derivatives o['expand'] = oo.pop('expand', cls.EXPAND) - o['optcomms'] = oo.pop('optcomms', True) + o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE) + o['deriv-unroll'] = oo.pop('deriv-unroll', False) + + # Misc + o['opt-comms'] = oo.pop('opt-comms', True) o['linearize'] = oo.pop('linearize', False) o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE) o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE) diff --git a/devito/core/operator.py b/devito/core/operator.py index 5d7f886a74..1ba976d3b6 100644 --- a/devito/core/operator.py +++ b/devito/core/operator.py @@ -1,7 +1,7 @@ from collections.abc import Iterable from devito.core.autotuning import autotune -from devito.exceptions import InvalidOperator +from devito.exceptions import InvalidArgument, InvalidOperator from devito.logger import warning from devito.mpi.routines import mpi_registry from devito.parameters import configuration @@ -96,6 +96,12 @@ class BasicOperator(Operator): finite-difference derivatives. """ + DERIV_SCHEDULE = 'basic' + """ + The schedule to use for the computation of finite-difference derivatives. + Only meaningful when `EXPAND=False`. + """ + MPI_MODES = tuple(mpi_registry) """ The supported MPI modes. @@ -144,6 +150,11 @@ def _check_kwargs(cls, **kwargs): if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES: raise InvalidOperator("Unsupported MPI mode `%s`" % oo['mpi']) + if oo['deriv-schedule'] not in ('basic', 'smart'): + raise InvalidArgument("Illegal `deriv-schedule` value") + if oo['deriv-unroll'] not in (False, 'inner', 'full'): + raise InvalidArgument("Illegal `deriv-unroll` value") + def _autotune(self, args, setup): if setup in [False, 'off']: return args diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 6cb4d4b45d..94c3139335 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -15,10 +15,11 @@ from devito.logger import warning from devito.tools import (as_tuple, filter_ordered, flatten, frozendict, infer_dtype, is_integer, split) -from devito.types import (Array, DimensionTuple, Evaluable, Indexed, Spacing, +from devito.types import (Array, DimensionTuple, Evaluable, Indexed, StencilDimension) -__all__ = ['Differentiable', 'IndexDerivative', 'EvalDerivative', 'Weights'] +__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative', + 'Weights'] class Differentiable(sympy.Expr, Evaluable): @@ -252,6 +253,14 @@ def __eq__(self, other): return all(getattr(self, i, None) == getattr(other, i, None) for i in self.__rkwargs__) + def _hashable_content(self): + # SymPy computes the hash of all Basic objects as: + # `hash((type(self).__name__,) + self._hashable_content())` + # However, our subclasses will be named after the main SymPy classes, + # for example sympy.Add -> differentiable.Add, so we need to override + # the hashable content to specify it's our own subclasses + return super()._hashable_content() + ('differentiable',) + @property def name(self): return "".join(f.name for f in self._functions) @@ -583,7 +592,7 @@ class Mod(DifferentiableOp, sympy.Mod): __sympy_class__ = sympy.Mod -class IndexSum(DifferentiableOp): +class IndexSum(sympy.Expr, Evaluable): """ Represent the summation over a multiindex, that is a collection of @@ -685,8 +694,6 @@ def __init_finalize__(self, *args, **kwargs): from devito.symbolics import pow_to_mul # noqa, sigh weights = tuple(pow_to_mul(sympy.sympify(i)) for i in weights) - self._spacings = set().union(*[i.find(Spacing) for i in weights]) - kwargs['scope'] = kwargs.get('scope', 'stack') kwargs['initvalue'] = weights @@ -708,10 +715,6 @@ def _hashable_content(self): def dimension(self): return self.dimensions[0] - @property - def spacings(self): - return self._spacings - weights = Array.initvalue def _xreplace(self, rule): @@ -803,9 +806,13 @@ def _evaluate(self, **kwargs): return EvalDerivative(expr, base=self.base) +class DiffDerivative(IndexDerivative, DifferentiableOp): + pass + + # SymPy args ordering is the same for Derivatives and IndexDerivatives -ordering_of_classes.insert(ordering_of_classes.index('Derivative') + 1, - 'IndexDerivative') +for i in ('DiffDerivative', 'IndexDerivative'): + ordering_of_classes.insert(ordering_of_classes.index('Derivative') + 1, i) class EvalDerivative(DifferentiableOp, sympy.Add): @@ -917,6 +924,12 @@ def _diff2sympy(obj): ax, af = _diff2sympy(a) args.append(ax) flag |= af + + # Handle special objects + if isinstance(obj, DiffDerivative): + return IndexDerivative(*args, obj.mapper), True + + # Handle generic objects such as arithmetic operations try: return obj.__sympy_class__(*args, evaluate=False), True except AttributeError: diff --git a/devito/finite_differences/finite_difference.py b/devito/finite_differences/finite_difference.py index 5749a82ae4..4b2274580e 100644 --- a/devito/finite_differences/finite_difference.py +++ b/devito/finite_differences/finite_difference.py @@ -1,6 +1,6 @@ from sympy import sympify -from .differentiable import EvalDerivative, IndexDerivative, Weights +from .differentiable import EvalDerivative, DiffDerivative, Weights from .tools import (numeric_weights, symbolic_weights, left, right, generate_indices, centered, direct, transpose, check_input, check_symbolic) @@ -247,7 +247,7 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic # Pure number pass - deriv = IndexDerivative(expr*weights, {dim: indices.free_dim}) + deriv = DiffDerivative(expr*weights, {dim: indices.free_dim}) else: terms = [] for i, c in zip(indices, weights): diff --git a/devito/ir/equations/algorithms.py b/devito/ir/equations/algorithms.py index c27ef5c54f..ab1d80af8b 100644 --- a/devito/ir/equations/algorithms.py +++ b/devito/ir/equations/algorithms.py @@ -1,7 +1,5 @@ from collections.abc import Iterable -from sympy import sympify - from devito.symbolics import retrieve_indexed, uxreplace, retrieve_dimensions from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered from devito.types import Dimension, IgnoreDimSort @@ -89,7 +87,7 @@ def handle_indexed(indexed): return ordering -def lower_exprs(expressions, **kwargs): +def lower_exprs(expressions, subs=None, **kwargs): """ Lowering an expression consists of the following passes: @@ -101,9 +99,10 @@ def lower_exprs(expressions, **kwargs): -------- f(x - 2*h_x, y) -> f[xi + 2, yi + 4] (assuming halo_size=4) """ - # Normalize subs - subs = {k: sympify(v) for k, v in kwargs.get('subs', {}).items()} + return _lower_exprs(expressions, subs or {}) + +def _lower_exprs(expressions, subs): processed = [] for expr in as_tuple(expressions): try: @@ -113,7 +112,7 @@ def lower_exprs(expressions, **kwargs): dimension_map = {} # Handle Functions (typical case) - mapper = {f: lower_exprs(f.indexify(subs=dimension_map), **kwargs) + mapper = {f: _lower_exprs(f.indexify(subs=dimension_map), subs) for f in expr.find(AbstractFunction)} # Handle Indexeds (from index notation) @@ -121,7 +120,7 @@ def lower_exprs(expressions, **kwargs): f = i.function # Introduce shifting to align with the computational domain - indices = [(lower_exprs(a) + o) for a, o in + indices = [_lower_exprs(a, subs) + o for a, o in zip(i.indices, f._size_nodomain.left)] # Substitute spacing (spacing only used in own dimension) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 234ac83a5d..df03ca7653 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -1,9 +1,10 @@ from collections import OrderedDict, namedtuple +import ctypes from operator import attrgetter from math import ceil from cached_property import cached_property -import ctypes +from sympy import sympify from devito.arch import compiler_registry, platform_registry from devito.data import default_allocator @@ -1219,4 +1220,7 @@ def parse_kwargs(**kwargs): kwargs['platform']) ) + # Normalize `subs`, if any + kwargs['subs'] = {k: sympify(v) for k, v in kwargs.get('subs', {}).items()} + return kwargs diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 4992657aaf..f8f339aa1e 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -1,10 +1,13 @@ +from functools import singledispatch + +from sympy import S + from devito.finite_differences import IndexDerivative from devito.ir import Backward, Forward, Interval, IterationSpace, Queue from devito.passes.clusters.misc import fuse -from devito.symbolics import (retrieve_dimensions, reuse_if_untouched, q_leaf, - uxreplace) -from devito.tools import filter_ordered, timed_pass -from devito.types import Eq, Inc, StencilDimension, Symbol +from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace +from devito.tools import infer_dtype, timed_pass +from devito.types import Eq, Inc, Indexed, Symbol __all__ = ['lower_index_derivatives'] @@ -23,12 +26,13 @@ def lower_index_derivatives(clusters, mode=None, **kwargs): # previously were just not detectable via e.g. plain CSE. For example, if # there were two IndexDerivatives such as `(p.dx + m.dx).dx` and `m.dx.dx` # then it's only after `_lower_index_derivatives` that they're detectable! + # TODO: see https://github.com/devitocodes/devito/issues/2306 clusters = CDE(mapper).process(clusters) return clusters -def _lower_index_derivatives(clusters, sregistry=None, **kwargs): +def _lower_index_derivatives(clusters, **kwargs): weights = {} processed = [] mapper = {} @@ -48,7 +52,7 @@ def dump(exprs, c): else: reusable = set() - expr, v = _core(e, c, weights, reusable, mapper, sregistry) + expr, v = _core(e, c, c.ispace, weights, reusable, mapper, **kwargs) if v: dump(exprs, c) @@ -70,87 +74,137 @@ def dump(exprs, c): return processed, weights, mapper -def _core(expr, c, weights, reusables, mapper, sregistry): +@singledispatch +def _core(expr, c, ispace, weights, reusables, mapper, **kwargs): """ - Recursively carry out the core of `lower_index_derivatives`. + Recursively carry out the core of `lower_index_derivatives` based + on single-dispatch. """ - if q_leaf(expr): - return expr, [] - args = [] processed = [] for a in expr.args: - e, clusters = _core(a, c, weights, reusables, mapper, sregistry) + e, clusters = _core(a, c, ispace, weights, reusables, mapper, **kwargs) args.append(e) processed.extend(clusters) expr = reuse_if_untouched(expr, args) - if not isinstance(expr, IndexDerivative): - return expr, processed + return expr, processed + + +@_core.register(Symbol) +@_core.register(Indexed) +@_core.register(BasicWrapperMixin) +def _(expr, c, ispace, weights, reusables, mapper, **kwargs): + return expr, [] + + +@_core.register(IndexDerivative) +def _(expr, c, ispace, weights, reusables, mapper, **kwargs): + sregistry = kwargs['sregistry'] + options = kwargs['options'] + subs_user = kwargs['subs'] - # Create concrete Weights and reuse them whenever possible + try: + cbk0 = deriv_schedule_registry[options['deriv-schedule']] + cbk1 = deriv_unroll_registry[options['deriv-unroll']] + except KeyError: + raise ValueError("Unknown derivative lowering mode") + + # Lower the IndexDerivative + init, ideriv = cbk0(expr) + + # Create the concrete Weights array, or reuse an already existing one + # if possible name = sregistry.make_name(prefix='w') - w0 = expr.weights.function + w0 = ideriv.weights.function + dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32 k = tuple(w0.weights) try: w = weights[k] except KeyError: - w = weights[k] = w0._rebuild(name=name, dtype=expr.dtype) - expr = uxreplace(expr, {w0.indexed: w.indexed}) + initvalue = tuple(i.subs(subs_user) for i in k) + w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue) - dims = retrieve_dimensions(expr, deep=True) - dims = filter_ordered(d for d in dims if isinstance(d, StencilDimension)) + # Replace the abstract Weights array with the concrete one + subs = {w0.indexed: w.indexed} + init = uxreplace(init, subs) + ideriv = uxreplace(ideriv, subs) - dims = tuple(reversed(dims)) - - # If a StencilDimension already appears in `c.ispace`, perhaps with its custom - # upper and lower offsets, we honor it - dims = tuple(d for d in dims if d not in c.ispace) + # The IterationSpace in which the IndexDerivative will be computed + dims = ideriv.dimensions intervals = [Interval(d) for d in dims] directions = {d: Backward if d.backward else Forward for d in dims} ispace0 = IterationSpace(intervals, directions=directions) - extra = (c.ispace.itdims + dims,) - ispace = IterationSpace.union(c.ispace, ispace0, relations=extra) - - # Set the IterationSpace along the StencilDimensions to start from 0 - # (rather than the default `d._min`) to minimize the amount of integer - # arithmetic to calculate the various index access functions + # Minimize the amount of integer arithmetic to calculate the various index + # access functions by enforcing start at 0, e.g. `r0[x + i0 + 2] -> r0[x + i0]` + base = ideriv.base for d in dims: - ispace = ispace.translate(d, -d._min) + ispace0 = ispace0.translate(d, -d._min) + base = base.subs(d, d + d._min) + ideriv = ideriv._subs(ideriv.base, base) + # Should the IndexDerivative be unrolled? + init, expr, ispace0 = cbk1(init, ideriv, ispace0) + + # The full IterationSpace + extra = (ispace.itdims + ispace0.itdims,) + ispace1 = IterationSpace.union(ispace, ispace0, relations=extra) + + # The Symbol that will hold the result of the IndexDerivative computation + # NOTE: created before recurring so that we ultimately get a sound ordering try: s = reusables.pop() - assert s.dtype is w.dtype + assert s.dtype is dtype except KeyError: name = sregistry.make_name(prefix='r') - s = Symbol(name=name, dtype=w.dtype) - expr0 = Eq(s, 0.) - ispace1 = ispace.project(lambda d: d is not dims[-1]) - processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1)) - - # Transform e.g. `r0[x + i0 + 2, y] -> r0[x + i0, y, z]` for alignment - # with the shifted `ispace` - base = expr.base - for d in dims: - base = base.subs(d, d + d._min) - expr1 = Inc(s, base*expr.weights) - processed.append(c.rebuild(exprs=expr1, ispace=ispace)) + s = Symbol(name=name, dtype=dtype) + + # Go inside `expr` and recursively lower any nested IndexDerivatives + expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, **kwargs) + + # Finally inject the lowered IndexDerivative + if init is not None: + expr0 = Eq(s, init) + processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace)) + + expr1 = Inc(s, expr) + processed.append(c.rebuild(exprs=expr1, ispace=ispace1)) + else: + expr1 = Eq(s, expr) + processed.append(c.rebuild(exprs=expr1, ispace=ispace1)) - # Track lowered IndexDerivative for subsequent optimization by the caller - mapper.setdefault(expr1.rhs, []).append(s) + # Track the lowered IndexDerivative for subsequent optimization by the caller + mapper.setdefault(expr, []).append(s) return s, processed +def _lower_index_derivative_base(ideriv): + return S.Zero, ideriv + + +deriv_schedule_registry = { + 'basic': _lower_index_derivative_base, +} + + +deriv_unroll_registry = { + False: lambda init, ideriv, ispace: (init, ideriv.expr, ispace) +} + + class CDE(Queue): """ Common derivative elimination. """ + _q_guards_in_key = True + _q_syncs_in_key = True + def __init__(self, mapper): super().__init__() diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index a5b97b37ff..0ac3ab51d5 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -345,7 +345,7 @@ def mpiize(graph, **kwargs): """ options = kwargs['options'] - if options['optcomms']: + if options['opt-comms']: optimize_halospots(graph, **kwargs) mpimode = options['mpi'] diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 52c84956be..33ada6ae5e 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -708,9 +708,10 @@ def test_index_derivative(self): grid = Grid((10,)) x, = grid.dimensions + so = 2 i = StencilDimension('i', 0, 2) - u = Function(name="u", grid=grid, space_order=2) + u = Function(name="u", grid=grid, space_order=so) ui = u.subs(x, x + i*x.spacing) w = Weights(name='w0', dimensions=i, initvalue=[-0.5, 0, 0.5]) @@ -720,7 +721,7 @@ def test_index_derivative(self): assert idxder.evaluate == -0.5*u + 0.5*ui.subs(i, 2) # Make sure subs works as expected - v = Function(name="v", grid=grid, space_order=2) + v = Function(name="v", grid=grid, space_order=so) vi0 = v.subs(x, x + i*x.spacing) vi1 = idxder.subs(ui, vi0) diff --git a/tests/test_operator.py b/tests/test_operator.py index 0801eb7da0..73fb633a88 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -1572,7 +1572,7 @@ def test_consistency_anti_dependences(self, exprs, directions, expected, visit): # Note: `topofuse` is a subset of `advanced` mode. We use it merely to # bypass 'blocking', which would complicate the asserts below - op = Operator(eqns, opt=('topofuse', {'openmp': False, 'optcomms': False})) + op = Operator(eqns, opt=('topofuse', {'openmp': False, 'opt-comms': False})) trees = retrieve_iteration_tree(op) iters = FindNodes(Iteration).visit(op) diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index 186f806926..3caa2dbe9a 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -324,10 +324,12 @@ def test_redundant_derivatives(self): assert len(get_arrays(op)) == 0 assert op._profiler._sections['section0'].sops == 74 exprs = FindNodes(Expression).visit(op) - assert len(exprs) == 6 + assert len(exprs) == 5 temps = [i for i in FindSymbols().visit(exprs) if isinstance(i, Symbol)] assert len(temps) == 2 + op.cfunction + class Test2Pass(object):