Hotfix: fix .diagonal() calls for keops kernel matrices. #2590

Open
wants to merge 2 commits into base: main
Changes from 1 commit
4 changes: 4 additions & 0 deletions gpytorch/kernels/keops/keops_kernel.py
@@ -12,6 +12,8 @@
import pykeops # noqa F401
from pykeops.torch import LazyTensor

_Anysor = Union[Tensor, LazyTensor]

def _lazify_and_expand_inputs(
x1: Tensor, x2: Tensor
) -> Tuple[Union[Tensor, LazyTensor], Union[Tensor, LazyTensor]]:
@@ -49,6 +51,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor, L

except ImportError:

_Anysor = Tensor

def _lazify_and_expand_inputs(x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor]:
x1_ = x1[..., :, None, :]
x2_ = x2[..., None, :, :]
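For context on the two hunks above: `_Anysor` mirrors the module's existing import guard, resolving to `Union[Tensor, LazyTensor]` when pykeops imports cleanly and to a plain `Tensor` alias otherwise, so the kernel modules below can annotate their covar funcs without a hard pykeops dependency. A minimal standalone sketch of that pattern (the pykeops branch here always lazifies, which is a simplification of the real `_lazify_and_expand_inputs`):

    from typing import Tuple, Union

    from torch import Tensor

    try:
        import pykeops  # noqa F401
        from pykeops.torch import LazyTensor

        # With pykeops available, covar funcs may see either dense or lazy inputs.
        _Anysor = Union[Tensor, LazyTensor]

        def _lazify_and_expand_inputs(x1: Tensor, x2: Tensor) -> Tuple[_Anysor, _Anysor]:
            # Insert broadcast dims and wrap both inputs as symbolic KeOps LazyTensors.
            return LazyTensor(x1[..., :, None, :]), LazyTensor(x2[..., None, :, :])

    except ImportError:
        # Without pykeops, everything stays a dense torch Tensor.
        _Anysor = Tensor

        def _lazify_and_expand_inputs(x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor]:
            return x1[..., :, None, :], x2[..., None, :, :]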
18 changes: 13 additions & 5 deletions gpytorch/kernels/keops/matern_kernel.py
@@ -2,11 +2,12 @@
import math

from linear_operator.operators import KernelLinearOperator
from torch import Tensor

from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel
from .keops_kernel import _Anysor, _lazify_and_expand_inputs, KeOpsKernel


def _covar_func(x1, x2, nu=2.5, **params):
def _covar_func(x1: _Anysor, x2: _Anysor, nu: float = 2.5, **params) -> _Anysor:
x1_, x2_ = _lazify_and_expand_inputs(x1, x2)

sq_distance = ((x1_ - x2_) ** 2).sum(-1)
@@ -57,15 +58,22 @@ class MaternKernel(KeOpsKernel):

has_lengthscale = True

def __init__(self, nu=2.5, **kwargs):
def __init__(self, nu: float = 2.5, **kwargs):
if nu not in {0.5, 1.5, 2.5}:
raise RuntimeError("nu expected to be 0.5, 1.5, or 2.5")
super().__init__(**kwargs)
self.nu = nu

def forward(self, x1, x2, **kwargs):
def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> KernelLinearOperator:
mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]
x1_ = (x1 - mean) / self.lengthscale
x2_ = (x2 - mean) / self.lengthscale
# return KernelLinearOperator inst only when calculating the whole covariance matrix
return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs)
res = KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs)

# TODO: diag=True mode will be removed with the GpyTorch 2.0 PR to remove LazyEvaluatedKernelTensor
# (it will be replaced by a `_symmetric_diag` method for quickly computing the diagonals of symmetric matrices)
if diag:
return res.diagonal(dim1=-1, dim2=-2)
else:
return res
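The effect of the new `diag` branch is easiest to see from the caller's side: `forward(..., diag=True)` now returns the 1-D diagonal directly instead of relying on a downstream `.diagonal()` call on the keops-backed operator. A usage sketch (inputs and shapes are made up for illustration; the same applies to the periodic and RBF kernels below):

    import torch
    from gpytorch.kernels.keops import MaternKernel

    x = torch.randn(100, 3)
    kern = MaternKernel(nu=2.5)

    full = kern(x, x)              # lazily evaluated n x n kernel matrix
    diag = kern(x, x, diag=True)   # with this fix: a 1-D tensor of the n diagonal entries

    assert diag.shape == torch.Size([100])
    assert torch.allclose(diag, full.to_dense().diagonal(dim1=-1, dim2=-2), atol=1e-4)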
16 changes: 12 additions & 4 deletions gpytorch/kernels/keops/periodic_kernel.py
@@ -3,12 +3,13 @@
import math

from linear_operator.operators import KernelLinearOperator
from torch import Tensor

from ..periodic_kernel import PeriodicKernel as GPeriodicKernel
from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel
from .keops_kernel import _Anysor, _lazify_and_expand_inputs, KeOpsKernel


def _covar_func(x1, x2, lengthscale, **kwargs):
def _covar_func(x1: _Anysor, x2: _Anysor, lengthscale: Tensor, **kwargs) -> _Anysor:
x1_, x2_ = _lazify_and_expand_inputs(x1, x2)
lengthscale = lengthscale[..., None, None, 0, :] # 1 x 1 x ndim
# do not use .power(2.0) as it gives NaN values on cuda
@@ -56,9 +57,16 @@ class PeriodicKernel(KeOpsKernel, GPeriodicKernel):

has_lengthscale = True

def forward(self, x1, x2, **kwargs):
def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> KernelLinearOperator:
x1_ = x1.div(self.period_length / math.pi)
x2_ = x2.div(self.period_length / math.pi)
# return KernelLinearOperator inst only when calculating the whole covariance matrix
# pass any parameters which are used inside _covar_func as *args to get gradients computed for them
return KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs)
res = KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs)

# TODO: diag=True mode will be removed with the GpyTorch 2.0 PR to remove LazyEvaluatedKernelTensor
# (it will be replaced by a `_symmetric_diag` method for quickly computing the diagonals of symmetric matrices)
if diag:
return res.diagonal(dim1=-1, dim2=-2)
else:
return res
16 changes: 12 additions & 4 deletions gpytorch/kernels/keops/rbf_kernel.py
@@ -2,11 +2,12 @@

# from linear_operator.operators import KeOpsLinearOperator
from linear_operator.operators import KernelLinearOperator
from torch import Tensor

from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel
from .keops_kernel import _Anysor, _lazify_and_expand_inputs, KeOpsKernel


def _covar_func(x1, x2, **kwargs):
def _covar_func(x1: _Anysor, x2: _Anysor, **kwargs) -> _Anysor:
x1_, x2_ = _lazify_and_expand_inputs(x1, x2)
K = (-((x1_ - x2_) ** 2).sum(-1) / 2).exp()
return K
@@ -40,8 +41,15 @@ class RBFKernel(KeOpsKernel):

has_lengthscale = True

def forward(self, x1, x2, **kwargs):
def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> KernelLinearOperator:
x1_ = x1 / self.lengthscale
x2_ = x2 / self.lengthscale
# return KernelLinearOperator inst only when calculating the whole covariance matrix
return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, **kwargs)
res = KernelLinearOperator(x1_, x2_, covar_func=_covar_func, **kwargs)

# TODO: diag=True mode will be removed with the GpyTorch 2.0 PR to remove LazyEvaluatedKernelTensor
# (it will be replaced by a `_symmetric_diag` method for quickly computing the diagonals of symmetric matrices)
if diag:
return res.diagonal(dim1=-1, dim2=-2)
else:
return res
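The same `diag` branch (and the same TODO) now appears in all three keops kernels. The TODO points at a future `_symmetric_diag` helper for computing the diagonal of a symmetric kernel matrix without building the full operator; that method does not exist in linear_operator yet, so the following is only a hypothetical sketch of the idea, written against a small dense RBF covar func for concreteness:

    import torch
    from torch import Tensor

    def rbf_covar(x1: Tensor, x2: Tensor) -> Tensor:
        # Dense RBF covariance on already-lengthscale-normalized inputs.
        x1_ = x1[..., :, None, :]
        x2_ = x2[..., None, :, :]
        return (-((x1_ - x2_) ** 2).sum(-1) / 2).exp()

    def symmetric_diag(covar_func, x: Tensor) -> Tensor:
        # Hypothetical `_symmetric_diag`: evaluate k(x_i, x_i) pointwise by giving
        # each input row its own length-1 batch, so no n x n matrix is ever formed.
        x_ = x.unsqueeze(-2)                  # (..., n, 1, d)
        return covar_func(x_, x_)[..., 0, 0]  # (..., n)

    x = torch.randn(5, 3)
    assert torch.allclose(symmetric_diag(rbf_covar, x),
                          rbf_covar(x, x).diagonal(dim1=-1, dim2=-2))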
10 changes: 10 additions & 0 deletions gpytorch/test/base_keops_test_case.py
@@ -42,6 +42,11 @@ def test_forward_x1_eq_x2(self, ard=False, use_keops=True, **kwargs):
k2 = kern2(x1, x1).to_dense()
self.assertLess(torch.norm(k1 - k2), 1e-4)

# Test diagonal
d1 = kern1(x1, x1).diagonal(dim1=-1, dim2=-2)
d2 = kern2(x1, x1).diagonal(dim1=-1, dim2=-2)
self.assertLess(torch.norm(d1 - d2), 1e-4)
Collaborator

This only tests that what the diagonal calls return is the same, not that they actually compute the diagonal. You should add something like this here.

Suggested change (replace the assertion above with):
self.assertLess(torch.norm(d1 - d2), 1e-4)
self.assertTrue(torch.equal(k1.diag(), d1))


if use_keops:
self.assertTrue(keops_mock.called)

@@ -68,6 +73,11 @@ def test_forward_x1_neq_x2(self, use_keops=True, ard=False, **kwargs):
k2 = kern2(x1, x2).to_dense()
self.assertLess(torch.norm(k1 - k2), 1e-3)

# Test diagonal
d1 = kern1(x1, x1).diagonal(dim1=-1, dim2=-2)
d2 = kern2(x1, x1).diagonal(dim1=-1, dim2=-2)
self.assertLess(torch.norm(d1 - d2), 1e-4)

if use_keops:
self.assertTrue(keops_mock.called)
