Improve Numba caching and source generation #1419

Open
brandonwillard opened this issue Feb 10, 2023 · 4 comments
Labels
enhancement New feature or request important Numba Involves Numba transpilation

Comments

@brandonwillard
Member

brandonwillard commented Feb 10, 2023

Looks like this wasn't given its own issue, but we need to improve caching and the way we generate source code for Numba JIT compilation.

To clarify, we need to do the following:

  • Cache the Python function objects returned by _numba_funcify based on their node arguments (i.e. the Apply nodes).
  • Store cached Python functions as persistent Python modules in aesara.config.compiledir.
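The second point could look roughly like the sketch below. It writes the generated source to a `.py` file named after the source's SHA-256 digest and imports it from there. This is a hypothetical helper, not an Aesara API, and a temporary directory stands in for `aesara.config.compiledir`. As far as I know, Numba's `cache=True` needs the jitted function to live in a real source file, so functions created with `exec` from a string can't be cached on disk. A module file on disk is one way around that.

```python
import hashlib
import importlib.util
import sys
import tempfile
from pathlib import Path

# Stand-in for aesara.config.compiledir (assumption: any writable directory).
COMPILEDIR = Path(tempfile.mkdtemp())


def load_cached_module(src: str, func_name: str, compiledir: Path = COMPILEDIR):
    """Persist `src` as a module named after its SHA-256 digest and return `func_name`.

    Hypothetical helper for illustration; not an Aesara API.
    """
    key = hashlib.sha256(src.encode("utf-8")).hexdigest()
    module_name = f"numba_cached_{key}"
    path = compiledir / f"{module_name}.py"

    # Write the module only the first time this exact source is seen.
    if not path.exists():
        path.write_text(src)

    # Reuse the already-imported module within this process.
    module = sys.modules.get(module_name)
    if module is None:
        spec = importlib.util.spec_from_file_location(module_name, str(path))
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    return getattr(module, func_name)


src = "def add_one(x):\n    return x + 1\n"
add_one = load_cached_module(src, "add_one")
```

In practice `compiledir` would be a fixed path, so that later processes find the files. A separate lookup from an `Apply`-node key to the generated source would still be needed to get the node-based caching described in the first point.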

We can repurpose the C backend's compile-directory code to do most of this, but if there are other packages that would simplify the entire process, we should look into those too.

@brandonwillard brandonwillard added enhancement New feature or request important Numba Involves Numba transpilation labels Feb 10, 2023
@rlouf rlouf moved this to Backends in Aesara Roadmap Feb 13, 2023
@Smit-create
Member

I'm thinking of the following approach:

  1. A class in the Numba backend that handles the cache, say NumbaCache.
  2. Some of the methods this class would probably have:
    • refresh: clear the cache
    • has_op_code: check whether the cache already has a Numba implementation of the Op
    • add_op_code: add the Numba function to the cache
  3. A single global instance of this class would be shared by each of the functions.
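Here is a minimal sketch of that interface, assuming an in-memory dict keyed by some hashable node key. `refresh`, `has_op_code` and `add_op_code` come from the list above. `get_op_code` and the `funcify_with_cache` wrapper are hypothetical additions, so that cached entries can actually be retrieved.

```python
from typing import Callable, Dict, Hashable


class NumbaCache:
    """Hypothetical in-memory cache mapping a node key to its generated function."""

    def __init__(self) -> None:
        self._store: Dict[Hashable, Callable] = {}

    def refresh(self) -> None:
        """Clear the cache."""
        self._store.clear()

    def has_op_code(self, key: Hashable) -> bool:
        """Check whether an implementation for `key` is already cached."""
        return key in self._store

    def add_op_code(self, key: Hashable, func: Callable) -> None:
        """Add (or overwrite) the function cached under `key`."""
        self._store[key] = func

    def get_op_code(self, key: Hashable) -> Callable:
        """Hypothetical addition: return the cached function for `key`."""
        return self._store[key]


# One module-level instance shared by every conversion function (item 3).
numba_cache = NumbaCache()


def funcify_with_cache(key: Hashable, build: Callable[[], Callable]) -> Callable:
    """Hypothetical wrapper: only call `build` when `key` is not cached yet."""
    if not numba_cache.has_op_code(key):
        numba_cache.add_op_code(key, build())
    return numba_cache.get_op_code(key)
```

A module-level instance keeps the call sites simple. The trade-off is that tests need to call `refresh()` to avoid leaking entries between cases.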

Some of the questions in the implementation that we have are:

  1. Should we use a .py file in the cache dir and append the function source code to it, or use a pickling library like dill? I'm not entirely sure whether Numba's njit works with dill serialization and deserialization.
  2. How will we hash the Op node for a quick lookup?
  3. Will we check the cache dir while visiting every node in the Numba backend? I think that might increase compilation time in some cases.

@brandonwillard
Member Author

brandonwillard commented Feb 24, 2023

Here's a high-level pseudocode outline of what we need: https://gist.github.com/brandonwillard/b9262d0eccb0e7016f836447ba8870fc.

Some of the questions in the implementation that we have are:

  1. Should we use a .py file in the cache dir and append the function source code to it, or use a pickling library like dill? I'm not entirely sure whether Numba's njit works with dill serialization and deserialization.

Yes, we need to investigate whether we can directly use the pickled modules produced by dill or something similar. This is mocked up in my outline.

  2. How will we hash the Op node for a quick lookup?

The outline uses a not-so-good pickle-to-SHA256 approach, but it illustrates the basic idea behind one possible approach.
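A pickle-to-SHA256 key can be as small as the sketch below. The tuple passed to `node_key` is a hypothetical stand-in for whatever picklable description of an `Apply` node we settle on. One reason this is "not-so-good" is that equal objects don't always pickle to the same bytes. For example, a set of strings is pickled in iteration order, which depends on `PYTHONHASHSEED`. Pickle's memoization also depends on object identity, so two equal objects can produce different bytes. Arbitrary Op objects may not pickle deterministically either.

```python
import hashlib
import pickle

# Pin the protocol so the digest doesn't change if Python's default protocol does.
PICKLE_PROTOCOL = 4


def node_key(description) -> str:
    """Return the SHA-256 hex digest of a picklable node description (hypothetical helper)."""
    return hashlib.sha256(pickle.dumps(description, protocol=PICKLE_PROTOCOL)).hexdigest()


# Hypothetical node descriptions: op name, op parameters, and input (dtype, ndim).
key_a = node_key(("Shape_i", (("i", 0),), ("float64", 2)))
key_b = node_key(("Shape_i", (("i", 1),), ("float64", 2)))
```

Restricting keys to plain tuples of strings and integers avoids most of the nondeterminism described above.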

  3. Will we check the cache dir while visiting every node in the Numba backend? I think that might increase compilation time in some cases.

Per the outline, we can use an in-memory, file-backed cache like shelve. Exactly how we should use it is still an open question, though.
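One possible arrangement is sketched below, under the assumption that we cache generated source strings rather than function objects. A plain dict serves repeated lookups within a process, and a `shelve` file backs it on disk, so only a miss in memory touches the file. Plain `pickle` stores functions by reference to their qualified name, so a function generated at runtime from a string usually can't be put in the shelf directly. Storing its source sidesteps that. `SourceCache` and `get_or_create` are hypothetical names.

```python
import os
import shelve
import tempfile
from typing import Callable, Dict


class SourceCache:
    """Hypothetical two-level cache: a dict in memory, backed by a shelve file on disk."""

    def __init__(self, path: str) -> None:
        self._path = path
        self._memory: Dict[str, str] = {}

    def get_or_create(self, key: str, generate: Callable[[], str]) -> str:
        # Fast path: already seen in this process.
        if key in self._memory:
            return self._memory[key]
        # Slow path: consult (and possibly populate) the on-disk shelf.
        with shelve.open(self._path) as db:
            if key not in db:
                db[key] = generate()
            src = db[key]
        self._memory[key] = src
        return src


# A second SourceCache on the same path stands in for a later process.
db_path = os.path.join(tempfile.mkdtemp(), "numba_sources")
```

Opening the shelf per miss keeps the file closed between lookups. Keeping it open for the whole compilation would be faster, but then we have to make sure it gets closed.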

@Smit-create
Member

Once we have completed #1470, I think we can then try to refactor some nodes like Shape_i and use partial functions.

Numba with partial funcs

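```python
import time
from functools import partial

import numba
import numpy as np


def some_func(arr, other_one):
    # Plain reduction loop; `other_one` is an ordinary runtime argument.
    t = 0.0
    for a in arr:
        t += a
    return t + other_one


arr = np.random.rand(1_000_000)

# First run: compiles on the first call (or loads from the on-disk cache,
# since cache=True; this requires running the snippet from a file).
a = time.perf_counter()
f = numba.njit(some_func, cache=True)
f2 = partial(f, other_one=100)
tmp = f2(arr)
b = time.perf_counter()
print("Time 1:", b - a, tmp)

# Second run: a new dispatcher and a different bound value, but the argument
# types are unchanged, so the cached machine code should be reusable.
a = time.perf_counter()
f = numba.njit(some_func, cache=True)
f2 = partial(f, other_one=0)
tmp = f2(arr)
b = time.perf_counter()
print("Time 2:", b - a, tmp)
```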
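The `Shape_i` refactor could compare two approaches, sketched here in plain Python (the helper names are hypothetical). In approach A, new source is rendered for every index `i`, so each node gets a distinct function object that would have to be jitted and cached separately. In approach B, there is one generic `shape_i(x, i)`, with the node's parameter bound through `functools.partial`. With B, only one underlying function would need compiling per input type. Whether that translates into faster compilation in practice would still need to be measured.

```python
from functools import partial

import numpy as np


# Approach A (hypothetical): render new source per index; each node gets its own function.
def make_shape_i_from_source(i: int):
    src = f"def shape_i(x):\n    return x.shape[{i}]\n"
    namespace = {}
    exec(src, namespace)
    return namespace["shape_i"]


# Approach B: one generic function; the node's parameter is bound with partial.
def shape_i(x, i):
    return x.shape[i]


def make_shape_i_with_partial(i: int):
    return partial(shape_i, i=i)


x = np.zeros((3, 4))
a0, a1 = make_shape_i_from_source(0), make_shape_i_from_source(1)
b0, b1 = make_shape_i_with_partial(0), make_shape_i_with_partial(1)
```

With B, both partials share the same underlying `shape_i`, while A produces two unrelated functions.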

@brandonwillard
Member Author

brandonwillard commented Mar 14, 2023

Once we have completed #1470, I think we can then try to refactor some nodes like Shape_i and use partial functions.

Are we sure that using partial is any better for compilation, though? I'm still not clear on that. It's definitely easier for us in the long run, so I'm all for it, but we need to know when, or whether, it will help.

Also, we can make the tests work using the same approach we have been using. That Numba disable-JIT approach would be nicer, but it shouldn't block anything.
