@@ -761,6 +761,7 @@ def _predictive(
761
761
return_sites = None ,
762
762
infer_discrete = False ,
763
763
parallel = True ,
764
+ exclude_deterministic : bool = True ,
764
765
model_args = (),
765
766
model_kwargs = {},
766
767
):
@@ -774,7 +775,7 @@ def _predictive(
774
775
posterior_samples ,
775
776
)
776
777
prototype_trace = trace (
777
- seed (substitute (masked_model , prototype_sample ), subkey )
778
+ seed (condition (masked_model , prototype_sample ), subkey )
778
779
).get_trace (* model_args , ** model_kwargs )
779
780
first_available_dim = - _guess_max_plate_nesting (prototype_trace ) - 1
780
781
@@ -795,9 +796,20 @@ def single_prediction(val):
795
796
** model_kwargs ,
796
797
)
797
798
else :
798
- model_trace = trace (
799
- seed (substitute (masked_model , samples ), rng_key )
800
- ).get_trace (* model_args , ** model_kwargs )
799
+
800
+ def _samples_wo_deterministic (msg ):
801
+ return (
802
+ samples .get (msg ["name" ]) if msg ["type" ] != "deterministic" else None
803
+ )
804
+
805
+ substituted_model = (
806
+ substitute (masked_model , substitute_fn = _samples_wo_deterministic )
807
+ if exclude_deterministic
808
+ else substitute (masked_model , samples )
809
+ )
810
+ model_trace = trace (seed (substituted_model , rng_key )).get_trace (
811
+ * model_args , ** model_kwargs
812
+ )
801
813
pred_samples = {name : site ["value" ] for name , site in model_trace .items ()}
802
814
803
815
if return_sites is not None :
@@ -870,6 +882,7 @@ class Predictive(object):
870
882
871
883
+ set `batch_ndims=1` to get predictions from a one dimensional batch of the guide and parameters
872
884
with shapes `(num_samples x batch_size x ...)`
885
+ :param exclude_deterministic: indicates whether to ignore deterministic sites from the posterior samples.
873
886
874
887
:return: dict of samples from the predictive distribution.
875
888
@@ -907,6 +920,7 @@ def __init__(
907
920
infer_discrete : bool = False ,
908
921
parallel : bool = False ,
909
922
batch_ndims : Optional [int ] = None ,
923
+ exclude_deterministic : bool = True ,
910
924
):
911
925
if posterior_samples is None and num_samples is None :
912
926
raise ValueError (
@@ -967,6 +981,7 @@ def __init__(
967
981
self .parallel = parallel
968
982
self .batch_ndims = batch_ndims
969
983
self ._batch_shape = batch_shape
984
+ self .exclude_deterministic = exclude_deterministic
970
985
971
986
def _call_with_params (self , rng_key , params , args , kwargs ):
972
987
posterior_samples = self .posterior_samples
@@ -983,6 +998,7 @@ def _call_with_params(self, rng_key, params, args, kwargs):
983
998
parallel = self .parallel ,
984
999
model_args = args ,
985
1000
model_kwargs = kwargs ,
1001
+ exclude_deterministic = self .exclude_deterministic ,
986
1002
)
987
1003
model = substitute (self .model , self .params )
988
1004
return _predictive (
@@ -995,6 +1011,7 @@ def _call_with_params(self, rng_key, params, args, kwargs):
995
1011
parallel = self .parallel ,
996
1012
model_args = args ,
997
1013
model_kwargs = kwargs ,
1014
+ exclude_deterministic = self .exclude_deterministic ,
998
1015
)
999
1016
1000
1017
def __call__ (self , rng_key , * args , ** kwargs ):
0 commit comments