From bcdae345030b4a061e7f3479918317de55cb7919 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 3 Aug 2023 09:41:49 -0400 Subject: [PATCH] api: prevent factorization for symbolic coefficients --- devito/passes/equations/linearity.py | 5 +++++ tests/test_symbolic_coefficients.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/devito/passes/equations/linearity.py b/devito/passes/equations/linearity.py index 44429caa23..ee905041d4 100644 --- a/devito/passes/equations/linearity.py +++ b/devito/passes/equations/linearity.py @@ -187,6 +187,11 @@ def _(expr): if not derivs: return reuse_if_untouched(expr, args) + # Cannot factorize derivatives with symbolic coefficients since + # they may have different coefficient values at evaluation + if any(d._uses_symbolic_coefficients for d in derivs): + return reuse_if_untouched(expr, args) + # Map by type of derivative # Note: `D0(a) + D1(b) == D(a + b)` <=> `D0` and `D1`'s metadata match, # i.e. they are the same type of derivative diff --git a/tests/test_symbolic_coefficients.py b/tests/test_symbolic_coefficients.py index 40da905f3f..05843fcc61 100644 --- a/tests/test_symbolic_coefficients.py +++ b/tests/test_symbolic_coefficients.py @@ -6,6 +6,7 @@ Dimension, solve, Operator, NODE) from devito.finite_differences import Differentiable from devito.tools import as_tuple +from devito.passes.equations.linearity import factorize_derivatives _PRECISION = 9 @@ -344,3 +345,17 @@ def test_with_timefunction(self, stagger): Operator([eq_f, eq_g])(t_m=0, t_M=1) assert np.allclose(f.data[-1], -g.data[-1], atol=1e-7) + + def test_collect_w_custom_coeffs(self): + grid = Grid(shape=(11, 11, 11)) + p = TimeFunction(name='p', grid=grid, space_order=8, time_order=2, + coefficients='symbolic') + + q = TimeFunction(name='q', grid=grid, space_order=8, time_order=2, + coefficients='symbolic') + + expr = p.dx2 + q.dx2 + collected = factorize_derivatives(expr) + assert collected == expr + assert collected.is_Add + Operator([Eq(p.forward, expr)])(time_M=2) # noqa