Skip to content

Commit 9d6dd1f

Browse files
committed
Increase disc predictive sampling
1 parent 5b827d0 commit 9d6dd1f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

vbll/layers/classification.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(self,
7878
prior_scale=1.,
7979
wishart_scale=1.,
8080
dof=1.,
81-
cov_rank=None):
81+
cov_rank=3):
8282
super(DiscClassification, self).__init__()
8383

8484
self.wishart_scale = wishart_scale
@@ -155,7 +155,7 @@ def forward(self, x):
155155
def logit_predictive(self, x):
156156
return (self.W() @ x[..., None]).squeeze(-1) + self.noise()
157157

158-
def predictive(self, x, n_samples = 10):
158+
def predictive(self, x, n_samples = 20):
159159
softmax_samples = F.softmax(self.logit_predictive(x).rsample(sample_shape=torch.Size([n_samples])), dim=-1)
160160
return torch.clip(torch.mean(softmax_samples, dim=0),min=0.,max=1.)
161161

@@ -222,7 +222,7 @@ def __init__(self,
222222
prior_scale=1.,
223223
wishart_scale=1.,
224224
dof=1.,
225-
cov_rank=None,
225+
cov_rank=3,
226226
):
227227

228228
super(tDiscClassification, self).__init__()
@@ -306,7 +306,7 @@ def logit_predictive(self, x):
306306
pred_cov = (Wx.variance + 1) * cov_sample
307307
return Normal(mean, torch.sqrt(pred_cov))
308308

309-
def predictive(self, x, n_samples = 10):
309+
def predictive(self, x, n_samples = 20):
310310
softmax_samples = F.softmax(self.logit_predictive(x).rsample(sample_shape=torch.Size([n_samples])), dim=-1)
311311
return torch.clip(torch.mean(softmax_samples, dim=0),min=0.,max=1.)
312312

0 commit comments

Comments
 (0)