Potential extension to compute_log_probs to facilitate VI model diagnostics #1939

Open
hessammehr opened this issue Dec 19, 2024 · 1 comment
Labels
question Further information is requested

Comments

@hessammehr
Contributor

It often happens that you want to diagnose a VI fit, specifically to examine how well the guide fits the prior, the data, and so on. So far, I've been using a function like the following, but I'd be interested to know whether there are better-established alternatives and, if not, whether this would be appropriate as a backwards-compatible extension to the newly introduced (and very useful) compute_log_probs function (or perhaps as a separate function).

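from numpyro.handlers import replay, seed, substitute, trace
from numpyro.infer.util import compute_log_probs as _compute_log_probs


def compute_log_probs(
    model,
    model_args: tuple,
    model_kwargs: dict,
    model_params: dict,
    guide=None,
    guide_params: dict = None,
    sum_log_prob: bool = True,
    rng_key=None,
):
    if guide is not None:
        # The guide samples its latent sites, so it needs a PRNG key: pass rng_key,
        # or call this function inside a numpyro.handlers.seed context.
        if rng_key is not None:
            guide = seed(guide, rng_key)
        # Record one draw from the guide at its fitted parameters.
        guide_trace = trace(substitute(guide, guide_params or {})).get_trace(
            *model_args, **model_kwargs
        )
        # Evaluate the model at the guide's draw of the latent sites.
        model = replay(model, guide_trace)
    return _compute_log_probs(
        model, model_args, model_kwargs, model_params, sum_log_prob=sum_log_prob
    )
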
fehiepsi added the question (Further information is requested) label on Dec 20, 2024
@fehiepsi
Member

Sorry for the late response. I think your implementation is correct. As for adding this behavior to compute_log_probs, I think it is unnecessary. Maybe @tillahoffmann has a different opinion on this.
