From 52b519161d03214c686562a164573f11db1e3904 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 29 Jul 2024 14:42:17 +0000 Subject: [PATCH 01/16] compiler: Fixup factorization --- devito/passes/clusters/factorization.py | 28 ++++++++++++++++++++++--- tests/test_dse.py | 14 +++++++++---- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 33253e245e..0fb94b610a 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -3,7 +3,8 @@ from sympy import Add, Mul, S, collect from devito.ir import cluster_pass -from devito.symbolics import BasicWrapperMixin, estimate_cost, retrieve_symbols +from devito.symbolics import (BasicWrapperMixin, estimate_cost, reuse_if_untouched, + retrieve_symbols, q_routine) from devito.tools import ReducerMap from devito.types.object import AbstractObject @@ -72,6 +73,17 @@ def collect_special(expr): else: terms.append(i) + # Any factorization possible? + if not any(len(i) > 1 for i in [w_funcs, w_pows, w_coeffs]): + # `evaluate=True` below guarantees de-nesting of Adds + # For example: + # `args[0] = ((0.1*(a + b)) + (0.2*(c + d)))` + # `args[1] = ((0.1*(e + f)) + (0.2*(g + h)))` + # -> + # `expr = (0.1*(a + b) + 0.2*(c + d) + 0.1*(e + f) + 0.2*(g + h))` + # Which is then further factorizable by `collect_const` + return reuse_if_untouched(expr, args, evaluate=True) + # Collect common funcs if len(w_funcs) > 1: w_funcs = Add(*w_funcs, evaluate=False) @@ -96,7 +108,7 @@ def collect_special(expr): # Collect common temporaries (r0, r1, ...) w_coeffs = Add(*w_coeffs, evaluate=False) symbols = retrieve_symbols(w_coeffs) - if symbols: + if len(set(symbols)) != len(symbols): w_coeffs = collect(w_coeffs, symbols, evaluate=False) try: terms.extend([Mul(k, collect_const(v), evaluate=False) @@ -128,6 +140,11 @@ def collect_const(expr): else: inverse_mapper[-v].append(-k) + # Any factorization possible? + if len(inverse_mapper) == len(expr.args) or \ + list(inverse_mapper) == [1]: + return expr + terms = [] for k, v in inverse_mapper.items(): if len(v) == 1 and not v[0].is_Add: @@ -176,6 +193,10 @@ def _collect_nested(expr): if expr.is_Number: return expr, {'coeffs': expr} + elif q_routine(expr): + # E.g., a DefFunction + args, candidates = zip(*[_collect_nested(arg) for arg in expr.args]) + return expr.func(*args, evaluate=False), {} elif expr.is_Function: return expr, {'funcs': expr} elif expr.is_Pow: @@ -187,7 +208,8 @@ def _collect_nested(expr): return strategies['default'](expr), {} elif expr.is_Mul: args, candidates = zip(*[_collect_nested(arg) for arg in expr.args]) - return Mul(*args), ReducerMap.fromdicts(*candidates) + expr = reuse_if_untouched(expr, args, evaluate=True) + return expr, ReducerMap.fromdicts(*candidates) elif expr.is_Equality: args, candidates = zip(*[_collect_nested(expr.lhs), _collect_nested(expr.rhs)]) diff --git a/tests/test_dse.py b/tests/test_dse.py index b9c419eb01..4e77ef11ff 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -13,7 +13,7 @@ ConditionalDimension, DefaultDimension, Grid, Operator, norm, grad, div, dimensions, switchconfig, configuration, centered, first_derivative, solve, transpose, Abs, cos, - sin, sqrt, Ge, Lt) + sin, sqrt, floor, Ge, Lt) from devito.exceptions import InvalidArgument, InvalidOperator from devito.finite_differences.differentiable import diffify from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes, @@ -276,13 +276,19 @@ def test_pow_to_mul(expr, expected): @pytest.mark.parametrize('expr,expected', [ ('s - SizeOf("int")*fa[x]', 's - fa[x]*sizeof(int)'), + ('foo(4*fa[x] + 4*fb[x])', 'foo(4*(fa[x] + fb[x]))'), + ('floor(0.1*a + 0.1*fa[x])', 'floor(0.1*(a + fa[x]))'), + ('floor(0.1*(a + fa[x]))', 'floor(0.1*(a + fa[x]))'), ]) def test_factorize(expr, expected): grid = Grid((4, 5)) x, y = grid.dimensions - s = Scalar(name='s') # noqa + s = Scalar(name='s', dtype=np.float32) # noqa + a = Symbol(name='a', dtype=np.float32) # noqa fa = Function(name='fa', grid=grid, dimensions=(x,), shape=(4,)) # noqa + fb = Function(name='fb', grid=grid, dimensions=(x,), shape=(4,)) # noqa + foo = lambda *args: DefFunction('foo', tuple(args)) # noqa assert str(collect_nested(eval(expr))) == expected @@ -2205,9 +2211,9 @@ def test_nested_first_derivatives_unbalanced(self): ('v.dx.dx + p.dx.dx', (2, 2, (0, 2)), (61, 61, 25)), ('(v.dx + v.dy).dx - (v.dx + v.dy).dy + 2*f.dx.dx + f*f.dy.dy + f.dx.dx(x0=1)', - (3, 3, (0, 3)), (218, 202, 75)), + (3, 3, (0, 3)), (218, 202, 66)), ('(g*(1 + f)*v.dx).dx + (2*g*f*v.dx).dx', - (1, 1, (0, 1)), (52, 70, 20)), + (1, 1, (0, 1)), (52, 66, 20)), ('g*(f.dx.dx + g.dx.dx)', (1, 2, (0, 1)), (47, 62, 17)), ]) From 7d14a06b0cdcfa6ee3913205010ff3cb9f1f5cc8 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 24 Jul 2024 13:11:44 +0000 Subject: [PATCH 02/16] tests: Move CSE tests into a separate module --- tests/test_cse.py | 199 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_dse.py | 192 +------------------------------------------- 2 files changed, 201 insertions(+), 190 deletions(-) create mode 100644 tests/test_cse.py diff --git a/tests/test_cse.py b/tests/test_cse.py new file mode 100644 index 0000000000..3fecb03acc --- /dev/null +++ b/tests/test_cse.py @@ -0,0 +1,199 @@ +import pytest + +from sympy import Ge, Lt +from sympy.core.mul import _mulsort + +from conftest import assert_structure +from devito import (Grid, Function, TimeFunction, ConditionalDimension, Eq, # noqa + Operator, cos) +from devito.finite_differences.differentiable import diffify +from devito.ir import DummyEq, FindNodes, FindSymbols, Conditional +from devito.ir.support import generator +from devito.passes.clusters.cse import CTemp, _cse +from devito.symbolics import indexify +from devito.types import Array, Symbol, Temp + + +@pytest.mark.parametrize('exprs,expected,min_cost', [ + # Simple cases + (['Eq(tu, 2/(t0 + t1))', 'Eq(ti0, t0 + t1)', 'Eq(ti1, t0 + t1)'], + ['t0 + t1', '2/r0', 'r0', 'r0'], 0), + (['Eq(tu, 2/(t0 + t1))', 'Eq(ti0, 2/(t0 + t1) + 1)', 'Eq(ti1, 2/(t0 + t1) + 1)'], + ['2/(t0 + t1)', 'r1', 'r1 + 1', 'r0', 'r0'], 0), + (['Eq(tu, (tv + tw + 5.)*(ti0 + ti1) + (t0 + t1)*(ti0 + ti1))'], + ['ti0[x, y, z] + ti1[x, y, z]', + 'r0*(t0 + t1) + r0*(tv[t, x, y, z] + tw[t, x, y, z] + 5.0)'], 0), + (['Eq(tu, t0/t1)', 'Eq(ti0, 2 + t0/t1)', 'Eq(ti1, 2 + t0/t1)'], + ['t0/t1', 'r1', 'r1 + 2', 'r0', 'r0'], 0), + # Across expressions + (['Eq(tu, tv*4 + tw*5 + tw*5*t0)', 'Eq(tv, tw*5)'], + ['5*tw[t, x, y, z]', 'r0 + 5*t0*tw[t, x, y, z] + 4*tv[t, x, y, z]', 'r0'], 0), + # Intersecting + pytest.param(['Eq(tu, ti0*ti1 + ti0*ti1*t0 + ti0*ti1*t0*t1)'], + ['ti0*ti1', 'r0', 'r0*t0', 'r0*t0*t1'], 0, + marks=pytest.mark.xfail), + # Divisions (== powers with negative exponenet) are always captured + (['Eq(tu, tv**-1*(tw*5 + tw*5*t0))', 'Eq(ti0, tv**-1*t0)'], + ['1/tv[t, x, y, z]', 'r0*(5*t0*tw[t, x, y, z] + 5*tw[t, x, y, z])', 'r0*t0'], 0), + # `compact_temporaries` must detect chains of isolated temporaries + (['Eq(t0, tv)', 'Eq(t1, t0)', 'Eq(t2, t1)', 'Eq(tu, t2)'], + ['tv[t, x, y, z]'], 0), + # Dimension-independent flow+anti dependences should be a stopper for CSE + (['Eq(t0, cos(t1))', 'Eq(t1, 5)', 'Eq(t2, cos(t1))'], + ['cos(t1)', '5', 'cos(t1)'], 0), + (['Eq(tu, tv + 1)', 'Eq(tv, tu)', 'Eq(tw, tv + 1)'], + ['tv[t, x, y, z] + 1', 'tu[t, x, y, z]', 'tv[t, x, y, z] + 1'], 0), + # Dimension-independent flow (but not anti) dependences are OK instead as + # long as the temporaries are introduced after the write + (['Eq(tu.forward, tu.dx + 1)', 'Eq(tv.forward, tv.dx + 1)', + 'Eq(tw.forward, tv.dt + 1)', 'Eq(tz.forward, tv.dt + 2)'], + ['1/h_x', '-r1*tu[t, x, y, z] + r1*tu[t, x + 1, y, z] + 1', + '-r1*tv[t, x, y, z] + r1*tv[t, x + 1, y, z] + 1', + '1/dt', '-r2*tv[t, x, y, z] + r2*tv[t + 1, x, y, z]', + 'r0 + 1', 'r0 + 2'], 0), + # Fancy use case with lots of temporaries + (['Eq(tu.forward, tu.dx + 1)', 'Eq(tv.forward, tv.dx + 1)', + 'Eq(tw.forward, tv.dt.dx2.dy2 + 1)', 'Eq(tz.forward, tv.dt.dy2.dx2 + 2)'], + ['1/h_x', '-r11*tu[t, x, y, z] + r11*tu[t, x + 1, y, z] + 1', + '-r11*tv[t, x, y, z] + r11*tv[t, x + 1, y, z] + 1', + '1/dt', '-r12*tv[t, x - 1, y - 1, z] + r12*tv[t + 1, x - 1, y - 1, z]', + '-r12*tv[t, x + 1, y - 1, z] + r12*tv[t + 1, x + 1, y - 1, z]', + '-r12*tv[t, x, y - 1, z] + r12*tv[t + 1, x, y - 1, z]', + '-r12*tv[t, x - 1, y + 1, z] + r12*tv[t + 1, x - 1, y + 1, z]', + '-r12*tv[t, x + 1, y + 1, z] + r12*tv[t + 1, x + 1, y + 1, z]', + '-r12*tv[t, x, y + 1, z] + r12*tv[t + 1, x, y + 1, z]', + '-r12*tv[t, x - 1, y, z] + r12*tv[t + 1, x - 1, y, z]', + '-r12*tv[t, x + 1, y, z] + r12*tv[t + 1, x + 1, y, z]', + '-r12*tv[t, x, y, z] + r12*tv[t + 1, x, y, z]', + 'h_x**(-2)', '-2.0*r13', 'h_y**(-2)', '-2.0*r14', + 'r10*(r13*r6 + r13*r7 + r8*r9) + r14*(r0*r13 + r1*r13 + r2*r9) + ' + + 'r14*(r13*r3 + r13*r4 + r5*r9) + 1', + 'r13*(r0*r14 + r10*r6 + r14*r3) + r13*(r1*r14 + r10*r7 + r14*r4) + ' + + 'r9*(r10*r8 + r14*r2 + r14*r5) + 2'], 0), + # Existing temporaries from nested Function as index + (['Eq(e0, fx[x])', 'Eq(tu, cos(-tu[t, e0, y, z]) + tv[t, x, y, z])', + 'Eq(tv, cos(tu[t, e0, y, z]) + tw)'], + ['fx[x]', 'cos(tu[t, e0, y, z])', 'r0 + tv[t, x, y, z]', 'r0 + tw[t, x, y, z]'], 0), + # Make sure -x isn't factorized with default minimum cse cost + (['Eq(e0, fx[x])', 'Eq(tu, -tu[t, e0, y, z] + tv[t, x, y, z])', + 'Eq(tv, -tu[t, e0, y, z] + tw)'], + ['fx[x]', '-tu[t, e0, y, z] + tv[t, x, y, z]', + '-tu[t, e0, y, z] + tw[t, x, y, z]'], 1) +]) +def test_default_algo(exprs, expected, min_cost): + """Test common subexpressions elimination.""" + grid = Grid((3, 3, 3)) + x, y, z = grid.dimensions + t = grid.stepping_dim # noqa + + tu = TimeFunction(name="tu", grid=grid, space_order=2) # noqa + tv = TimeFunction(name="tv", grid=grid, space_order=2) # noqa + tw = TimeFunction(name="tw", grid=grid, space_order=2) # noqa + tz = TimeFunction(name="tz", grid=grid, space_order=2) # noqa + fx = Function(name="fx", grid=grid, dimensions=(x,), shape=(3,)) # noqa + ti0 = Array(name='ti0', shape=(3, 5, 7), dimensions=(x, y, z)).indexify() # noqa + ti1 = Array(name='ti1', shape=(3, 5, 7), dimensions=(x, y, z)).indexify() # noqa + t0 = CTemp(name='t0') # noqa + t1 = CTemp(name='t1') # noqa + t2 = CTemp(name='t2') # noqa + # Needs to not be a Temp to mimic nested index extraction and prevent + # cse to compact the temporary back. + e0 = Symbol(name='e0') # noqa + + # List comprehension would need explicit locals/globals mappings to eval + for i, e in enumerate(list(exprs)): + exprs[i] = DummyEq(indexify(diffify(eval(e).evaluate))) + + counter = generator() + make = lambda: CTemp(name='r%d' % counter()).indexify() + processed = _cse(exprs, make, min_cost) + + assert len(processed) == len(expected) + assert all(str(i.rhs) == j for i, j in zip(processed, expected)) + + +def test_temp_order(): + # Test order of classes inserted to Sympy's core ordering + a = Temp(name='r6') + b = CTemp(name='r6') + c = Symbol(name='r6') + + args = [b, a, c] + + _mulsort(args) + + assert type(args[0]) is Symbol + assert type(args[1]) is Temp + assert type(args[2]) is CTemp + + +def test_w_conditionals(): + grid = Grid(shape=(10, 10, 10)) + x, _, _ = grid.dimensions + + cd = ConditionalDimension(name='cd', parent=x, condition=Ge(x, 4), + indirect=True) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + h = Function(name='h', grid=grid) + a0 = Function(name='a0', grid=grid) + a1 = Function(name='a1', grid=grid) + + eqns = [Eq(h, a0, implicit_dims=cd), + Eq(a0, a0 + f*g, implicit_dims=cd), + Eq(a1, a1 + f*g, implicit_dims=cd)] + + op = Operator(eqns) + + assert_structure(op, ['x,y,z'], 'xyz') + assert len(FindNodes(Conditional).visit(op)) == 1 + + +def test_w_multi_conditionals(): + grid = Grid(shape=(10, 10, 10)) + x, _, _ = grid.dimensions + + cd = ConditionalDimension(name='cd', parent=x, condition=Ge(x, 4), + indirect=True) + + cd2 = ConditionalDimension(name='cd2', parent=x, condition=Lt(x, 4), + indirect=True) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + h = Function(name='h', grid=grid) + a0 = Function(name='a0', grid=grid) + a1 = Function(name='a1', grid=grid) + a2 = Function(name='a2', grid=grid) + a3 = Function(name='a3', grid=grid) + + eq0 = Eq(h, a0, implicit_dims=cd) + eq1 = Eq(a0, a0 + f*g, implicit_dims=cd) + eq2 = Eq(a1, a1 + f*g, implicit_dims=cd) + eq3 = Eq(a2, a2 + f*g, implicit_dims=cd2) + eq4 = Eq(a3, a3 + f*g, implicit_dims=cd2) + + op = Operator([eq0, eq1, eq3]) + + assert_structure(op, ['x,y,z'], 'xyz') + assert len(FindNodes(Conditional).visit(op)) == 2 + + tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')] + assert len(tmps) == 0 + + op = Operator([eq0, eq1, eq3, eq4]) + + assert_structure(op, ['x,y,z'], 'xyz') + assert len(FindNodes(Conditional).visit(op)) == 2 + + tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')] + assert len(tmps) == 1 + + op = Operator([eq0, eq1, eq2, eq3, eq4]) + + assert_structure(op, ['x,y,z'], 'xyz') + assert len(FindNodes(Conditional).visit(op)) == 2 + + tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')] + assert len(tmps) == 2 diff --git a/tests/test_dse.py b/tests/test_dse.py index 4e77ef11ff..a8c95ce393 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -4,7 +4,6 @@ import pytest from sympy import Mul # noqa -from sympy.core.mul import _mulsort from conftest import (skipif, EVAL, _R, assert_structure, assert_blocking, # noqa get_params, get_arrays, check_array) @@ -15,18 +14,16 @@ centered, first_derivative, solve, transpose, Abs, cos, sin, sqrt, floor, Ge, Lt) from devito.exceptions import InvalidArgument, InvalidOperator -from devito.finite_differences.differentiable import diffify from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes, FindSymbols, ParallelIteration, retrieve_iteration_tree) from devito.passes.clusters.aliases import collect from devito.passes.clusters.factorization import collect_nested -from devito.passes.clusters.cse import CTemp, _cse from devito.passes.iet.parpragma import VExpanded from devito.symbolics import (INT, FLOAT, DefFunction, FieldFromPointer, # noqa IndexedPointer, Keyword, SizeOf, estimate_cost, pow_to_mul, indexify) -from devito.tools import as_tuple, generator -from devito.types import Array, Scalar, Symbol, PrecomputedSparseTimeFunction, Temp +from devito.tools import as_tuple +from devito.types import Scalar, Symbol, PrecomputedSparseTimeFunction from examples.seismic.acoustic import AcousticWaveSolver from examples.seismic import demo_model, AcquisitionGeometry @@ -55,191 +52,6 @@ def test_scheduling_after_rewrite(): assert all(trees[1].root.dim is tree.root.dim for tree in trees[1:]) -@pytest.mark.parametrize('exprs,expected,min_cost', [ - # Simple cases - (['Eq(tu, 2/(t0 + t1))', 'Eq(ti0, t0 + t1)', 'Eq(ti1, t0 + t1)'], - ['t0 + t1', '2/r0', 'r0', 'r0'], 0), - (['Eq(tu, 2/(t0 + t1))', 'Eq(ti0, 2/(t0 + t1) + 1)', 'Eq(ti1, 2/(t0 + t1) + 1)'], - ['2/(t0 + t1)', 'r1', 'r1 + 1', 'r0', 'r0'], 0), - (['Eq(tu, (tv + tw + 5.)*(ti0 + ti1) + (t0 + t1)*(ti0 + ti1))'], - ['ti0[x, y, z] + ti1[x, y, z]', - 'r0*(t0 + t1) + r0*(tv[t, x, y, z] + tw[t, x, y, z] + 5.0)'], 0), - (['Eq(tu, t0/t1)', 'Eq(ti0, 2 + t0/t1)', 'Eq(ti1, 2 + t0/t1)'], - ['t0/t1', 'r1', 'r1 + 2', 'r0', 'r0'], 0), - # Across expressions - (['Eq(tu, tv*4 + tw*5 + tw*5*t0)', 'Eq(tv, tw*5)'], - ['5*tw[t, x, y, z]', 'r0 + 5*t0*tw[t, x, y, z] + 4*tv[t, x, y, z]', 'r0'], 0), - # Intersecting - pytest.param(['Eq(tu, ti0*ti1 + ti0*ti1*t0 + ti0*ti1*t0*t1)'], - ['ti0*ti1', 'r0', 'r0*t0', 'r0*t0*t1'], 0, - marks=pytest.mark.xfail), - # Divisions (== powers with negative exponenet) are always captured - (['Eq(tu, tv**-1*(tw*5 + tw*5*t0))', 'Eq(ti0, tv**-1*t0)'], - ['1/tv[t, x, y, z]', 'r0*(5*t0*tw[t, x, y, z] + 5*tw[t, x, y, z])', 'r0*t0'], 0), - # `compact_temporaries` must detect chains of isolated temporaries - (['Eq(t0, tv)', 'Eq(t1, t0)', 'Eq(t2, t1)', 'Eq(tu, t2)'], - ['tv[t, x, y, z]'], 0), - # Dimension-independent flow+anti dependences should be a stopper for CSE - (['Eq(t0, cos(t1))', 'Eq(t1, 5)', 'Eq(t2, cos(t1))'], - ['cos(t1)', '5', 'cos(t1)'], 0), - (['Eq(tu, tv + 1)', 'Eq(tv, tu)', 'Eq(tw, tv + 1)'], - ['tv[t, x, y, z] + 1', 'tu[t, x, y, z]', 'tv[t, x, y, z] + 1'], 0), - # Dimension-independent flow (but not anti) dependences are OK instead as - # long as the temporaries are introduced after the write - (['Eq(tu.forward, tu.dx + 1)', 'Eq(tv.forward, tv.dx + 1)', - 'Eq(tw.forward, tv.dt + 1)', 'Eq(tz.forward, tv.dt + 2)'], - ['1/h_x', '-r1*tu[t, x, y, z] + r1*tu[t, x + 1, y, z] + 1', - '-r1*tv[t, x, y, z] + r1*tv[t, x + 1, y, z] + 1', - '1/dt', '-r2*tv[t, x, y, z] + r2*tv[t + 1, x, y, z]', - 'r0 + 1', 'r0 + 2'], 0), - # Fancy use case with lots of temporaries - (['Eq(tu.forward, tu.dx + 1)', 'Eq(tv.forward, tv.dx + 1)', - 'Eq(tw.forward, tv.dt.dx2.dy2 + 1)', 'Eq(tz.forward, tv.dt.dy2.dx2 + 2)'], - ['1/h_x', '-r11*tu[t, x, y, z] + r11*tu[t, x + 1, y, z] + 1', - '-r11*tv[t, x, y, z] + r11*tv[t, x + 1, y, z] + 1', - '1/dt', '-r12*tv[t, x - 1, y - 1, z] + r12*tv[t + 1, x - 1, y - 1, z]', - '-r12*tv[t, x + 1, y - 1, z] + r12*tv[t + 1, x + 1, y - 1, z]', - '-r12*tv[t, x, y - 1, z] + r12*tv[t + 1, x, y - 1, z]', - '-r12*tv[t, x - 1, y + 1, z] + r12*tv[t + 1, x - 1, y + 1, z]', - '-r12*tv[t, x + 1, y + 1, z] + r12*tv[t + 1, x + 1, y + 1, z]', - '-r12*tv[t, x, y + 1, z] + r12*tv[t + 1, x, y + 1, z]', - '-r12*tv[t, x - 1, y, z] + r12*tv[t + 1, x - 1, y, z]', - '-r12*tv[t, x + 1, y, z] + r12*tv[t + 1, x + 1, y, z]', - '-r12*tv[t, x, y, z] + r12*tv[t + 1, x, y, z]', - 'h_x**(-2)', '-2.0*r13', 'h_y**(-2)', '-2.0*r14', - 'r10*(r13*r6 + r13*r7 + r8*r9) + r14*(r0*r13 + r1*r13 + r2*r9) + ' + - 'r14*(r13*r3 + r13*r4 + r5*r9) + 1', - 'r13*(r0*r14 + r10*r6 + r14*r3) + r13*(r1*r14 + r10*r7 + r14*r4) + ' + - 'r9*(r10*r8 + r14*r2 + r14*r5) + 2'], 0), - # Existing temporaries from nested Function as index - (['Eq(e0, fx[x])', 'Eq(tu, cos(-tu[t, e0, y, z]) + tv[t, x, y, z])', - 'Eq(tv, cos(tu[t, e0, y, z]) + tw)'], - ['fx[x]', 'cos(tu[t, e0, y, z])', 'r0 + tv[t, x, y, z]', 'r0 + tw[t, x, y, z]'], 0), - # Make sure -x isn't factorized with default minimum cse cost - (['Eq(e0, fx[x])', 'Eq(tu, -tu[t, e0, y, z] + tv[t, x, y, z])', - 'Eq(tv, -tu[t, e0, y, z] + tw)'], - ['fx[x]', '-tu[t, e0, y, z] + tv[t, x, y, z]', - '-tu[t, e0, y, z] + tw[t, x, y, z]'], 1) -]) -def test_cse(exprs, expected, min_cost): - """Test common subexpressions elimination.""" - grid = Grid((3, 3, 3)) - x, y, z = grid.dimensions - t = grid.stepping_dim # noqa - - tu = TimeFunction(name="tu", grid=grid, space_order=2) # noqa - tv = TimeFunction(name="tv", grid=grid, space_order=2) # noqa - tw = TimeFunction(name="tw", grid=grid, space_order=2) # noqa - tz = TimeFunction(name="tz", grid=grid, space_order=2) # noqa - fx = Function(name="fx", grid=grid, dimensions=(x,), shape=(3,)) # noqa - ti0 = Array(name='ti0', shape=(3, 5, 7), dimensions=(x, y, z)).indexify() # noqa - ti1 = Array(name='ti1', shape=(3, 5, 7), dimensions=(x, y, z)).indexify() # noqa - t0 = CTemp(name='t0') # noqa - t1 = CTemp(name='t1') # noqa - t2 = CTemp(name='t2') # noqa - # Needs to not be a Temp to mimic nested index extraction and prevent - # cse to compact the temporary back. - e0 = Symbol(name='e0') # noqa - - # List comprehension would need explicit locals/globals mappings to eval - for i, e in enumerate(list(exprs)): - exprs[i] = DummyEq(indexify(diffify(eval(e).evaluate))) - - counter = generator() - make = lambda: CTemp(name='r%d' % counter()).indexify() - processed = _cse(exprs, make, min_cost) - - assert len(processed) == len(expected) - assert all(str(i.rhs) == j for i, j in zip(processed, expected)) - - -def test_cse_temp_order(): - # Test order of classes inserted to Sympy's core ordering - a = Temp(name='r6') - b = CTemp(name='r6') - c = Symbol(name='r6') - - args = [b, a, c] - - _mulsort(args) - - assert type(args[0]) is Symbol - assert type(args[1]) is Temp - assert type(args[2]) is CTemp - - -def test_cse_w_conditionals(): - grid = Grid(shape=(10, 10, 10)) - x, _, _ = grid.dimensions - - cd = ConditionalDimension(name='cd', parent=x, condition=Ge(x, 4), - indirect=True) - - f = Function(name='f', grid=grid) - g = Function(name='g', grid=grid) - h = Function(name='h', grid=grid) - a0 = Function(name='a0', grid=grid) - a1 = Function(name='a1', grid=grid) - - eqns = [Eq(h, a0, implicit_dims=cd), - Eq(a0, a0 + f*g, implicit_dims=cd), - Eq(a1, a1 + f*g, implicit_dims=cd)] - - op = Operator(eqns) - - assert_structure(op, ['x,y,z'], 'xyz') - assert len(FindNodes(Conditional).visit(op)) == 1 - - -def test_cse_w_multi_conditionals(): - grid = Grid(shape=(10, 10, 10)) - x, _, _ = grid.dimensions - - cd = ConditionalDimension(name='cd', parent=x, condition=Ge(x, 4), - indirect=True) - - cd2 = ConditionalDimension(name='cd2', parent=x, condition=Lt(x, 4), - indirect=True) - - f = Function(name='f', grid=grid) - g = Function(name='g', grid=grid) - h = Function(name='h', grid=grid) - a0 = Function(name='a0', grid=grid) - a1 = Function(name='a1', grid=grid) - a2 = Function(name='a2', grid=grid) - a3 = Function(name='a3', grid=grid) - - eq0 = Eq(h, a0, implicit_dims=cd) - eq1 = Eq(a0, a0 + f*g, implicit_dims=cd) - eq2 = Eq(a1, a1 + f*g, implicit_dims=cd) - eq3 = Eq(a2, a2 + f*g, implicit_dims=cd2) - eq4 = Eq(a3, a3 + f*g, implicit_dims=cd2) - - op = Operator([eq0, eq1, eq3]) - - assert_structure(op, ['x,y,z'], 'xyz') - assert len(FindNodes(Conditional).visit(op)) == 2 - - tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')] - assert len(tmps) == 0 - - op = Operator([eq0, eq1, eq3, eq4]) - - assert_structure(op, ['x,y,z'], 'xyz') - assert len(FindNodes(Conditional).visit(op)) == 2 - - tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')] - assert len(tmps) == 1 - - op = Operator([eq0, eq1, eq2, eq3, eq4]) - - assert_structure(op, ['x,y,z'], 'xyz') - assert len(FindNodes(Conditional).visit(op)) == 2 - - tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')] - assert len(tmps) == 2 - - @pytest.mark.parametrize('expr,expected', [ ('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'), ('fa[x]**2', 'fa[x]*fa[x]'), From d9cb86fbc783ea9c3e3a78aab8ae0f524c8d80d0 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 24 Jul 2024 13:22:42 +0000 Subject: [PATCH 03/16] compiler: Add skeleton of cse_tuplets --- devito/passes/clusters/cse.py | 59 ++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index ba578d89e6..c97ebeee4d 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -32,12 +32,12 @@ class CTemp(Temp): @cluster_pass -def cse(cluster, sregistry, options, *args): +def cse(cluster, sregistry, options, *args, mode='default'): """ Common sub-expressions elimination (CSE). """ make = lambda: CTemp(name=sregistry.make_name(), dtype=cluster.dtype) - exprs = _cse(cluster, make, min_cost=options['cse-min-cost']) + exprs = _cse(cluster, make, min_cost=options['cse-min-cost'], mode=mode) return cluster.rebuild(exprs=exprs) @@ -55,28 +55,27 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): make : callable Build symbols to store temporary, redundant values. mode : str, optional - The CSE algorithm applied. Accepted: ['default']. + The CSE algorithm applied. Accepted: ['default', 'tuplets', 'all']. """ + assert mode in ('default', 'tuplets', 'all') # Note: not defaulting to SymPy's CSE() function for three reasons: - # - it also captures array index access functions (eg, i+1 in A[i+1] and B[i+1]); + # - it also captures array index access functions + # (e.g., i+1 in A[i+1] and B[i+1]); # - it sometimes "captures too much", losing factorization opportunities; # - very slow - # TODO: a second "sympy" mode will be provided, relying on SymPy's CSE() but - # also ensuring some form of post-processing - assert mode == 'default' # Only supported mode ATM # Accept Clusters, Eqs or even just exprs if isinstance(maybe_exprs, Cluster): - processed = list(maybe_exprs.exprs) + exprs = list(maybe_exprs.exprs) scope = maybe_exprs.scope else: maybe_exprs = as_list(maybe_exprs) if all(e.is_Equality for e in maybe_exprs): - processed = maybe_exprs + exprs = maybe_exprs scope = Scope(maybe_exprs) else: - processed = [Eq(make(), e) for e in maybe_exprs] + exprs = [Eq(make(), e) for e in maybe_exprs] scope = Scope([]) # Some sub-expressions aren't really "common" -- that's the case of Dimension- @@ -92,9 +91,24 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): d_anti = {i.source.access for i in scope.d_anti.independent()} exclude = d_flow & d_anti + if mode in ('default', 'all'): + exprs = _cse_default(exprs, exclude, make, min_cost) + if mode in ('tuplets', 'all'): + exprs = _cse_tuplets(exprs, exclude, make) + + # Drop useless temporaries (e.g., r0=r1) + processed = _compact_temporaries(exprs, exclude) + + return processed + + +def _cse_default(exprs, exclude, make, min_cost): + """ + The default common sub-expressions elimination algorithm. + """ while True: # Detect redundancies - counted = count(processed).items() + counted = count(exprs).items() targets = OrderedDict([(k, estimate_cost(k.expr, True)) for k, v in counted if v > 1]) # Rule out Dimension-independent data dependencies @@ -111,27 +125,34 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): # The extracted temporaries are inserted before the first expression # that contains it scheduled = [] - updated = [] - for e in processed: + processed = [] + for e in exprs: pe = e for k, v in chosen: if not k.conditionals == e.conditionals: continue pe, changed = _uxreplace(pe, {k.expr: v}) if changed and v not in scheduled: - updated.append(pe.func(v, k.expr, operation=None)) + processed.append(pe.func(v, k.expr, operation=None)) scheduled.append(v) - updated.append(pe) - processed = updated + processed.append(pe) + exprs = processed # Update `exclude` for the same reasons as above -- to rule out CSE across # Dimension-independent data dependences exclude.update(scheduled) - # At this point we may have useless temporaries (e.g., r0=r1). Let's drop them - processed = _compact_temporaries(processed, exclude) + return exprs - return processed + +def _cse_tuplets(exprs, exclude, make): + """ + The tuplets-based common sub-expressions elimination algorithm. + """ + while True: + break + + return exprs def _compact_temporaries(exprs, exclude): From e72037704c10bef91cfd800ce062ae9048f4318c Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 24 Jul 2024 15:56:41 +0000 Subject: [PATCH 04/16] compiler: Draft tuplets-based CSE algo --- devito/core/cpu.py | 3 +- devito/core/gpu.py | 3 +- devito/core/operator.py | 5 ++ devito/ir/clusters/visitors.py | 7 +- devito/passes/clusters/cse.py | 123 +++++++++++++++++++++++++-------- tests/test_cse.py | 32 ++++++++- 6 files changed, 139 insertions(+), 34 deletions(-) diff --git a/devito/core/cpu.py b/devito/core/cpu.py index dbf50c3e3b..01752daa53 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -38,6 +38,7 @@ def _normalize_kwargs(cls, **kwargs): # CSE o['cse-min-cost'] = oo.pop('cse-min-cost', cls.CSE_MIN_COST) + o['cse-algo'] = oo.pop('cse-algo', cls.CSE_ALGO) # Blocking o['blockinner'] = oo.pop('blockinner', False) @@ -175,7 +176,7 @@ def _specialize_clusters(cls, clusters, **kwargs): clusters = fuse(clusters) # Reduce flops - clusters = cse(clusters, sregistry, options) + clusters = cse(clusters, **kwargs) # Blocking to improve data locality if options['blocklazy']: diff --git a/devito/core/gpu.py b/devito/core/gpu.py index 05e4e6451d..4c594a09ac 100644 --- a/devito/core/gpu.py +++ b/devito/core/gpu.py @@ -48,6 +48,7 @@ def _normalize_kwargs(cls, **kwargs): # CSE o['cse-min-cost'] = oo.pop('cse-min-cost', cls.CSE_MIN_COST) + o['cse-algo'] = oo.pop('cse-algo', cls.CSE_ALGO) # Blocking o['blockinner'] = oo.pop('blockinner', True) @@ -203,7 +204,7 @@ def _specialize_clusters(cls, clusters, **kwargs): clusters = fuse(clusters) # Reduce flops - clusters = cse(clusters, sregistry, options) + clusters = cse(clusters, **kwargs) # Blocking to define thread blocks if options['blocklazy']: diff --git a/devito/core/operator.py b/devito/core/operator.py index c05d98b60e..0e95a6c90d 100644 --- a/devito/core/operator.py +++ b/devito/core/operator.py @@ -27,6 +27,11 @@ class BasicOperator(Operator): common sub=expression. """ + CSE_ALGO = 'default' + """ + The algorithm to use for common sub-expression elimination. + """ + BLOCK_LEVELS = 1 """ Loop blocking depth. So, 1 => "blocks", 2 => "blocks" and "sub-blocks", diff --git a/devito/ir/clusters/visitors.py b/devito/ir/clusters/visitors.py index 07d7fcb76e..441dce5f50 100644 --- a/devito/ir/clusters/visitors.py +++ b/devito/ir/clusters/visitors.py @@ -204,11 +204,12 @@ def __init__(self, func, mode='dense'): else: self.cond = lambda c: True - def __call__(self, *args): + def __call__(self, *args, **kwargs): if timed_pass.is_enabled(): - maybe_timed = lambda *_args: timed_pass(self.func, self.func.__name__)(*_args) + maybe_timed = lambda *_args: \ + timed_pass(self.func, self.func.__name__)(*_args, **kwargs) else: - maybe_timed = lambda *_args: self.func(*_args) + maybe_timed = lambda *_args: self.func(*_args, **kwargs) args = list(args) maybe_clusters = args.pop(0) if isinstance(maybe_clusters, Iterable): diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index c97ebeee4d..5f33767efb 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -1,4 +1,4 @@ -from collections import Counter, OrderedDict, namedtuple +from collections import Counter, OrderedDict, defaultdict, namedtuple from functools import singledispatch import sympy @@ -12,15 +12,16 @@ from devito.finite_differences.differentiable import IndexDerivative from devito.ir import Cluster, Scope, cluster_pass from devito.passes.clusters.utils import makeit_ssa -from devito.symbolics import estimate_cost, q_leaf -from devito.symbolics.manipulation import _uxreplace +from devito.symbolics import estimate_cost, q_leaf, search +from devito.symbolics.manipulation import Uxmapper, _uxreplace from devito.tools import as_list, frozendict from devito.types import Eq, Symbol, Temp __all__ = ['cse'] -Counted = namedtuple('Candidate', 'expr, conditionals') +Candidate = namedtuple('Candidate', 'expr conditionals sources') +Candidate.__new__.__defaults__ = (None, None, None) class CTemp(Temp): @@ -32,12 +33,14 @@ class CTemp(Temp): @cluster_pass -def cse(cluster, sregistry, options, *args, mode='default'): +def cse(cluster, sregistry=None, options=None, **kwargs): """ Common sub-expressions elimination (CSE). """ make = lambda: CTemp(name=sregistry.make_name(), dtype=cluster.dtype) - exprs = _cse(cluster, make, min_cost=options['cse-min-cost'], mode=mode) + exprs = _cse(cluster, make, + min_cost=options['cse-min-cost'], + mode=options['cse-algo']) return cluster.rebuild(exprs=exprs) @@ -55,9 +58,9 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): make : callable Build symbols to store temporary, redundant values. mode : str, optional - The CSE algorithm applied. Accepted: ['default', 'tuplets', 'all']. + The CSE algorithm applied. Accepted: ['default', 'tuplets', 'advanced']. """ - assert mode in ('default', 'tuplets', 'all') + assert mode in ('default', 'tuplets', 'advanced') # Note: not defaulting to SymPy's CSE() function for three reasons: # - it also captures array index access functions @@ -91,9 +94,9 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): d_anti = {i.source.access for i in scope.d_anti.independent()} exclude = d_flow & d_anti - if mode in ('default', 'all'): + if mode in ('default', 'advanced'): exprs = _cse_default(exprs, exclude, make, min_cost) - if mode in ('tuplets', 'all'): + if mode in ('tuplets', 'advanced'): exprs = _cse_tuplets(exprs, exclude, make) # Drop useless temporaries (e.g., r0=r1) @@ -111,6 +114,7 @@ def _cse_default(exprs, exclude, make, min_cost): counted = count(exprs).items() targets = OrderedDict([(k, estimate_cost(k.expr, True)) for k, v in counted if v > 1]) + # Rule out Dimension-independent data dependencies targets = OrderedDict([(k, v) for k, v in targets.items() if not k.expr.free_symbols & exclude]) @@ -122,21 +126,7 @@ def _cse_default(exprs, exclude, make, min_cost): chosen = [(k, make()) for k, v in targets.items() if v == hit] # Apply replacements - # The extracted temporaries are inserted before the first expression - # that contains it - scheduled = [] - processed = [] - for e in exprs: - pe = e - for k, v in chosen: - if not k.conditionals == e.conditionals: - continue - pe, changed = _uxreplace(pe, {k.expr: v}) - if changed and v not in scheduled: - processed.append(pe.func(v, k.expr, operation=None)) - scheduled.append(v) - processed.append(pe) - exprs = processed + exprs, scheduled = _inject_temporaries(exprs, chosen, exclude) # Update `exclude` for the same reasons as above -- to rule out CSE across # Dimension-independent data dependences @@ -148,13 +138,92 @@ def _cse_default(exprs, exclude, make, min_cost): def _cse_tuplets(exprs, exclude, make): """ The tuplets-based common sub-expressions elimination algorithm. + + This algo relies on SymPy's canonical ordering of operands. It extracts + sub-expressions of decreasing size that may or may not be redundant. + + Unlike the default algorithm, this one looks inside the individual operations. + However, it does so speculatively, as it doesn't attempt to estimate the cost + of the extracted sub-expressions, which would be an hard problem to solve. + + Another simplification is that we only explore operations whose operands are + leaves, i.e., symbols or indexed objects. + + Examples + -------- + Given the expression `a*b*c*d + c*d + a*b*c + a*b*e`, the following + sub-expressions are extracted: `r0 = a*b, r1 = r0*c`, which leads to the + following optimized expression: `r1*d + c*d + r1 + r0*e`. """ + key = lambda candidate: len(candidate.expr.args) + while True: - break + mapper = defaultdict(list) + for e in exprs: + try: + cond = e.conditionals + except AttributeError: + cond = None + + for op in (Add, Mul): + for i in search(e, op): + # The args are in canonical order (thanks to SymPy); let's pick + # the largest sub-expression that is not `i` itself + args = i.args[:-1] + + terms = [a for a in args if q_leaf(a)] + + if len(terms) > 1: + mapper[Candidate(op(*terms), cond)].append(i) + + #mapper = {k: v for k, v in mapper.items() + # if not k.expr.free_symbols & exclude} + + if not mapper: + break + + # Create temporaries of decreasing size + hit = max(mapper, key=key) + chosen = [(Candidate(i.expr, i.conditionals, sources), make()) + for i, sources in mapper.items() if key(i) == key(hit)] + + # Apply replacements + exprs, _ = _inject_temporaries(exprs, chosen, exclude) return exprs +def _inject_temporaries(exprs, chosen, exclude): + """ + Insert temporaries into the expression list such that they appear right + before the first expression that contains them. + """ + scheduled = [] + processed = [] + for e in exprs: + pe = e + for k, v in chosen: + if k.conditionals != e.conditionals: + continue + + if k.sources: + # Perform compound-based replacement, see uxreplace.__doc__ + args = list(k.expr.args) + pivot = args.pop(0) + compound = {pivot: v, **{a: None for a in args}} + subs = {i: compound for i in k.sources} + else: + subs = {k.expr: v} + + pe, changed = _uxreplace(pe, subs) + if changed and v not in scheduled: + processed.append(pe.func(v, k.expr, operation=None)) + scheduled.append(v) + processed.append(pe) + + return processed, scheduled + + def _compact_temporaries(exprs, exclude): """ Drop temporaries consisting of isolated symbols. @@ -209,7 +278,7 @@ def _(expr): cond = expr.conditionals except AttributeError: cond = frozendict() - return {Counted(e, cond): v for e, v in mapper.items()} + return {Candidate(e, cond): v for e, v in mapper.items()} @count.register(Indexed) diff --git a/tests/test_cse.py b/tests/test_cse.py index 3fecb03acc..3c61d89861 100644 --- a/tests/test_cse.py +++ b/tests/test_cse.py @@ -5,7 +5,7 @@ from conftest import assert_structure from devito import (Grid, Function, TimeFunction, ConditionalDimension, Eq, # noqa - Operator, cos) + Operator, cos, sin) from devito.finite_differences.differentiable import diffify from devito.ir import DummyEq, FindNodes, FindSymbols, Conditional from devito.ir.support import generator @@ -81,7 +81,7 @@ '-tu[t, e0, y, z] + tw[t, x, y, z]'], 1) ]) def test_default_algo(exprs, expected, min_cost): - """Test common subexpressions elimination.""" + """Test the default common subexpressions elimination algorithm.""" grid = Grid((3, 3, 3)) x, y, z = grid.dimensions t = grid.stepping_dim # noqa @@ -197,3 +197,31 @@ def test_w_multi_conditionals(): tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')] assert len(tmps) == 2 + + +@pytest.mark.parametrize('exprs,expected', [ + (['Eq(u, sin(f)*cos(g)*sin(g) + sin(f)*cos(g)*cos(f))'], + ['cos(g[x, y, z])', 'sin(f[x, y, z])', 'r0*r1', + 'r2*sin(g[x, y, z]) + r2*cos(f[x, y, z])']), + (['Eq(u, sin(f)*cos(f)*sin(g)*cos(g) + sin(f)*cos(f)*sin(g) + sin(f)*cos(f))'], + ['cos(f[x, y, z])', 'sin(f[x, y, z])', 'sin(g[x, y, z])', 'r0*r1', + 'r2*r4', 'r0*r1 + r2*r4 + r3*cos(g[x, y, z])']), +]) +def test_tuplets_algo(exprs, expected): + """Test the tuplets-based common subexpressions elimination algorithm.""" + grid = Grid((3, 3, 3)) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + u = TimeFunction(name="u", grid=grid, space_order=2) + + # List comprehension would need explicit locals/globals mappings to eval + for i, e in enumerate(list(exprs)): + exprs[i] = DummyEq(indexify(diffify(eval(e).evaluate))) + + counter = generator() + make = lambda: CTemp(name='r%d' % counter()).indexify() + processed = _cse(exprs, make, mode='advanced') + + assert len(processed) == len(expected) + assert all(str(i.rhs) == j for i, j in zip(processed, expected)) From fbbbc5bf02f23100da8d0778bfb758699b3824b1 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 25 Jul 2024 10:01:06 +0000 Subject: [PATCH 05/16] compiler: Enhance and fix tuplets-based CSE algo --- devito/finite_differences/differentiable.py | 2 + devito/passes/clusters/cse.py | 405 ++++++++++++-------- devito/symbolics/manipulation.py | 24 +- tests/test_cse.py | 85 ++-- tests/test_dse.py | 2 +- tests/test_unexpansion.py | 4 +- 6 files changed, 329 insertions(+), 193 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index b392bd977e..eee2bdcf4c 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -511,6 +511,7 @@ def __new__(cls, *args, **kwargs): # set of basic simplifications # (a+b)+c -> a+b+c (flattening) + # TODO: use symbolics.flatten_args; not using it to avoid a circular import nested, others = split(args, lambda e: isinstance(e, Add)) args = flatten(e.args for e in nested) + list(others) @@ -533,6 +534,7 @@ def __new__(cls, *args, **kwargs): # to avoid generating functional, but ugly, code # (a*b)*c -> a*b*c (flattening) + # TODO: use symbolics.flatten_args; not using it to avoid a circular import nested, others = split(args, lambda e: isinstance(e, Mul)) args = flatten(e.args for e in nested) + list(others) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 5f33767efb..4c1e85d783 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -1,5 +1,5 @@ -from collections import Counter, OrderedDict, defaultdict, namedtuple -from functools import singledispatch +from collections import defaultdict +from functools import cached_property, singledispatch import sympy from sympy import Add, Function, Indexed, Mul, Pow @@ -12,18 +12,14 @@ from devito.finite_differences.differentiable import IndexDerivative from devito.ir import Cluster, Scope, cluster_pass from devito.passes.clusters.utils import makeit_ssa -from devito.symbolics import estimate_cost, q_leaf, search -from devito.symbolics.manipulation import Uxmapper, _uxreplace -from devito.tools import as_list, frozendict +from devito.symbolics import estimate_cost, q_leaf, q_terminal +from devito.symbolics.manipulation import _uxreplace +from devito.tools import DAG, as_list, as_tuple, frozendict from devito.types import Eq, Symbol, Temp __all__ = ['cse'] -Candidate = namedtuple('Candidate', 'expr conditionals sources') -Candidate.__new__.__defaults__ = (None, None, None) - - class CTemp(Temp): """ @@ -35,21 +31,50 @@ class CTemp(Temp): @cluster_pass def cse(cluster, sregistry=None, options=None, **kwargs): """ - Common sub-expressions elimination (CSE). + Perform common sub-expressions elimination (CSE) on a Cluster. + + Two algorithms are available, 'default' and 'advanced'. + + The 'default' algorithm searches for common sub-expressions across the + operations in a given Cluster. However, it does not look for sub-expressions + that are subsets of operands of a given n-ary operation. For example, given + the expression `a*b*c*d + c*d + a*b*c + a*b*e`, it would capture `a*b*c`, + but not `a*b`. + + The 'advanced' algorithm also extracts subsets of operands from a + given n-ary operation, e.g. `a*b` in `a*b*c*d`. In particular, for a given + operation `op(a1, a2, ..., an)` it searches for `n-2` additional + sub-expressions of increasing size, namely `a1*a2`, `a1*a2*a3`, etc. + This algorithm heuristically relies on SymPy's canonical ordering of operands + to maximize the likelihood of finding common sub-expressions. + + Parameters + ---------- + cluster : Cluster + The input Cluster. + sregistry : SymbolRegistry + The symbol registry to use for creating temporaries. + options : dict + The optimization options. + Accepted: ['cse-min-cost', 'cse-algo']. + * 'cse-min-cost': int. The minimum cost of a common sub-expression to be + considered for CSE. Default is 1. + * 'cse-algo': str. The CSE algorithm to apply. Accepted: ['default', + 'advanced']. Default is 'default'. """ + min_cost = options['cse-min-cost'] + mode = options['cse-algo'] + make = lambda: CTemp(name=sregistry.make_name(), dtype=cluster.dtype) - exprs = _cse(cluster, make, - min_cost=options['cse-min-cost'], - mode=options['cse-algo']) + + exprs = _cse(cluster, make, min_cost=min_cost, mode=mode) return cluster.rebuild(exprs=exprs) def _cse(maybe_exprs, make, min_cost=1, mode='default'): """ - Main common sub-expressions elimination routine. - - Note: the output is guaranteed to be topologically sorted. + Carry out the bulk of the CSE process. Parameters ---------- @@ -58,15 +83,17 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): make : callable Build symbols to store temporary, redundant values. mode : str, optional - The CSE algorithm applied. Accepted: ['default', 'tuplets', 'advanced']. - """ - assert mode in ('default', 'tuplets', 'advanced') + The CSE algorithm applied. Accepted: ['default', 'advanced']. - # Note: not defaulting to SymPy's CSE() function for three reasons: - # - it also captures array index access functions - # (e.g., i+1 in A[i+1] and B[i+1]); - # - it sometimes "captures too much", losing factorization opportunities; - # - very slow + Notes + ----- + We're not using SymPy's CSE for three reasons: + + * It also captures array index access functions (e.g., i+1 in A[i+1]); + * It sometimes "captures too much", losing factorization opportunities; + * It tends to be very slow. + """ + assert mode in ('default', 'advanced') # Accept Clusters, Eqs or even just exprs if isinstance(maybe_exprs, Cluster): @@ -94,111 +121,47 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): d_anti = {i.source.access for i in scope.d_anti.independent()} exclude = d_flow & d_anti - if mode in ('default', 'advanced'): - exprs = _cse_default(exprs, exclude, make, min_cost) - if mode in ('tuplets', 'advanced'): - exprs = _cse_tuplets(exprs, exclude, make) - - # Drop useless temporaries (e.g., r0=r1) - processed = _compact_temporaries(exprs, exclude) - - return processed - + # Perform CSE + key = lambda c: c.cost -def _cse_default(exprs, exclude, make, min_cost): - """ - The default common sub-expressions elimination algorithm. - """ + scheduled = {} while True: # Detect redundancies - counted = count(exprs).items() - targets = OrderedDict([(k, estimate_cost(k.expr, True)) - for k, v in counted if v > 1]) + candidates = catch(exprs, mode) # Rule out Dimension-independent data dependencies - targets = OrderedDict([(k, v) for k, v in targets.items() - if not k.expr.free_symbols & exclude]) - if not targets or max(targets.values()) < min_cost: + candidates = [c for c in candidates if not c.expr.free_symbols & exclude] + + if not candidates: break - # Create temporaries - hit = max(targets.values()) - chosen = [(k, make()) for k, v in targets.items() if v == hit] + # Start with the largest + cost = key(max(candidates, key=key)) + if cost < min_cost: + break + candidates = [c for c in candidates if c.cost == cost] # Apply replacements - exprs, scheduled = _inject_temporaries(exprs, chosen, exclude) - - # Update `exclude` for the same reasons as above -- to rule out CSE across - # Dimension-independent data dependences - exclude.update(scheduled) + chosen = [(c, scheduled.get(c.key) or make()) for c in candidates] + exprs = _inject(exprs, chosen, scheduled) - return exprs - - -def _cse_tuplets(exprs, exclude, make): - """ - The tuplets-based common sub-expressions elimination algorithm. + # Drop useless temporaries (e.g., r0=r1) + processed = _compact(exprs, exclude) - This algo relies on SymPy's canonical ordering of operands. It extracts - sub-expressions of decreasing size that may or may not be redundant. + # Ensure topo-sorting + if mode == 'advanced': + processed = _toposort(processed) - Unlike the default algorithm, this one looks inside the individual operations. - However, it does so speculatively, as it doesn't attempt to estimate the cost - of the extracted sub-expressions, which would be an hard problem to solve. + return processed - Another simplification is that we only explore operations whose operands are - leaves, i.e., symbols or indexed objects. - Examples - -------- - Given the expression `a*b*c*d + c*d + a*b*c + a*b*e`, the following - sub-expressions are extracted: `r0 = a*b, r1 = r0*c`, which leads to the - following optimized expression: `r1*d + c*d + r1 + r0*e`. +def _inject(exprs, chosen, scheduled): """ - key = lambda candidate: len(candidate.expr.args) - - while True: - mapper = defaultdict(list) - for e in exprs: - try: - cond = e.conditionals - except AttributeError: - cond = None - - for op in (Add, Mul): - for i in search(e, op): - # The args are in canonical order (thanks to SymPy); let's pick - # the largest sub-expression that is not `i` itself - args = i.args[:-1] - - terms = [a for a in args if q_leaf(a)] - - if len(terms) > 1: - mapper[Candidate(op(*terms), cond)].append(i) - - #mapper = {k: v for k, v in mapper.items() - # if not k.expr.free_symbols & exclude} - - if not mapper: - break - - # Create temporaries of decreasing size - hit = max(mapper, key=key) - chosen = [(Candidate(i.expr, i.conditionals, sources), make()) - for i, sources in mapper.items() if key(i) == key(hit)] - - # Apply replacements - exprs, _ = _inject_temporaries(exprs, chosen, exclude) - - return exprs - + Insert temporaries into the expression list. -def _inject_temporaries(exprs, chosen, exclude): + The resulting expression list may not be topologically sorted. The caller + is responsible for ensuring that. """ - Insert temporaries into the expression list such that they appear right - before the first expression that contains them. - """ - scheduled = [] processed = [] for e in exprs: pe = e @@ -206,37 +169,49 @@ def _inject_temporaries(exprs, chosen, exclude): if k.conditionals != e.conditionals: continue - if k.sources: - # Perform compound-based replacement, see uxreplace.__doc__ - args = list(k.expr.args) - pivot = args.pop(0) - compound = {pivot: v, **{a: None for a in args}} - subs = {i: compound for i in k.sources} - else: - subs = {k.expr: v} + if e.lhs is v: + # This happens when `k.expr` wasn't substituted in a previous + # iteration because `k.sources` (whose construction + # is based on heuristics to avoid a combinatorial explosion) + # didn't include all of the `k.expr` occurrences across `exprs`, + # in particular those as part of a middle-term in a n-ary operation + # (e.g., `b*c` in `a*b*c*d`) + assert k.expr == e.rhs + continue + + subs = k.as_subs(v) pe, changed = _uxreplace(pe, subs) - if changed and v not in scheduled: + + if changed and k.key not in scheduled: processed.append(pe.func(v, k.expr, operation=None)) - scheduled.append(v) + scheduled[k.key] = v + processed.append(pe) - return processed, scheduled + return processed -def _compact_temporaries(exprs, exclude): +def _compact(exprs, exclude): """ - Drop temporaries consisting of isolated symbols. + Drop useless temporaries: + + * Temporaries of the form `t0 = s`, where `s` is a leaf; + * Temporaries of the form `t0 = expr` such that `t0` is accessed only once. """ - # First of all, convert to SSA exprs = makeit_ssa(exprs) - # Drop candidates are all exprs in the form `t0 = s` where `s` is a symbol - # Note: only CSE-captured Temps, which are by construction local objects, may - # safely be compacted; a generic Symbol could instead be accessed in a subsequent - # Cluster, for example: `for (i = ...) { a = b; for (j = a ...) ...` - mapper = {e.lhs: e.rhs for e in exprs - if isinstance(e.lhs, CTemp) and q_leaf(e.rhs) and e.lhs not in exclude} + # Only CSE-captured Temps, namely CTemps, can safely be optimized; a + # generic Symbol could instead be accessed in a subsequent Cluster, e.g. + # `for (i = ...) { a = b; for (j = a ...) ... }` + candidates = [e for e in exprs if isinstance(e.lhs, CTemp)] + + mapper = {e.lhs: e.rhs for e in candidates + if q_leaf(e.rhs) and e.lhs not in exclude} + + #TODO? TO GO? + mapper.update({e.lhs: e.rhs for e in candidates + if sum([i.rhs.count(e.lhs) for i in exprs]) == 1}) processed = [] for e in exprs: @@ -250,65 +225,181 @@ def _compact_temporaries(exprs, exclude): return processed +def _toposort(exprs): + """ + Ensure the expression list is topologically sorted. + """ + dag = DAG(exprs) + + for e0 in exprs: + if not isinstance(e0.lhs, CTemp): + continue + + for e1 in exprs: + if e0.lhs in e1.rhs.free_symbols: + dag.add_edge(e0, e1, force_add=True) + + def choose_element(queue, scheduled): + # Try to honor temporary names as much as possible + first = sorted(queue, key=lambda i: i.lhs.base.name).pop(0) + queue.remove(first) + return first + + processed = dag.topological_sort(choose_element) + + return processed + + +class Candidate(tuple): + + def __new__(cls, expr, conditionals=None, sources=()): + conditionals = frozendict(conditionals or {}) + sources = as_tuple(sources) + return tuple.__new__(cls, (expr, conditionals, sources)) + + @property + def expr(self): + return self[0] + + @property + def conditionals(self): + return self[1] + + @property + def sources(self): + return self[2] + + @property + def key(self): + return (self.expr, self.conditionals) + + @cached_property + def cost(self): + if len(self.sources) == 1: + return 0 + else: + return estimate_cost(self.expr) + + def as_subs(self, v): + subs = {self.expr: v} + + # Also add in subs for compound-based replacement + # E.g., `a*b*c*d` -> `r0*c*d` + for i in self.sources: + if self.expr == i: + continue + + args = [v] + queue = list(self.expr.args) + for a in i.args: + try: + queue.remove(a) + except ValueError: + args.append(a) + assert not queue + subs[i] = self.expr.func(*args) + + return subs + + +def catch(exprs, mode): + """ + Return all common sub-expressions in `exprs` as Candidates. + """ + mapper = _catch(exprs) + + candidates = [] + for k, v in mapper.items(): + if mode == 'default': + sources = [i for i in v if i == k.expr] + else: + sources = v + + if len(sources) > 1: + candidates.append(Candidate(k.expr, k.conditionals, sources)) + + return candidates + + @singledispatch -def count(expr): +def _catch(expr): """ - Construct a mapper `expr -> #occurrences` for each sub-expression in `expr`. + Construct a mapper `(expr, cond) -> [occurrences]` for each sub-expression + in `expr`. + + For example, given `expr = a*b*c`, the output would be: + `{(a*b*c, None): [a*b*c], (a*b, None): [a*b*c]}`. """ - mapper = Counter() + mapper = defaultdict(list) for a in expr.args: - mapper.update(count(a)) + for k, v in _catch(a).items(): + mapper[k].extend(v) return mapper -@count.register(list) -@count.register(tuple) +@_catch.register(list) +@_catch.register(tuple) def _(exprs): - mapper = Counter() + mapper = defaultdict(list) for e in exprs: - mapper.update(count(e)) - + for k, v in _catch(e).items(): + mapper[k].extend(v) return mapper -@count.register(sympy.Eq) +@_catch.register(sympy.Eq) def _(expr): - mapper = count(expr.rhs) + mapper = _catch(expr.rhs) try: cond = expr.conditionals except AttributeError: cond = frozendict() - return {Candidate(e, cond): v for e, v in mapper.items()} + return {Candidate(c.expr, cond): v for c, v in mapper.items()} -@count.register(Indexed) -@count.register(Symbol) +@_catch.register(Indexed) +@_catch.register(Symbol) def _(expr): """ Handler for objects preventing CSE to propagate through their arguments. """ - return Counter() + return {} -@count.register(IndexDerivative) +@_catch.register(IndexDerivative) def _(expr): """ Handler for symbol-binding objects. There can be many of them and therefore they should be detected as common subexpressions, but it's either pointless or forbidden to look inside them. """ - return Counter([expr]) + return {Candidate(expr): [expr]} -@count.register(Add) -@count.register(Mul) -@count.register(Pow) -@count.register(Function) +@_catch.register(Add) +@_catch.register(Mul) def _(expr): - mapper = Counter() - for a in expr.args: - mapper.update(count(a)) + mapper = _catch(expr.args) + + mapper[Candidate(expr)].append(expr) + + for n in range(2, len(expr.args)): + terms = expr.args[:n] + + # Heuristic: let the factorizer handle the rest + terms = [a for a in terms if q_terminal(a)] + + v = expr.func(*terms, evaluate=False) + mapper[Candidate(v)].append(expr) + + return mapper + + +@_catch.register(Pow) +@_catch.register(Function) +def _(expr): + mapper = _catch(expr.args) - mapper[expr] += 1 + mapper[Candidate(expr)].append(expr) return mapper diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 8447c22b92..1762c4250d 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -6,6 +6,7 @@ from sympy.core.add import _addsort from sympy.core.mul import _mulsort +from devito.finite_differences.differentiable import EvalDerivative from devito.symbolics.extended_sympy import DefFunction, rfunc from devito.symbolics.queries import q_leaf from devito.symbolics.search import retrieve_indexed, retrieve_functions @@ -17,7 +18,7 @@ __all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args', 'normalize_args', 'uxreplace', 'Uxmapper', 'reuse_if_untouched', - 'evalrel'] + 'evalrel', 'flatten_args'] def uxreplace(expr, rule): @@ -139,6 +140,7 @@ def _(expr, args, kwargs): if all(i.is_commutative for i in args): _addsort(args) _eval_numbers(expr, args) + args = flatten_args(args, Add, ignore=EvalDerivative) return expr.func(*args, evaluate=False) else: return expr._new_rawargs(*args) @@ -154,6 +156,7 @@ def _(expr, args, kwargs): if all(i.is_commutative for i in args): _mulsort(args) _eval_numbers(expr, args) + args = flatten_args(args, Mul, ignore=EvalDerivative) return expr.func(*args, evaluate=False) else: return expr._new_rawargs(*args) @@ -276,6 +279,25 @@ def _eval_numbers(expr, args): args[:] = [expr.func(*numbers)] + others +def flatten_args(args, op, ignore=None): + """ + Flatten the arguments of type `op` in `args`. + + Examples + -------- + * (a+b)+c -> a+b+c + * (a*b)*c -> a*b*c + * (a+b)*c -> (a+b)*c + """ + if ignore is not None and any(isinstance(a, ignore) for a in args): + return args + + key = lambda e: isinstance(e, op) + nested, others = split(args, key) + + return flatten(e.args for e in nested) + list(others) + + def pow_to_mul(expr): if q_leaf(expr) or isinstance(expr, Basic): return expr diff --git a/tests/test_cse.py b/tests/test_cse.py index 3c61d89861..56e0fddba0 100644 --- a/tests/test_cse.py +++ b/tests/test_cse.py @@ -1,5 +1,6 @@ import pytest +import numpy as np from sympy import Ge, Lt from sympy.core.mul import _mulsort @@ -35,7 +36,7 @@ # Divisions (== powers with negative exponenet) are always captured (['Eq(tu, tv**-1*(tw*5 + tw*5*t0))', 'Eq(ti0, tv**-1*t0)'], ['1/tv[t, x, y, z]', 'r0*(5*t0*tw[t, x, y, z] + 5*tw[t, x, y, z])', 'r0*t0'], 0), - # `compact_temporaries` must detect chains of isolated temporaries + # `cse._compact(...)` must detect chains of isolated temporaries (['Eq(t0, tv)', 'Eq(t1, t0)', 'Eq(t2, t1)', 'Eq(tu, t2)'], ['tv[t, x, y, z]'], 0), # Dimension-independent flow+anti dependences should be a stopper for CSE @@ -54,22 +55,23 @@ # Fancy use case with lots of temporaries (['Eq(tu.forward, tu.dx + 1)', 'Eq(tv.forward, tv.dx + 1)', 'Eq(tw.forward, tv.dt.dx2.dy2 + 1)', 'Eq(tz.forward, tv.dt.dy2.dx2 + 2)'], - ['1/h_x', '-r11*tu[t, x, y, z] + r11*tu[t, x + 1, y, z] + 1', - '-r11*tv[t, x, y, z] + r11*tv[t, x + 1, y, z] + 1', - '1/dt', '-r12*tv[t, x - 1, y - 1, z] + r12*tv[t + 1, x - 1, y - 1, z]', - '-r12*tv[t, x + 1, y - 1, z] + r12*tv[t + 1, x + 1, y - 1, z]', - '-r12*tv[t, x, y - 1, z] + r12*tv[t + 1, x, y - 1, z]', - '-r12*tv[t, x - 1, y + 1, z] + r12*tv[t + 1, x - 1, y + 1, z]', - '-r12*tv[t, x + 1, y + 1, z] + r12*tv[t + 1, x + 1, y + 1, z]', - '-r12*tv[t, x, y + 1, z] + r12*tv[t + 1, x, y + 1, z]', - '-r12*tv[t, x - 1, y, z] + r12*tv[t + 1, x - 1, y, z]', - '-r12*tv[t, x + 1, y, z] + r12*tv[t + 1, x + 1, y, z]', - '-r12*tv[t, x, y, z] + r12*tv[t + 1, x, y, z]', - 'h_x**(-2)', '-2.0*r13', 'h_y**(-2)', '-2.0*r14', - 'r10*(r13*r6 + r13*r7 + r8*r9) + r14*(r0*r13 + r1*r13 + r2*r9) + ' + - 'r14*(r13*r3 + r13*r4 + r5*r9) + 1', - 'r13*(r0*r14 + r10*r6 + r14*r3) + r13*(r1*r14 + r10*r7 + r14*r4) + ' + - 'r9*(r10*r8 + r14*r2 + r14*r5) + 2'], 0), + ['1/h_x', + '-r9*tu[t, x, y, z] + r9*tu[t, x + 1, y, z] + 1', + '-r9*tv[t, x, y, z] + r9*tv[t, x + 1, y, z] + 1', + '1/dt', + '-r10*tv[t, x - 1, y - 1, z] + r10*tv[t + 1, x - 1, y - 1, z]', + '-r10*tv[t, x + 1, y - 1, z] + r10*tv[t + 1, x + 1, y - 1, z]', + '-r10*tv[t, x, y - 1, z] + r10*tv[t + 1, x, y - 1, z]', + '-r10*tv[t, x - 1, y + 1, z] + r10*tv[t + 1, x - 1, y + 1, z]', + '-r10*tv[t, x + 1, y + 1, z] + r10*tv[t + 1, x + 1, y + 1, z]', + '-r10*tv[t, x, y + 1, z] + r10*tv[t + 1, x, y + 1, z]', + '-r10*tv[t, x - 1, y, z] + r10*tv[t + 1, x - 1, y, z]', + '-r10*tv[t, x + 1, y, z] + r10*tv[t + 1, x + 1, y, z]', + '-r10*tv[t, x, y, z] + r10*tv[t + 1, x, y, z]', + 'h_y**(-2)', + 'h_x**(-2)', + '(-2.0*r11)*(r12*r6 + r12*r7 - 2.0*r12*r8) + r11*(r0*r12 + r1*r12 - 2.0*r12*r2) + r11*(r12*r3 + r12*r4 - 2.0*r12*r5) + 1', + '(-2.0*r12)*(r11*r2 + r11*r5 - 2.0*r11*r8) + r12*(r0*r11 + r11*r3 - 2.0*r11*r6) + r12*(r1*r11 + r11*r4 - 2.0*r11*r7) + 2'], 0), # Existing temporaries from nested Function as index (['Eq(e0, fx[x])', 'Eq(tu, cos(-tu[t, e0, y, z]) + tv[t, x, y, z])', 'Eq(tv, cos(tu[t, e0, y, z]) + tw)'], @@ -91,14 +93,16 @@ def test_default_algo(exprs, expected, min_cost): tw = TimeFunction(name="tw", grid=grid, space_order=2) # noqa tz = TimeFunction(name="tz", grid=grid, space_order=2) # noqa fx = Function(name="fx", grid=grid, dimensions=(x,), shape=(3,)) # noqa - ti0 = Array(name='ti0', shape=(3, 5, 7), dimensions=(x, y, z)).indexify() # noqa - ti1 = Array(name='ti1', shape=(3, 5, 7), dimensions=(x, y, z)).indexify() # noqa - t0 = CTemp(name='t0') # noqa - t1 = CTemp(name='t1') # noqa - t2 = CTemp(name='t2') # noqa + ti0 = Array(name='ti0', shape=(3, 5, 7), dimensions=(x, y, z), + dtype=np.float32).indexify() # noqa + ti1 = Array(name='ti1', shape=(3, 5, 7), dimensions=(x, y, z), + dtype=np.float32).indexify() # noqa + t0 = CTemp(name='t0', dtype=np.float32) # noqa + t1 = CTemp(name='t1', dtype=np.float32) # noqa + t2 = CTemp(name='t2', dtype=np.float32) # noqa # Needs to not be a Temp to mimic nested index extraction and prevent # cse to compact the temporary back. - e0 = Symbol(name='e0') # noqa + e0 = Symbol(name='e0', dtype=np.float32) # noqa # List comprehension would need explicit locals/globals mappings to eval for i, e in enumerate(list(exprs)): @@ -201,26 +205,43 @@ def test_w_multi_conditionals(): @pytest.mark.parametrize('exprs,expected', [ (['Eq(u, sin(f)*cos(g)*sin(g) + sin(f)*cos(g)*cos(f))'], - ['cos(g[x, y, z])', 'sin(f[x, y, z])', 'r0*r1', + ['sin(f[x, y, z])*cos(g[x, y, z])', 'r2*sin(g[x, y, z]) + r2*cos(f[x, y, z])']), (['Eq(u, sin(f)*cos(f)*sin(g)*cos(g) + sin(f)*cos(f)*sin(g) + sin(f)*cos(f))'], - ['cos(f[x, y, z])', 'sin(f[x, y, z])', 'sin(g[x, y, z])', 'r0*r1', - 'r2*r4', 'r0*r1 + r2*r4 + r3*cos(g[x, y, z])']), + ['sin(f[x, y, z])*cos(f[x, y, z])', 'r4*sin(g[x, y, z])', + 'r3*cos(g[x, y, z]) + r3 + r4']), + (['Eq(u, t0*t1*t2)'], + ['t0*t1*t2']), + # Because of the compound heuristic, we ain't catching the inner r0*r1 + (['Eq(u, 2*sin(f)*cos(f)*sin(g) + 3*sin(f)*cos(f))'], + ['cos(f[x, y, z])', 'sin(f[x, y, z])', '2*r0*r1*sin(g[x, y, z]) + 3*r0*r1']), + (['Eq(u, 2*sin(f)*cos(f)*sin(g) + sin(f)*cos(f))'], + ['sin(f[x, y, z])*cos(f[x, y, z])', '2*r2*sin(g[x, y, z]) + r2']), + (['Eq(u, t0 + t1 - (t2 + t3 + f))', 'Eq(v, t0 + t1 - (t2 + t3 + g))'], + ['t0 + t1', 'r0 - t2 - t3 - f[x, y, z]', 'r0 - t2 - t3 - g[x, y, z]']), + (['Eq(u, t0 + t1 - f*(t2 + t3))', 'Eq(v, f*(t0 + t1) - g*(t2 + t3))'], + ['t2 + t3', 't0 + t1', '-r0*f[x, y, z] + r1', + '-r0*g[x, y, z] + r1*f[x, y, z]']), ]) -def test_tuplets_algo(exprs, expected): - """Test the tuplets-based common subexpressions elimination algorithm.""" +def test_advanced_algo(exprs, expected): + """Test the advanced common subexpressions elimination algorithm.""" grid = Grid((3, 3, 3)) - f = Function(name='f', grid=grid) - g = Function(name='g', grid=grid) - u = TimeFunction(name="u", grid=grid, space_order=2) + f = Function(name='f', grid=grid) # noqa + g = Function(name='g', grid=grid) # noqa + u = TimeFunction(name="u", grid=grid, space_order=2) # noqa + v = TimeFunction(name="v", grid=grid, space_order=2) # noqa + t0 = CTemp(name='t0', dtype=np.float32) # noqa + t1 = CTemp(name='t1', dtype=np.float32) # noqa + t2 = CTemp(name='t2', dtype=np.float32) # noqa + t3 = CTemp(name='t3', dtype=np.float32) # noqa # List comprehension would need explicit locals/globals mappings to eval for i, e in enumerate(list(exprs)): exprs[i] = DummyEq(indexify(diffify(eval(e).evaluate))) counter = generator() - make = lambda: CTemp(name='r%d' % counter()).indexify() + make = lambda: CTemp(name='r%d' % counter(), dtype=np.float32).indexify() processed = _cse(exprs, make, mode='advanced') assert len(processed) == len(expected) diff --git a/tests/test_dse.py b/tests/test_dse.py index a8c95ce393..bba5cf2a78 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -1763,7 +1763,7 @@ def g2_tilde(field, phi, theta): assert len([i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array]) == 7 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 3 - @pytest.mark.parametrize('so_ops', [(4, 146), (8, 210)]) + @pytest.mark.parametrize('so_ops', [(4, 147), (8, 211)]) @switchconfig(profiling='advanced') def test_tti_J_akin_complete(self, so_ops): grid = Grid(shape=(16, 16, 16)) diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index 211dcd234e..8b865b155a 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -227,7 +227,7 @@ def test_v4(self): 'cire-mingain': 400})) # Check code generation - assert op._profiler._sections['section1'].sops == 1442 + assert op._profiler._sections['section1'].sops == 1443 assert_structure(op, ['x,y,z', 't,x0_blk0,y0_blk0,x,y,z', 't,x0_blk0,y0_blk0,x,y,z,i1', @@ -397,7 +397,7 @@ def test_v1(self): 'openmp': False})) # Check code generation - assert op._profiler._sections['section1'].sops == 190 + assert op._profiler._sections['section1'].sops == 191 assert_structure(op, ['x,y,z', 't,x0_blk0,y0_blk0,x,y,z', 't,x0_blk0,y0_blk0,x,y,z,i0', From 910129cece506bc8fac2ce86c43c190ca3bbc90d Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 31 Jul 2024 14:02:45 +0000 Subject: [PATCH 06/16] compiler: Refine CSE and factorization algo selection --- devito/core/cpu.py | 5 ++-- devito/core/gpu.py | 5 ++-- devito/core/operator.py | 10 ++++++- devito/passes/clusters/cse.py | 26 +++++++++++------- devito/passes/clusters/factorization.py | 36 +++++++++++++------------ 5 files changed, 50 insertions(+), 32 deletions(-) diff --git a/devito/core/cpu.py b/devito/core/cpu.py index 01752daa53..b9baedb237 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -36,9 +36,10 @@ def _normalize_kwargs(cls, **kwargs): # Fusion o['fuse-tasks'] = oo.pop('fuse-tasks', False) - # CSE + # Flops minimization o['cse-min-cost'] = oo.pop('cse-min-cost', cls.CSE_MIN_COST) o['cse-algo'] = oo.pop('cse-algo', cls.CSE_ALGO) + o['fact-schedule'] = oo.pop('fact-schedule', cls.FACT_SCHEDULE) # Blocking o['blockinner'] = oo.pop('blockinner', False) @@ -169,7 +170,7 @@ def _specialize_clusters(cls, clusters, **kwargs): # Reduce flops clusters = cire(clusters, 'sops', sregistry, options, platform) - clusters = factorize(clusters) + clusters = factorize(clusters, **kwargs) clusters = optimize_pows(clusters) # The previous passes may have created fusion opportunities diff --git a/devito/core/gpu.py b/devito/core/gpu.py index 4c594a09ac..0e42b4886c 100644 --- a/devito/core/gpu.py +++ b/devito/core/gpu.py @@ -46,9 +46,10 @@ def _normalize_kwargs(cls, **kwargs): # Fusion o['fuse-tasks'] = oo.pop('fuse-tasks', False) - # CSE + # Flops minimization o['cse-min-cost'] = oo.pop('cse-min-cost', cls.CSE_MIN_COST) o['cse-algo'] = oo.pop('cse-algo', cls.CSE_ALGO) + o['fact-schedule'] = oo.pop('fact-schedule', cls.FACT_SCHEDULE) # Blocking o['blockinner'] = oo.pop('blockinner', True) @@ -197,7 +198,7 @@ def _specialize_clusters(cls, clusters, **kwargs): # Reduce flops clusters = cire(clusters, 'sops', sregistry, options, platform) - clusters = factorize(clusters) + clusters = factorize(clusters, **kwargs) clusters = optimize_pows(clusters) # The previous passes may have created fusion opportunities diff --git a/devito/core/operator.py b/devito/core/operator.py index 0e95a6c90d..39c7a61fef 100644 --- a/devito/core/operator.py +++ b/devito/core/operator.py @@ -27,11 +27,16 @@ class BasicOperator(Operator): common sub=expression. """ - CSE_ALGO = 'default' + CSE_ALGO = 'basic' """ The algorithm to use for common sub-expression elimination. """ + FACT_SCHEDULE = 'basic' + """ + The schedule to use for the computation of factorizations. + """ + BLOCK_LEVELS = 1 """ Loop blocking depth. So, 1 => "blocks", 2 => "blocks" and "sub-blocks", @@ -164,6 +169,9 @@ def _check_kwargs(cls, **kwargs): if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES: raise InvalidOperator("Unsupported MPI mode `%s`" % oo['mpi']) + if oo['cse-algo'] not in ('basic', 'smartsort', 'advanced'): + raise InvalidArgument("Illegal `cse-algo` value") + if oo['deriv-schedule'] not in ('basic', 'smart'): raise InvalidArgument("Illegal `deriv-schedule` value") if oo['deriv-unroll'] not in (False, 'inner', 'full'): diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 4c1e85d783..839dbe5edc 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -33,20 +33,26 @@ def cse(cluster, sregistry=None, options=None, **kwargs): """ Perform common sub-expressions elimination (CSE) on a Cluster. - Two algorithms are available, 'default' and 'advanced'. + Three algorithms are available, 'basic', 'smartsort', and 'advanced'. - The 'default' algorithm searches for common sub-expressions across the + The 'basic' algorithm searches for common sub-expressions across the operations in a given Cluster. However, it does not look for sub-expressions that are subsets of operands of a given n-ary operation. For example, given the expression `a*b*c*d + c*d + a*b*c + a*b*e`, it would capture `a*b*c`, but not `a*b`. + The 'smartsort' algorithm is an extension of the 'basic' algorithm. It + performs a final topological sorting of the expressions to maximize the + proximity of the common sub-expressions to their uses. + The 'advanced' algorithm also extracts subsets of operands from a given n-ary operation, e.g. `a*b` in `a*b*c*d`. In particular, for a given operation `op(a1, a2, ..., an)` it searches for `n-2` additional sub-expressions of increasing size, namely `a1*a2`, `a1*a2*a3`, etc. This algorithm heuristically relies on SymPy's canonical ordering of operands to maximize the likelihood of finding common sub-expressions. + This algorithm also performs a final topological sorting of the expressions, + like the 'smartsort' algorithm. Parameters ---------- @@ -59,8 +65,8 @@ def cse(cluster, sregistry=None, options=None, **kwargs): Accepted: ['cse-min-cost', 'cse-algo']. * 'cse-min-cost': int. The minimum cost of a common sub-expression to be considered for CSE. Default is 1. - * 'cse-algo': str. The CSE algorithm to apply. Accepted: ['default', - 'advanced']. Default is 'default'. + * 'cse-algo': str. The CSE algorithm to apply. Accepted: ['basic', + 'smartsort', 'advanced']. Default is 'basic'. """ min_cost = options['cse-min-cost'] mode = options['cse-algo'] @@ -72,7 +78,7 @@ def cse(cluster, sregistry=None, options=None, **kwargs): return cluster.rebuild(exprs=exprs) -def _cse(maybe_exprs, make, min_cost=1, mode='default'): +def _cse(maybe_exprs, make, min_cost=1, mode='basic'): """ Carry out the bulk of the CSE process. @@ -83,7 +89,7 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): make : callable Build symbols to store temporary, redundant values. mode : str, optional - The CSE algorithm applied. Accepted: ['default', 'advanced']. + The CSE algorithm applied. Accepted: ['basic', 'smartsort', 'advanced']. Notes ----- @@ -93,7 +99,7 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): * It sometimes "captures too much", losing factorization opportunities; * It tends to be very slow. """ - assert mode in ('default', 'advanced') + assert mode in ('basic', 'smartsort', 'advanced') # Accept Clusters, Eqs or even just exprs if isinstance(maybe_exprs, Cluster): @@ -148,8 +154,8 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): # Drop useless temporaries (e.g., r0=r1) processed = _compact(exprs, exclude) - # Ensure topo-sorting - if mode == 'advanced': + # Ensure topo-sorting ('basic' doesn't require it) + if mode in ('smartsort', 'advanced'): processed = _toposort(processed) return processed @@ -310,7 +316,7 @@ def catch(exprs, mode): candidates = [] for k, v in mapper.items(): - if mode == 'default': + if mode in ('basic', 'smartsort'): sources = [i for i in v if i == k.expr] else: sources = v diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 0fb94b610a..da9b74a7db 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -19,7 +19,7 @@ @cluster_pass -def factorize(cluster, *args): +def factorize(cluster, *args, options=None, **kwargs): """ Factorize trascendental functions, symbolic powers, numeric coefficients. @@ -27,16 +27,18 @@ def factorize(cluster, *args): then the algorithm is applied recursively until no more factorization opportunities are detected. """ + strategy = options.get('fact_schedule', 'basic') + processed = [] for expr in cluster.exprs: - handle = collect_nested(expr) + handle = collect_nested(expr, strategy) cost_handle = estimate_cost(handle) if cost_handle >= MIN_COST_FACTORIZE: handle_prev = handle cost_prev = estimate_cost(expr) while cost_handle < cost_prev: - handle_prev, handle = handle, collect_nested(handle) + handle_prev, handle = handle, collect_nested(handle, strategy) cost_prev, cost_handle = cost_handle, estimate_cost(handle) cost_handle, handle = cost_prev, handle_prev @@ -45,12 +47,12 @@ def factorize(cluster, *args): return cluster.rebuild(processed) -def collect_special(expr): +def collect_special(expr, strategy): """ Factorize elemental functions, pows, and other special symbolic objects, prioritizing the most expensive entities. """ - args, candidates = zip(*[_collect_nested(arg) for arg in expr.args]) + args, candidates = zip(*[_collect_nested(a, strategy) for a in expr.args]) candidates = ReducerMap.fromdicts(*candidates) funcs = candidates.getall('funcs', []) @@ -173,19 +175,19 @@ def collect_const(expr): return Add(*terms) -def strategy0(expr): - rebuilt = collect_special(expr) +def strategy0(expr, strategy): + rebuilt = collect_special(expr, strategy) rebuilt = collect_const(rebuilt) return rebuilt strategies = { - 'default': strategy0 + 'basic': strategy0 } -def _collect_nested(expr): +def _collect_nested(expr, strategy): """ Recursion helper for `collect_nested`. """ @@ -195,7 +197,7 @@ def _collect_nested(expr): return expr, {'coeffs': expr} elif q_routine(expr): # E.g., a DefFunction - args, candidates = zip(*[_collect_nested(arg) for arg in expr.args]) + args, candidates = zip(*[_collect_nested(a, strategy) for a in expr.args]) return expr.func(*args, evaluate=False), {} elif expr.is_Function: return expr, {'funcs': expr} @@ -205,21 +207,21 @@ def _collect_nested(expr): isinstance(expr, (BasicWrapperMixin, AbstractObject))): return expr, {} elif expr.is_Add: - return strategies['default'](expr), {} + return strategies[strategy](expr, strategy), {} elif expr.is_Mul: - args, candidates = zip(*[_collect_nested(arg) for arg in expr.args]) + args, candidates = zip(*[_collect_nested(a, strategy) for a in expr.args]) expr = reuse_if_untouched(expr, args, evaluate=True) return expr, ReducerMap.fromdicts(*candidates) elif expr.is_Equality: - args, candidates = zip(*[_collect_nested(expr.lhs), - _collect_nested(expr.rhs)]) + args, candidates = zip(*[_collect_nested(expr.lhs, strategy), + _collect_nested(expr.rhs, strategy)]) return expr.func(*args, evaluate=False), ReducerMap.fromdicts(*candidates) else: - args, candidates = zip(*[_collect_nested(arg) for arg in expr.args]) + args, candidates = zip(*[_collect_nested(a, strategy) for a in expr.args]) return expr.func(*args), ReducerMap.fromdicts(*candidates) -def collect_nested(expr): +def collect_nested(expr, strategy='basic'): """ Collect numeric coefficients, trascendental functions, pows, and other symbolic objects across all levels of the expression tree. @@ -229,4 +231,4 @@ def collect_nested(expr): expr : expr-like The expression to be factorized. """ - return _collect_nested(expr)[0] + return _collect_nested(expr, strategy)[0] From 7232dfca6742f7661084fed2470e94db9b73091e Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 20 Aug 2024 14:32:02 +0000 Subject: [PATCH 07/16] compiler: Relax detect_accesses --- devito/ir/support/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/devito/ir/support/utils.py b/devito/ir/support/utils.py index 7a40883e12..0e75619cbc 100644 --- a/devito/ir/support/utils.py +++ b/devito/ir/support/utils.py @@ -186,7 +186,11 @@ def detect_accesses(exprs): other_dims = set() for e in as_tuple(exprs): other_dims.update(i for i in e.free_symbols if isinstance(i, Dimension)) - other_dims.update(e.implicit_dims or {}) + try: + other_dims.update(e.implicit_dims or {}) + except AttributeError: + # Not a types.Eq + pass other_dims = filter_sorted(other_dims) mapper[None] = Stencil([(i, 0) for i in other_dims]) From 60aeb314fcf88d325afd5eb08e50d1d2e3ec0d2c Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 21 Aug 2024 08:59:35 +0000 Subject: [PATCH 08/16] compiler: Drop useless makeit_ssa --- devito/passes/clusters/cse.py | 3 --- devito/passes/clusters/utils.py | 40 ++------------------------------- tests/test_dse.py | 38 ------------------------------- tests/test_roundoff.py | 1 - 4 files changed, 2 insertions(+), 80 deletions(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 839dbe5edc..5c8c289af0 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -11,7 +11,6 @@ from devito.finite_differences.differentiable import IndexDerivative from devito.ir import Cluster, Scope, cluster_pass -from devito.passes.clusters.utils import makeit_ssa from devito.symbolics import estimate_cost, q_leaf, q_terminal from devito.symbolics.manipulation import _uxreplace from devito.tools import DAG, as_list, as_tuple, frozendict @@ -205,8 +204,6 @@ def _compact(exprs, exclude): * Temporaries of the form `t0 = s`, where `s` is a leaf; * Temporaries of the form `t0 = expr` such that `t0` is accessed only once. """ - exprs = makeit_ssa(exprs) - # Only CSE-captured Temps, namely CTemps, can safely be optimized; a # generic Symbol could instead be accessed in a subsequent Cluster, e.g. # `for (i = ...) { a = b; for (j = a ...) ... }` diff --git a/devito/passes/clusters/utils.py b/devito/passes/clusters/utils.py index 7405dd984d..7a48f3c486 100644 --- a/devito/passes/clusters/utils.py +++ b/devito/passes/clusters/utils.py @@ -1,44 +1,8 @@ from devito.ir import Cluster -from devito.symbolics import uxreplace from devito.tools import as_tuple -from devito.types import CriticalRegion, Eq, Symbol, Wildcard +from devito.types import CriticalRegion, Eq, Symbol -__all__ = ['makeit_ssa', 'is_memcpy', 'make_critical_sequence', - 'in_critical_region'] - - -def makeit_ssa(exprs): - """ - Convert an iterable of Eqs into Static Single Assignment (SSA) form. - """ - # Identify recurring LHSs - seen = {} - for i, e in enumerate(exprs): - if not isinstance(e.lhs, Wildcard): - seen.setdefault(e.lhs, []).append(i) - # Optimization: don't waste time reconstructing stuff if already in SSA form - if all(len(i) == 1 for i in seen.values()): - return exprs - # SSA conversion - c = 0 - mapper = {} - processed = [] - for i, e in enumerate(exprs): - where = seen[e.lhs] - rhs = uxreplace(e.rhs, mapper) - if len(where) > 1: - needssa = e.is_Scalar or where[-1] != i - lhs = Symbol(name='ssa%d' % c, dtype=e.dtype) if needssa else e.lhs - if e.is_Increment: - # Turn AugmentedAssignment into Assignment - processed.append(e.func(lhs, mapper[e.lhs] + rhs, operation=None)) - else: - processed.append(e.func(lhs, rhs)) - mapper[e.lhs] = lhs - c += 1 - else: - processed.append(e.func(e.lhs, rhs)) - return processed +__all__ = ['is_memcpy', 'make_critical_sequence', 'in_critical_region'] def is_memcpy(expr): diff --git a/tests/test_dse.py b/tests/test_dse.py index bba5cf2a78..60006365ce 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -190,44 +190,6 @@ def test_estimate_cost(expr, expected, estimate): assert estimate_cost(eval(expr), estimate) == expected -@pytest.mark.parametrize('exprs,exp_u,exp_v', [ - (['Eq(s, 0, implicit_dims=(x, y))', 'Eq(s, s + 4, implicit_dims=(x, y))', - 'Eq(u, s)'], 4, 0), - (['Eq(s, 0, implicit_dims=(x, y))', 'Eq(s, s + s + 4, implicit_dims=(x, y))', - 'Eq(s, s + 4, implicit_dims=(x, y))', 'Eq(u, s)'], 8, 0), - (['Eq(s, 0, implicit_dims=(x, y))', 'Inc(s, 4, implicit_dims=(x, y))', - 'Eq(u, s)'], 4, 0), - (['Eq(s, 0, implicit_dims=(x, y))', 'Inc(s, 4, implicit_dims=(x, y))', 'Eq(v, s)', - 'Eq(u, s)'], 4, 4), - (['Eq(s, 0, implicit_dims=(x, y))', 'Inc(s, 4, implicit_dims=(x, y))', 'Eq(v, s)', - 'Eq(s, s + 4, implicit_dims=(x, y))', 'Eq(u, s)'], 8, 4), - (['Eq(s, 0, implicit_dims=(x, y))', 'Inc(s, 4, implicit_dims=(x, y))', 'Eq(v, s)', - 'Inc(s, 4, implicit_dims=(x, y))', 'Eq(u, s)'], 8, 4), - (['Eq(u, 0)', 'Inc(u, 4)', 'Eq(v, u)', 'Inc(u, 4)'], 8, 4), - (['Eq(u, 1)', 'Eq(v, 4)', 'Inc(u, v)', 'Inc(v, u)'], 5, 9), -]) -def test_makeit_ssa(exprs, exp_u, exp_v): - """ - A test building Operators with non-trivial sequences of input expressions - that push hard on the `makeit_ssa` utility function. - """ - grid = Grid(shape=(4, 4)) - x, y = grid.dimensions # noqa - u = Function(name='u', grid=grid) # noqa - v = Function(name='v', grid=grid) # noqa - s = Scalar(name='s') # noqa - - # List comprehension would need explicit locals/globals mappings to eval - for i, e in enumerate(list(exprs)): - exprs[i] = eval(e) - - op = Operator(exprs) - op.apply() - - assert np.all(u.data == exp_u) - assert np.all(v.data == exp_v) - - @pytest.mark.parametrize('opt', ['noop', 'advanced']) def test_time_dependent_split(opt): grid = Grid(shape=(10, 10)) diff --git a/tests/test_roundoff.py b/tests/test_roundoff.py index e853019deb..5b3901f791 100644 --- a/tests/test_roundoff.py +++ b/tests/test_roundoff.py @@ -97,7 +97,6 @@ def test_lm_fb(self, dat, dtype): grid = Grid(shape=(2, 2), extent=(1, 1), dtype=dtype) dt = grid.stepping_dim.spacing - print("dt = ", dt) f0 = TimeFunction(name='f0', grid=grid, time_order=2, dtype=dtype) f1 = TimeFunction(name='f1', grid=grid, time_order=2, save=iterations+2, From 5df6a9074686cb9dcd6b4d09ceef5fd022d536d6 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 21 Aug 2024 09:24:29 +0000 Subject: [PATCH 09/16] compiler: Improve CSE's compact() --- devito/passes/clusters/cse.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 5c8c289af0..5e06b1d1be 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -203,16 +203,19 @@ def _compact(exprs, exclude): * Temporaries of the form `t0 = s`, where `s` is a leaf; * Temporaries of the form `t0 = expr` such that `t0` is accessed only once. + + Notes + ----- + Only CSE-captured Temps, namely CTemps, can safely be optimized; a + generic Symbol could instead be accessed in a subsequent Cluster, e.g. + `for (i = ...) { a = b; for (j = a ...) ... }`. Hence, this routine + only targets CTemps. """ - # Only CSE-captured Temps, namely CTemps, can safely be optimized; a - # generic Symbol could instead be accessed in a subsequent Cluster, e.g. - # `for (i = ...) { a = b; for (j = a ...) ... }` - candidates = [e for e in exprs if isinstance(e.lhs, CTemp)] + candidates = [e for e in exprs + if isinstance(e.lhs, CTemp) and e.lhs not in exclude] - mapper = {e.lhs: e.rhs for e in candidates - if q_leaf(e.rhs) and e.lhs not in exclude} + mapper = {e.lhs: e.rhs for e in candidates if q_leaf(e.rhs)} - #TODO? TO GO? mapper.update({e.lhs: e.rhs for e in candidates if sum([i.rhs.count(e.lhs) for i in exprs]) == 1}) From eda1b7008f9922aaffcf2c4a509cda96751469f3 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 21 Aug 2024 09:55:58 +0000 Subject: [PATCH 10/16] compiler: Fix CSE's smartsort determinism --- devito/passes/clusters/cse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 5e06b1d1be..566eefb616 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -247,7 +247,7 @@ def _toposort(exprs): def choose_element(queue, scheduled): # Try to honor temporary names as much as possible - first = sorted(queue, key=lambda i: i.lhs.base.name).pop(0) + first = sorted(queue, key=lambda i: str(i.lhs)).pop(0) queue.remove(first) return first From 370c3a142f652ab408bf44ba2b22e6162826c084 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 21 Aug 2024 09:56:20 +0000 Subject: [PATCH 11/16] api: Change nthreads default value --- devito/types/parallel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/devito/types/parallel.py b/devito/types/parallel.py index 65ee2fbdbc..9f127e1263 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -44,8 +44,10 @@ class NThreadsBase(NThreadsAbstract): @cached_property def default_value(self): - return int(os.environ.get('OMP_NUM_THREADS', - configuration['platform'].cores_physical)) + return int(os.environ.get( + 'OMP_NUM_THREADS', + configuration['platform'].cores_physical_per_numa_domain + )) def _arg_defaults(self, **kwargs): base_nthreads = self.default_value From 91f04ca3537bf0ed8aa3c938cfcbc41799dd5ba4 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 21 Aug 2024 10:02:07 +0000 Subject: [PATCH 12/16] pep8 happiness --- tests/test_cse.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_cse.py b/tests/test_cse.py index 56e0fddba0..b84f7c099c 100644 --- a/tests/test_cse.py +++ b/tests/test_cse.py @@ -70,8 +70,8 @@ '-r10*tv[t, x, y, z] + r10*tv[t + 1, x, y, z]', 'h_y**(-2)', 'h_x**(-2)', - '(-2.0*r11)*(r12*r6 + r12*r7 - 2.0*r12*r8) + r11*(r0*r12 + r1*r12 - 2.0*r12*r2) + r11*(r12*r3 + r12*r4 - 2.0*r12*r5) + 1', - '(-2.0*r12)*(r11*r2 + r11*r5 - 2.0*r11*r8) + r12*(r0*r11 + r11*r3 - 2.0*r11*r6) + r12*(r1*r11 + r11*r4 - 2.0*r11*r7) + 2'], 0), + '(-2.0*r11)*(r12*r6 + r12*r7 - 2.0*r12*r8) + r11*(r0*r12 + r1*r12 - 2.0*r12*r2) + r11*(r12*r3 + r12*r4 - 2.0*r12*r5) + 1', # noqa + '(-2.0*r12)*(r11*r2 + r11*r5 - 2.0*r11*r8) + r12*(r0*r11 + r11*r3 - 2.0*r11*r6) + r12*(r1*r11 + r11*r4 - 2.0*r11*r7) + 2'], 0), # noqa # Existing temporaries from nested Function as index (['Eq(e0, fx[x])', 'Eq(tu, cos(-tu[t, e0, y, z]) + tv[t, x, y, z])', 'Eq(tv, cos(tu[t, e0, y, z]) + tw)'], @@ -93,10 +93,10 @@ def test_default_algo(exprs, expected, min_cost): tw = TimeFunction(name="tw", grid=grid, space_order=2) # noqa tz = TimeFunction(name="tz", grid=grid, space_order=2) # noqa fx = Function(name="fx", grid=grid, dimensions=(x,), shape=(3,)) # noqa - ti0 = Array(name='ti0', shape=(3, 5, 7), dimensions=(x, y, z), - dtype=np.float32).indexify() # noqa - ti1 = Array(name='ti1', shape=(3, 5, 7), dimensions=(x, y, z), - dtype=np.float32).indexify() # noqa + ti0 = Array(name='ti0', shape=(3, 5, 7), dimensions=(x, y, z), # noqa + dtype=np.float32).indexify() + ti1 = Array(name='ti1', shape=(3, 5, 7), dimensions=(x, y, z), # noqa + dtype=np.float32).indexify() t0 = CTemp(name='t0', dtype=np.float32) # noqa t1 = CTemp(name='t1', dtype=np.float32) # noqa t2 = CTemp(name='t2', dtype=np.float32) # noqa From 469aee23a8557d68ec028f7ca959480ab8bac80f Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 21 Aug 2024 12:54:00 +0000 Subject: [PATCH 13/16] compiler: Fixup factorization setup --- devito/passes/clusters/factorization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index da9b74a7db..766e2c2baa 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -27,7 +27,10 @@ def factorize(cluster, *args, options=None, **kwargs): then the algorithm is applied recursively until no more factorization opportunities are detected. """ - strategy = options.get('fact_schedule', 'basic') + try: + strategy = options['fact-schedule'] + except TypeError: + strategy = 'basic' processed = [] for expr in cluster.exprs: From b7186eba6153401223aa37dff545a08e4282f791 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 21 Aug 2024 12:54:08 +0000 Subject: [PATCH 14/16] examples: Update expected output --- examples/performance/00_overview.ipynb | 12 ++++++------ examples/performance/01_gpu.ipynb | 3 +-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/performance/00_overview.ipynb b/examples/performance/00_overview.ipynb index abe32cd8f1..3ed1a45014 100644 --- a/examples/performance/00_overview.ipynb +++ b/examples/performance/00_overview.ipynb @@ -719,7 +719,7 @@ " #pragma omp simd aligned(f,u:32)\n", " for (int z = z_m; z <= z_M; z += 1)\n", " {\n", - " u[t1][x + 4][y + 4][z + 4] = ((-6.66666667e-1F/h_y)*r0[y + 1][z] + (-8.33333333e-2F/h_y)*r0[y + 4][z] + (8.33333333e-2F/h_y)*r0[y][z] + (6.66666667e-1F/h_y)*r0[y + 3][z])*sinf(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", + " u[t1][x + 4][y + 4][z + 4] = (8.33333333e-2F*r0[y][z]/h_y - 6.66666667e-1F*r0[y + 1][z]/h_y + 6.66666667e-1F*r0[y + 3][z]/h_y - 8.33333333e-2F*r0[y + 4][z]/h_y)*sinf(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", " }\n", " }\n", " }\n", @@ -783,7 +783,7 @@ " {\n", " for (int z = z_m; z <= z_M; z += 1)\n", " {\n", - " u[t1][x + 4][y + 4][z + 4] = ((-6.66666667e-1F/h_x)*r0[x + 1][y + 2][z] + (-8.33333333e-2F/h_x)*r0[x + 4][y + 2][z] + (8.33333333e-2F/h_x)*r0[x][y + 2][z] + (6.66666667e-1F/h_x)*r0[x + 3][y + 2][z] + (-6.66666667e-1F/h_y)*r1[x + 2][y + 1][z] + (-8.33333333e-2F/h_y)*r1[x + 2][y + 4][z] + (8.33333333e-2F/h_y)*r1[x + 2][y][z] + (6.66666667e-1F/h_y)*r1[x + 2][y + 3][z])*sinf(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", + " u[t1][x + 4][y + 4][z + 4] = (8.33333333e-2F*r0[x][y + 2][z]/h_x - 6.66666667e-1F*r0[x + 1][y + 2][z]/h_x + 6.66666667e-1F*r0[x + 3][y + 2][z]/h_x - 8.33333333e-2F*r0[x + 4][y + 2][z]/h_x + 8.33333333e-2F*r1[x + 2][y][z]/h_y - 6.66666667e-1F*r1[x + 2][y + 1][z]/h_y + 6.66666667e-1F*r1[x + 2][y + 3][z]/h_y - 8.33333333e-2F*r1[x + 2][y + 4][z]/h_y)*sinf(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", " }\n", " }\n", " }\n", @@ -854,7 +854,7 @@ " #pragma omp simd aligned(f,u:32)\n", " for (int z = z_m; z <= z_M; z += 1)\n", " {\n", - " u[t1][x + 4][y + 4][z + 4] = ((-6.66666667e-1F/h_x)*r0[x + 1][y][z] + (-8.33333333e-2F/h_x)*r0[x + 4][y][z] + (8.33333333e-2F/h_x)*r0[x][y][z] + (6.66666667e-1F/h_x)*r0[x + 3][y][z] + (-6.66666667e-1F/h_y)*r1[y + 1][z] + (-8.33333333e-2F/h_y)*r1[y + 4][z] + (8.33333333e-2F/h_y)*r1[y][z] + (6.66666667e-1F/h_y)*r1[y + 3][z])*sinf(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", + " u[t1][x + 4][y + 4][z + 4] = (8.33333333e-2F*r0[x][y][z]/h_x - 6.66666667e-1F*r0[x + 1][y][z]/h_x + 6.66666667e-1F*r0[x + 3][y][z]/h_x - 8.33333333e-2F*r0[x + 4][y][z]/h_x + 8.33333333e-2F*r1[y][z]/h_y - 6.66666667e-1F*r1[y + 1][z]/h_y + 6.66666667e-1F*r1[y + 3][z]/h_y - 8.33333333e-2F*r1[y + 4][z]/h_y)*sinf(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", " }\n", " }\n", " }\n", @@ -986,7 +986,7 @@ " #pragma omp simd aligned(f,u:32)\n", " for (int z = z_m; z <= z_M; z += 1)\n", " {\n", - " u[t1][x + 4][y + 4][z + 4] = ((-6.66666667e-1F/h_y)*r0[y + 1][z] + (-8.33333333e-2F/h_y)*r0[y + 4][z] + (8.33333333e-2F/h_y)*r0[y][z] + (6.66666667e-1F/h_y)*r0[y + 3][z])*sinf(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", + " u[t1][x + 4][y + 4][z + 4] = (8.33333333e-2F*r0[y][z]/h_y - 6.66666667e-1F*r0[y + 1][z]/h_y + 6.66666667e-1F*r0[y + 3][z]/h_y - 8.33333333e-2F*r0[y + 4][z]/h_y)*sinf(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", " }\n", " }\n", " }\n", @@ -1113,7 +1113,7 @@ " {\n", " for (int z = z_m; z <= z_M; z += 1)\n", " {\n", - " u[t1][x + 4][y + 4][z + 4] = ((-6.66666667e-1F/h_y)*r0[x][y + 1][z] + (-8.33333333e-2F/h_y)*r0[x][y + 4][z] + (8.33333333e-2F/h_y)*r0[x][y][z] + (6.66666667e-1F/h_y)*r0[x][y + 3][z])*sinf(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", + " u[t1][x + 4][y + 4][z + 4] = (8.33333333e-2F*r0[x][y][z]/h_y - 6.66666667e-1F*r0[x][y + 1][z]/h_y + 6.66666667e-1F*r0[x][y + 3][z]/h_y - 8.33333333e-2F*r0[x][y + 4][z]/h_y)*sinf(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", " }\n", " }\n", " }\n", @@ -1692,7 +1692,7 @@ " #pragma omp simd aligned(f,u:32)\n", " for (int z = z_m; z <= z_M; z += 1)\n", " {\n", - " u[t1][x + 4][y + 4][z + 4] = (f[x + 1][y + 1][z + 1]*f[x + 1][y + 1][z + 1])*(r1*(8.33333333e-2F*(r3[x][y + 2][z] - r3[x + 4][y + 2][z]) + 6.66666667e-1F*(-r3[x + 1][y + 2][z] + r3[x + 3][y + 2][z])) + r2*(8.33333333e-2F*(r4[x + 2][y][z] - r4[x + 2][y + 4][z]) + 6.66666667e-1F*(-r4[x + 2][y + 1][z] + r4[x + 2][y + 3][z])))*r0[x][y][z];\n", + " u[t1][x + 4][y + 4][z + 4] = (r1*(8.33333333e-2F*(r3[x][y + 2][z] - r3[x + 4][y + 2][z]) + 6.66666667e-1F*(-r3[x + 1][y + 2][z] + r3[x + 3][y + 2][z])) + r2*(8.33333333e-2F*(r4[x + 2][y][z] - r4[x + 2][y + 4][z]) + 6.66666667e-1F*(-r4[x + 2][y + 1][z] + r4[x + 2][y + 3][z])))*f[x + 1][y + 1][z + 1]*f[x + 1][y + 1][z + 1]*r0[x][y][z];\n", " }\n", " }\n", " }\n", diff --git a/examples/performance/01_gpu.ipynb b/examples/performance/01_gpu.ipynb index 02fabe139d..78c7b0f81a 100644 --- a/examples/performance/01_gpu.ipynb +++ b/examples/performance/01_gpu.ipynb @@ -307,8 +307,7 @@ " {\n", " for (int y = y_m; y <= y_M; y += 1)\n", " {\n", - " float r3 = -2.0F*uL0(time, x + 2, y + 2);\n", - " uL0(time + 1, x + 2, y + 2) = dt*(c*(r3*r1 + r3*r2 + r1*uL0(time, x + 1, y + 2) + r1*uL0(time, x + 3, y + 2) + r2*uL0(time, x + 2, y + 1) + r2*uL0(time, x + 2, y + 3)) + r0*uL0(time, x + 2, y + 2));\n", + " uL0(time + 1, x + 2, y + 2) = dt*(c*(r1*uL0(time, x + 1, y + 2) + r1*uL0(time, x + 3, y + 2) + r2*uL0(time, x + 2, y + 1) + r2*uL0(time, x + 2, y + 3) - 2.0F*(r1*uL0(time, x + 2, y + 2) + r2*uL0(time, x + 2, y + 2))) + r0*uL0(time, x + 2, y + 2));\n", " }\n", " }\n", " STOP(section0,timers)\n", From ad5fd141a735b651214a2fbe185d5d0eeaa470d5 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 21 Aug 2024 15:15:59 +0000 Subject: [PATCH 15/16] tests: Relax tolerance --- tests/test_linearize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_linearize.py b/tests/test_linearize.py index 0a95d1e3f3..1b531e6ab8 100644 --- a/tests/test_linearize.py +++ b/tests/test_linearize.py @@ -49,7 +49,7 @@ def test_mpi(mode): op0.apply(time_M=10) op1.apply(time_M=10, u=u1) - assert np.all(u.data == u1.data) + assert np.allclose(u.data, u1.data, rtol=1e-5) def test_cire(): From b1a10eeb96f86d762fd2b8fe95d21ca88cfaa9c7 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 22 Aug 2024 07:43:17 +0000 Subject: [PATCH 16/16] compiler: Polish factorizer --- devito/passes/clusters/factorization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 766e2c2baa..8007d3fee2 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -147,7 +147,7 @@ def collect_const(expr): # Any factorization possible? if len(inverse_mapper) == len(expr.args) or \ - list(inverse_mapper) == [1]: + (len(inverse_mapper) == 1 and 1 in inverse_mapper): return expr terms = [] @@ -216,9 +216,9 @@ def _collect_nested(expr, strategy): expr = reuse_if_untouched(expr, args, evaluate=True) return expr, ReducerMap.fromdicts(*candidates) elif expr.is_Equality: - args, candidates = zip(*[_collect_nested(expr.lhs, strategy), - _collect_nested(expr.rhs, strategy)]) - return expr.func(*args, evaluate=False), ReducerMap.fromdicts(*candidates) + rhs, _ = _collect_nested(expr.rhs, strategy) + expr = reuse_if_untouched(expr, (expr.lhs, rhs)) + return expr, {} else: args, candidates = zip(*[_collect_nested(a, strategy) for a in expr.args]) return expr.func(*args), ReducerMap.fromdicts(*candidates)