Skip to content

Commit 15a87e1

Browse files
David Eriksson and meta-codesync[bot]
authored and committed
Fix shape error in qNegIntegratedPosteriorVariance (#3068)
Summary: Pull Request resolved: #3068 `qNegIntegratedPosteriorVariance` doesn't work through MBM due to a shape bug. This diff addresses the underlying issue and would allow us to use active learning through Ax (at some point). Reviewed By: saitcakmak Differential Revision: D85507279 fbshipit-source-id: 1917902e3abd989596147bd4bb74f934697cb5df
1 parent 5818db2 commit 15a87e1

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

botorch/acquisition/active_learning.py

Lines changed: 4 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -120,13 +120,12 @@ def forward(self, X: Tensor) -> Tensor:
120120
)
121121

122122
neg_variance = posterior.variance.mul(-1.0)
123-
124-
if self.posterior_transform is None:
125-
# if single-output, shape is 1 x batch_shape x num_grid_points x 1
126-
return neg_variance.mean(dim=-2).squeeze(-1).squeeze(0)
127-
else:
123+
if self.model.num_outputs > 1:
128124
# if multi-output + obj, shape is num_grid_points x batch_shape x 1 x 1
129125
return neg_variance.mean(dim=0).squeeze(-1).squeeze(-1)
126+
else:
127+
# if single-output, shape is 1 x batch_shape x num_grid_points x 1
128+
return neg_variance.mean(dim=-2).squeeze(-1).squeeze(0)
130129

131130

132131
class PairwiseMCPosteriorVariance(MCAcquisitionFunction):

test/acquisition/test_active_learning.py

Lines changed: 23 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -132,6 +132,29 @@ def test_q_neg_int_post_variance(self):
132132
val_exp = -0.5 * variance.mean(dim=0).view(3, -1).mean(dim=-1)
133133
self.assertAllClose(val, val_exp, atol=1e-4)
134134

135+
# single-output with posterior transform
136+
mean = torch.zeros(2, 10, 3, 1, device=self.device, dtype=dtype)
137+
variance = torch.ones(2, 10, 3, 1, device=self.device, dtype=dtype)
138+
cov = torch.diag_embed(variance.view(2, 10, -1))
139+
f_posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov))
140+
mc_points = torch.rand(10, 1, device=self.device, dtype=dtype)
141+
mfm = MockModel(f_posterior)
142+
with mock.patch.object(MockModel, "fantasize", return_value=mfm):
143+
with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
144+
mock_num_outputs.return_value = 1
145+
mm = MockModel(None)
146+
qNIPV = qNegIntegratedPosteriorVariance(
147+
model=mm,
148+
mc_points=mc_points,
149+
posterior_transform=ScalarizedPosteriorTransform(
150+
weights=torch.tensor([1.0], device=self.device, dtype=dtype)
151+
),
152+
)
153+
X = torch.empty(2, 3, 1, 1, device=self.device, dtype=dtype)
154+
val = qNIPV(X)
155+
val_exp = -variance.mean(dim=-2).squeeze(-1).squeeze(0)
156+
self.assertAllClose(val, val_exp, atol=1e-4)
157+
135158

136159
class TestPairwiseMCPosteriorVariance(BotorchTestCase):
137160
def test_pairwise_mc_post_var(self):

0 commit comments

Comments (0)