Skip to content

Conversation

@ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Mar 26, 2024

Few notes about this PR

  • Importing functorch.dim messes up torch.Tensor a bit (e.g., the torch.Tensor.split method starts to fail). Therefore, I siloed the functorch.dim import from the main Pyro
    • That's why I also removed pyro/contrib/named from doctest in order to avoid the import of functorch.dim
    • And use hasattr(self.dim, "is_bound") instead of isinstance(self.dim, Dim)
  • Trace_ELBO implementation is actually similar to TraceGraph_ELBO with dependency tracking with the goal of having a single Trace_ELBO implementation that will generalize to TraceEnum_ELBO and others.
  • pyro.enable_validation needs to be set to False to avoid validation errors caused by if value.all() method
  • Distribution arguments all need to be bound by named dim, otherwise broadcasting attempt of parameters by a distribution will lead to segmentation fault

@ordabayevy ordabayevy added the WIP label Mar 27, 2024
]
if bind_named_dims:
result = result[bind_named_dims]
return result
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unit tests for distribution shapes for log_prob, mean, sample, rsample, entropy (fail when named and positional dims are mixed in the batch/event/sample shape; conflicting named dims)

Generalize named dim binding implementation.

Test transforms and support.

Shape inference.

return vindex(self._tensor, args)


def index_select(input, dim, index):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add type annotation. Move to contrib/named in the follow up PR.

@eb8680 eb8680 self-requested a review April 9, 2024 18:49
@fritzo fritzo removed their request for review April 16, 2024 16:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants