Status: Closed. Label: question (Further information is requested)
Description
It often happens that you want to diagnose a VI fit, specifically by examining how well the guide fits the prior, the data, etc. So far, I've been using a function like the following, but I would be interested to know whether there are better-established alternatives and, if not, whether this would be appropriate as a backwards-compatible extension to the newly introduced (and very useful) `compute_log_probs` function (or perhaps as a separate function).
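```python
import jax
from numpyro.handlers import replay, seed, substitute, trace
from numpyro.infer.util import compute_log_probs as clp


def compute_log_probs(
    model,
    model_args: tuple,
    model_kwargs: dict,
    model_params: dict,
    guide=None,
    guide_params: dict = None,
    sum_log_prob: bool = True,
    rng_key=None,
):
    # If a guide is given, draw latent values from it and replay them in the
    # model, so the model's log-probs are evaluated at guide samples.
    if guide is not None:
        # The guide's latent sample sites need a PRNG key to draw values;
        # PRNGKey(0) is an arbitrary fallback when no key is supplied.
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        seeded_guide = seed(substitute(guide, guide_params or {}), rng_key)
        guide_trace = trace(seeded_guide).get_trace(*model_args, **model_kwargs)
        model = replay(model, guide_trace)
    return clp(model, model_args, model_kwargs, model_params, sum_log_prob=sum_log_prob)
```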