Remove need for explicit left expand_dims in inputs of Elemwise#1967
Draft
ricardoV94 wants to merge 1 commit intopymc-devs:v3from
Draft
Remove need for explicit left expand_dims in inputs of Elemwise#1967ricardoV94 wants to merge 1 commit intopymc-devs:v3from
ricardoV94 wants to merge 1 commit intopymc-devs:v3from
Conversation
1753598 to
e865c1f
Compare
e865c1f to
2fb8d88
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Mostly written by claude
The idea is to get rid of the left expand_dims we use to align input dimensions of Elemwise (and in follow up Blockwise/RandomVariables).
This simplifies the graphs by removing a bunch of "useless" nodes. I started thinking about this after #1961 and thinking that we want to get rid of dummy dimensions as soon as possible, and only introduce them at the end of the graph (if we really need them). This allows us to reduce book-keping checks in the loops of Elemwise/Blockwise/RV/CAReduce.
But our current machinery tries to push them as early as possible towards the inputs (any dimshuffle, not just expand_dims). I think this happens because we needed them anyway at most inputs. So if you did elemwise
bar(foo(scalar, vec), mat), you ended up withbar(foo(scalar[None], vec)[None, :], mat)anyway, so why not start withbar(foo(scalar[None, None], vec[None, :]), mat)? At least no DimShuffle in between the elemwise, so we can reason about them cleanly.I agree we should push squeeze towards inputs (less axes), and we can do also transpose (just to pick a canonical form), but I don't think we should do this for left expand_dims. We don't want to bookeep more than we need to (note how we ended up with an extra dummy dim in
foo(scalar[None, None], vec[None, :])that wasn't ever needed.After this PR you just have
bar(foo(scalar, vec), mat)so both cleaner graph, and no extra bookkeeping.It makes rewrites that reason about dimshuffle/broadcasting a bit more verbose, and probably harder to get right, but hopefully we need less of them as well.