Skip to content

Commit

Permalink
api: prevent factorization for symbolic coefficients
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Aug 3, 2023
1 parent a7a7e1c commit bcdae34
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
5 changes: 5 additions & 0 deletions devito/passes/equations/linearity.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ def _(expr):
if not derivs:
return reuse_if_untouched(expr, args)

# Cannot factorize derivatives with symbolic coefficients since
# they may have different coefficient values at evaluation
if any(d._uses_symbolic_coefficients for d in derivs):
return reuse_if_untouched(expr, args)

# Map by type of derivative
# Note: `D0(a) + D1(b) == D(a + b)` <=> `D0` and `D1`'s metadata match,
# i.e. they are the same type of derivative
Expand Down
15 changes: 15 additions & 0 deletions tests/test_symbolic_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Dimension, solve, Operator, NODE)
from devito.finite_differences import Differentiable
from devito.tools import as_tuple
from devito.passes.equations.linearity import factorize_derivatives

_PRECISION = 9

Expand Down Expand Up @@ -344,3 +345,17 @@ def test_with_timefunction(self, stagger):
Operator([eq_f, eq_g])(t_m=0, t_M=1)

assert np.allclose(f.data[-1], -g.data[-1], atol=1e-7)

def test_collect_w_custom_coeffs(self):
grid = Grid(shape=(11, 11, 11))
p = TimeFunction(name='p', grid=grid, space_order=8, time_order=2,
coefficients='symbolic')

q = TimeFunction(name='q', grid=grid, space_order=8, time_order=2,
coefficients='symbolic')

expr = p.dx2 + q.dx2
collected = factorize_derivatives(expr)
assert collected == expr
assert collected.is_Add
Operator([Eq(p.forward, expr)])(time_M=2) # noqa

0 comments on commit bcdae34

Please sign in to comment.