
Conversation

@ordabayevy (Member) commented Apr 23, 2021

Addresses #533

Group-coded with @fritzo, @eb8680, and @fehiepsi.

Yerdos Ordabayev added 2 commits April 23, 2021 09:50
@fritzo added the examples (Examples and tutorials) label Apr 23, 2021
Yerdos Ordabayev added 2 commits April 23, 2021 21:18
from pyro.optim.clipped_adam import ClippedAdam as _ClippedAdam

import funsor
from funsor.adam import Adam # noqa: F401
@ordabayevy (Member, Author) commented Apr 24, 2021

for compatibility with pyroapi
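A hedged sketch of what this compatibility buys (assuming pyroapi is installed and dispatches to funsor.minipyro for the "funsor" backend): generic code written against the pyroapi namespaces can be switched between backends, so this module must re-export the same public optimizer names that the Pyro backend exposes.

from pyroapi import pyro_backend, infer, optim, pyro

with pyro_backend("funsor"):
    # Inside this context pyro, infer, and optim dispatch to
    # funsor.minipyro, so names like Adam, ClippedAdam, and SVI must
    # exist there under the same names as in Pyro.
    ...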

value, _ = PARAM_STORE[name]
if event_dim is None:
    event_dim = value.dim()
output = funsor.Reals[value.shape[value.dim() - event_dim :]]
@ordabayevy (Member, Author)

infer output when pyro.param was already defined elsewhere
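A small worked example of the shape arithmetic above (values are illustrative, not from the PR): with a stored tensor of shape (3, 4) and event_dim=1, only the rightmost dimension belongs to the output domain.

import torch
import funsor

value = torch.zeros(3, 4)  # as if registered earlier via pyro.param
event_dim = 1              # treat only the rightmost dim as the event shape
output = funsor.Reals[value.shape[value.dim() - event_dim:]]
print(output)  # Reals[4]; the leading dim of size 3 is a batch dim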


def step(self, *args, **kwargs):
    self.optim.num_steps = 1
    return self.run(*args, **kwargs)
@ordabayevy (Member, Author)

for compatibility with the SVI interface

@fritzo (Member) commented Apr 25, 2021

Hmm, let's think about alternative workarounds... One issue here is that the Adam optimizer statistics would not persist across SVI steps.

One option is simply to change pyroapi's SVI interface to look for .run() and, if it is missing, fall back to .step(). Also, I think it's more important to create a simple didactic example than to fastidiously conform to the pyroapi interface (since that interface hasn't seen much use).
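A hedged sketch of that pyroapi-side fallback (illustrative code, not the actual pyroapi implementation):

def run_svi(svi, num_steps, *args, **kwargs):
    # Prefer a batch .run() when the backend provides one; otherwise
    # fall back to stepping one iteration at a time.
    if hasattr(svi, "run"):
        return svi.run(*args, **kwargs)
    return [svi.step(*args, **kwargs) for _ in range(num_steps)]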

for p in params:
    p.grad = torch.zeros_like(p.grad)
return loss.item()

with funsor.terms.lazy:
@ordabayevy (Member, Author)

The lazy interpretation is needed here to make sure that funsor.Integrate is not eagerly expanded in Expectation.
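A minimal sketch of the difference (the expression is illustrative; funsor.terms.lazy and funsor.reinterpret are the interpretation tools assumed here): under the default eager interpretation, terms are normalized as they are built, while under the lazy interpretation they stay symbolic until reinterpret forces them.

import funsor
from funsor import Real, Variable

funsor.set_backend("torch")

x = Variable("x", Real)
with funsor.terms.lazy:
    # Built lazily: expr stays an unnormalized symbolic term, so
    # downstream code (e.g. Expectation) can pattern-match on
    # Integrate before it is eagerly expanded.
    expr = x * x + 1.0
print(funsor.reinterpret(expr))  # force eager evaluation afterwards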

@ordabayevy (Member, Author)

examples/minipyro.py is currently failing with jit. My guess is that jitting needs to be baked into funsor.adam.Adam; I will think more about this.

@ordabayevy added the WIP label Apr 24, 2021
@ordabayevy (Member, Author)

Am I right that when using funsor.adam.Adam, the function that needs to be jit-traced is the loss function below (a Subs funsor), @fritzo @eb8680? If yes, does it first need to be converted to a function with positional arguments?

step_loss = loss(**{k: v[...] for k, v in params.items()}).data
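If so, a hedged sketch of such a positional-argument wrapper (names are illustrative, not code from the PR):

import torch

def make_traceable(loss, params):
    names = list(params)

    def positional_loss(*tensors):
        # Substitute by keyword internally, but expose the positional
        # signature that torch.jit.trace expects.
        return loss(**dict(zip(names, tensors))).data

    example_inputs = tuple(params[name][...] for name in names)
    return torch.jit.trace(positional_loss, example_inputs)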

@fritzo (Member) commented Apr 25, 2021

...failing with jit. My guess is that jitting needs to be baked into funsor.adam.Adam

I think you're right, but let's discuss. That's a little different from Pyro, where jit is baked into ELBO subclasses.
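For contrast, a minimal standard-Pyro example where the jit is selected at the loss rather than the optimizer:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO
from pyro.optim import Adam

def model(data):
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    pass

# Swapping JitTrace_ELBO for Trace_ELBO toggles jit without touching
# the optimizer, which is the design being contrasted here.
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=JitTrace_ELBO())
svi.step(torch.tensor(1.0))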

