From f9adc13fbcc8bf3cc3b344b620ca56d88b499b35 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Mon, 29 Apr 2024 15:57:13 +0100 Subject: [PATCH 1/2] dsl: Sanitise zeroth-order derivatives --- devito/finite_differences/derivative.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 9e13454170..1fc631bea1 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -138,6 +138,9 @@ def _process_kwargs(cls, expr, *dims, **kwargs): variable_count = [sympy.Tuple(s, dims.count(s)) for s in filter_ordered(dims)] return dims, deriv_orders, fd_orders, variable_count + # Sanitise `dims`. ((x, 2), (y, 0)) is valid input, but (y, 0) should be dropped. + dims = tuple(d for d in dims if not (isinstance(d, Iterable) and d[1] == 0)) + # Check `dims`. It can be a single Dimension, an iterable of Dimensions, or even # an iterable of 2-tuple (Dimension, deriv_order) if len(dims) == 0: From ca2d74836db788da80c2420cbfa64f0e8b5d784a Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 30 Apr 2024 11:05:52 +0100 Subject: [PATCH 2/2] tests: Add test for derivative specification --- devito/finite_differences/derivative.py | 1 + tests/test_derivatives.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 1fc631bea1..889cfbfdbc 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -138,6 +138,7 @@ def _process_kwargs(cls, expr, *dims, **kwargs): variable_count = [sympy.Tuple(s, dims.count(s)) for s in filter_ordered(dims)] return dims, deriv_orders, fd_orders, variable_count + # Sanitise `dims`. ((x, 2), (y, 0)) is valid input, but (y, 0) should be dropped. dims = tuple(d for d in dims if not (isinstance(d, Iterable) and d[1] == 0)) diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index e6f5849925..0ee95450c7 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -632,6 +632,27 @@ def test_substitution(self): assert f1 is not f2 assert f1.subs(f2, -1) == -1 + def test_zero_spec(self): + """ + Test that derivatives specified as Derivative(f, (x, 2), (y, 0)) are + correctly handled. + """ + grid = Grid((11, 11)) + x, y = grid.dimensions + f = Function(name="f", grid=grid, space_order=4) + # Check that both specifications match + drv0 = Derivative(f, (x, 2)) + drv1 = Derivative(f, (x, 2), (y, 0)) + assert drv0.dims == drv1.dims + assert drv0.fd_order == drv1.fd_order + assert drv0.deriv_order == drv1.deriv_order + + # Check that substitution can applied correctly + expr0 = drv0 + 1 + expr1 = drv1 + 1 + assert expr0.subs(drv0, drv1) == expr1 + assert expr1.subs(drv1, drv0) == expr0 + class TestTwoStageEvaluation(object):