Skip to content

Commit f1e6a4b

Browse files
Merge pull request #727 from zwicker-group/expression_consts
Fix ScalarExpression and TensorExpression to support constants in differentiation and __getitem__
2 parents 0bef99a + 749d66e commit f1e6a4b

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

pde/tools/expressions.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,7 @@ def differentiate(self, var: str) -> ScalarExpression:
714714
signature=self.vars,
715715
allow_indexed=self.allow_indexed,
716716
user_funcs=self.user_funcs,
717+
consts=self.consts,
717718
)
718719

719720
@cached_property()
@@ -732,7 +733,10 @@ def derivatives(self) -> TensorExpression:
732733

733734
grad = sympy.Array([self._sympy_expr.diff(sympy.Symbol(v)) for v in self.vars])
734735
return TensorExpression(
735-
sympy.simplify(grad), signature=self.vars, user_funcs=self.user_funcs
736+
sympy.simplify(grad),
737+
signature=self.vars,
738+
user_funcs=self.user_funcs,
739+
consts=self.consts,
736740
)
737741

738742

@@ -788,6 +792,10 @@ def __init__(
788792
user_funcs = expression.user_funcs
789793
else:
790794
user_funcs.update(expression.user_funcs)
795+
if consts is None:
796+
consts = expression.consts
797+
else:
798+
consts.update(expression.consts)
791799

792800
elif isinstance(expression, (np.ndarray, list, tuple)):
793801
# expression is a constant array
@@ -834,11 +842,17 @@ def __getitem__(self, index):
834842
expr = self._sympy_expr[index]
835843
if isinstance(expr, sympy.Array):
836844
return TensorExpression(
837-
expr, signature=self.vars, user_funcs=self.user_funcs
845+
expr,
846+
signature=self.vars,
847+
user_funcs=self.user_funcs,
848+
consts=self.consts,
838849
)
839850
else:
840851
return ScalarExpression(
841-
expr, signature=self.vars, user_funcs=self.user_funcs
852+
expr,
853+
signature=self.vars,
854+
user_funcs=self.user_funcs,
855+
consts=self.consts,
842856
)
843857

844858
@property
@@ -870,7 +884,9 @@ def differentiate(self, var: str) -> TensorExpression:
870884
derivative = np.zeros(self.shape)
871885
else:
872886
derivative = self._sympy_expr.diff(sympy.Symbol(var))
873-
return TensorExpression(derivative, self.vars, user_funcs=self.user_funcs)
887+
return TensorExpression(
888+
derivative, self.vars, user_funcs=self.user_funcs, consts=self.consts
889+
)
874890

875891
@cached_property()
876892
def derivatives(self) -> TensorExpression:
@@ -885,7 +901,9 @@ def derivatives(self) -> TensorExpression:
885901
dx = sympy.Array([sympy.Symbol(s) for s in self.vars])
886902
derivatives = sympy.derive_by_array(self._sympy_expr, dx)
887903

888-
return TensorExpression(derivatives, self.vars, user_funcs=self.user_funcs)
904+
return TensorExpression(
905+
derivatives, self.vars, user_funcs=self.user_funcs, consts=self.consts
906+
)
889907

890908
def get_compiled_array(
891909
self, single_arg: bool = True

tests/tools/test_expressions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,23 @@ def test_expression_consts():
348348
np.testing.assert_allclose(expr(np.array([2, 3])), np.array([3, 5]))
349349
np.testing.assert_allclose(expr.get_compiled()(np.array([2, 3])), np.array([3, 5]))
350350

351+
expr = ScalarExpression("a * b", consts={"a": np.array([1, 2])})
352+
dexpr_da = expr.differentiate("b")
353+
np.testing.assert_allclose(dexpr_da(np.array([2, 3])), np.array([1, 2]))
354+
dexpr = expr.derivatives
355+
assert dexpr.shape == (1,)
356+
np.testing.assert_allclose(dexpr(np.array([2, 3])), np.array([[1, 2]]))
357+
358+
359+
def test_tensor_expression_consts():
360+
"""Test the usage of consts in TensorExpression."""
361+
e = TensorExpression("[a, a*b]", consts={"b": 5})
362+
assert e[0](2) == 2
363+
assert e[1](2) == 10
364+
d1 = e.differentiate("a")
365+
assert d1[0](2) == 1
366+
assert d1[1](2) == 5
367+
351368

352369
def test_evaluate_func_scalar():
353370
"""Test the evaluate function with scalar fields."""

0 commit comments

Comments
 (0)