-
-
Notifications
You must be signed in to change notification settings - Fork 1k
New Trace_ELBO that generalizes Trace_ELBO, TraceEnum_ELBO, and TraceGraph_ELBO
#2893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
ordabayevy
wants to merge
62
commits into
pyro-ppl:dev
Choose a base branch
from
ordabayevy:fix-funsor-traceelbo
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 44 commits
Commits
Show all changes
62 commits
Select commit
Hold shift + click to select a range
b00948c
trace_elbo
da2f887
lint
f6c95e4
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
3ec076d
test_gradient
a22ff4e
copy traceenum_elbo and add test model with poisson dist
d551fa2
lint
b68bb3f
use constant funsor
bfb13bf
working version
ca1a1fe
pass second test
6d6a9ed
clean up trace_elbo
0f23b42
add another test
91384ed
lazy eval
c18a8bd
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
34d9a3c
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
b0182c0
vectorize particles; update tests
dc31767
minor fixes; pin to funsor@normalize-logaddexp
5c0fe75
update docs/requirements
2b15fe1
combine Trace_ELBO and TraceEnum_ELBO
351090b
eager evaluation
7d029c7
rm file
1bb7380
lazy
42ad4fa
remove memoize
5b6afdb
merge TraceEnum_ELBO
33628aa
skip test
18a973b
fixes
2c3ead3
convert Tensor to Categorical
5fb1522
restore docs/requirements.txt
f907f93
pin funsor in docs/requirements
902e445
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
0042f85
use funsor.optimizer.apply_optimizer; higher precision in the test
ee5a5ad
pin funsor to the latest commit
e4c6760
optimize logzq
aba300a
optimize logzq
d823153
restore TraceEnum_ELBO
c06e9e4
revert hmm changes
eee297d
_tensor_to_categorical helper function
d748efa
lazy to_funsor
a1970d6
reduce over particle_var
4c1ee9e
address comment in tests
5df30c8
import pyroapi
46ff6f4
compute expected grads using dice factors
d7ee7ee
add test with guide enumeration
49553c3
add two more tests
835f815
pin funsor
760eeb0
lint
ab3831c
remove breakpoint
0b46f3a
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
b6ff8e0
Approximate(ops.sample, ...) based approach
b5bece7
Importance funsor based approach
d6e246e
fixes
6582d7d
Merge branch 'dev' into fix-funsor-traceelbo
714fd62
fix funsor model enumeration
2d2210e
Merge branch 'fix-model-enumeration-funsor' into fix-funsor-traceelbo
29bad7a
use Sampled funsor
9144be1
fixes
e4c8a47
git fixes
c147ad9
Merge branch 'dev' into fix-funsor-traceelbo
703a2fa
use Provenance funsor
3137b1b
clean up
88713f6
fixes
99a0647
Merge branch 'dev' into fix-funsor-traceelbo
14131ad
use provenance
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,10 +4,11 @@ | |
| import contextlib | ||
|
|
||
| import funsor | ||
| from funsor.adjoint import adjoint | ||
| from funsor.constant import Constant | ||
|
|
||
| from pyro.contrib.funsor import to_data, to_funsor | ||
| from pyro.contrib.funsor.handlers import enum, plate, replay, trace | ||
| from pyro.contrib.funsor.infer import config_enumerate | ||
| from pyro.contrib.funsor.handlers import enum, plate, provenance, replay, trace | ||
| from pyro.distributions.util import copy_docs_from | ||
| from pyro.infer import Trace_ELBO as _OrigTrace_ELBO | ||
|
|
||
|
|
@@ -18,32 +19,102 @@ | |
| @copy_docs_from(_OrigTrace_ELBO) | ||
| class Trace_ELBO(ELBO): | ||
| def differentiable_loss(self, model, guide, *args, **kwargs): | ||
| with enum(), plate( | ||
| size=self.num_particles | ||
| with enum( | ||
| first_available_dim=(-self.max_plate_nesting - 1) | ||
| if self.max_plate_nesting is not None | ||
| and self.max_plate_nesting != float("inf") | ||
| else None | ||
| ), provenance(), plate( | ||
| name="num_particles_vectorized", | ||
| size=self.num_particles, | ||
| dim=-self.max_plate_nesting, | ||
| ) if self.num_particles > 1 else contextlib.ExitStack(): | ||
| guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace( | ||
| *args, **kwargs | ||
| ) | ||
| guide_tr = trace(guide).get_trace(*args, **kwargs) | ||
| model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) | ||
|
|
||
| model_terms = terms_from_trace(model_tr) | ||
| guide_terms = terms_from_trace(guide_tr) | ||
|
|
||
| log_measures = guide_terms["log_measures"] + model_terms["log_measures"] | ||
| log_factors = model_terms["log_factors"] + [ | ||
| -f for f in guide_terms["log_factors"] | ||
| ] | ||
| plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"] | ||
| measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"] | ||
|
|
||
| elbo = funsor.Integrate( | ||
| sum(log_measures, to_funsor(0.0)), | ||
| sum(log_factors, to_funsor(0.0)), | ||
| measure_vars, | ||
| particle_var = ( | ||
| frozenset({"num_particles_vectorized"}) | ||
| if self.num_particles > 1 | ||
| else frozenset() | ||
| ) | ||
| elbo = elbo.reduce(funsor.ops.add, plate_vars) | ||
| plate_vars = ( | ||
| guide_terms["plate_vars"] | model_terms["plate_vars"] | ||
| ) - particle_var | ||
|
|
||
| model_measure_vars = model_terms["measure_vars"] - guide_terms["measure_vars"] | ||
| with funsor.terms.lazy: | ||
| # identify and contract out auxiliary variables in the model with partial_sum_product | ||
| contracted_factors, uncontracted_factors = [], [] | ||
| for f in model_terms["log_factors"]: | ||
| if model_measure_vars.intersection(f.inputs): | ||
| contracted_factors.append(f) | ||
| else: | ||
| uncontracted_factors.append(f) | ||
| # incorporate the effects of subsampling and handlers.scale through a common scale factor | ||
| contracted_costs = [ | ||
| model_terms["scale"] * f | ||
| for f in funsor.sum_product.partial_sum_product( | ||
| funsor.ops.logaddexp, | ||
| funsor.ops.add, | ||
| model_terms["log_measures"] + contracted_factors, | ||
| plates=plate_vars, | ||
| eliminate=model_measure_vars, | ||
| ) | ||
| ] | ||
|
|
||
| # accumulate costs from model (logp) and guide (-logq) | ||
| costs = contracted_costs + uncontracted_factors # model costs: logp | ||
| costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq | ||
|
|
||
| # compute log_measures corresponding to each cost term | ||
| # the goal is to achieve fine-grained Rao-Blackwellization | ||
| targets = dict() | ||
| for cost in costs: | ||
| if cost.input_vars not in targets: | ||
| targets[cost.input_vars] = Constant( | ||
| cost.inputs, | ||
| funsor.Tensor( | ||
| funsor.ops.new_zeros( | ||
| funsor.tensor.get_default_prototype(), | ||
| (), | ||
| ) | ||
| ), | ||
| ) | ||
|
|
||
| logzq = funsor.sum_product.sum_product( | ||
ordabayevy marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| funsor.ops.logaddexp, | ||
| funsor.ops.add, | ||
| guide_terms["log_measures"] + list(targets.values()), | ||
| plates=plate_vars, | ||
| eliminate=(plate_vars | guide_terms["measure_vars"]), | ||
| ) | ||
| logzq = funsor.optimizer.apply_optimizer(logzq) | ||
|
|
||
| marginals = adjoint(funsor.ops.logaddexp, funsor.ops.add, logzq) | ||
|
|
||
| with funsor.terms.lazy: | ||
| # finally, integrate out guide variables in the elbo and all plates | ||
| elbo = to_funsor(0, output=funsor.Real) | ||
| for cost in costs: | ||
| target = targets[cost.input_vars] | ||
| # FIXME account for normalization factor for unnormalized logzq | ||
| log_measure = marginals[target] | ||
| measure_vars = (frozenset(cost.inputs) - plate_vars) - particle_var | ||
| elbo_term = funsor.Integrate( | ||
| log_measure, | ||
| cost, | ||
| measure_vars, | ||
| ) | ||
| elbo += elbo_term.reduce( | ||
| funsor.ops.add, plate_vars & frozenset(cost.inputs) | ||
| ) | ||
| # average over Monte-Carlo particles | ||
| elbo = elbo.reduce(funsor.ops.mean, particle_var) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reducing over |
||
|
|
||
| return -to_data(elbo) | ||
| return -to_data(funsor.optimizer.apply_optimizer(elbo)) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using |
||
|
|
||
|
|
||
| class JitTrace_ELBO(Jit_ELBO, Trace_ELBO): | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.