@@ -78,7 +78,7 @@ def __init__(self,
7878 prior_scale = 1. ,
7979 wishart_scale = 1. ,
8080 dof = 1. ,
81- cov_rank = None ):
81+ cov_rank = 3 ):
8282 super (DiscClassification , self ).__init__ ()
8383
8484 self .wishart_scale = wishart_scale
@@ -155,7 +155,7 @@ def forward(self, x):
155155 def logit_predictive (self , x ):
156156 return (self .W () @ x [..., None ]).squeeze (- 1 ) + self .noise ()
157157
158- def predictive (self , x , n_samples = 10 ):
158+ def predictive (self , x , n_samples = 20 ):
159159 softmax_samples = F .softmax (self .logit_predictive (x ).rsample (sample_shape = torch .Size ([n_samples ])), dim = - 1 )
160160 return torch .clip (torch .mean (softmax_samples , dim = 0 ),min = 0. ,max = 1. )
161161
@@ -222,7 +222,7 @@ def __init__(self,
222222 prior_scale = 1. ,
223223 wishart_scale = 1. ,
224224 dof = 1. ,
225- cov_rank = None ,
225+ cov_rank = 3 ,
226226 ):
227227
228228 super (tDiscClassification , self ).__init__ ()
@@ -306,7 +306,7 @@ def logit_predictive(self, x):
306306 pred_cov = (Wx .variance + 1 ) * cov_sample
307307 return Normal (mean , torch .sqrt (pred_cov ))
308308
309- def predictive (self , x , n_samples = 10 ):
309+ def predictive (self , x , n_samples = 20 ):
310310 softmax_samples = F .softmax (self .logit_predictive (x ).rsample (sample_shape = torch .Size ([n_samples ])), dim = - 1 )
311311 return torch .clip (torch .mean (softmax_samples , dim = 0 ),min = 0. ,max = 1. )
312312
0 commit comments