New Trace_ELBO that generalizes Trace_ELBO, TraceEnum_ELBO, and TraceGraph_ELBO#2893
New Trace_ELBO that generalizes Trace_ELBO, TraceEnum_ELBO, and TraceGraph_ELBO#2893ordabayevy wants to merge 62 commits intopyro-ppl:devfrom
Trace_ELBO that generalizes Trace_ELBO, TraceEnum_ELBO, and TraceGraph_ELBO#2893Conversation
|
|
||
| elbo = to_funsor(0.0) | ||
| for cost in costs: | ||
| elbo += cost.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs)) |
There was a problem hiding this comment.
Isn't this missing Dice factors included in log_measures? IIRC that was the reason for using Integrate.
There was a problem hiding this comment.
I copied the test from #2894 which has a simple model/guide pair. When running that model (Elbo=Trace_ELBO, backend=contrib.funsor, reparam-False) both guide_terms["log_measures"] and model_terms["log_measures"] are empty. I can't find Dice factors anywhere in model_terms or guide_terms.
There was a problem hiding this comment.
I guess they're not included because Funsor.sample isn't used in the evaluation of Trace_ELBO. I don't think contrib.funsor.infer.Trace_ELBO is tested extensively outside the pyro-api tests in tests/contrib/funsor/test_pyroapi_funsor.py, which is why this wasn't noticed before.
A more general Funsor-based implementation of Trace_ELBO is certainly possible and would look very similar to the guide-side enumeration handling logic in TraceEnum_ELBO. We might even be able to write a custom "enumeration" strategy that just called Funsor.sample and reuse TraceEnum_ELBO as the Trace_ELBO implementation.
I believe a completely general version might require variable elimination logic beyond what's currently in funsor.sum_product handling cases where the guide had plate structure incompatible with the restrictions there, although I can't immediately think of existing tests or examples where that would be the case.
| - df_a * logqa | ||
| - df_a * (qb * logqb).sum() | ||
| - df_a * (qb * df_c * logqc).sum() | ||
| ) |
There was a problem hiding this comment.
@eb8680 can you check the math here? This is an example with b enumerated in the guide. Trace_ELBO works correctly here.
| # +-----------+ | ||
| # a -|-> b --> c | | ||
| # | \--> d | | ||
| # +-----------+ |
There was a problem hiding this comment.
e is observed and b is enumerated.
| # guide (c is enumerated) | ||
| # +-----------+ | ||
| # a -|-> b --> c | | ||
| # +-----------+ |
There was a problem hiding this comment.
d is observed and c is enumerated.
| # guide (b is enumerated) | ||
| # +-----------+ | ||
| # a -|-> b --> c | | ||
| # +-----------+ |
There was a problem hiding this comment.
d is observed and b is enumerated.
Design Doc
New version of
Trace_ELBOthat extendsTraceEnum_ELBO:dice_factors (as importance weights) (Dispatch toIntegrate(Delta, ...)innormalize_integrate_contractionfunsor#551)I get wrong values forelbo(much larger absolute value compared topyro.infer.trace_elbo.Trace_ELBO), presumably becausesum(log_factors, to_funsor(0.0))in line 41 broadcasts terms inlog_factorsand then that leads to large absolute values afterelbo.reduce(funsor.ops.add, plate_vars)summation.Here I try to fix it by reducing each cost term individually similar toTraceEnum_ELBO. I'm also not sure if integration is needed here.