-
-
Notifications
You must be signed in to change notification settings - Fork 1k
New Trace_ELBO that generalizes Trace_ELBO, TraceEnum_ELBO, and TraceGraph_ELBO
#2893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
|
|
||
| elbo = to_funsor(0.0) | ||
| for cost in costs: | ||
| elbo += cost.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this missing Dice factors included in log_measures? IIRC that was the reason for using Integrate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I copied the test from #2894 which has a simple model/guide pair. When running that model (Elbo=Trace_ELBO, backend=contrib.funsor, reparam-False) both guide_terms["log_measures"] and model_terms["log_measures"] are empty. I can't find Dice factors anywhere in model_terms or guide_terms.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess they're not included because Funsor.sample isn't used in the evaluation of Trace_ELBO. I don't think contrib.funsor.infer.Trace_ELBO is tested extensively outside the pyro-api tests in tests/contrib/funsor/test_pyroapi_funsor.py, which is why this wasn't noticed before.
A more general Funsor-based implementation of Trace_ELBO is certainly possible and would look very similar to the guide-side enumeration handling logic in TraceEnum_ELBO. We might even be able to write a custom "enumeration" strategy that just called Funsor.sample and reuse TraceEnum_ELBO as the Trace_ELBO implementation.
I believe a completely general version might require variable elimination logic beyond what's currently in funsor.sum_product handling cases where the guide had plate structure incompatible with the restrictions there, although I can't immediately think of existing tests or examples where that would be the case.
| - df_a * logqa | ||
| - df_a * (qb * logqb).sum() | ||
| - df_a * (qb * df_c * logqc).sum() | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eb8680 can you check the math here? This is an example with b enumerated in the guide. Trace_ELBO works correctly here.
| # +-----------+ | ||
| # a -|-> b --> c | | ||
| # | \--> d | | ||
| # +-----------+ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
e is observed and b is enumerated.
| # guide (c is enumerated) | ||
| # +-----------+ | ||
| # a -|-> b --> c | | ||
| # +-----------+ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
d is observed and c is enumerated.
| # guide (b is enumerated) | ||
| # +-----------+ | ||
| # a -|-> b --> c | | ||
| # +-----------+ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
d is observed and b is enumerated.
Design Doc
New version of
Trace_ELBOthat extendsTraceEnum_ELBO:dice_factors (as importance weights) (Dispatch toIntegrate(Delta, ...)innormalize_integrate_contractionfunsor#551)I get wrong values forelbo(much larger absolute value compared topyro.infer.trace_elbo.Trace_ELBO), presumably becausesum(log_factors, to_funsor(0.0))in line 41 broadcasts terms inlog_factorsand then that leads to large absolute values afterelbo.reduce(funsor.ops.add, plate_vars)summation.Here I try to fix it by reducing each cost term individually similar toTraceEnum_ELBO. I'm also not sure if integration is needed here.