Conversation
|
@eb8680 is this along the lines what you were suggesting? Is this new tensor type supposed to be wrapped by |
funsor/torch/metadata.py
Outdated
| if kwargs is None: | ||
| kwargs = {} | ||
| meta = frozenset().union( | ||
| *tuple(a._metadata for a in args if hasattr(a, "_metadata")) |
There was a problem hiding this comment.
Provenance of the output tensor is the union of provenances of input tensors.
funsor/constant.py
Outdated
| return super(ConstantMeta, cls).__call__(const_inputs, arg) | ||
|
|
||
|
|
||
| class Constant(Funsor, metaclass=ConstantMeta): |
There was a problem hiding this comment.
Interesting! It would probably be easiest for us to go over this PR and pyro-ppl/pyro#2893 over Zoom, but one thing that would help me beforehand is if you could add a docstring here explaining how Constant behaves differently from Delta wrt Reduce/Contraction/Integrate
| def __repr__(self): | ||
| return "Provenance:\n{}\nTensor:\n{}".format(self._provenance, self._t) | ||
|
|
||
| def __torch_function__(self, func, types, args=(), kwargs=None): |
There was a problem hiding this comment.
@ordabayevy now that you've had a chance to play around with __torch_function__, I'm curious about whether you think we should add a Funsor.__torch_function__ method and attempt to use it in Pyro more directly in lieu of the combination of ProvenanceTensor and to_data/to_funsor. I opened #546 to discuss.
eb8680
left a comment
There was a problem hiding this comment.
Implementation seems reasonable, and nicely separated from the rest of the code.
| provenance = frozenset() | ||
| # extract ProvenanceTensor._t data from args | ||
| _args = [] | ||
| for arg in args: |
There was a problem hiding this comment.
This logic is a bit convoluted. Maybe it could be simplified with some of the helpers in torch.overrides?
There was a problem hiding this comment.
the helpers in torch.overrides?
That might be useful, I look more into torch.overrides functionality.
This is an implementation of Provenance Tracking (https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.361.7132&rep=rep1&type=pdf) in Pytorch. The main idea is that provenance of the output tensor is the union of provenances of input tensors.
Tests: