Skip to content

Commit

Permalink
Fix KeOps regressions from #2296.
Browse files Browse the repository at this point in the history
KernelLinearOperator was throwing errors when computing the diagonal of a KeOps kernel.
(This computation happens during preconditioning, which requires the
diagonal of the already-formed kernel LinearOperator object.)
This error was because KeopsLinearOperator.diagonal calls to_dense on
the output of a batch kernel operation. However, to_dense is not defined
for KeOps LazyTensors.

This PR is in some sense a hack fix to this bug (a less hack fix will
require changes to KernelLinearOperator), but it is also a generally
nice and helpful refactor that will improve KeOps kernels in general.

The fixes:
- KeOpsKernels now only define a forward function, that will be used
both when we want to use KeOps and when we want to bypass it.
- KeOpsKernels now use a `_lazify_inputs` helper method, which
(potentially) wraps the inputs as KeOpsLazyTensors, or potentially
leaves the inputs as torch Tensors.
- The KeOps wrapping happens unless we want to bypass KeOps, which
occurs when either (1) the matrix is small (below Cholesky size) or (2)
when the use has turned off the `gpytorch.settings.use_keops` option
(*NEW IN THIS PR*).

Why this is beneficial:
- KeOps kernels now follow the same API as non-KeOps kernels (define a
forward method)
- The user now only has to define one forward method, that works in both
the keops and non-keops cases
- The `diagonal` call in KeopsLinearOperator constructs a batch 1x1
matrix, which is small enough to bypass keops and thus avoid the current
bug. (Hence why this solution is currently a hack, but could become less
hacky with a small modification to KernelLinearOperator and/or the
to_dense method in LinearOperator).

Other changes:
- Fix stability issues with the keops MaternKernel. (There were some NaN
issues)
- Introduce a `gpytorch.settings.use_keops` feature flag.
- Clean up KeOPs notebook

[Fixes #2363]
  • Loading branch information
gpleiss committed Sep 21, 2023
1 parent 43383c2 commit f7eaa80
Show file tree
Hide file tree
Showing 8 changed files with 319 additions and 349 deletions.
155 changes: 77 additions & 78 deletions examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,14 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import gpytorch\n",
"import tqdm.notebook as tqdm\n",
"from matplotlib import pyplot as plt\n",
"\n",
"%matplotlib inline\n",
Expand All @@ -45,22 +37,16 @@
"metadata": {},
"source": [
"### Downloading Data\n",
"We will be using the 3droad UCI dataset which contains a total of 278,319 data points. The next cell will download this dataset from a Google drive and load it."
"We will be using the 3droad UCI dataset which contains a total of 434,874 data points. We will split the data in half for training and half for testing.\n",
"\n",
"The next cell will download this dataset from a Google drive and load it."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading '3droad' UCI dataset...\n"
]
}
],
"outputs": [],
"source": [
"import urllib.request\n",
"import os.path\n",
Expand All @@ -76,15 +62,25 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num train: 217437\n",
"Num test: 217437\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"N = data.shape[0]\n",
"# make train/val/test\n",
"n_train = int(0.8 * N)\n",
"n_train = int(0.5 * N)\n",
"\n",
"train_x, train_y = data[:n_train, :-1], data[:n_train, -1]\n",
"test_x, test_y = data[n_train:, :-1], data[n_train:, -1]\n",
"\n",
Expand All @@ -106,7 +102,12 @@
"output_device = torch.device('cuda:0')\n",
"\n",
"train_x, train_y = train_x.to(output_device), train_y.to(output_device)\n",
"test_x, test_y = test_x.to(output_device), test_y.to(output_device)"
"test_x, test_y = test_x.to(output_device), test_y.to(output_device)\n",
"\n",
"print(\n",
" f\"Num train: {train_y.size(-1)}\\n\"\n",
" f\"Num test: {test_y.size(-1)}\"\n",
")"
]
},
{
Expand All @@ -120,7 +121,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -139,16 +140,36 @@
"\n",
"# initialize likelihood and model\n",
"likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()\n",
"model = ExactGPModel(train_x, train_y, likelihood).cuda()"
"model = ExactGPModel(train_x, train_y, likelihood).cuda()\n",
"\n",
"# Because we know some properties about this dataset,\n",
"# we will initialize the lengthscale to be somewhat small\n",
"# This step isn't necessary, but it will help the model converge faster.\n",
"model.covar_module.base_kernel.lengthscale = 0.05"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {
"scrolled": false
},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "691194d2d51e4d389fef9f0f7cb34f6b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0%| | 0/25 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Find optimal model hyperparameters\n",
"model.train()\n",
Expand All @@ -158,64 +179,44 @@
"optimizer = torch.optim.Adam(model.parameters(), lr=0.1) # Includes GaussianLikelihood parameters\n",
"\n",
"# \"Loss\" for GPs - the marginal log likelihood\n",
"mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
"mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)\n",
"\n",
"import time\n",
"training_iter = 50\n",
"for i in range(training_iter):\n",
"training_iter = 25\n",
"iterator = tqdm.tqdm(range(training_iter), desc=\"Training\")\n",
"for i in iterator:\n",
" start_time = time.time()\n",
" # Zero gradients from previous iteration\n",
" optimizer.zero_grad()\n",
" # Output from model\n",
" output = model(train_x)\n",
" # Calc loss and backprop gradients\n",
" loss = -mll(output, train_y)\n",
" print_values = dict(\n",
" loss=loss.item(),\n",
" ls=model.covar_module.base_kernel.lengthscale.norm().item(),\n",
" os=model.covar_module.outputscale.item(),\n",
" noise=model.likelihood.noise.item(),\n",
" mu=model.mean_module.constant.item(),\n",
" )\n",
" iterator.set_postfix(**print_values)\n",
" loss.backward()\n",
" print('Iter %d/%d - Loss: %.3f lengthscale: %.3f noise: %.3f' % (\n",
" i + 1, training_iter, loss.item(),\n",
" model.covar_module.base_kernel.lengthscale.item(),\n",
" model.likelihood.noise.item()\n",
" ))\n",
" optimizer.step()\n",
" print(time.time() - start_time)"
" optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Compiling libKeOpstorchd7ba409487 in /home/jake.gardner/.cache/pykeops-1.1.1//build-libKeOpstorchd7ba409487:\n",
" formula: Sum_Reduction(((((Var(0,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))) + (IntCst(1) + (Var(3,1,2) * Square(Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1))))))))) * Exp((Var(4,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))))) * Var(5,3320,1)),0)\n",
" aliases: Var(0,1,2); Var(1,18,0); Var(2,18,1); Var(3,1,2); Var(4,1,2); Var(5,3320,1); \n",
" dtype : float32\n",
"... Done.\n",
"Compiling libKeOpstorch7385e76d34 in /home/jake.gardner/.cache/pykeops-1.1.1//build-libKeOpstorch7385e76d34:\n",
" formula: Sum_Reduction(((((Var(0,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))) + (IntCst(1) + (Var(3,1,2) * Square(Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1))))))))) * Exp((Var(4,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))))) * Var(5,1,1)),0)\n",
" aliases: Var(0,1,2); Var(1,18,0); Var(2,18,1); Var(3,1,2); Var(4,1,2); Var(5,1,1); \n",
" dtype : float32\n",
"... Done.\n",
"Compiling libKeOpstorch97105370ea in /home/jake.gardner/.cache/pykeops-1.1.1//build-libKeOpstorch97105370ea:\n",
" formula: Sum_Reduction(((((Var(0,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))) + (IntCst(1) + (Var(3,1,2) * Square(Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1))))))))) * Exp((Var(4,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))))) * Var(5,100,1)),0)\n",
" aliases: Var(0,1,2); Var(1,18,0); Var(2,18,1); Var(3,1,2); Var(4,1,2); Var(5,100,1); \n",
" dtype : float32\n",
"... Done.\n"
]
}
],
"outputs": [],
"source": [
"# Get into evaluation (predictive posterior) mode\n",
"model.eval()\n",
"likelihood.eval()\n",
"\n",
"# Test points are regularly spaced along [0,1]\n",
"# Make predictions by feeding model through likelihood\n",
"with torch.no_grad(), gpytorch.settings.fast_pred_var():\n",
" observed_pred = likelihood(model(test_x))"
" observed_pred = model.likelihood(model(test_x))"
]
},
{
Expand All @@ -227,29 +228,27 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.1068, device='cuda:0')"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"RMSE: 0.138\n"
]
}
],
"source": [
"torch.sqrt(torch.mean(torch.pow(observed_pred.mean - test_y, 2)))"
"rmse = (observed_pred.mean - test_y).square().mean().sqrt().item()\n",
"print(f\"RMSE: {rmse:.3f}\")"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -263,7 +262,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
"version": "3.8.0"
}
},
"nbformat": 4,
Expand Down
3 changes: 2 additions & 1 deletion gpytorch/kernels/keops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .keops_kernel import KeOpsKernel
from .matern_kernel import MaternKernel
from .periodic_kernel import PeriodicKernel
from .rbf_kernel import RBFKernel

__all__ = ["MaternKernel", "RBFKernel", "PeriodicKernel"]
__all__ = ["KeOpsKernel", "MaternKernel", "PeriodicKernel", "RBFKernel"]
76 changes: 47 additions & 29 deletions gpytorch/kernels/keops/keops_kernel.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,66 @@
from abc import abstractmethod
from typing import Any
import warnings
from typing import Any, Tuple, Union

import torch
from linear_operator import LinearOperator
from torch import Tensor

from ... import settings
from ..kernel import Kernel

try:
import pykeops # noqa F401
from pykeops.torch import LazyTensor

def _lazify_and_expand_inputs(
x1: Tensor, x2: Tensor
) -> Tuple[Union[Tensor, LazyTensor], Union[Tensor, LazyTensor]]:
r"""
Potentially wrap inputs x1 and x2 as KeOps LazyTensors,
depending on whether or not we want to use KeOps under the hood or not.
"""
x1_ = x1[..., :, None, :]
x2_ = x2[..., None, :, :]
if _use_keops(x1, x2):
res = LazyTensor(x1_), LazyTensor(x2_)
return res
return x1_, x2_

def _use_keops(x1: Tensor, x2: Tensor) -> bool:
r"""
Determine whether or not to use KeOps under the hood
This largely depends on the size of the kernel matrix
There are situations where we do not want the KeOps linear operator to use KeOps under the hood.
See https://github.com/cornellius-gp/gpytorch/pull/1319
"""
return (
settings.use_keops.on()
and x1.size(-2) >= settings.max_cholesky_size.value()
and x2.size(-2) >= settings.max_cholesky_size.value()
)

class KeOpsKernel(Kernel):
@abstractmethod
def _nonkeops_forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any):
r"""
Computes the covariance matrix (or diagonal) without using KeOps.
This function must implement both the diag=True and diag=False options.
"""
raise NotImplementedError

@abstractmethod
def _keops_forward(self, x1: Tensor, x2: Tensor, **kwargs: Any):
r"""
Computes the covariance matrix with KeOps.
This function only implements the diag=False option, and no diag keyword should be passed in.
"""
raise NotImplementedError

def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any):
if diag:
return self._nonkeops_forward(x1, x2, diag=True, **kwargs)
elif x1.size(-2) < settings.max_cholesky_size.value() or x2.size(-2) < settings.max_cholesky_size.value():
return self._nonkeops_forward(x1, x2, diag=False, **kwargs)
else:
return self._keops_forward(x1, x2, **kwargs)

def __call__(self, *args: Any, **kwargs: Any):
def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor, LazyTensor]:
# Hotfix for zero gradients. See https://github.com/cornellius-gp/gpytorch/issues/1543
args = [arg.contiguous() if torch.is_tensor(arg) else arg for arg in args]
kwargs = {k: v.contiguous() if torch.is_tensor(v) else v for k, v in kwargs.items()}
return super().__call__(*args, **kwargs)

except ImportError:

def _lazify_and_expand_inputs(x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor]:
x1_ = x1[..., :, None, :]
x2_ = x2[..., None, :, :]
return x1_, x2_

def _use_keops(x1: Tensor, x2: Tensor) -> bool:
return False

class KeOpsKernel(Kernel):
def __init__(self, *args: Any, **kwargs: Any):
raise RuntimeError("You must have KeOps installed to use a KeOpsKernel")
def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor]:
warnings.warn(
"KeOps is not installed. " f"{type(self)} will revert to the the non-keops version of this kernel.",
RuntimeWarning,
)
return super().__call__(*args, **kwargs)
Loading

0 comments on commit f7eaa80

Please sign in to comment.