Skip to content

Remove need for explicit left expand_dims in inputs of Elemwise#1967

Draft
ricardoV94 wants to merge 1 commit intopymc-devs:v3from
ricardoV94:no_more_expand_dims
Draft

Remove need for explicit left expand_dims in inputs of Elemwise#1967
ricardoV94 wants to merge 1 commit intopymc-devs:v3from
ricardoV94:no_more_expand_dims

Conversation

@ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Mar 10, 2026

Mostly written by claude

The idea is to get rid of the left expand_dims we use to align input dimensions of Elemwise (and in follow up Blockwise/RandomVariables).

This simplifies the graphs by removing a bunch of "useless" nodes. I started thinking about this after #1961 and thinking that we want to get rid of dummy dimensions as soon as possible, and only introduce them at the end of the graph (if we really need them). This allows us to reduce book-keping checks in the loops of Elemwise/Blockwise/RV/CAReduce.

But our current machinery tries to push them as early as possible towards the inputs (any dimshuffle, not just expand_dims). I think this happens because we needed them anyway at most inputs. So if you did elemwise bar(foo(scalar, vec), mat), you ended up with bar(foo(scalar[None], vec)[None, :], mat) anyway, so why not start with bar(foo(scalar[None, None], vec[None, :]), mat)? At least no DimShuffle in between the elemwise, so we can reason about them cleanly.

I agree we should push squeeze towards inputs (less axes), and we can do also transpose (just to pick a canonical form), but I don't think we should do this for left expand_dims. We don't want to bookeep more than we need to (note how we ended up with an extra dummy dim in foo(scalar[None, None], vec[None, :]) that wasn't ever needed.

After this PR you just have bar(foo(scalar, vec), mat) so both cleaner graph, and no extra bookkeeping.

It makes rewrites that reason about dimshuffle/broadcasting a bit more verbose, and probably harder to get right, but hopefully we need less of them as well.

@ricardoV94 ricardoV94 force-pushed the no_more_expand_dims branch 8 times, most recently from 1753598 to e865c1f Compare March 17, 2026 12:16
@ricardoV94 ricardoV94 force-pushed the no_more_expand_dims branch from e865c1f to 2fb8d88 Compare March 17, 2026 12:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant