Skip to content

Commit

Permalink
Merge pull request #2366 from devitocodes/deriv_subs_patch
Browse files Browse the repository at this point in the history
dsl: Patch edge-case derivative specifications
  • Loading branch information
FabioLuporini authored Apr 30, 2024
2 parents daf782e + ca2d748 commit cc04dba
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
4 changes: 4 additions & 0 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
variable_count = [sympy.Tuple(s, dims.count(s))
for s in filter_ordered(dims)]
return dims, deriv_orders, fd_orders, variable_count

# Sanitise `dims`. ((x, 2), (y, 0)) is valid input, but (y, 0) should be dropped.
dims = tuple(d for d in dims if not (isinstance(d, Iterable) and d[1] == 0))

# Check `dims`. It can be a single Dimension, an iterable of Dimensions, or even
# an iterable of 2-tuple (Dimension, deriv_order)
if len(dims) == 0:
Expand Down
21 changes: 21 additions & 0 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,27 @@ def test_substitution(self):
assert f1 is not f2
assert f1.subs(f2, -1) == -1

def test_zero_spec(self):
"""
Test that derivatives specified as Derivative(f, (x, 2), (y, 0)) are
correctly handled.
"""
grid = Grid((11, 11))
x, y = grid.dimensions
f = Function(name="f", grid=grid, space_order=4)
# Check that both specifications match
drv0 = Derivative(f, (x, 2))
drv1 = Derivative(f, (x, 2), (y, 0))
assert drv0.dims == drv1.dims
assert drv0.fd_order == drv1.fd_order
assert drv0.deriv_order == drv1.deriv_order

# Check that substitution can applied correctly
expr0 = drv0 + 1
expr1 = drv1 + 1
assert expr0.subs(drv0, drv1) == expr1
assert expr1.subs(drv1, drv0) == expr0


class TestTwoStageEvaluation(object):

Expand Down

0 comments on commit cc04dba

Please sign in to comment.