
Potential extension to compute_log_probs to facilitate VI model diagnostics #1939

@hessammehr


It often happens that you want to diagnose a VI fit, specifically by examining how well the guide fits the prior, the data, etc. So far, I've been using a function like the following. I'd be interested to know whether there are better-established alternatives and, if not, whether this would be appropriate as a backwards-compatible extension to the newly introduced (and very useful) compute_log_probs function, or perhaps as a separate function.

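import jax
from numpyro.handlers import replay, seed, substitute, trace
from numpyro.infer.util import compute_log_probs as clp


def compute_log_probs(
    model,
    model_args: tuple,
    model_kwargs: dict,
    model_params: dict,
    guide=None,
    guide_params: dict = None,
    sum_log_prob: bool = True,
    rng_key=None,
):
    if guide is not None:
        # The guide's sample sites need an RNG key, so seed it before tracing
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        seeded_guide = seed(substitute(guide, data=guide_params or {}), rng_key)
        guide_trace = trace(seeded_guide).get_trace(*model_args, **model_kwargs)
        # Replay the model against the guide's draws so each site's log prob
        # is evaluated at the guide's latent values rather than fresh samples
        model = replay(model, trace=guide_trace)
    return clp(model, model_args, model_kwargs, model_params, sum_log_prob=sum_log_prob)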

Labels: question (Further information is requested)
