Skip to content

Commit abb1b3e

Browse files
Remove numba backend (#55)
* Remove numba backend * Manually type numba functions * Ensure linearized system is contiguous and float64 * Remove njit from cycle reduction functions (windows CI fails mysteriously)
1 parent 3e9c9cf commit abb1b3e

File tree

11 files changed

+55
-273
lines changed

11 files changed

+55
-273
lines changed

gEconpy/model/compile.py

Lines changed: 3 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
from sympytensor import as_tensor
1111

1212
from gEconpy.classes.containers import SteadyStateResults
13-
from gEconpy.numbaf.utilities import numba_lambdify
1413

15-
BACKENDS = Literal["numpy", "numba", "pytensor"]
14+
BACKENDS = Literal["numpy", "pytensor"]
1615

1716

1817
def output_to_tensor(x, cache):
@@ -151,7 +150,7 @@ def compile_function(
151150
**kwargs,
152151
) -> tuple[Callable, dict]:
153152
"""
154-
Dispatch compilation of a sympy function to one of three possible backends: numpy, numba, or pytensor.
153+
Dispatch compilation of a sympy function to one of three possible backends: numpy, or pytensor.
155154
156155
Parameters
157156
----------
@@ -161,7 +160,7 @@ def compile_function(
161160
outputs: list[Union[sp.Symbol, sp.Expr]]
162161
The outputs of the function.
163162
164-
backend: str, one of "numpy", "numba", "pytensor"
163+
backend: str, one of "numpy" or "pytensor"
165164
The backend to use for the compiled function.
166165
167166
cache: dict, optional
@@ -195,10 +194,6 @@ def compile_function(
195194
f, cache = compile_to_numpy(
196195
inputs, outputs, cache, stack_return, pop_return, **kwargs
197196
)
198-
elif backend == "numba":
199-
f, cache = compile_to_numba(
200-
inputs, outputs, cache, stack_return, pop_return, **kwargs
201-
)
202197
elif backend == "pytensor":
203198
f, cache = compile_to_pytensor_function(
204199
inputs, outputs, cache, stack_return, pop_return, return_symbolic, **kwargs
@@ -256,48 +251,6 @@ def compile_to_numpy(
256251
return f, cache
257252

258253

259-
def compile_to_numba(
260-
inputs: list[sp.Symbol],
261-
outputs: list[sp.Symbol | sp.Expr],
262-
cache: dict,
263-
stack_return: bool,
264-
pop_return: bool,
265-
**kwargs,
266-
):
267-
"""
268-
Convert a sympy function to a numba njit function using :func:`numba_lambdify`.
269-
270-
Parameters
271-
----------
272-
inputs: list[sp.Symbol]
273-
The inputs to the function.
274-
outputs: list[Union[sp.Symbol, sp.Expr]]
275-
The outputs of the function.
276-
cache: dict
277-
Mapping between sympy variables and pytensor variables. Ignored by this function; included for compatibility
278-
with other compile functions.
279-
stack_return: bool
280-
If True, the function will return a single numpy array with all outputs. Otherwise it will return a list
281-
of numpy arrays.
282-
pop_return: bool
283-
If True, the function will return only the 0th element of the output. Used to remove the list wrapper around
284-
single-output functions.
285-
kwargs: dict
286-
Ignored, included for compatibility with other compile functions
287-
288-
Returns
289-
-------
290-
f: Callable
291-
The compiled function.
292-
cache: dict
293-
Pytensor caching information.
294-
"""
295-
f = numba_lambdify(inputs, outputs, stack_outputs=stack_return)
296-
if pop_return:
297-
f = pop_return_wrapper(f)
298-
return f, cache
299-
300-
301254
def compile_to_pytensor_function(
302255
inputs: list[sp.Symbol],
303256
outputs: list[sp.Symbol | sp.Expr],

gEconpy/model/model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from copy import deepcopy
77
from typing import Literal, cast
88

9-
import numba as nb
109
import numpy as np
1110
import pandas as pd
1211
import sympy as sp
@@ -42,6 +41,7 @@
4241
)
4342
from gEconpy.utilities import get_name, postprocess_optimizer_res, safe_to_ss
4443

44+
4545
VariableType = sp.Symbol | TimeAwareSymbol
4646
_log = logging.getLogger(__name__)
4747

@@ -1133,7 +1133,10 @@ def linearize_model(
11331133
**param_dict, **steady_state, not_loglin_variable=not_loglin_flags
11341134
)
11351135

1136-
return A, B, C, D
1136+
# Using A.dtype to avoid hard-coding float64 (we might be using float32)
1137+
# The reason for casting is mostly D, which sometimes comes out as an int32/64 array
1138+
1139+
return list(map(lambda x: np.ascontiguousarray(x, dtype=A.dtype), [A, B, C, D]))
11371140

11381141
def solve_model(
11391142
self,
@@ -1337,6 +1340,8 @@ def solve_model(
13371340
**parameter_updates,
13381341
)
13391342

1343+
assert all(x.flags["C_CONTIGUOUS"] for x in [A, B, C, D])
1344+
13401345
if solver == "gensys":
13411346
gensys_results = solve_policy_function_with_gensys(A, B, C, D, tol)
13421347
G_1, constant, impact, f_mat, f_wt, y_wt, gev, eu, loose = gensys_results

gEconpy/model/parameters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def compile_param_dict_func(
2929
deterministic_dict: SymbolDictionary
3030
A dictionary of deterministic parameters, with the keys being the parameters and the values being the
3131
expressions to compute them.
32-
backend: str, one of "numpy", "numba", "pytensor"
32+
backend: str, one of "numpy", "pytensor"
3333
The backend to use for the compiled function.
3434
cache: dict, optional
3535
A dictionary mapping from pytensor symbols to sympy expressions. Used to prevent duplicate mappings from

gEconpy/model/steady_state.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,9 @@ def compile_ss_resid_and_sq_err(
158158
backend=backend,
159159
cache=cache,
160160
return_symbolic=return_symbolic,
161-
# for pytensor/numba, the return is a single object; don't stack into a (1,n,n) array
161+
# for pytensor, the return is a single object; don't stack into a (1,n,n) array
162162
stack_return=backend == "numpy",
163-
# Numba directly returns the jacobian as an array, don't pop
164-
# pytensor and lambdify return a list of one item, so we have to extract it.
165-
pop_return=backend != "numba",
163+
pop_return=True,
166164
**kwargs,
167165
)
168166

@@ -209,8 +207,7 @@ def compile_ss_resid_and_sq_err(
209207
return_symbolic=return_symbolic,
210208
# error_hess is a list of one element; don't stack into a (1,n,n) array
211209
stack_return=backend != "pytensor",
212-
# Numba directly returns the hessian as an array, don't pop
213-
pop_return=backend != "numba",
210+
pop_return=True,
214211
**kwargs,
215212
)
216213

gEconpy/numbaf/utilities.py

Lines changed: 0 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -166,159 +166,3 @@ def _print_exp(self, expr):
166166
self._module_format(self._module + ".exp"),
167167
",".join(self._print(i) for i in expr.args),
168168
)
169-
170-
171-
def numba_lambdify(
172-
inputs: list[sp.Symbol],
173-
expr: list[sp.Expr] | sp.Matrix | list[sp.Matrix],
174-
func_signature: str | None = None,
175-
ravel_outputs=False,
176-
stack_outputs=False,
177-
) -> Callable:
178-
"""
179-
Convert a sympy expression into a Numba-compiled function. Unlike sp.lambdify, the resulting function can be
180-
pickled. In addition, common sub-expressions are gathered using sp.cse and assigned to local variables,
181-
giving a (very) modest performance boost. A signature can optionally be provided for numba.njit.
182-
183-
Finally, the resulting function always returns a numpy array, rather than a list.
184-
185-
Parameters
186-
----------
187-
inputs: list of sympy.Symbol
188-
A list of "exogenous" variables. The distinction between "exogenous" and "enodgenous" is
189-
useful when passing the resulting function to a scipy.optimize optimizer. In this context, exogenous
190-
variables should be the choice varibles used to minimize the function.
191-
expr : list of sympy.Expr or sp.Matrix
192-
The sympy expression(s) to be converted. Expects a list of expressions (in the case that we're compiling a
193-
system to be stacked into a single output vector), a single matrix (which is returned as a single nd-array)
194-
or a list of matrices (which are returned as a list of nd-arrays)
195-
func_signature: str
196-
A numba function signature, passed to the numba.njit decorator on the generated function.
197-
ravel_outputs: bool, default False
198-
If true, all outputs of the jitted function will be raveled before they are returned. This is useful for
199-
removing size-1 dimensions from sympy vectors.
200-
stack_outputs: bool, default False
201-
If true, stack all return values into a single vector. Otherwise they are returned as a tuple as usual.
202-
203-
Returns
204-
-------
205-
numba.types.function
206-
A Numba-compiled function equivalent to the input expression.
207-
208-
Notes
209-
-----
210-
The function returned by this function is pickleable.
211-
"""
212-
ZERO_PATTERN = re.compile(r"(?<![\.\w])0([ ,\]])")
213-
ZERO_ONE_INDEX_PATTERN = re.compile(r"((?<=\[)(([0,1])\.0)(?=\]))")
214-
FLOAT_SUBS = {
215-
sp.core.numbers.One(): sp.Float(1),
216-
sp.core.numbers.NegativeOne(): sp.Float(-1),
217-
}
218-
printer = NumbaFriendlyNumPyPrinter()
219-
220-
if func_signature is None:
221-
decorator = "@nb.njit"
222-
else:
223-
decorator = f"@nb.njit({func_signature})"
224-
225-
# Special case: expr is [[]]. This can occur if no user-defined steady-state values were provided.
226-
# It shouldn't happen otherwise.
227-
if expr == [[]]:
228-
sub_dict = ()
229-
code = ""
230-
retvals = ["[None]"]
231-
232-
else:
233-
# Need to make the float substitutions so that numba can correctly interpret everything, but we have to handle
234-
# several cases:
235-
# Case 1: expr is just a single Sympy thing
236-
if isinstance(expr, sp.Matrix | sp.Expr):
237-
expr = expr.subs(FLOAT_SUBS)
238-
239-
# Case 2: expr is a list. Items in the list are either lists of expressions (systems of equations),
240-
# single equations, or matrices.
241-
elif isinstance(expr, list):
242-
new_expr = []
243-
for item in expr:
244-
# Case 2a: It's a simple list of sympy things
245-
if isinstance(item, sp.Matrix | sp.Expr):
246-
new_expr.append(item.subs(FLOAT_SUBS))
247-
# Case 2b: It's a system of equations, List[List[sp.Expr]]
248-
elif isinstance(item, list):
249-
if all([isinstance(x, sp.Matrix | sp.Expr) for x in item]):
250-
new_expr.append([x.subs(FLOAT_SUBS) for x in item])
251-
else:
252-
raise ValueError("Unexpected input type for expr")
253-
254-
# Case 2c: It's a constant -- just pass it along unchanged.
255-
elif isinstance(item, int | float):
256-
new_expr.append(item)
257-
else:
258-
raise ValueError(f"Unexpected input type for expr: {expr}")
259-
260-
expr = new_expr
261-
else:
262-
raise ValueError("Unexpected input type for expr")
263-
sub_dict, expr = sp.cse(expr)
264-
265-
# Converting matrices to a list of lists is convenient because NumPyPrinter() won't wrap them in np.array
266-
exprs = []
267-
for ex in expr:
268-
if hasattr(ex, "tolist"):
269-
exprs.append(ex.tolist())
270-
else:
271-
exprs.append(ex)
272-
273-
codes = []
274-
retvals = []
275-
for i, expr in enumerate(exprs):
276-
code = printer.doprint(expr)
277-
278-
delimiter = "]," if "]," in code else ","
279-
delimiter = ","
280-
code = code.split(delimiter)
281-
code = [" " * 8 + eq.strip() for eq in code]
282-
code = f"{delimiter}\n".join(code)
283-
code = code.replace("numpy.", "np.")
284-
285-
# Handle conversion of 0 to 0.0
286-
code = re.sub(ZERO_PATTERN, r"0.0\g<1>", code)
287-
288-
# Repair indexing -- we might have converted x[0] to x[0.0] or x[1] to x[1.0]
289-
code = re.sub(ZERO_ONE_INDEX_PATTERN, r"\g<3>", code)
290-
291-
code_name = f"retval_{i}"
292-
retvals.append(code_name)
293-
code = f" {code_name} = np.array(\n{code}\n )"
294-
if ravel_outputs:
295-
code += ".ravel()"
296-
297-
codes.append(code)
298-
code = "\n".join(codes)
299-
300-
input_signature = ", ".join([getattr(x, "safe_name", x.name) for x in inputs])
301-
302-
assignments = "\n".join(
303-
[
304-
f" {x} = {printer.doprint(y).replace('numpy.', 'np.')}"
305-
for x, y in sub_dict
306-
]
307-
)
308-
assignments = re.sub(ZERO_ONE_INDEX_PATTERN, r"\g<3>", assignments)
309-
310-
if len(retvals) > 1:
311-
returns = f"({','.join(retvals)})"
312-
if stack_outputs:
313-
returns = f"np.stack({returns})"
314-
else:
315-
returns = retvals[0]
316-
# returns = f'[{",".join(retvals)}]' if len(retvals) > 1 else retvals[0]
317-
full_code = f"{decorator}\ndef f({input_signature}):\n\n{assignments}\n\n{code}\n\n return {returns}"
318-
319-
docstring = f"'''Automatically generated code:\n{full_code}'''"
320-
code = f"{decorator}\ndef f({input_signature}):\n {docstring}\n\n{assignments}\n\n{code}\n\n return {returns}"
321-
322-
exec_namespace = {}
323-
exec(code, globals(), exec_namespace)
324-
return exec_namespace["f"]

gEconpy/solvers/cycle_reduction.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
)
1616

1717

18-
# @nb.njit(cache=True)
18+
# TODO: These njit decorators cause the CI to fail on Windows only -- no idea why. Disabling it for now.
19+
# @nb.njit(
20+
# (nb.float64[:, ::1], nb.float64[:, ::1], nb.float64[:, ::1], nb.int64, nb.float64),
21+
# cache=True,
22+
# )
1923
def nb_cycle_reduction(
2024
A0: np.ndarray,
2125
A1: np.ndarray,
@@ -113,8 +117,11 @@ def nb_cycle_reduction(
113117
return X, res, result, log_norm
114118

115119

116-
# @nb.njit(cache=True)
117-
def nb_solve_shock_matrix(B, C, D, G_1):
120+
# @nb.njit(
121+
# (nb.float64[:, ::1], nb.float64[:, ::1], nb.float64[:, ::1], nb.float64[:, ::1]),
122+
# cache=True,
123+
# )
124+
def nb_solve_shock_matrix(B: np.ndarray, C: np.ndarray, D: np.ndarray, G_1: np.ndarray):
118125
"""
119126
Given the partial solution to the linear approximate policy function G_1, solve for the remaining component of the
120127
policy function, R.

tests/model/test_compile.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def f():
3232
assert result["b"] == 2.0
3333

3434

35-
@pytest.mark.parametrize("backend", ["numpy", "numba", "pytensor"])
36-
def test_scalar_function(backend: Literal["numpy", "numba", "pytensor"]):
35+
@pytest.mark.parametrize("backend", ["numpy", "pytensor"])
36+
def test_scalar_function(backend: Literal["numpy", "pytensor"]):
3737
x = sp.symbols("x")
3838
f = x**2
3939
f_func, _ = compile_function(
@@ -42,11 +42,9 @@ def test_scalar_function(backend: Literal["numpy", "numba", "pytensor"]):
4242
assert f_func(x=2) == 4
4343

4444

45-
@pytest.mark.parametrize("backend", ["numpy", "numba", "pytensor"])
45+
@pytest.mark.parametrize("backend", ["numpy", "pytensor"])
4646
@pytest.mark.parametrize("stack_return", [True, False])
47-
def test_multiple_outputs(
48-
backend: Literal["numpy", "numba", "pytensor"], stack_return: bool
49-
):
47+
def test_multiple_outputs(backend: Literal["numpy", "pytensor"], stack_return: bool):
5048
x, y, z = sp.symbols("x y z ")
5149
x2 = x**2
5250
y2 = y**2
@@ -68,8 +66,8 @@ def test_multiple_outputs(
6866
)
6967

7068

71-
@pytest.mark.parametrize("backend", ["numpy", "numba", "pytensor"])
72-
def test_matrix_function(backend: Literal["numpy", "numba", "pytensor"]):
69+
@pytest.mark.parametrize("backend", ["numpy", "pytensor"])
70+
def test_matrix_function(backend: Literal["numpy", "pytensor"]):
7371
x, y, z = sp.symbols("x y z")
7472
f = sp.Matrix([x, y, z]).reshape(1, 3)
7573

@@ -87,7 +85,7 @@ def test_matrix_function(backend: Literal["numpy", "numba", "pytensor"]):
8785
np.testing.assert_allclose(res, np.array([[2.0, 3.0, 4.0]]))
8886

8987

90-
@pytest.mark.parametrize("backend", ["numpy", "numba", "pytensor"])
88+
@pytest.mark.parametrize("backend", ["numpy", "pytensor"])
9189
def test_compile_gradient(backend: BACKENDS):
9290
x, y, z = sp.symbols("x y z")
9391
f = x**2 + y**2 + z**2

0 commit comments

Comments
 (0)