-
Notifications
You must be signed in to change notification settings - Fork 21
Delayed param #534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Delayed param #534
Conversation
| from pyro.optim.clipped_adam import ClippedAdam as _ClippedAdam | ||
|
|
||
| import funsor | ||
| from funsor.adam import Adam # noqa: F401 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for compatibility with pyroapi
| value, _ = PARAM_STORE[name] | ||
| if event_dim is None: | ||
| event_dim = value.dim() | ||
| output = funsor.Reals[value.shape[value.dim() - event_dim :]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
infer output when pyro.param was already defined elsewhere
|
|
||
| def step(self, *args, **kwargs): | ||
| self.optim.num_steps = 1 | ||
| return self.run(*args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for compatibility with SVI interface
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, let's think about alternative workarounds... One issue here is that the Adam optimizer statistics would not be persisted across svi steps.
One option is simply to change pyroapi's SVI interface to look for either .run() or if missing fall back to .step(). Also I think it's more important to create a simple didactic example than to fastidiously conform to the pyroapi interface (since that interface hasn't seen much use).
| for p in params: | ||
| p.grad = torch.zeros_like(p.grad) | ||
| return loss.item() | ||
| with funsor.terms.lazy: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lazy interpretation is needed here to make sure that funsor.Integrate is not eagerly expanded in Expectation
|
|
I think you're right, but let's discuss. That's a little different from Pyro where jit is baked into ELBO subclasses. |
Addresses #533
Group coded with @fritzo @eb8680 @fehiepsi