
Commit 7b1e4ec

JasonKChow authored and facebook-github-bot committed

Monotonic rejection model and generator (facebookresearch#458)

Summary: Add GPU support to the monotonic rejection model. Since these models are tied to their generators, the generators are made GPU-ready as well. Differential Revision: D65638150

1 parent d096c6a commit 7b1e4ec
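A minimal sketch of what this enables, assuming a CUDA device is available (the constructor arguments shown here are illustrative, not taken from this commit):

import torch
from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE
from aepsych.generators import MonotonicRejectionGenerator
from aepsych.models.monotonic_rejection_gp import MonotonicRejectionGP

# Build a monotonic rejection model and move it to the GPU; with this commit,
# its bounds and inducing points (now registered buffers) move with it.
model = MonotonicRejectionGP(
    monotonic_idxs=[0],
    lb=torch.zeros(2),
    ub=torch.ones(2),
).to("cuda")

# The generator reads bounds and device off the model, so gen() stays on GPU.
generator = MonotonicRejectionGenerator(
    acqf=MonotonicMCLSE, acqf_kwargs={"target": 0.75}
)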

8 files changed: +272 −21 lines

aepsych/generators/monotonic_rejection_generator.py

Lines changed: 1 addition & 1 deletion

@@ -101,7 +101,7 @@ def gen(
        )

        # Augment bounds with deriv indicator
-        bounds = torch.cat((model.bounds_, torch.zeros(2, 1)), dim=1)
+        bounds = torch.cat((model.bounds_, torch.zeros(2, 1).to(model.device)), dim=1)
        # Fix deriv indicator to 0 during optimization
        fixed_features = {(bounds.shape[1] - 1): 0.0}
        # Fix explore features to random values
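The fix follows the usual torch.cat device rule: all inputs must live on the same device. A standalone sketch of the failure mode this avoids (assumes CUDA):

import torch

bounds_gpu = torch.zeros(2, 3, device="cuda")  # stands in for model.bounds_ on GPU

# torch.cat((bounds_gpu, torch.zeros(2, 1)), dim=1)  # would raise a device-mismatch RuntimeError
aug = torch.cat((bounds_gpu, torch.zeros(2, 1).to(bounds_gpu.device)), dim=1)
print(aug.device)  # cuda:0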

aepsych/models/base.py

Lines changed: 6 additions & 1 deletion

@@ -415,7 +415,12 @@ def set_train_data(self, inputs=None, targets=None, strict=False):
    def device(self) -> torch.device:
        # We assume all models have some parameters and all models will only use one device
        # notice that this has no setting, don't let users set device, use .to().
-        return next(self.parameters()).device
+        try:
+            return next(self.parameters()).device
+        except (
+            AttributeError
+        ):  # Fallback for cases where we need device before we have params
+            return torch.device("cpu")

    @property
    def train_inputs(self) -> Optional[Tuple[torch.Tensor]]:
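The try/except covers access to .device during construction, before nn.Module has set up its parameter storage (where next(self.parameters()) raises AttributeError). A self-contained sketch of the same pattern; DeviceAware is an illustrative stand-in, not the aepsych mixin itself:

import torch
from torch import nn

class DeviceAware(nn.Module):
    @property
    def device(self) -> torch.device:
        try:
            return next(self.parameters()).device
        except AttributeError:  # parameter storage not set up yet, e.g. mid-__init__
            return torch.device("cpu")

    def __init__(self):
        print(self.device)  # cpu: falls back, since super().__init__() hasn't run
        super().__init__()
        self.linear = nn.Linear(2, 1)

m = DeviceAware()
print(m.device)  # device of the first parameter (cpu here)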

aepsych/models/derivative_gp.py

Lines changed: 2 additions & 0 deletions

@@ -96,6 +96,7 @@ def __init__(
        else:
            self.covar_module = covar_module

+        self.to(self.device)
        self._num_outputs = 1
        self.train_inputs = (train_x,)
        self.train_targets = train_y
@@ -111,6 +112,7 @@ def forward(self, x: torch.Tensor) -> MultivariateNormal:
            MultivariateNormal: Object containig mean and covariance
            of GP at these points.
        """
+        x = x.to(self.device)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
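Moving x onto the model's device at the top of forward is the standard guard for CPU-resident inputs hitting a GPU model. A generic sketch of the pattern (the class and names here are illustrative):

import torch
from torch import nn

class DeviceSafe(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Align the input with the module's parameters before any op mixes them.
        x = x.to(next(self.parameters()).device)
        return self.net(x)

model = DeviceSafe()
if torch.cuda.is_available():
    model = model.cuda()
out = model(torch.randn(4, 3))  # a CPU input works either way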

aepsych/models/monotonic_rejection_gp.py

Lines changed: 15 additions & 11 deletions

@@ -18,7 +18,7 @@
from aepsych.factory.monotonic import monotonic_mean_covar_factory
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
-from aepsych.models.base import AEPsychMixin
+from aepsych.models.base import AEPsychModelDeviceMixin
from aepsych.models.utils import select_inducing_points
from aepsych.utils import _process_bounds, promote_0d
from botorch.fit import fit_gpytorch_mll
@@ -32,7 +32,7 @@
from torch import Tensor


-class MonotonicRejectionGP(AEPsychMixin, ApproximateGP):
+class MonotonicRejectionGP(AEPsychModelDeviceMixin, ApproximateGP):
    """A monotonic GP using rejection sampling.

    This takes the same insight as in e.g. Riihimäki & Vehtari 2010 (that the derivative of a GP
@@ -83,15 +83,15 @@ def __init__(
            objective (Optional[MCAcquisitionObjective], optional): Transformation of GP to apply before computing acquisition function. Defaults to identity transform for gaussian likelihood, probit transform for probit-bernoulli.
            extra_acqf_args (Optional[Dict[str, object]], optional): Additional arguments to pass into the acquisition function. Defaults to None.
        """
-        self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
+        lb, ub, self.dim = _process_bounds(lb, ub, dim)
        if likelihood is None:
            likelihood = BernoulliLikelihood()

        self.inducing_size = num_induc
        self.inducing_point_method = inducing_point_method
        inducing_points = select_inducing_points(
            inducing_size=self.inducing_size,
-            bounds=self.bounds,
+            bounds=torch.stack((lb, ub)),
            method="sobol",
        )

@@ -134,7 +134,9 @@ def __init__(

        super().__init__(variational_strategy)

-        self.bounds_ = torch.stack([self.lb, self.ub])
+        self.register_buffer("lb", lb)
+        self.register_buffer("ub", ub)
+        self.register_buffer("bounds_", torch.stack([self.lb, self.ub]))
        self.mean_module = mean_module
        self.covar_module = covar_module
        self.likelihood = likelihood
@@ -144,7 +146,8 @@ def __init__(
        self.num_samples = num_samples
        self.num_rejection_samples = num_rejection_samples
        self.fixed_prior_mean = fixed_prior_mean
-        self.inducing_points = inducing_points
+        # self.inducing_points = inducing_points
+        self.register_buffer("inducing_points", inducing_points)

    def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
        """Fit the model
@@ -161,7 +164,7 @@ def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
            X=self.train_inputs[0],
            bounds=self.bounds,
            method=self.inducing_point_method,
-        )
+        ).to(self.device)
        self._set_model(train_x, train_y)

    def _set_model(
@@ -284,13 +287,14 @@ def predict_probability(
        return self.predict(x, probability_space=True)

    def _augment_with_deriv_index(self, x: Tensor, indx) -> Tensor:
+        x = x.to(self.device)
        return torch.cat(
-            (x, indx * torch.ones(x.shape[0], 1)),
+            (x, indx * torch.ones(x.shape[0], 1).to(self.device)),
            dim=1,
        )

    def _get_deriv_constraint_points(self) -> Tensor:
-        deriv_cp = torch.tensor([])
+        deriv_cp = torch.tensor([]).to(self.device)
        for i in self.monotonic_idxs:
            induc_i = self._augment_with_deriv_index(self.inducing_points, i + 1)
            deriv_cp = torch.cat((deriv_cp, induc_i), dim=0)
@@ -299,8 +303,8 @@ def _get_deriv_constraint_points(self) -> Tensor:
    @classmethod
    def from_config(cls, config: Config) -> MonotonicRejectionGP:
        classname = cls.__name__
-        num_induc = config.gettensor(classname, "num_induc", fallback=25)
-        num_samples = config.gettensor(classname, "num_samples", fallback=250)
+        num_induc = config.getint(classname, "num_induc", fallback=25)
+        num_samples = config.getint(classname, "num_samples", fallback=250)
        num_rejection_samples = config.getint(
            classname, "num_rejection_samples", fallback=5000
        )
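Switching lb, ub, bounds_, and inducing_points from plain tensor attributes to registered buffers is what makes .to(device) (and state_dict round-trips) carry them along. A minimal sketch of the difference:

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.zeros(2)                  # left behind by .to()
        self.register_buffer("buf", torch.zeros(2))  # tracked by the module

demo = Demo()
if torch.cuda.is_available():
    demo = demo.to("cuda")
    print(demo.plain.device)  # cpu
    print(demo.buf.device)    # cuda:0
print("buf" in demo.state_dict())    # True: buffers are serialized
print("plain" in demo.state_dict())  # False: plain attributes are not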

aepsych/strategy.py

Lines changed: 3 additions & 8 deletions

@@ -71,13 +71,6 @@ class Strategy(object):

    _n_eval_points: int = 1000

-    no_gpu_acqfs = (
-        MonotonicMCAcquisition,
-        MonotonicBernoulliMCMutualInformation,
-        MonotonicMCPosteriorVariance,
-        MonotonicMCLSE,
-    )
-
    def __init__(
        self,
        generator: Union[AEPsychGenerator, ParameterTransformedGenerator],
@@ -182,7 +175,7 @@ def __init__(
            )
            self.generator_device = torch.device("cpu")
        else:
-            if hasattr(generator, "acqf") and generator.acqf in self.no_gpu_acqfs:
+            if hasattr(generator, "acqf"):
                warnings.warn(
                    f"GPU requested for acquistion function {type(generator.acqf).__name__}, but this acquisiton function does not support GPU! Using CPU instead.",
                    UserWarning,
@@ -283,9 +276,11 @@ def normalize_inputs(
            x = x[None, :]

        if self.x is not None:
+            x = x.to(self.x)
            x = torch.cat((self.x, x), dim=0)

        if self.y is not None:
+            y = y.to(self.y)
            y = torch.cat((self.y, y), dim=0)

        # Ensure the correct dtype
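The x.to(self.x) calls use the Tensor.to(other) overload, which matches both the dtype and the device of the reference tensor before concatenation. A small sketch:

import torch

existing = torch.zeros(3, 2, dtype=torch.float64)  # e.g. data the strategy already holds
incoming = torch.ones(1, 2)                        # float32, possibly on another device

incoming = incoming.to(existing)  # adopts existing's dtype and device
combined = torch.cat((existing, incoming), dim=0)
print(combined.dtype)  # torch.float64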

tests_gpu/acquisition/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta, Inc. and its affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE
+from aepsych.acquisition.objective import ProbitObjective
+from aepsych.models.derivative_gp import MixedDerivativeVariationalGP
+from botorch.acquisition.objective import IdentityMCObjective
+from botorch.utils.testing import BotorchTestCase
+
+
+class TestMonotonicAcq(BotorchTestCase):
+    def test_monotonic_acq_gpu(self):
+        # Init
+        train_X_aug = torch.tensor(
+            [[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 2.0, 0.0]]
+        ).cuda()
+        deriv_constraint_points = torch.tensor(
+            [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [2.0, 2.0, 1.0]]
+        ).cuda()
+        train_Y = torch.tensor([[1.0], [2.0], [3.0]]).cuda()
+
+        m = MixedDerivativeVariationalGP(
+            train_x=train_X_aug, train_y=train_Y, inducing_points=train_X_aug
+        ).cuda()
+        acq = MonotonicMCLSE(
+            model=m,
+            deriv_constraint_points=deriv_constraint_points,
+            num_samples=5,
+            num_rejection_samples=8,
+            target=1.9,
+        )
+        self.assertTrue(isinstance(acq.objective, IdentityMCObjective))
+        acq = MonotonicMCLSE(
+            model=m,
+            deriv_constraint_points=deriv_constraint_points,
+            num_samples=5,
+            num_rejection_samples=8,
+            target=1.9,
+            objective=ProbitObjective(),
+        ).cuda()
+        # forward
+        acq(train_X_aug)
+        Xfull = torch.cat((train_X_aug, acq.deriv_constraint_points), dim=0)
+        posterior = m.posterior(Xfull)
+        samples = acq.sampler(posterior)
+        self.assertEqual(samples.shape, torch.Size([5, 6, 1]))
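The expected sampler shape torch.Size([5, 6, 1]) follows directly from the setup: num_samples=5 posterior draws, evaluated jointly at 6 points (the 3 rows of train_X_aug plus the 3 deriv_constraint_points), with 1 output dimension. Since every tensor is created with .cuda(), these tests assume a CUDA-capable machine; as BotorchTestCase is unittest-based, they can presumably be run with something like:

python -m unittest discover tests_gpu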
