@@ -714,6 +714,7 @@ def differentiate(self, var: str) -> ScalarExpression:
714714 signature = self .vars ,
715715 allow_indexed = self .allow_indexed ,
716716 user_funcs = self .user_funcs ,
717+ consts = self .consts ,
717718 )
718719
719720 @cached_property ()
@@ -732,7 +733,10 @@ def derivatives(self) -> TensorExpression:
732733
733734 grad = sympy .Array ([self ._sympy_expr .diff (sympy .Symbol (v )) for v in self .vars ])
734735 return TensorExpression (
735- sympy .simplify (grad ), signature = self .vars , user_funcs = self .user_funcs
736+ sympy .simplify (grad ),
737+ signature = self .vars ,
738+ user_funcs = self .user_funcs ,
739+ consts = self .consts ,
736740 )
737741
738742
@@ -788,6 +792,10 @@ def __init__(
788792 user_funcs = expression .user_funcs
789793 else :
790794 user_funcs .update (expression .user_funcs )
795+ if consts is None :
796+ consts = expression .consts
797+ else :
798+ consts .update (expression .consts )
791799
792800 elif isinstance (expression , (np .ndarray , list , tuple )):
793801 # expression is a constant array
@@ -834,11 +842,17 @@ def __getitem__(self, index):
834842 expr = self ._sympy_expr [index ]
835843 if isinstance (expr , sympy .Array ):
836844 return TensorExpression (
837- expr , signature = self .vars , user_funcs = self .user_funcs
845+ expr ,
846+ signature = self .vars ,
847+ user_funcs = self .user_funcs ,
848+ consts = self .consts ,
838849 )
839850 else :
840851 return ScalarExpression (
841- expr , signature = self .vars , user_funcs = self .user_funcs
852+ expr ,
853+ signature = self .vars ,
854+ user_funcs = self .user_funcs ,
855+ consts = self .consts ,
842856 )
843857
844858 @property
@@ -870,7 +884,9 @@ def differentiate(self, var: str) -> TensorExpression:
870884 derivative = np .zeros (self .shape )
871885 else :
872886 derivative = self ._sympy_expr .diff (sympy .Symbol (var ))
873- return TensorExpression (derivative , self .vars , user_funcs = self .user_funcs )
887+ return TensorExpression (
888+ derivative , self .vars , user_funcs = self .user_funcs , consts = self .consts
889+ )
874890
875891 @cached_property ()
876892 def derivatives (self ) -> TensorExpression :
@@ -885,7 +901,9 @@ def derivatives(self) -> TensorExpression:
885901 dx = sympy .Array ([sympy .Symbol (s ) for s in self .vars ])
886902 derivatives = sympy .derive_by_array (self ._sympy_expr , dx )
887903
888- return TensorExpression (derivatives , self .vars , user_funcs = self .user_funcs )
904+ return TensorExpression (
905+ derivatives , self .vars , user_funcs = self .user_funcs , consts = self .consts
906+ )
889907
890908 def get_compiled_array (
891909 self , single_arg : bool = True
0 commit comments