Skip to content

Commit

Permalink
jitter crash fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kegl authored Jan 15, 2024
1 parent 33fc831 commit d3c90c0
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions HEBO/hebo/models/gp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from gpytorch.means import ConstantMean
from gpytorch.distributions import MultivariateNormal
from gpytorch.constraints import GreaterThan
from gpytorch.settings import cholesky_jitter

from ..util import filter_nan
from ..base_model import BaseModel
Expand Down Expand Up @@ -99,13 +100,24 @@ def fit(self, Xc : Tensor, Xe : Tensor, y : Tensor):
opt = torch.optim.Adam(self.gp.parameters(), lr = self.lr)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.lik, self.gp)
for epoch in range(self.num_epochs):
def closure():
dist = self.gp(self.Xc, self.Xe)
loss = -1 * mll(dist, self.y.squeeze())
opt.zero_grad()
loss.backward()
return loss
opt.step(closure)
jitter = 10 ** -8
cont = True
while cont:
cont = False
cholesky_jitter._set_value(
double_value=jitter, float_value=100*jitter, half_value=10000*jitter)
def closure():
dist = self.gp(self.Xc, self.Xe)
loss = -1 * mll(dist, self.y.squeeze())
opt.zero_grad()
loss.backward()
return loss
try:
opt.step(closure)
except:
jitter *= 10
cont = True
print(f'jitter = {jitter}')
if self.verbose and ((epoch + 1) % self.print_every == 0 or epoch == 0):
print('After %d epochs, loss = %g' % (epoch + 1, closure().item()), flush = True)
self.gp.eval()
Expand All @@ -114,7 +126,18 @@ def closure():
def predict(self, Xc, Xe):
Xc, Xe = self.xtrans(Xc, Xe)
with gpytorch.settings.fast_pred_var(), gpytorch.settings.debug(False):
pred = self.gp(Xc, Xe)
jitter = 10 ** -8
cont = True
while cont:
cont = False
cholesky_jitter._set_value(
double_value=jitter, float_value=100*jitter, half_value=10000*jitter)
try:
pred = self.gp(Xc, Xe)
except:
jitter *= 10
cont = True
print(f'jitter = {jitter}')
if self.pred_likeli:
pred = self.lik(pred)
mu_ = pred.mean.reshape(-1, self.num_out)
Expand Down

0 comments on commit d3c90c0

Please sign in to comment.