Skip to content

Commit

Permalink
Reduce memory usage in ConstrainedMaxPosteriorSampling
Browse files Browse the repository at this point in the history
Summary:
Resolves #2620

When a `ModelListGP` is used for the constraint model, we can loop over the sub-models to generate the samples to reduce peak memory usage.

Differential Revision: D65881867
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 13, 2024
1 parent 5181cb8 commit 953d7f8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
15 changes: 12 additions & 3 deletions botorch/generation/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,19 @@ def forward(
)
Y_samples = posterior.rsample(sample_shape=torch.Size([num_samples]))

c_posterior = self.constraint_model.posterior(
X=X, observation_noise=observation_noise
# Loop over the constraint models (if possible) to reduce peak memory usage.
constraint_models = (
self.constraint_model.models
if isinstance(self.constraint_model, ModelListGP)
else [self.constraint_model]
)
C_samples = c_posterior.rsample(sample_shape=torch.Size([num_samples]))
C_samples_list = []
for c_model in constraint_models:
c_posterior = c_model.posterior(X=X, observation_noise=observation_noise)
C_samples_list.append(
c_posterior.rsample(sample_shape=torch.Size([num_samples]))
)
C_samples = torch.cat(C_samples_list, dim=-1)

# Convert the objective and constraint samples into a scalar-valued "score"
scores = self._convert_samples_to_scores(
Expand Down
28 changes: 12 additions & 16 deletions test/generation/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,11 @@ def test_init(self):
ConstrainedMaxPosteriorSampling(mm, cmms, objective=obj, replacement=False)

def test_constrained_max_posterior_sampling(self):
batch_shapes = (torch.Size(), torch.Size([3]), torch.Size([3, 2]))
dtypes = (torch.float, torch.double)
for (
batch_shape,
dtype,
N,
num_samples,
d,
observation_noise,
) in itertools.product(
batch_shapes, dtypes, (5, 6), (1, 2), (1, 2), (True, False)
):
for batch_shape, dtype, N, num_samples, d, observation_noise in [
(torch.Size(), torch.float, 5, 1, 1, False),
(torch.Size([3]), torch.float, 6, 3, 2, False),
(torch.Size([3, 2]), torch.double, 6, 3, 2, True),
]:
tkwargs = {"device": self.device, "dtype": dtype}
expected_shape = torch.Size(list(batch_shape) + [num_samples] + [d])
# X is `batch_shape x N x d` = batch_shape x N x 1.
Expand All @@ -193,16 +186,19 @@ def test_constrained_max_posterior_sampling(self):
with mock.patch.object(MockModel, "posterior", return_value=mp):
mm = MockModel(None)
c_model1 = SingleTaskGP(
X, torch.randn(X.shape[0:-1], **tkwargs).unsqueeze(-1)
X, torch.randn(X.shape[:-1], **tkwargs).unsqueeze(-1)
)
c_model2 = SingleTaskGP(
X, torch.randn(X.shape[0:-1], **tkwargs).unsqueeze(-1)
X, torch.randn(X.shape[:-1], **tkwargs).unsqueeze(-1)
)
c_model3 = SingleTaskGP(
X, torch.randn(X.shape[0:-1], **tkwargs).unsqueeze(-1)
X, torch.randn(X.shape[:-1], **tkwargs).unsqueeze(-1)
)
cmms1 = MockModel(MockPosterior(mean=None))
cmms2 = ModelListGP(c_model1, c_model2)
cmms2 = SingleTaskGP( # Multi-output model as constraint.
X, torch.randn((X.shape[0:-1] + (4,)), **tkwargs)
)
# ModelListGP as constraint.
cmms3 = ModelListGP(c_model1, c_model2, c_model3)
for cmms in [cmms1, cmms2, cmms3]:
CPS = ConstrainedMaxPosteriorSampling(mm, cmms)
Expand Down

0 comments on commit 953d7f8

Please sign in to comment.