Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predictive fix when deterministic sites are present #1789

Merged
merged 9 commits into from
May 2, 2024

Conversation

kylejcaron
Copy link
Contributor

This PR attempts to fix #1772 - when deterministic sites are included in Predictive with params, Predictive won't generate new samples for those deterministic sites.

This PR solves that by ignoring deterministic sites in the Predictive substitute call. Added a new test to cover this scenario as well.

@kylejcaron kylejcaron changed the title Deterministic predictive fix Predictive fix when deterministic sites are present Apr 30, 2024
model_trace = trace(
seed(substitute(masked_model, samples), rng_key)
seed(
substitute(masked_model, substitute_fn=_samples_wo_deterministic),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you help me change line 777 to condition so that? It would be nice to add an argument like exclude_deterministic_from_posterior to Predictive to maintain two behaviors. We will pass such argument to this _predictive function to control the behavior.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the conditional logic, I tried the name exclude_deterministic_params because it was shorter but let me know if that's insufficient

For this could you elaborate?

Could you help me change line 777 to condition so that?

Would I be changing L777 from substitute to condition?

And should I add the deterministic fix there as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, using condition there is fine because we don't substitute deterministic sites under thecondition handler.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok just switched it to condition

numpyro/infer/util.py Outdated Show resolved Hide resolved
@kylejcaron kylejcaron requested a review from fehiepsi May 2, 2024 15:25
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing the breakage, @kylejcaron! We'll make a patch release soon.

@fehiepsi fehiepsi merged commit 7c3ec50 into pyro-ppl:master May 2, 2024
4 checks passed
@kylejcaron kylejcaron deleted the deterministic-predictive-fix branch May 3, 2024 13:55
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
* added custom effect handler for predictive

* added test, fixed predictive_substitute

* fixed typo, removed unneeded custom substitute calls

* removed custom effect handler, improved readability

* reverted formatting of imports

* added conditional arg for handling deterministic sites to predictive

* changed arg name to exclude_deterministic

* updated exclude_deterministic description

* changed substitute to condition in infer_discrete _predctive workflow

---------

Co-authored-by: kylejcaron <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

numpyro.deterministic static on infer.Predictive
2 participants