Skip to content

Commit 8e33092

Browse files
authored
Merge pull request #2254 from devitocodes/fix-fd-shct
api: prevent derivative shortcut with incompatible fd order
2 parents d4ebfa9 + e6ca17e commit 8e33092

File tree

4 files changed

+28
-3
lines changed

4 files changed

+28
-3
lines changed

devito/finite_differences/differentiable.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,16 @@ def is_TimeDependent(self):
105105

106106
@cached_property
107107
def _fd(self):
108-
return dict(ChainMap(*[getattr(i, '_fd', {}) for i in self._args_diff]))
108+
# Filter out all args with fd order too high
109+
fd_args = []
110+
for f in self._args_diff:
111+
try:
112+
if f.space_order <= self.space_order and \
113+
(not f.is_TimeDependent or f.time_order <= self.time_order):
114+
fd_args.append(f)
115+
except AttributeError:
116+
pass
117+
return dict(ChainMap(*[getattr(i, '_fd', {}) for i in fd_args]))
109118

110119
@cached_property
111120
def _symbolic_functions(self):

devito/ir/clusters/cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def dtype(self):
483483
If two Clusters perform calculations with different precision, the
484484
data type with highest precision is returned.
485485
"""
486-
dtypes = {i.dtype for i in self}
486+
dtypes = {i.dtype for i in self} - {None}
487487

488488
return infer_dtype(dtypes)
489489

tests/test_derivatives.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,21 @@ def test_all_shortcuts(self, so):
457457
for fd in g._fd:
458458
assert getattr(g, fd)
459459

460+
for d in grid.dimensions:
461+
assert 'd%s' % d.name in f._fd
462+
assert 'd%s' % d.name in g._fd
463+
for o in range(2, min(7, so+1)):
464+
assert 'd%s%s' % (d.name, o) in f._fd
465+
assert 'd%s%s' % (d.name, o) in g._fd
466+
467+
def test_shortcuts_mixed(self):
468+
grid = Grid(shape=(10,))
469+
f = Function(name='f', grid=grid, space_order=2)
470+
g = Function(name='g', grid=grid, space_order=4)
471+
assert 'dx4' not in (f*g)._fd
472+
assert 'dx4' not in (f+g)._fd
473+
assert 'dx4' not in (g*f.dx)._fd
474+
460475
def test_transpose_simple(self):
461476
grid = Grid(shape=(4, 4))
462477

tests/test_lower_exprs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,10 @@ def test_symbolic_constant_times_add(self):
6565
dt = grid.time_dim.spacing
6666

6767
u = TimeFunction(name="u", grid=grid, space_order=4, time_order=2)
68-
f = Function(name='f', grid=grid)
68+
f = Function(name='f', grid=grid, space_order=4)
6969

7070
eq = Eq(u.forward, u.laplace + dt**0.2*u.biharmonic(1/f))
71+
7172
leq = collect_derivatives.func([eq])[0]
7273

7374
assert len(eq.rhs.args) == 3

0 commit comments

Comments
 (0)