@@ -18,7 +18,7 @@
 from aepsych.factory.monotonic import monotonic_mean_covar_factory
 from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
 from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
-from aepsych.models.base import AEPsychMixin
+from aepsych.models.base import AEPsychModelDeviceMixin
 from aepsych.models.utils import select_inducing_points
 from aepsych.utils import _process_bounds, promote_0d
 from botorch.fit import fit_gpytorch_mll
@@ -32,7 +32,7 @@
 from torch import Tensor
 
 
-class MonotonicRejectionGP(AEPsychMixin, ApproximateGP):
+class MonotonicRejectionGP(AEPsychModelDeviceMixin, ApproximateGP):
     """A monotonic GP using rejection sampling.
 
     This takes the same insight as in e.g. Riihimäki & Vehtari 2010 (that the derivative of a GP
@@ -83,15 +83,15 @@ def __init__(
             objective (Optional[MCAcquisitionObjective], optional): Transformation of GP to apply before computing acquisition function. Defaults to identity transform for gaussian likelihood, probit transform for probit-bernoulli.
             extra_acqf_args (Optional[Dict[str, object]], optional): Additional arguments to pass into the acquisition function. Defaults to None.
         """
-        self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
+        lb, ub, self.dim = _process_bounds(lb, ub, dim)
         if likelihood is None:
             likelihood = BernoulliLikelihood()
 
         self.inducing_size = num_induc
         self.inducing_point_method = inducing_point_method
         inducing_points = select_inducing_points(
             inducing_size=self.inducing_size,
-            bounds=self.bounds,
+            bounds=torch.stack((lb, ub)),
             method="sobol",
         )
 
@@ -134,7 +134,9 @@ def __init__(
 
         super().__init__(variational_strategy)
 
-        self.bounds_ = torch.stack([self.lb, self.ub])
+        self.register_buffer("lb", lb)
+        self.register_buffer("ub", ub)
+        self.register_buffer("bounds_", torch.stack([self.lb, self.ub]))
         self.mean_module = mean_module
         self.covar_module = covar_module
         self.likelihood = likelihood
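Note on the `register_buffer` changes above: unlike plain tensor attributes, registered buffers are module state, so they follow `.to(device)`/`.cuda()` and appear in `state_dict()`, which is what lets `AEPsychModelDeviceMixin` keep `lb`, `ub`, and `bounds_` on the model's device. A minimal sketch of the distinction (the `Demo` module here is made up for illustration, not part of aepsych):

```python
import torch

class Demo(torch.nn.Module):
    def __init__(self, lb: torch.Tensor, ub: torch.Tensor) -> None:
        super().__init__()
        # Registered buffers are module state: they move with .to()/.cuda()
        # and are included in state_dict().
        self.register_buffer("lb", lb)
        self.register_buffer("ub", ub)
        # A plain tensor attribute is invisible to .to() and stays put.
        self.plain_bounds = torch.stack((lb, ub))

demo = Demo(torch.zeros(2), torch.ones(2))
if torch.cuda.is_available():
    demo = demo.to("cuda")
    print(demo.lb.device)            # cuda:0
    print(demo.plain_bounds.device)  # cpu
```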
@@ -144,7 +146,7 @@ def __init__(
         self.num_samples = num_samples
         self.num_rejection_samples = num_rejection_samples
         self.fixed_prior_mean = fixed_prior_mean
-        self.inducing_points = inducing_points
+        self.register_buffer("inducing_points", inducing_points)
 
     def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
         """Fit the model
@@ -161,7 +163,7 @@ def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
             X=self.train_inputs[0],
             bounds=self.bounds,
             method=self.inducing_point_method,
-        )
+        ).to(self.device)
         self._set_model(train_x, train_y)
 
     def _set_model(
@@ -284,13 +286,14 @@ def predict_probability(
         return self.predict(x, probability_space=True)
 
     def _augment_with_deriv_index(self, x: Tensor, indx) -> Tensor:
+        x = x.to(self.device)
         return torch.cat(
-            (x, indx * torch.ones(x.shape[0], 1)),
+            (x, indx * torch.ones(x.shape[0], 1).to(self.device)),
             dim=1,
         )
 
     def _get_deriv_constraint_points(self) -> Tensor:
-        deriv_cp = torch.tensor([])
+        deriv_cp = torch.tensor([]).to(self.device)
         for i in self.monotonic_idxs:
             induc_i = self._augment_with_deriv_index(self.inducing_points, i + 1)
             deriv_cp = torch.cat((deriv_cp, induc_i), dim=0)
@@ -299,8 +302,8 @@ def _get_deriv_constraint_points(self) -> Tensor:
     @classmethod
     def from_config(cls, config: Config) -> MonotonicRejectionGP:
         classname = cls.__name__
-        num_induc = config.gettensor(classname, "num_induc", fallback=25)
-        num_samples = config.gettensor(classname, "num_samples", fallback=250)
+        num_induc = config.getint(classname, "num_induc", fallback=25)
+        num_samples = config.getint(classname, "num_samples", fallback=250)
         num_rejection_samples = config.getint(
             classname, "num_rejection_samples", fallback=5000
         )
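The `.to(self.device)` additions in `_augment_with_deriv_index` keep the appended derivative-index column on the same device as the input, so `torch.cat` never mixes CPU and GPU tensors. The same pattern re-sketched as a standalone function (illustrative only, not the library's code; passing `device=`/`dtype=` at construction is an equivalent alternative to the `.to(...)` calls in the diff):

```python
import torch

def augment_with_deriv_index(x: torch.Tensor, indx: int) -> torch.Tensor:
    # Build the derivative-index column on x's device (and dtype) so that
    # torch.cat concatenates tensors living on the same device.
    col = indx * torch.ones(x.shape[0], 1, device=x.device, dtype=x.dtype)
    return torch.cat((x, col), dim=1)

x = torch.rand(4, 2)
print(augment_with_deriv_index(x, 1))  # shape (4, 3); last column is all 1s
```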