From f7eaa80d43a13ac6119d267787ed46a7eacf7e3f Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 21 Sep 2023 20:16:55 +0000 Subject: [PATCH] Fix KeOps regressions from #2296. KernelLinearOperator was throwing errors when computing the diagonal of a KeOps kernel. (This computation happens during preconditioning, which requires the diagonal of the already-formed kernel LinearOperator object.) This error was because KeopsLinearOperator.diagonal calls to_dense on the output of a batch kernel operation. However, to_dense is not defined for KeOps LazyTensors. This PR is in some sense a hack fix to this bug (a less hack fix will require changes to KernelLinearOperator), but it is also a generally nice and helpful refactor that will improve KeOps kernels in general. The fixes: - KeOpsKernels now only define a forward function, that will be used both when we want to use KeOps and when we want to bypass it. - KeOpsKernels now use a `_lazify_inputs` helper method, which (potentially) wraps the inputs as KeOpsLazyTensors, or potentially leaves the inputs as torch Tensors. - The KeOps wrapping happens unless we want to bypass KeOps, which occurs when either (1) the matrix is small (below Cholesky size) or (2) when the use has turned off the `gpytorch.settings.use_keops` option (*NEW IN THIS PR*). Why this is beneficial: - KeOps kernels now follow the same API as non-KeOps kernels (define a forward method) - The user now only has to define one forward method, that works in both the keops and non-keops cases - The `diagonal` call in KeopsLinearOperator constructs a batch 1x1 matrix, which is small enough to bypass keops and thus avoid the current bug. (Hence why this solution is currently a hack, but could become less hacky with a small modification to KernelLinearOperator and/or the to_dense method in LinearOperator). Other changes: - Fix stability issues with the keops MaternKernel. (There were some NaN issues) - Introduce a `gpytorch.settings.use_keops` feature flag. - Clean up KeOPs notebook [Fixes #2363] --- .../KeOps_GP_Regression.ipynb | 155 +++++++++--------- gpytorch/kernels/keops/__init__.py | 3 +- gpytorch/kernels/keops/keops_kernel.py | 76 +++++---- gpytorch/kernels/keops/matern_kernel.py | 153 ++++++++--------- gpytorch/kernels/keops/periodic_kernel.py | 143 +++++++--------- gpytorch/kernels/keops/rbf_kernel.py | 97 +++++------ gpytorch/settings.py | 14 ++ gpytorch/test/base_keops_test_case.py | 27 +-- 8 files changed, 319 insertions(+), 349 deletions(-) diff --git a/examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.ipynb b/examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.ipynb index d17048e82..5f796a282 100644 --- a/examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.ipynb +++ b/examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.ipynb @@ -17,22 +17,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "import math\n", "import torch\n", "import gpytorch\n", + "import tqdm.notebook as tqdm\n", "from matplotlib import pyplot as plt\n", "\n", "%matplotlib inline\n", @@ -45,22 +37,16 @@ "metadata": {}, "source": [ "### Downloading Data\n", - "We will be using the 3droad UCI dataset which contains a total of 278,319 data points. The next cell will download this dataset from a Google drive and load it." + "We will be using the 3droad UCI dataset which contains a total of 434,874 data points. We will split the data in half for training and half for testing.\n", + "\n", + "The next cell will download this dataset from a Google drive and load it." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading '3droad' UCI dataset...\n" - ] - } - ], + "outputs": [], "source": [ "import urllib.request\n", "import os.path\n", @@ -76,15 +62,25 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Num train: 217437\n", + "Num test: 217437\n" + ] + } + ], "source": [ "import numpy as np\n", "\n", "N = data.shape[0]\n", "# make train/val/test\n", - "n_train = int(0.8 * N)\n", + "n_train = int(0.5 * N)\n", + "\n", "train_x, train_y = data[:n_train, :-1], data[:n_train, -1]\n", "test_x, test_y = data[n_train:, :-1], data[n_train:, -1]\n", "\n", @@ -106,7 +102,12 @@ "output_device = torch.device('cuda:0')\n", "\n", "train_x, train_y = train_x.to(output_device), train_y.to(output_device)\n", - "test_x, test_y = test_x.to(output_device), test_y.to(output_device)" + "test_x, test_y = test_x.to(output_device), test_y.to(output_device)\n", + "\n", + "print(\n", + " f\"Num train: {train_y.size(-1)}\\n\"\n", + " f\"Num test: {test_y.size(-1)}\"\n", + ")" ] }, { @@ -120,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -139,16 +140,36 @@ "\n", "# initialize likelihood and model\n", "likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()\n", - "model = ExactGPModel(train_x, train_y, likelihood).cuda()" + "model = ExactGPModel(train_x, train_y, likelihood).cuda()\n", + "\n", + "# Because we know some properties about this dataset,\n", + "# we will initialize the lengthscale to be somewhat small\n", + "# This step isn't necessary, but it will help the model converge faster.\n", + "model.covar_module.base_kernel.lengthscale = 0.05" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "691194d2d51e4d389fef9f0f7cb34f6b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0%| | 0/25 [00:00 Tuple[Union[Tensor, LazyTensor], Union[Tensor, LazyTensor]]: + r""" + Potentially wrap inputs x1 and x2 as KeOps LazyTensors, + depending on whether or not we want to use KeOps under the hood or not. + """ + x1_ = x1[..., :, None, :] + x2_ = x2[..., None, :, :] + if _use_keops(x1, x2): + res = LazyTensor(x1_), LazyTensor(x2_) + return res + return x1_, x2_ + + def _use_keops(x1: Tensor, x2: Tensor) -> bool: + r""" + Determine whether or not to use KeOps under the hood + This largely depends on the size of the kernel matrix + + There are situations where we do not want the KeOps linear operator to use KeOps under the hood. + See https://github.com/cornellius-gp/gpytorch/pull/1319 + """ + return ( + settings.use_keops.on() + and x1.size(-2) >= settings.max_cholesky_size.value() + and x2.size(-2) >= settings.max_cholesky_size.value() + ) class KeOpsKernel(Kernel): - @abstractmethod - def _nonkeops_forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any): - r""" - Computes the covariance matrix (or diagonal) without using KeOps. - This function must implement both the diag=True and diag=False options. - """ - raise NotImplementedError - - @abstractmethod - def _keops_forward(self, x1: Tensor, x2: Tensor, **kwargs: Any): - r""" - Computes the covariance matrix with KeOps. - This function only implements the diag=False option, and no diag keyword should be passed in. - """ - raise NotImplementedError - - def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any): - if diag: - return self._nonkeops_forward(x1, x2, diag=True, **kwargs) - elif x1.size(-2) < settings.max_cholesky_size.value() or x2.size(-2) < settings.max_cholesky_size.value(): - return self._nonkeops_forward(x1, x2, diag=False, **kwargs) - else: - return self._keops_forward(x1, x2, **kwargs) - - def __call__(self, *args: Any, **kwargs: Any): + def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor, LazyTensor]: # Hotfix for zero gradients. See https://github.com/cornellius-gp/gpytorch/issues/1543 args = [arg.contiguous() if torch.is_tensor(arg) else arg for arg in args] kwargs = {k: v.contiguous() if torch.is_tensor(v) else v for k, v in kwargs.items()} @@ -43,6 +49,18 @@ def __call__(self, *args: Any, **kwargs: Any): except ImportError: + def _lazify_and_expand_inputs(x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor]: + x1_ = x1[..., :, None, :] + x2_ = x2[..., None, :, :] + return x1_, x2_ + + def _use_keops(x1: Tensor, x2: Tensor) -> bool: + return False + class KeOpsKernel(Kernel): - def __init__(self, *args: Any, **kwargs: Any): - raise RuntimeError("You must have KeOps installed to use a KeOpsKernel") + def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor]: + warnings.warn( + "KeOps is not installed. " f"{type(self)} will revert to the the non-keops version of this kernel.", + RuntimeWarning, + ) + return super().__call__(*args, **kwargs) diff --git a/gpytorch/kernels/keops/matern_kernel.py b/gpytorch/kernels/keops/matern_kernel.py index 73365c2d8..3ea5235fc 100644 --- a/gpytorch/kernels/keops/matern_kernel.py +++ b/gpytorch/kernels/keops/matern_kernel.py @@ -1,92 +1,71 @@ #!/usr/bin/env python3 import math -import torch from linear_operator.operators import KernelLinearOperator -from ..matern_kernel import MaternKernel as GMaternKernel -from .keops_kernel import KeOpsKernel - -try: - from pykeops.torch import LazyTensor as KEOLazyTensor - - def _covar_func(x1, x2, nu=2.5, **params): - x1_ = KEOLazyTensor(x1[..., :, None, :]) - x2_ = KEOLazyTensor(x2[..., None, :, :]) - - distance = ((x1_ - x2_) ** 2).sum(-1).sqrt() - exp_component = (-math.sqrt(nu * 2) * distance).exp() - - if nu == 0.5: - constant_component = 1 - elif nu == 1.5: - constant_component = (math.sqrt(3) * distance) + 1 - elif nu == 2.5: - constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * (distance**2)) - - return constant_component * exp_component - - class MaternKernel(KeOpsKernel): - """ - Implements the Matern kernel using KeOps as a driver for kernel matrix multiplies. - - This class can be used as a drop in replacement for :class:`gpytorch.kernels.MaternKernel` in most cases, - and supports the same arguments. - - :param nu: (Default: 2.5) The smoothness parameter. - :type nu: float (0.5, 1.5, or 2.5) - :param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each - input dimension. It should be `d` if x1 is a `... x n x d` matrix. - :type ard_num_dims: int, optional - :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each - batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output. - :type batch_shape: torch.Size, optional - :param active_dims: (Default: `None`) Set this if you want to - compute the covariance of only a few input dimensions. The ints - corresponds to the indices of the dimensions. - :type active_dims: Tuple(int) - :param lengthscale_prior: (Default: `None`) - Set this if you want to apply a prior to the lengthscale parameter. - :type lengthscale_prior: ~gpytorch.priors.Prior, optional - :param lengthscale_constraint: (Default: `Positive`) Set this if you want - to apply a constraint to the lengthscale parameter. - :type lengthscale_constraint: ~gpytorch.constraints.Interval, optional - :param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors). - :type eps: float, optional - """ - - has_lengthscale = True - - def __init__(self, nu=2.5, **kwargs): - if nu not in {0.5, 1.5, 2.5}: - raise RuntimeError("nu expected to be 0.5, 1.5, or 2.5") - super(MaternKernel, self).__init__(**kwargs) - self.nu = nu - - def _nonkeops_forward(self, x1, x2, diag=False, **kwargs): - mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)] - x1_ = (x1 - mean) / self.lengthscale - x2_ = (x2 - mean) / self.lengthscale - - distance = self.covar_dist(x1_, x2_, diag=diag, **kwargs) - exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance) - - if self.nu == 0.5: - constant_component = 1 - elif self.nu == 1.5: - constant_component = (math.sqrt(3) * distance).add(1) - elif self.nu == 2.5: - constant_component = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance**2) - return constant_component * exp_component - - def _keops_forward(self, x1, x2, **kwargs): - mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)] - x1_ = (x1 - mean) / self.lengthscale - x2_ = (x2 - mean) / self.lengthscale - # return KernelLinearOperator inst only when calculating the whole covariance matrix - return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs) - -except ImportError: - - class MaternKernel(GMaternKernel): - pass +from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel + + +def _covar_func(x1, x2, nu=2.5, **params): + x1_, x2_ = _lazify_and_expand_inputs(x1, x2) + + sq_distance = ((x1_ - x2_) ** 2).sum(-1) + distance = (sq_distance + 1e-20).sqrt() + # ^^ Need to add epsilon to prevent small negative values with the sqrt + # backward pass (otherwise we get NaNs). + # using .clamp(1e-20, math.inf) doesn't work in KeOps; it also creates NaNs + exp_component = (-math.sqrt(nu * 2) * distance).exp() + + if nu == 0.5: + constant_component = 1 + elif nu == 1.5: + constant_component = (math.sqrt(3) * distance) + 1 + elif nu == 2.5: + constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * sq_distance) + + return constant_component * exp_component + + +class MaternKernel(KeOpsKernel): + """ + Implements the Matern kernel using KeOps as a driver for kernel matrix multiplies. + + This class can be used as a drop in replacement for :class:`gpytorch.kernels.MaternKernel` in most cases, + and supports the same arguments. + + :param nu: (Default: 2.5) The smoothness parameter. + :type nu: float (0.5, 1.5, or 2.5) + :param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each + input dimension. It should be `d` if x1 is a `... x n x d` matrix. + :type ard_num_dims: int, optional + :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each + batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output. + :type batch_shape: torch.Size, optional + :param active_dims: (Default: `None`) Set this if you want to + compute the covariance of only a few input dimensions. The ints + corresponds to the indices of the dimensions. + :type active_dims: Tuple(int) + :param lengthscale_prior: (Default: `None`) + Set this if you want to apply a prior to the lengthscale parameter. + :type lengthscale_prior: ~gpytorch.priors.Prior, optional + :param lengthscale_constraint: (Default: `Positive`) Set this if you want + to apply a constraint to the lengthscale parameter. + :type lengthscale_constraint: ~gpytorch.constraints.Interval, optional + :param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors). + :type eps: float, optional + """ + + has_lengthscale = True + + def __init__(self, nu=2.5, **kwargs): + if nu not in {0.5, 1.5, 2.5}: + raise RuntimeError("nu expected to be 0.5, 1.5, or 2.5") + super().__init__(**kwargs) + self.nu = nu + + def forward(self, x1, x2, **kwargs): + mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)] + x1_ = (x1 - mean) / self.lengthscale + x2_ = (x2 - mean) / self.lengthscale + # return KernelLinearOperator inst only when calculating the whole covariance matrix + return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs) diff --git a/gpytorch/kernels/keops/periodic_kernel.py b/gpytorch/kernels/keops/periodic_kernel.py index b844425fa..fe4831a1d 100644 --- a/gpytorch/kernels/keops/periodic_kernel.py +++ b/gpytorch/kernels/keops/periodic_kernel.py @@ -5,89 +5,60 @@ from linear_operator.operators import KernelLinearOperator from ..periodic_kernel import PeriodicKernel as GPeriodicKernel -from .keops_kernel import KeOpsKernel - -# from ...kernels import PeriodicKernel gives a cyclic import - -try: - from pykeops.torch import LazyTensor as KEOLazyTensor - - def _covar_func(x1, x2, lengthscale, **kwargs): - # symbolic array of shape ..., ndatax1_ x 1 x ndim - x1_ = KEOLazyTensor(x1[..., :, None, :]) - # symbolic array of shape ..., 1 x ndatax2_ x ndim - x2_ = KEOLazyTensor(x2[..., None, :, :]) - lengthscale = lengthscale[..., None, None, 0, :] # 1 x 1 x ndim - # do not use .power(2.0) as it gives NaN values on cuda - # seems related to https://github.com/getkeops/keops/issues/112 - K = ((((x1_ - x2_).abs().sin()) ** 2) * (-2.0 / lengthscale)).sum(-1).exp() - return K - - # subclass from original periodic kernel to reduce code duplication - class PeriodicKernel(KeOpsKernel, GPeriodicKernel): - """ - Implements the Periodic Kernel using KeOps as a driver for kernel matrix multiplies. - - This class can be used as a drop in replacement for :class:`gpytorch.kernels.PeriodicKernel` in most cases, - and supports the same arguments. - - :param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each - input dimension. It should be `d` if x1 is a `... x n x d` matrix. - :type ard_num_dims: int, optional - :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each - batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output. - :type batch_shape: torch.Size, optional - :param active_dims: (Default: `None`) Set this if you want to - compute the covariance of only a few input dimensions. The ints - corresponds to the indices of the dimensions. - :type active_dims: Tuple(int) - :param period_length_prior: (Default: `None`) - Set this if you want to apply a prior to the period length parameter. - :type period_length_prior: ~gpytorch.priors.Prior, optional - :param period_length_constraint: (Default: `Positive`) Set this if you want - to apply a constraint to the period length parameter. - :type period_length_constraint: ~gpytorch.constraints.Interval, optional - :param lengthscale_prior: (Default: `None`) - Set this if you want to apply a prior to the lengthscale parameter. - :type lengthscale_prior: ~gpytorch.priors.Prior, optional - :param lengthscale_constraint: (Default: `Positive`) Set this if you want - to apply a constraint to the lengthscale parameter. - :type lengthscale_constraint: ~gpytorch.constraints.Interval, optional - :param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors). - :type eps: float, optional - - :var torch.Tensor period_length: The period length parameter. Size/shape of parameter depends on the - ard_num_dims and batch_shape arguments. - """ - - has_lengthscale = True - - # code from the already-implemented Periodic Kernel - def _nonkeops_forward(self, x1, x2, diag=False, **kwargs): - x1_ = x1.div(self.period_length / math.pi) - x2_ = x2.div(self.period_length / math.pi) - - # We are automatically overriding last_dim_is_batch here so that we can manually sum over dimensions. - diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True) - - if diag: - lengthscale = self.lengthscale[..., 0, :, None] - else: - lengthscale = self.lengthscale[..., 0, :, None, None] - - exp_term = diff.sin().pow(2.0).div(lengthscale).mul(-2.0) - exp_term = exp_term.sum(dim=(-2 if diag else -3)) - - return exp_term.exp() - - def _keops_forward(self, x1, x2, **kwargs): - x1_ = x1.div(self.period_length / math.pi) - x2_ = x2.div(self.period_length / math.pi) - # return KernelLinearOperator inst only when calculating the whole covariance matrix - # pass any parameters which are used inside _covar_func as *args to get gradients computed for them - return KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs) - -except ImportError: - - class PeriodicKernel(GPeriodicKernel): - pass +from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel + + +def _covar_func(x1, x2, lengthscale, **kwargs): + x1_, x2_ = _lazify_and_expand_inputs(x1, x2) + lengthscale = lengthscale[..., None, None, 0, :] # 1 x 1 x ndim + # do not use .power(2.0) as it gives NaN values on cuda + # seems related to https://github.com/getkeops/keops/issues/112 + K = ((((x1_ - x2_).abs().sin()) ** 2) * (-2.0 / lengthscale)).sum(-1).exp() + return K + + +# subclass from original periodic kernel to reduce code duplication +class PeriodicKernel(KeOpsKernel, GPeriodicKernel): + """ + Implements the Periodic Kernel using KeOps as a driver for kernel matrix multiplies. + + This class can be used as a drop in replacement for :class:`gpytorch.kernels.PeriodicKernel` in most cases, + and supports the same arguments. + + :param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each + input dimension. It should be `d` if x1 is a `... x n x d` matrix. + :type ard_num_dims: int, optional + :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each + batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output. + :type batch_shape: torch.Size, optional + :param active_dims: (Default: `None`) Set this if you want to + compute the covariance of only a few input dimensions. The ints + corresponds to the indices of the dimensions. + :type active_dims: Tuple(int) + :param period_length_prior: (Default: `None`) + Set this if you want to apply a prior to the period length parameter. + :type period_length_prior: ~gpytorch.priors.Prior, optional + :param period_length_constraint: (Default: `Positive`) Set this if you want + to apply a constraint to the period length parameter. + :type period_length_constraint: ~gpytorch.constraints.Interval, optional + :param lengthscale_prior: (Default: `None`) + Set this if you want to apply a prior to the lengthscale parameter. + :type lengthscale_prior: ~gpytorch.priors.Prior, optional + :param lengthscale_constraint: (Default: `Positive`) Set this if you want + to apply a constraint to the lengthscale parameter. + :type lengthscale_constraint: ~gpytorch.constraints.Interval, optional + :param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors). + :type eps: float, optional + + :var torch.Tensor period_length: The period length parameter. Size/shape of parameter depends on the + ard_num_dims and batch_shape arguments. + """ + + has_lengthscale = True + + def forward(self, x1, x2, **kwargs): + x1_ = x1.div(self.period_length / math.pi) + x2_ = x2.div(self.period_length / math.pi) + # return KernelLinearOperator inst only when calculating the whole covariance matrix + # pass any parameters which are used inside _covar_func as *args to get gradients computed for them + return KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs) diff --git a/gpytorch/kernels/keops/rbf_kernel.py b/gpytorch/kernels/keops/rbf_kernel.py index 663d092a8..5497f0f47 100644 --- a/gpytorch/kernels/keops/rbf_kernel.py +++ b/gpytorch/kernels/keops/rbf_kernel.py @@ -3,58 +3,45 @@ # from linear_operator.operators import KeOpsLinearOperator from linear_operator.operators import KernelLinearOperator -from ..rbf_kernel import postprocess_rbf, RBFKernel as GRBFKernel -from .keops_kernel import KeOpsKernel - -try: - from pykeops.torch import LazyTensor as KEOLazyTensor - - def _covar_func(x1, x2, **kwargs): - x1_ = KEOLazyTensor(x1[..., :, None, :]) - x2_ = KEOLazyTensor(x2[..., None, :, :]) - K = (-((x1_ - x2_) ** 2).sum(-1) / 2).exp() - return K - - class RBFKernel(KeOpsKernel): - r""" - Implements the RBF kernel using KeOps as a driver for kernel matrix multiplies. - - This class can be used as a drop in replacement for :class:`gpytorch.kernels.RBFKernel` in most cases, - and supports the same arguments. - - :param ard_num_dims: Set this if you want a separate lengthscale for each input - dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.) - :param batch_shape: Set this if you want a separate lengthscale for each batch of input - data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is - a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor. - :param active_dims: Set this if you want to compute the covariance of only - a few input dimensions. The ints corresponds to the indices of the - dimensions. (Default: `None`.) - :param lengthscale_prior: Set this if you want to apply a prior to the - lengthscale parameter. (Default: `None`) - :param lengthscale_constraint: Set this if you want to apply a constraint - to the lengthscale parameter. (Default: `Positive`.) - :param eps: The minimum value that the lengthscale can take (prevents - divide by zero errors). (Default: `1e-6`.) - - :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the - ard_num_dims and batch_shape arguments. - """ - - has_lengthscale = True - - def _nonkeops_forward(self, x1, x2, diag=False, **kwargs): - x1_ = x1 / self.lengthscale - x2_ = x2 / self.lengthscale - return postprocess_rbf(self.covar_dist(x1_, x2_, square_dist=True, diag=diag, **kwargs)) - - def _keops_forward(self, x1, x2, **kwargs): - x1_ = x1 / self.lengthscale - x2_ = x2 / self.lengthscale - # return KernelLinearOperator inst only when calculating the whole covariance matrix - return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, **kwargs) - -except ImportError: - - class RBFKernel(GRBFKernel): - pass +from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel + + +def _covar_func(x1, x2, **kwargs): + x1_, x2_ = _lazify_and_expand_inputs(x1, x2) + K = (-((x1_ - x2_) ** 2).sum(-1) / 2).exp() + return K + + +class RBFKernel(KeOpsKernel): + r""" + Implements the RBF kernel using KeOps as a driver for kernel matrix multiplies. + + This class can be used as a drop in replacement for :class:`gpytorch.kernels.RBFKernel` in most cases, + and supports the same arguments. + + :param ard_num_dims: Set this if you want a separate lengthscale for each input + dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.) + :param batch_shape: Set this if you want a separate lengthscale for each batch of input + data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is + a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor. + :param active_dims: Set this if you want to compute the covariance of only + a few input dimensions. The ints corresponds to the indices of the + dimensions. (Default: `None`.) + :param lengthscale_prior: Set this if you want to apply a prior to the + lengthscale parameter. (Default: `None`) + :param lengthscale_constraint: Set this if you want to apply a constraint + to the lengthscale parameter. (Default: `Positive`.) + :param eps: The minimum value that the lengthscale can take (prevents + divide by zero errors). (Default: `1e-6`.) + + :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the + ard_num_dims and batch_shape arguments. + """ + + has_lengthscale = True + + def forward(self, x1, x2, **kwargs): + x1_ = x1 / self.lengthscale + x2_ = x2 / self.lengthscale + # return KernelLinearOperator inst only when calculating the whole covariance matrix + return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, **kwargs) diff --git a/gpytorch/settings.py b/gpytorch/settings.py index 595ffe10d..99528c419 100644 --- a/gpytorch/settings.py +++ b/gpytorch/settings.py @@ -448,6 +448,19 @@ def _fill_tensor(cls, observations) -> Tensor: return torch.nan_to_num(observations, nan=cls._fill_value) +class use_keops(_feature_flag): + """ + Whether or not to use KeOps under the hood (when using any :class:`gpytorch.kernels.keops.KeOpsKernel`. + In general, this flag should be set to True. + Setting it to false will resort to non-KeOps computation, + which will be slower but may be useful for debugging or timing comparisons. + + (Default: True) + """ + + _default = True + + __all__ = [ "_linalg_dtype_symeig", "_linalg_dtype_cholesky", @@ -487,6 +500,7 @@ def _fill_tensor(cls, observations) -> Tensor: "terminate_cg_by_size", "trace_mode", "tridiagonal_jitter", + "use_keops", "use_toeplitz", "variational_cholesky_jitter", "verbose_linalg", diff --git a/gpytorch/test/base_keops_test_case.py b/gpytorch/test/base_keops_test_case.py index 2ce4cd090..fb261c860 100644 --- a/gpytorch/test/base_keops_test_case.py +++ b/gpytorch/test/base_keops_test_case.py @@ -35,14 +35,15 @@ def test_forward_x1_eq_x2(self, ard=False, use_keops=True, **kwargs): kern1 = self.k1(**kwargs) kern2 = self.k2(**kwargs) - with patch.object(self.k1, "_keops_forward", wraps=kern1._keops_forward) as _keops_forward_mock: - # The patch makes sure that we're actually using KeOps + # The patch makes sure that we're actually using KeOps + # However, we're going to bypass KeOps and instead just use non-LazyTensors + with patch("gpytorch.kernels.keops.keops_kernel.LazyTensor", wraps=lambda x: x) as keops_mock: k1 = kern1(x1, x1).to_dense() k2 = kern2(x1, x1).to_dense() self.assertLess(torch.norm(k1 - k2), 1e-4) if use_keops: - self.assertTrue(_keops_forward_mock.called) + self.assertTrue(keops_mock.called) def test_forward_x1_eq_x2_ard(self): return self.test_forward_x1_eq_x2(ard=True) @@ -61,14 +62,14 @@ def test_forward_x1_neq_x2(self, use_keops=True, ard=False, **kwargs): kern1 = self.k1(**kwargs) kern2 = self.k2(**kwargs) - with patch.object(self.k1, "_keops_forward", wraps=kern1._keops_forward) as _keops_forward_mock: + with patch("gpytorch.kernels.keops.keops_kernel.LazyTensor", wraps=lambda x: x) as keops_mock: # The patch makes sure that we're actually using KeOps k1 = kern1(x1, x2).to_dense() k2 = kern2(x1, x2).to_dense() self.assertLess(torch.norm(k1 - k2), 1e-4) if use_keops: - self.assertTrue(_keops_forward_mock.called) + self.assertTrue(keops_mock.called) def test_forward_x1_meq_x2_ard(self): return self.test_forward_x1_neq_x2(ard=True) @@ -81,14 +82,14 @@ def test_batch_matmul(self, use_keops=True, **kwargs): kern2 = self.k2(**kwargs) rhs = torch.randn(3, 2, 100, 1) - with patch.object(self.k1, "_keops_forward", wraps=kern1._keops_forward) as _keops_forward_mock: + with patch("gpytorch.kernels.keops.keops_kernel.LazyTensor", wraps=lambda x: x) as keops_mock: # The patch makes sure that we're actually using KeOps res1 = kern1(x1, x1).matmul(rhs) res2 = kern2(x1, x1).matmul(rhs) self.assertLess(torch.norm(res1 - res2), 1e-4) if use_keops: - self.assertTrue(_keops_forward_mock.called) + self.assertTrue(keops_mock.called) def test_gradient(self, use_keops=True, ard=False, **kwargs): max_cholesky_size = CHOLESKY_SIZE_KEOPS if use_keops else CHOLESKY_SIZE_NONKEOPS @@ -104,20 +105,20 @@ def test_gradient(self, use_keops=True, ard=False, **kwargs): kern1 = self.k1(**kwargs) kern2 = self.k2(**kwargs) - with patch.object(self.k1, "_keops_forward", wraps=kern1._keops_forward) as _keops_forward_mock: + with patch("gpytorch.kernels.keops.keops_kernel.LazyTensor", wraps=lambda x: x) as keops_mock: # The patch makes sure that we're actually using KeOps res1 = kern1(x1, x1) res2 = kern2(x1, x1) s1 = res1.sum() s2 = res2.sum() - # stack all gradients into a tensor - grad_s1 = torch.vstack(torch.autograd.grad(s1, [*kern1.hyperparameters()])) - grad_s2 = torch.vstack(torch.autograd.grad(s2, [*kern2.hyperparameters()])) - self.assertAllClose(grad_s1, grad_s2, rtol=1e-4, atol=1e-5) + # stack all gradients into a tensor + grad_s1 = torch.vstack(torch.autograd.grad(s1, [*kern1.hyperparameters()])) + grad_s2 = torch.vstack(torch.autograd.grad(s2, [*kern2.hyperparameters()])) + self.assertAllClose(grad_s1, grad_s2, rtol=1e-4, atol=1e-5) if use_keops: - self.assertTrue(_keops_forward_mock.called) + self.assertTrue(keops_mock.called) def test_gradient_ard(self): return self.test_gradient(ard=True)