-
-
Notifications
You must be signed in to change notification settings - Fork 153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fuse consecutive Elemwise
subgraphs with multiple clients
#1242
base: main
Are you sure you want to change the base?
Conversation
074f125
to
fff0d3c
Compare
c84e827
to
fb6abe0
Compare
fb6abe0
to
cbd2c61
Compare
d4301b1
to
0246fce
Compare
It's hard to tell what's going on using those numbers alone. For example, the extra time could be spent in compilation, and the run-time could be significantly reduced. Regardless, the difference is alarming. Situations like this are another reason we should get #718 in place sooner than later. |
The logic for inplacing will have to be rethought, as some inplaced outputs could overwrite inputs that are still needed for other outputs. Basically we will need something that reasons about the inner graph like we do for the general function. Edit: For now I just restricted inplace to single-output Composites |
Another more interesting issue I am finding is some Edit: It was a bug in the subgraph algorithm. Fixed! |
76b30b4
to
f000323
Compare
The same job is done by canonicalize before this rewrite is ever called.
f000323
to
618c11c
Compare
Elemwise
subgraphs with multiple clients
762061b
to
e70ea5f
Compare
233f68b
to
1bfbf24
Compare
This seems to be now working (more often than not) on the C-backend. It provides less speedups than I was expecting: import aesara
import aesara.tensor as at
import numpy as np
x = at.dvector("x")
mu = at.dvector("mu")
logp = (- ((x - mu) **2) / 2)
grad = at.grad(logp.sum(), x)
func = aesara.function([mu, x], [logp, grad])
func.trust_input = True
aesara.dprint(func)
rng = np.random.default_rng(123)
size = 100_000
xv = rng.normal(size=size)
muv = rng.normal(size=size)
%timeit func(xv, muv) The speedup depends on the size.
I couldn't test the effects on the Numba backend, because mulit-output Elemwises are disabled (we could test https://numba.pydata.org/numba-doc/latest/user/vectorize.html#the-guvectorize-decorator). The JAX backend also errors out but I didn't investigate why yet. @brandonwillard do you know of an easy way to retrieve the |
ddc83a4
to
e00c125
Compare
Which function exactly? All the C code generated during an |
It's possible that this new feature has to sometimes trade off between the benefits of "merging"/CSE and fusion. Your example in #1237 illustrates this possibility with the |
@brandonwillard I extended the motivation behind this PR in the original issue: #1237 (comment) |
e00c125
to
e518d4b
Compare
Otherwise they fail due to lack of support for multi-output Elemwises in the Numba backend
e078f4e
to
1302b49
Compare
Closes #1237
Todo
add_mul_fusion
elemwise_max_input_fct