@@ -166,159 +166,3 @@ def _print_exp(self, expr):
166166 self ._module_format (self ._module + ".exp" ),
167167 "," .join (self ._print (i ) for i in expr .args ),
168168 )
169-
170-
171- def numba_lambdify (
172- inputs : list [sp .Symbol ],
173- expr : list [sp .Expr ] | sp .Matrix | list [sp .Matrix ],
174- func_signature : str | None = None ,
175- ravel_outputs = False ,
176- stack_outputs = False ,
177- ) -> Callable :
178- """
179- Convert a sympy expression into a Numba-compiled function. Unlike sp.lambdify, the resulting function can be
180- pickled. In addition, common sub-expressions are gathered using sp.cse and assigned to local variables,
181- giving a (very) modest performance boost. A signature can optionally be provided for numba.njit.
182-
183- Finally, the resulting function always returns a numpy array, rather than a list.
184-
185- Parameters
186- ----------
187- inputs: list of sympy.Symbol
188- A list of "exogenous" variables. The distinction between "exogenous" and "enodgenous" is
189- useful when passing the resulting function to a scipy.optimize optimizer. In this context, exogenous
190- variables should be the choice varibles used to minimize the function.
191- expr : list of sympy.Expr or sp.Matrix
192- The sympy expression(s) to be converted. Expects a list of expressions (in the case that we're compiling a
193- system to be stacked into a single output vector), a single matrix (which is returned as a single nd-array)
194- or a list of matrices (which are returned as a list of nd-arrays)
195- func_signature: str
196- A numba function signature, passed to the numba.njit decorator on the generated function.
197- ravel_outputs: bool, default False
198- If true, all outputs of the jitted function will be raveled before they are returned. This is useful for
199- removing size-1 dimensions from sympy vectors.
200- stack_outputs: bool, default False
201- If true, stack all return values into a single vector. Otherwise they are returned as a tuple as usual.
202-
203- Returns
204- -------
205- numba.types.function
206- A Numba-compiled function equivalent to the input expression.
207-
208- Notes
209- -----
210- The function returned by this function is pickleable.
211- """
212- ZERO_PATTERN = re .compile (r"(?<![\.\w])0([ ,\]])" )
213- ZERO_ONE_INDEX_PATTERN = re .compile (r"((?<=\[)(([0,1])\.0)(?=\]))" )
214- FLOAT_SUBS = {
215- sp .core .numbers .One (): sp .Float (1 ),
216- sp .core .numbers .NegativeOne (): sp .Float (- 1 ),
217- }
218- printer = NumbaFriendlyNumPyPrinter ()
219-
220- if func_signature is None :
221- decorator = "@nb.njit"
222- else :
223- decorator = f"@nb.njit({ func_signature } )"
224-
225- # Special case: expr is [[]]. This can occur if no user-defined steady-state values were provided.
226- # It shouldn't happen otherwise.
227- if expr == [[]]:
228- sub_dict = ()
229- code = ""
230- retvals = ["[None]" ]
231-
232- else :
233- # Need to make the float substitutions so that numba can correctly interpret everything, but we have to handle
234- # several cases:
235- # Case 1: expr is just a single Sympy thing
236- if isinstance (expr , sp .Matrix | sp .Expr ):
237- expr = expr .subs (FLOAT_SUBS )
238-
239- # Case 2: expr is a list. Items in the list are either lists of expressions (systems of equations),
240- # single equations, or matrices.
241- elif isinstance (expr , list ):
242- new_expr = []
243- for item in expr :
244- # Case 2a: It's a simple list of sympy things
245- if isinstance (item , sp .Matrix | sp .Expr ):
246- new_expr .append (item .subs (FLOAT_SUBS ))
247- # Case 2b: It's a system of equations, List[List[sp.Expr]]
248- elif isinstance (item , list ):
249- if all ([isinstance (x , sp .Matrix | sp .Expr ) for x in item ]):
250- new_expr .append ([x .subs (FLOAT_SUBS ) for x in item ])
251- else :
252- raise ValueError ("Unexpected input type for expr" )
253-
254- # Case 2c: It's a constant -- just pass it along unchanged.
255- elif isinstance (item , int | float ):
256- new_expr .append (item )
257- else :
258- raise ValueError (f"Unexpected input type for expr: { expr } " )
259-
260- expr = new_expr
261- else :
262- raise ValueError ("Unexpected input type for expr" )
263- sub_dict , expr = sp .cse (expr )
264-
265- # Converting matrices to a list of lists is convenient because NumPyPrinter() won't wrap them in np.array
266- exprs = []
267- for ex in expr :
268- if hasattr (ex , "tolist" ):
269- exprs .append (ex .tolist ())
270- else :
271- exprs .append (ex )
272-
273- codes = []
274- retvals = []
275- for i , expr in enumerate (exprs ):
276- code = printer .doprint (expr )
277-
278- delimiter = "]," if "]," in code else ","
279- delimiter = ","
280- code = code .split (delimiter )
281- code = [" " * 8 + eq .strip () for eq in code ]
282- code = f"{ delimiter } \n " .join (code )
283- code = code .replace ("numpy." , "np." )
284-
285- # Handle conversion of 0 to 0.0
286- code = re .sub (ZERO_PATTERN , r"0.0\g<1>" , code )
287-
288- # Repair indexing -- we might have converted x[0] to x[0.0] or x[1] to x[1.0]
289- code = re .sub (ZERO_ONE_INDEX_PATTERN , r"\g<3>" , code )
290-
291- code_name = f"retval_{ i } "
292- retvals .append (code_name )
293- code = f" { code_name } = np.array(\n { code } \n )"
294- if ravel_outputs :
295- code += ".ravel()"
296-
297- codes .append (code )
298- code = "\n " .join (codes )
299-
300- input_signature = ", " .join ([getattr (x , "safe_name" , x .name ) for x in inputs ])
301-
302- assignments = "\n " .join (
303- [
304- f" { x } = { printer .doprint (y ).replace ('numpy.' , 'np.' )} "
305- for x , y in sub_dict
306- ]
307- )
308- assignments = re .sub (ZERO_ONE_INDEX_PATTERN , r"\g<3>" , assignments )
309-
310- if len (retvals ) > 1 :
311- returns = f"({ ',' .join (retvals )} )"
312- if stack_outputs :
313- returns = f"np.stack({ returns } )"
314- else :
315- returns = retvals [0 ]
316- # returns = f'[{",".join(retvals)}]' if len(retvals) > 1 else retvals[0]
317- full_code = f"{ decorator } \n def f({ input_signature } ):\n \n { assignments } \n \n { code } \n \n return { returns } "
318-
319- docstring = f"'''Automatically generated code:\n { full_code } '''"
320- code = f"{ decorator } \n def f({ input_signature } ):\n { docstring } \n \n { assignments } \n \n { code } \n \n return { returns } "
321-
322- exec_namespace = {}
323- exec (code , globals (), exec_namespace )
324- return exec_namespace ["f" ]
0 commit comments