Skip to content

Commit bcdae34

Browse files
committed
api: prevent factorization for symbolic coefficients
1 parent a7a7e1c commit bcdae34

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

devito/passes/equations/linearity.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,11 @@ def _(expr):
187187
if not derivs:
188188
return reuse_if_untouched(expr, args)
189189

190+
# Cannot factorize derivatives with symbolic coefficients since
191+
# they may have different coefficient values at evaluation
192+
if any(d._uses_symbolic_coefficients for d in derivs):
193+
return reuse_if_untouched(expr, args)
194+
190195
# Map by type of derivative
191196
# Note: `D0(a) + D1(b) == D(a + b)` <=> `D0` and `D1`'s metadata match,
192197
# i.e. they are the same type of derivative

tests/test_symbolic_coefficients.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Dimension, solve, Operator, NODE)
77
from devito.finite_differences import Differentiable
88
from devito.tools import as_tuple
9+
from devito.passes.equations.linearity import factorize_derivatives
910

1011
_PRECISION = 9
1112

@@ -344,3 +345,17 @@ def test_with_timefunction(self, stagger):
344345
Operator([eq_f, eq_g])(t_m=0, t_M=1)
345346

346347
assert np.allclose(f.data[-1], -g.data[-1], atol=1e-7)
348+
349+
def test_collect_w_custom_coeffs(self):
350+
grid = Grid(shape=(11, 11, 11))
351+
p = TimeFunction(name='p', grid=grid, space_order=8, time_order=2,
352+
coefficients='symbolic')
353+
354+
q = TimeFunction(name='q', grid=grid, space_order=8, time_order=2,
355+
coefficients='symbolic')
356+
357+
expr = p.dx2 + q.dx2
358+
collected = factorize_derivatives(expr)
359+
assert collected == expr
360+
assert collected.is_Add
361+
Operator([Eq(p.forward, expr)])(time_M=2) # noqa

0 commit comments

Comments
 (0)