Skip to content

Commit

Permalink
Merge pull request #2254 from devitocodes/fix-fd-shct
Browse files Browse the repository at this point in the history
api: prevent derivative shortcut with incompatible fd order
  • Loading branch information
mloubout authored Oct 31, 2023
2 parents d4ebfa9 + e6ca17e commit 8e33092
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 3 deletions.
11 changes: 10 additions & 1 deletion devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,16 @@ def is_TimeDependent(self):

@cached_property
def _fd(self):
return dict(ChainMap(*[getattr(i, '_fd', {}) for i in self._args_diff]))
# Filter out all args with fd order too high
fd_args = []
for f in self._args_diff:
try:
if f.space_order <= self.space_order and \
(not f.is_TimeDependent or f.time_order <= self.time_order):
fd_args.append(f)
except AttributeError:
pass
return dict(ChainMap(*[getattr(i, '_fd', {}) for i in fd_args]))

@cached_property
def _symbolic_functions(self):
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def dtype(self):
If two Clusters perform calculations with different precision, the
data type with highest precision is returned.
"""
dtypes = {i.dtype for i in self}
dtypes = {i.dtype for i in self} - {None}

return infer_dtype(dtypes)

Expand Down
15 changes: 15 additions & 0 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,21 @@ def test_all_shortcuts(self, so):
for fd in g._fd:
assert getattr(g, fd)

for d in grid.dimensions:
assert 'd%s' % d.name in f._fd
assert 'd%s' % d.name in g._fd
for o in range(2, min(7, so+1)):
assert 'd%s%s' % (d.name, o) in f._fd
assert 'd%s%s' % (d.name, o) in g._fd

def test_shortcuts_mixed(self):
grid = Grid(shape=(10,))
f = Function(name='f', grid=grid, space_order=2)
g = Function(name='g', grid=grid, space_order=4)
assert 'dx4' not in (f*g)._fd
assert 'dx4' not in (f+g)._fd
assert 'dx4' not in (g*f.dx)._fd

def test_transpose_simple(self):
grid = Grid(shape=(4, 4))

Expand Down
3 changes: 2 additions & 1 deletion tests/test_lower_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ def test_symbolic_constant_times_add(self):
dt = grid.time_dim.spacing

u = TimeFunction(name="u", grid=grid, space_order=4, time_order=2)
f = Function(name='f', grid=grid)
f = Function(name='f', grid=grid, space_order=4)

eq = Eq(u.forward, u.laplace + dt**0.2*u.biharmonic(1/f))

leq = collect_derivatives.func([eq])[0]

assert len(eq.rhs.args) == 3
Expand Down

0 comments on commit 8e33092

Please sign in to comment.