diff --git a/gpytorch/kernels/keops/keops_kernel.py b/gpytorch/kernels/keops/keops_kernel.py
index 742e9b3ad..aeae4a971 100644
--- a/gpytorch/kernels/keops/keops_kernel.py
+++ b/gpytorch/kernels/keops/keops_kernel.py
@@ -12,6 +12,8 @@
     import pykeops  # noqa F401
    from pykeops.torch import LazyTensor
 
+    _Anysor = Union[Tensor, LazyTensor]
+
     def _lazify_and_expand_inputs(
         x1: Tensor, x2: Tensor
     ) -> Tuple[Union[Tensor, LazyTensor], Union[Tensor, LazyTensor]]:
@@ -49,6 +51,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor, L
 
 except ImportError:
 
+    _Anysor = Tensor
+
     def _lazify_and_expand_inputs(x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor]:
         x1_ = x1[..., :, None, :]
         x2_ = x2[..., None, :, :]
diff --git a/gpytorch/kernels/keops/matern_kernel.py b/gpytorch/kernels/keops/matern_kernel.py
index 3ea5235fc..60b9eb33b 100644
--- a/gpytorch/kernels/keops/matern_kernel.py
+++ b/gpytorch/kernels/keops/matern_kernel.py
@@ -2,11 +2,12 @@
 import math
 
 from linear_operator.operators import KernelLinearOperator
+from torch import Tensor
 
-from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel
+from .keops_kernel import _Anysor, _lazify_and_expand_inputs, KeOpsKernel
 
 
-def _covar_func(x1, x2, nu=2.5, **params):
+def _covar_func(x1: _Anysor, x2: _Anysor, nu: float = 2.5, **params) -> _Anysor:
     x1_, x2_ = _lazify_and_expand_inputs(x1, x2)
 
     sq_distance = ((x1_ - x2_) ** 2).sum(-1)
@@ -57,15 +58,22 @@ class MaternKernel(KeOpsKernel):
 
     has_lengthscale = True
 
-    def __init__(self, nu=2.5, **kwargs):
+    def __init__(self, nu: float = 2.5, **kwargs):
         if nu not in {0.5, 1.5, 2.5}:
             raise RuntimeError("nu expected to be 0.5, 1.5, or 2.5")
         super().__init__(**kwargs)
         self.nu = nu
 
-    def forward(self, x1, x2, **kwargs):
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> KernelLinearOperator:
         mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]
 
         x1_ = (x1 - mean) / self.lengthscale
         x2_ = (x2 - mean) / self.lengthscale
         # return KernelLinearOperator inst only when calculating the whole covariance matrix
-        return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs)
+        res = KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs)
+
+        # TODO: diag=True mode will be removed with the GpyTorch 2.0 PR to remove LazyEvaluatedKernelTensor
+        # (it will be replaced by a `_symmetric_diag` method for quickly computing the diagonals of symmetric matrices)
+        if diag:
+            return res.diagonal(dim1=-1, dim2=-2)
+        else:
+            return res
diff --git a/gpytorch/kernels/keops/periodic_kernel.py b/gpytorch/kernels/keops/periodic_kernel.py
index fe4831a1d..d733189a7 100644
--- a/gpytorch/kernels/keops/periodic_kernel.py
+++ b/gpytorch/kernels/keops/periodic_kernel.py
@@ -3,12 +3,13 @@
 import math
 
 from linear_operator.operators import KernelLinearOperator
+from torch import Tensor
 
 from ..periodic_kernel import PeriodicKernel as GPeriodicKernel
-from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel
+from .keops_kernel import _Anysor, _lazify_and_expand_inputs, KeOpsKernel
 
 
-def _covar_func(x1, x2, lengthscale, **kwargs):
+def _covar_func(x1: _Anysor, x2: _Anysor, lengthscale: Tensor, **kwargs) -> _Anysor:
     x1_, x2_ = _lazify_and_expand_inputs(x1, x2)
     lengthscale = lengthscale[..., None, None, 0, :]  # 1 x 1 x ndim
     # do not use .power(2.0) as it gives NaN values on cuda
@@ -56,9 +57,16 @@ class PeriodicKernel(KeOpsKernel, GPeriodicKernel):
 
     has_lengthscale = True
 
-    def forward(self, x1, x2, **kwargs):
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> KernelLinearOperator:
         x1_ = x1.div(self.period_length / math.pi)
         x2_ = x2.div(self.period_length / math.pi)
         # return KernelLinearOperator inst only when calculating the whole covariance matrix
         # pass any parameters which are used inside _covar_func as *args to get gradients computed for them
-        return KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs)
+        res = KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs)
+
+        # TODO: diag=True mode will be removed with the GpyTorch 2.0 PR to remove LazyEvaluatedKernelTensor
+        # (it will be replaced by a `_symmetric_diag` method for quickly computing the diagonals of symmetric matrices)
+        if diag:
+            return res.diagonal(dim1=-1, dim2=-2)
+        else:
+            return res
diff --git a/gpytorch/kernels/keops/rbf_kernel.py b/gpytorch/kernels/keops/rbf_kernel.py
index 5497f0f47..bf66a8e81 100644
--- a/gpytorch/kernels/keops/rbf_kernel.py
+++ b/gpytorch/kernels/keops/rbf_kernel.py
@@ -2,11 +2,12 @@
 
 # from linear_operator.operators import KeOpsLinearOperator
 from linear_operator.operators import KernelLinearOperator
+from torch import Tensor
 
-from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel
+from .keops_kernel import _Anysor, _lazify_and_expand_inputs, KeOpsKernel
 
 
-def _covar_func(x1, x2, **kwargs):
+def _covar_func(x1: _Anysor, x2: _Anysor, **kwargs) -> _Anysor:
     x1_, x2_ = _lazify_and_expand_inputs(x1, x2)
     K = (-((x1_ - x2_) ** 2).sum(-1) / 2).exp()
     return K
@@ -40,8 +41,15 @@ class RBFKernel(KeOpsKernel):
 
     has_lengthscale = True
 
-    def forward(self, x1, x2, **kwargs):
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> KernelLinearOperator:
         x1_ = x1 / self.lengthscale
         x2_ = x2 / self.lengthscale
         # return KernelLinearOperator inst only when calculating the whole covariance matrix
-        return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, **kwargs)
+        res = KernelLinearOperator(x1_, x2_, covar_func=_covar_func, **kwargs)
+
+        # TODO: diag=True mode will be removed with the GpyTorch 2.0 PR to remove LazyEvaluatedKernelTensor
+        # (it will be replaced by a `_symmetric_diag` method for quickly computing the diagonals of symmetric matrices)
+        if diag:
+            return res.diagonal(dim1=-1, dim2=-2)
+        else:
+            return res
diff --git a/gpytorch/test/base_keops_test_case.py b/gpytorch/test/base_keops_test_case.py
index ca32b4d64..9b8cbb13b 100644
--- a/gpytorch/test/base_keops_test_case.py
+++ b/gpytorch/test/base_keops_test_case.py
@@ -42,6 +42,11 @@ def test_forward_x1_eq_x2(self, ard=False, use_keops=True, **kwargs):
             k2 = kern2(x1, x1).to_dense()
             self.assertLess(torch.norm(k1 - k2), 1e-4)
 
+            # Test diagonal
+            d1 = kern1(x1, x1).diagonal(dim1=-1, dim2=-2)
+            d2 = kern2(x1, x1).diagonal(dim1=-1, dim2=-2)
+            self.assertLess(torch.norm(d1 - d2), 1e-4)
+
             if use_keops:
                 self.assertTrue(keops_mock.called)
 
@@ -68,6 +73,11 @@ def test_forward_x1_neq_x2(self, use_keops=True, ard=False, **kwargs):
             k2 = kern2(x1, x2).to_dense()
             self.assertLess(torch.norm(k1 - k2), 1e-3)
 
+            # Test diagonal
+            d1 = kern1(x1, x1).diagonal(dim1=-1, dim2=-2)
+            d2 = kern2(x1, x1).diagonal(dim1=-1, dim2=-2)
+            self.assertLess(torch.norm(d1 - d2), 1e-4)
+
             if use_keops:
                 self.assertTrue(keops_mock.called)