Skip to content

Commit 748e801

Browse files
Merge pull request #700 from zwicker-group/test_mixed_derivatives
Add tests for mixed derivatives in Cartesian and cylindrical grids
2 parents 25d861a + 597f53b commit 748e801

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

pde/fields/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -636,15 +636,15 @@ def apply(
636636
Args:
637637
func (callable or str):
638638
The (vectorized) function being applied to the data or an expression
639-
that can be parsed using sympy (:func:`~pde.tools.expression.evaluate`
639+
that can be parsed using sympy (:func:`~pde.tools.expressions.evaluate`
640640
is used in this case). The local field values can be accessed using the
641641
field labels for a field collection and via the variable `c` otherwise.
642642
out (FieldBase, optional):
643643
Optional field into which the data is written
644644
label (str, optional):
645645
Name of the returned field
646646
evaluate_args (dict):
647-
Additional arguments passed to :func:`~pde.tools.expression.evaluate`.
647+
Additional arguments passed to :func:`~pde.tools.expressions.evaluate`.
648648
Only used when `func` is a string.
649649
650650
Returns:
@@ -660,7 +660,10 @@ def apply(
660660
if evaluate_args is None:
661661
evaluate_args = {}
662662
if isinstance(self, DataFieldBase):
663-
result = evaluate(func, {"c": self}, **evaluate_args)
663+
fields = {"c": self}
664+
if self.label is not None:
665+
fields[self.label] = self
666+
result = evaluate(func, fields, **evaluate_args)
664667
elif isinstance(self, FieldCollection):
665668
result = evaluate(func, self, **evaluate_args)
666669
else:

tests/grids/test_cartesian_grids.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,14 @@ def test_9point_stencil(periodic, corner_weight):
353353
reference = field.laplace(bc="auto_periodic_neumann")
354354
test = field.laplace(bc="auto_periodic_neumann", corner_weight=corner_weight)
355355
np.testing.assert_allclose(reference.data, test.data, atol=corner_weight / 3)
356+
357+
358+
@pytest.mark.parametrize("periodic", [True, False])
359+
def test_mixed_derivatives(periodic):
360+
"""Test mixed derivatives of scalar fields."""
361+
grid = CartesianGrid([[0, 1], [-1, 0.5]], [7, 9], periodic=periodic)
362+
field = ScalarField.random_normal(grid, label="fld")
363+
364+
res1 = field.apply("d_dx(d_dy(fld))")
365+
res2 = field.apply("d_dy(d_dx(fld))")
366+
np.testing.assert_allclose(res1.data, res2.data)

tests/grids/test_cylindrical_grids.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,13 @@ def test_setting_boundary_conditions():
7777
grid.get_boundary_conditions({"r": "derivative", "z": "periodic"})
7878
with pytest.raises(RuntimeError):
7979
grid.get_boundary_conditions({"r": "derivative", "z": "derivative"})
80+
81+
82+
def test_mixed_derivatives():
83+
"""Test mixed derivatives of scalar fields."""
84+
grid = CylindricalSymGrid(1, [-1, 0.5], [7, 9])
85+
field = ScalarField.random_normal(grid, label="c")
86+
87+
res1 = field.apply("d_dz(d_dr(c))")
88+
res2 = field.apply("d_dr(d_dz(c))")
89+
np.testing.assert_allclose(res1.data, res2.data)

0 commit comments

Comments
 (0)