-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Add support for named dims (torchdim)
#3347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
| ] | ||
| if bind_named_dims: | ||
| result = result[bind_named_dims] | ||
| return result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unit tests for distribution shapes for log_prob, mean, sample, rsample, entropy (fail when named and positional dims are mixed in the batch/event/sample shape; conflicting named dims)
Generalize named dim binding implementation.
Test transforms and support.
Shape inference.
| return vindex(self._tensor, args) | ||
|
|
||
|
|
||
| def index_select(input, dim, index): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add type annotation. Move to contrib/named in the follow up PR.
Few notes about this PR
functorch.dimmesses uptorch.Tensora bit (e.g., thetorch.Tensor.splitmethod starts to fail). Therefore, I siloed thefunctorch.dimimport from the main Pyropyro/contrib/namedfrom doctest in order to avoid the import offunctorch.dimhasattr(self.dim, "is_bound")instead ofisinstance(self.dim, Dim)Trace_ELBOimplementation is actually similar toTraceGraph_ELBOwith dependency tracking with the goal of having a singleTrace_ELBOimplementation that will generalize toTraceEnum_ELBOand others.pyro.enable_validationneeds to be set toFalseto avoid validation errors caused byif value.all()methodDistributionarguments all need to be bound by named dim, otherwise broadcasting attempt of parameters by a distribution will lead to segmentation fault