-
Notifications
You must be signed in to change notification settings - Fork 562
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
KernelLinearOperator was throwing errors when computing the diagonal of a KeOps kernel. (This computation happens during preconditioning, which requires the diagonal of the already-formed kernel LinearOperator object.) This error was because KeopsLinearOperator.diagonal calls to_dense on the output of a batch kernel operation. However, to_dense is not defined for KeOps LazyTensors. This PR is in some sense a hack fix to this bug (a less hacky fix will require changes to KernelLinearOperator), but it is also a generally nice and helpful refactor that will improve KeOps kernels in general. The fixes: - KeOpsKernels now only define a forward function, that will be used both when we want to use KeOps and when we want to bypass it. - KeOpsKernels now use a `_lazify_inputs` helper method, which (potentially) wraps the inputs as KeOpsLazyTensors, or potentially leaves the inputs as torch Tensors. - The KeOps wrapping happens unless we want to bypass KeOps, which occurs when either (1) the matrix is small (below Cholesky size) or (2) when the user has turned off the `gpytorch.settings.use_keops` option (*NEW IN THIS PR*). Why this is beneficial: - KeOps kernels now follow the same API as non-KeOps kernels (define a forward method) - The user now only has to define one forward method, that works in both the keops and non-keops cases - The `diagonal` call in KeopsLinearOperator constructs a batch 1x1 matrix, which is small enough to bypass keops and thus avoid the current bug. (Hence why this solution is currently a hack, but could become less hacky with a small modification to KernelLinearOperator and/or the to_dense method in LinearOperator). Other changes: - Fix stability issues with the keops MaternKernel. (There were some NaN issues) - Introduce a `gpytorch.settings.use_keops` feature flag. - Clean up KeOps notebook [Fixes #2363]
- Loading branch information
Showing
8 changed files
with
319 additions
and
349 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .keops_kernel import KeOpsKernel | ||
from .matern_kernel import MaternKernel | ||
from .periodic_kernel import PeriodicKernel | ||
from .rbf_kernel import RBFKernel | ||
|
||
__all__ = ["MaternKernel", "RBFKernel", "PeriodicKernel"] | ||
__all__ = ["KeOpsKernel", "MaternKernel", "PeriodicKernel", "RBFKernel"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,66 @@ | ||
from abc import abstractmethod | ||
from typing import Any | ||
import warnings | ||
from typing import Any, Tuple, Union | ||
|
||
import torch | ||
from linear_operator import LinearOperator | ||
from torch import Tensor | ||
|
||
from ... import settings | ||
from ..kernel import Kernel | ||
|
||
try: | ||
import pykeops # noqa F401 | ||
from pykeops.torch import LazyTensor | ||
|
||
def _lazify_and_expand_inputs(
    x1: Tensor, x2: Tensor
) -> Tuple[Union[Tensor, LazyTensor], Union[Tensor, LazyTensor]]:
    r"""
    Unsqueeze ``x1`` into row form and ``x2`` into column form so they broadcast
    against each other, wrapping both as KeOps LazyTensors when KeOps should be
    used under the hood (see :func:`_use_keops`); otherwise the plain torch
    Tensors are returned.
    """
    rows = x1[..., :, None, :]
    cols = x2[..., None, :, :]
    if not _use_keops(x1, x2):
        return rows, cols
    return LazyTensor(rows), LazyTensor(cols)
|
||
def _use_keops(x1: Tensor, x2: Tensor) -> bool:
    r"""
    Decide whether KeOps should be used under the hood.

    KeOps is bypassed when the ``gpytorch.settings.use_keops`` flag is off, or
    when either input is small enough (below the Cholesky-size threshold) that
    a dense computation is preferable.
    See https://github.com/cornellius-gp/gpytorch/pull/1319
    """
    if not settings.use_keops.on():
        return False
    threshold = settings.max_cholesky_size.value()
    return x1.size(-2) >= threshold and x2.size(-2) >= threshold
|
||
class KeOpsKernel(Kernel):
    r"""
    Base class for kernels that can be computed with KeOps.

    Subclasses implement both a KeOps and a non-KeOps code path; ``forward``
    dispatches between them based on the requested output (diagonal vs. full
    matrix) and the size of the inputs.
    """

    @abstractmethod
    def _nonkeops_forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any):
        r"""
        Computes the covariance matrix (or diagonal) without using KeOps.
        This function must implement both the diag=True and diag=False options.
        """
        raise NotImplementedError

    @abstractmethod
    def _keops_forward(self, x1: Tensor, x2: Tensor, **kwargs: Any):
        r"""
        Computes the covariance matrix with KeOps.
        This function only implements the diag=False option, and no diag keyword should be passed in.
        """
        raise NotImplementedError

    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any):
        # A diagonal is always computed without KeOps (the KeOps path only
        # produces full matrices).
        if diag:
            return self._nonkeops_forward(x1, x2, diag=True, **kwargs)
        # Small matrices (below the Cholesky threshold) bypass KeOps as well.
        max_size = settings.max_cholesky_size.value()
        if min(x1.size(-2), x2.size(-2)) < max_size:
            return self._nonkeops_forward(x1, x2, diag=False, **kwargs)
        return self._keops_forward(x1, x2, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor, LazyTensor]:
        # Hotfix for zero gradients. See https://github.com/cornellius-gp/gpytorch/issues/1543
        def _contig(value: Any) -> Any:
            return value.contiguous() if torch.is_tensor(value) else value

        args = [_contig(arg) for arg in args]
        kwargs = {name: _contig(value) for name, value in kwargs.items()}
        return super().__call__(*args, **kwargs)
|
||
except ImportError: | ||
|
||
def _lazify_and_expand_inputs(x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor]: | ||
x1_ = x1[..., :, None, :] | ||
x2_ = x2[..., None, :, :] | ||
return x1_, x2_ | ||
|
||
def _use_keops(x1: Tensor, x2: Tensor) -> bool: | ||
return False | ||
|
||
class KeOpsKernel(Kernel):
    r"""
    Fallback used when pykeops is not installed: behaves exactly like the
    non-KeOps code path of the kernel, emitting a warning on every call.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor]:
        # Fix: the warning previously read "revert to the the non-keops version"
        # (duplicated word); also merged the split string/f-string pair into one f-string.
        warnings.warn(
            f"KeOps is not installed. {type(self)} will revert to the non-keops version of this kernel.",
            RuntimeWarning,
        )
        return super().__call__(*args, **kwargs)
Oops, something went wrong.