Implement multi-output Elemwise in Numba via guvectorize #1271
base: main
Conversation
input_names = [f"i{i}" for i in range(len(node.inputs))]
output_names = [f"o{i}" for i in range(len(node.outputs))]
gu_fn_name = "gu_func"
I'm not familiar with the auto-naming strategy we have going on with Numba; are there any developer docs I can use as a reference?
We should only need to make sure that the names we generate are fixed and don't clobber each other. As long as you're generating all the names yourself, everything should be fine; it's usually when you're using unknown names provided by something else that problems start to arise.
On a very related note, if we want certain types of caching to work (e.g. the kind that's based on hashes of source code), we'll need to clean up some old code that uses Variable.name, and anything else that could differ between equivalent graphs. Since most of the unique-name-based code was used to avoid Variable.name issues, we can probably drop all of it now. In summary, it might be useful for debugging and readability, but it's not necessary and it can negatively affect caching, so don't worry about it.
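To make the clobbering point concrete, here is a minimal sketch (a hypothetical helper written for this comment, not necessarily what the project uses) of compiling generated source with fixed names into its own namespace, so the generated names can never collide with anything defined elsewhere:

def compile_generated_fn(src, fn_name):
    # Execute the generated source in a fresh namespace; the generated
    # function name only ever exists inside this dict.
    namespace = {}
    exec(compile(src, "<generated>", "exec"), namespace)
    return namespace[fn_name]

gu_fn_src = "def gu_func(i0, o0):\n    o0[:] = i0 + 1\n"
gu_func = compile_generated_fn(gu_fn_src, "gu_func")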
    for i in range({input_names[0]}.shape[0]):
        {'[i], '.join(output_names)}[i] = scalar_op_fn({'[i], '.join(input_names)}[i])
"""
print(gu_fn_src)
This basically creates a function that looks like:
def gu_func(i0, i1, ..., iN, o0, o1, ..., oN):
for i in range(i0.shape[0]):
o0[i], o1[i], ..., oN[i] = scalar_op_fn(i0[i], i1[i], ..., iN[i])
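For a concrete (hypothetical) one-input, two-output case like the Composite{exp(i0), log(i0)} that shows up in the error below, the generated function would reduce to something like the following, with scalar_op_fn standing in for the compiled scalar Composite:

import numpy as np

def scalar_op_fn(i0):
    # Stand-in for the compiled scalar Composite: exp(i0), log(i0)
    return np.exp(i0), np.log(i0)

def gu_func(i0, o0, o1):
    # One input, two outputs; the loop runs over the single core dimension
    for i in range(i0.shape[0]):
        o0[i], o1[i] = scalar_op_fn(i0[i])

x = np.array([1.0, 2.0, 3.0])
o0, o1 = np.empty_like(x), np.empty_like(x)
gu_func(x, o0, o1)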
This also only seems to work for vector inputs. Am I supposed to do a nested loop for every dimension, or is there a shortcut/helper I can use?
This is outdated now! I shouldn't need the loop at all
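For reference, here is a sketch (an assumption about how the pieces could fit together, not code from this PR) of registering such a kernel with numba.guvectorize. The kernel only has to handle the core dimension declared in the signature; the gufunc machinery broadcasts over any leading dimensions, so no nested per-dimension loops are needed, and with scalar core dimensions even the inner loop could go away:

import numpy as np
from numba import guvectorize, float64

@guvectorize([(float64[:], float64[:], float64[:])], "(n)->(n),(n)")
def gu_func(i0, o0, o1):
    # Kernel over the core dimension only
    for i in range(i0.shape[0]):
        o0[i] = np.exp(i0[i])
        o1[i] = np.log(i0[i])

x = np.random.rand(3, 4)       # leading axis handled by gufunc broadcasting
out_exp, out_log = gu_func(x)  # both outputs have shape (3, 4)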
@@ -27,6 +27,7 @@ def fgraph_convert(self, fgraph, **kwargs):
        return numba_funcify(fgraph, **kwargs)

    def jit_compile(self, fn):
        return fn
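For context, the unmodified method presumably wraps the funcified graph with Numba's JIT; returning fn unchanged, as above, is what "disabling the function-level jitting" in the PR description refers to. A sketch, assuming the original behavior was simply numba.njit:

import numba

def jit_compile(self, fn):
    # Assumed original behavior: JIT-compile the whole funcified graph.
    # return numba.njit(fn)
    # Temporary workaround in this PR: run the outer function as plain Python,
    # so the gufunc is only ever called from Python (see the error below).
    return fn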
My test errors out when the jitting of the whole function is attempted:
E numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
E Untyped global name 'gu_func': Cannot determine Numba type of <class 'numba.np.ufunc.gufunc.GUFunc'>
E
E File "../../../../../../tmp/tmphopvpths", line 3:
E def numba_funcified_fgraph(tensor_variable):
E <source elided>
E # Elemwise{Composite{exp(i0), log(i0)}}(<TensorType(float64, (None,))>)
E tensor_variable_1, tensor_variable_2 = gu_func(tensor_variable)
Probably related to numba/numba#2089
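A minimal reproduction of that class of failure (a sketch written for this comment, not taken from the PR; whether it actually fails depends on the Numba version, since support for calling gufuncs from nopython code is what numba/numba#2089 tracks):

import numpy as np
from numba import njit, guvectorize, float64

@guvectorize([(float64[:], float64[:])], "(n)->(n)")
def gu_exp(i0, o0):
    for i in range(i0.shape[0]):
        o0[i] = np.exp(i0[i])

@njit
def outer(x):
    # On Numba versions without gufunc support in nopython mode, this call
    # cannot be typed, producing the TypingError quoted above.
    return gu_exp(x)

outer(np.arange(3.0))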
@@ -233,7 +233,7 @@ def assert_fn(x, y):
    numba_res = aesara_numba_fn(*inputs)

    # Get some coverage
-    eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
+    # eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
This fails for reasons I haven't explored. It seems like it may need special logic for the multi-output Elemwises?
This approach might be dead in the water due to numba/numba#2089
Just remember that we don't have to use
…scalar_op_fn' is not defined
This was just a quick hack to see if it makes sense.
I have not worked with Numba before and have mostly pattern-matched my way so far. I made it to the point where it works if I disable the function-level jitting; otherwise it errors out (see comment below).
Motivated by #1242