From f7191453c38371affcc79971770796392fef4c57 Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Thu, 18 Apr 2024 15:56:49 -0400
Subject: [PATCH] ConstantKernel

---
 docs/source/kernels.rst                |   9 +-
 gpytorch/kernels/__init__.py           |   2 +
 gpytorch/kernels/constant_kernel.py    | 119 +++++++++++++++++++++++
 gpytorch/test/base_kernel_test_case.py |  10 +--
 test/kernels/test_constant_kernel.py   | 113 +++++++++++++++++++++++
 5 files changed, 246 insertions(+), 7 deletions(-)
 create mode 100644 gpytorch/kernels/constant_kernel.py
 create mode 100644 test/kernels/test_constant_kernel.py

diff --git a/docs/source/kernels.rst b/docs/source/kernels.rst
index 5c7ae0945..4c240f6ef 100644
--- a/docs/source/kernels.rst
+++ b/docs/source/kernels.rst
@@ -9,7 +9,7 @@ gpytorch.kernels
 
 If you don't know what kernel to use, we recommend that you start out with a
-:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())`.
+:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) + gpytorch.kernels.ConstantKernel()`.
 
 
 Kernel
@@ -22,6 +22,13 @@ Kernel
 Standard Kernels
 -----------------------------
 
+:hidden:`ConstantKernel`
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: ConstantKernel
+   :members:
+
+
 :hidden:`CosineKernel`
 ~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/gpytorch/kernels/__init__.py b/gpytorch/kernels/__init__.py
index cc85fe624..1d87e764b 100644
--- a/gpytorch/kernels/__init__.py
+++ b/gpytorch/kernels/__init__.py
@@ -2,6 +2,7 @@
 from . import keops
 from .additive_structure_kernel import AdditiveStructureKernel
 from .arc_kernel import ArcKernel
+from .constant_kernel import ConstantKernel
 from .cosine_kernel import CosineKernel
 from .cylindrical_kernel import CylindricalKernel
 from .distributional_input_kernel import DistributionalInputKernel
@@ -38,6 +39,7 @@
     "ArcKernel",
     "AdditiveKernel",
     "AdditiveStructureKernel",
+    "ConstantKernel",
     "CylindricalKernel",
     "MultiDeviceKernel",
     "CosineKernel",
diff --git a/gpytorch/kernels/constant_kernel.py b/gpytorch/kernels/constant_kernel.py
new file mode 100644
index 000000000..5b82c1b49
--- /dev/null
+++ b/gpytorch/kernels/constant_kernel.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+
+from typing import Optional, Tuple
+
+import torch
+
+from ..constraints import Interval, Positive
+from ..priors import Prior
+from .kernel import Kernel
+
+
+class ConstantKernel(Kernel):
+    """
+    Constant covariance kernel for the probabilistic inference of constant coefficients.
+
+    ConstantKernel represents the prior variance `k(x1, x2) = var(c)` of a constant `c`.
+    The prior variance of the constant is optimized during the GP hyper-parameter
+    optimization stage. The actual value of the constant is computed (implicitly) using
+    the linear algebraic approaches for the computation of GP samples and posteriors.
+
+    The kernel (`k_constant`) is most useful as a modification of an arbitrary base
+    kernel `k_base`:
+
+    1) Additive constants: The modification `k_base + k_constant` allows the GP to
+    infer a non-zero asymptotic value far from the training data, which generally
+    leads to more accurate extrapolation. Notably, the uncertainty in this constant
+    value affects the posterior covariances through the posterior inference
+    equations. This is not the case when a constant prior mean is used, since the
+    prior mean does not show up in the posterior covariance and is not regularized
+    by the log-determinant during the optimization of the marginal likelihood.
+
+    2) Multiplicative constants: The modification `k_base * k_constant` allows the
+    GP to modulate the variance of the kernel `k_base`, and is mathematically
+    identical to `ScaleKernel(base_kernel)` with the same constant.
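+
+    Example (a minimal usage sketch; `MaternKernel` here stands in for any base kernel):
+        >>> import torch
+        >>> from gpytorch.kernels import ConstantKernel, MaternKernel
+        >>>
+        >>> x = torch.randn(10, 3)
+        >>> kernel = MaternKernel() + ConstantKernel()  # additive constant
+        >>> covar = kernel(x).to_dense()  # 10 x 10 covariance matrix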
+    """
+
+    has_lengthscale = False
+
+    def __init__(
+        self,
+        batch_shape: Optional[torch.Size] = None,
+        constant_prior: Optional[Prior] = None,
+        constant_constraint: Optional[Interval] = None,
+        active_dims: Optional[Tuple[int, ...]] = None,
+    ):
+        """Constructor of ConstantKernel.
+
+        Args:
+            batch_shape: The batch shape of the kernel.
+            constant_prior: Prior over the constant parameter.
+            constant_constraint: Constraint to place on the constant parameter.
+            active_dims: The dimensions of the input on which the kernel operates.
+        """
+        super().__init__(batch_shape=batch_shape, active_dims=active_dims)
+
+        self.register_parameter(
+            name="raw_constant",
+            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
+        )
+
+        if constant_prior is not None:
+            if not isinstance(constant_prior, Prior):
+                raise TypeError("Expected gpytorch.priors.Prior but got " + type(constant_prior).__name__)
+            self.register_prior(
+                "constant_prior",
+                constant_prior,
+                lambda m: m.constant,
+                lambda m, v: m._set_constant(v),
+            )
+
+        if constant_constraint is None:
+            constant_constraint = Positive()
+        self.register_constraint("raw_constant", constant_constraint)
+
+    @property
+    def constant(self) -> torch.Tensor:
+        return self.raw_constant_constraint.transform(self.raw_constant)
+
+    @constant.setter
+    def constant(self, value: torch.Tensor) -> None:
+        self._set_constant(value)
+
+    def _set_constant(self, value: torch.Tensor) -> None:
+        value = value.view(*self.batch_shape, 1)
+        self.initialize(raw_constant=self.raw_constant_constraint.inverse_transform(value))
+
+    def forward(
+        self,
+        x1: torch.Tensor,
+        x2: torch.Tensor,
+        diag: Optional[bool] = False,
+        last_dim_is_batch: Optional[bool] = False,
+    ) -> torch.Tensor:
+        """Evaluates the constant kernel.
+
+        Args:
+            x1: First input tensor of shape (batch_shape x n1 x d).
+            x2: Second input tensor of shape (batch_shape x n2 x d).
+            diag: If True, returns the diagonal of the covariance matrix.
+            last_dim_is_batch: If True, the last dimension of size `d` of the input
+                tensors is treated as a batch dimension.
+
+        Returns:
+            A (batch_shape x n1 x n2)-dim tensor of constant covariance values if
+            diag is False, and a (batch_shape x n1)-dim tensor if diag is True.
+        """
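+        # NOTE: the constant kernel is independent of the values of x1 and x2;
+        # their shapes, dtype, and device only determine the shape, dtype, and
+        # device to which the (batch_shape x 1)-dim constant is expanded below.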
+ """ + if last_dim_is_batch: + x1 = x1.transpose(-1, -2).unsqueeze(-1) + x2 = x2.transpose(-1, -2).unsqueeze(-1) + + dtype = torch.promote_types(x1.dtype, x2.dtype) + batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) + shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],)) + constant = self.constant.to(dtype=dtype, device=x1.device) + + if not diag: + constant = constant.unsqueeze(-1) + + if last_dim_is_batch: + constant = constant.unsqueeze(-1) + + return constant.expand(shape) diff --git a/gpytorch/test/base_kernel_test_case.py b/gpytorch/test/base_kernel_test_case.py index 5301ce2d9..88f6afbd5 100644 --- a/gpytorch/test/base_kernel_test_case.py +++ b/gpytorch/test/base_kernel_test_case.py @@ -122,23 +122,21 @@ def test_no_batch_kernel_double_batch_x_ard(self): actual_diag = actual_covar_mat.diagonal(dim1=-1, dim2=-2) self.assertAllClose(kernel_diag, actual_diag, rtol=1e-3, atol=1e-5) - def test_smoke_double_batch_kernel_double_batch_x_no_ard(self): + def test_smoke_double_batch_kernel_double_batch_x_no_ard(self) -> None: kernel = self.create_kernel_no_ard(batch_shape=torch.Size([3, 2])) x = self.create_data_double_batch() - batch_covar_mat = kernel(x).evaluate_kernel().to_dense() + kernel(x).evaluate_kernel().to_dense() kernel(x, diag=True) - return batch_covar_mat - def test_smoke_double_batch_kernel_double_batch_x_ard(self): + def test_smoke_double_batch_kernel_double_batch_x_ard(self) -> None: try: kernel = self.create_kernel_ard(num_dims=2, batch_shape=torch.Size([3, 2])) except NotImplementedError: return x = self.create_data_double_batch() - batch_covar_mat = kernel(x).evaluate_kernel().to_dense() + kernel(x).evaluate_kernel().to_dense() kernel(x, diag=True) - return batch_covar_mat def test_kernel_getitem_single_batch(self): kernel = self.create_kernel_no_ard(batch_shape=torch.Size([2])) diff --git a/test/kernels/test_constant_kernel.py b/test/kernels/test_constant_kernel.py new file mode 100644 index 000000000..849ec3996 --- /dev/null +++ b/test/kernels/test_constant_kernel.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 + +import itertools +import unittest + +import torch + +from torch import Tensor + +from gpytorch.kernels import AdditiveKernel, ConstantKernel, MaternKernel, ProductKernel, ScaleKernel +from gpytorch.lazy import LazyEvaluatedKernelTensor +from gpytorch.priors.torch_priors import GammaPrior +from gpytorch.test.base_kernel_test_case import BaseKernelTestCase + + +class TestConstantKernel(unittest.TestCase, BaseKernelTestCase): + def create_kernel_no_ard(self, **kwargs): + return ConstantKernel(**kwargs) + + def test_constant_kernel(self): + with self.subTest(device="cpu"): + self._test_constant_kernel(torch.device("cpu")) + + if torch.cuda.is_available(): + with self.subTest(device="cuda"): + self._test_constant_kernel(torch.device("cuda")) + + def _test_constant_kernel(self, device: torch.device): + n, d = 3, 5 + dtypes = [torch.float, torch.double] + batch_shapes = [(), (2,), (7, 2)] + torch.manual_seed(123) + for dtype, batch_shape in itertools.product(dtypes, batch_shapes): + tkwargs = {"dtype": dtype, "device": device} + places = 6 if dtype == torch.float else 12 + X = torch.rand(*batch_shape, n, d, **tkwargs) + + constant_kernel = ConstantKernel(batch_shape=batch_shape) + KL = constant_kernel(X) + self.assertIsInstance(KL, LazyEvaluatedKernelTensor) + KM = KL.to_dense() + self.assertIsInstance(KM, Tensor) + self.assertEqual(KM.shape, (*batch_shape, n, n)) + self.assertEqual(KM.dtype, dtype) + 
+            self.assertEqual(KM.device.type, device.type)
+            # standard deviation is zero iff KM is constant
+            self.assertAlmostEqual(KM.std().item(), 0, places=places)
+
+            # testing last_dim_is_batch
+            with self.subTest(last_dim_is_batch=True):
+                KD = constant_kernel(X, last_dim_is_batch=True).to(device=device)
+                self.assertIsInstance(KD, LazyEvaluatedKernelTensor)
+                KM = KD.to_dense()
+                self.assertIsInstance(KM, Tensor)
+                self.assertEqual(KM.shape, (*batch_shape, d, n, n))
+                self.assertAlmostEqual(KM.std().item(), 0, places=places)
+                self.assertEqual(KM.dtype, dtype)
+                self.assertEqual(KM.device.type, device.type)
+
+            # testing diag
+            with self.subTest(diag=True):
+                KD = constant_kernel(X, diag=True)
+                self.assertIsInstance(KD, Tensor)
+                self.assertEqual(KD.shape, (*batch_shape, n))
+                self.assertAlmostEqual(KD.std().item(), 0, places=places)
+                self.assertEqual(KD.dtype, dtype)
+                self.assertEqual(KD.device.type, device.type)
+
+            # testing diag and last_dim_is_batch
+            with self.subTest(diag=True, last_dim_is_batch=True):
+                KD = constant_kernel(X, diag=True, last_dim_is_batch=True)
+                self.assertIsInstance(KD, Tensor)
+                self.assertEqual(KD.shape, (*batch_shape, d, n))
+                self.assertAlmostEqual(KD.std().item(), 0, places=places)
+                self.assertEqual(KD.dtype, dtype)
+                self.assertEqual(KD.device.type, device.type)
+
+            # testing AD
+            with self.subTest(requires_grad=True):
+                X.requires_grad = True
+                constant_kernel(X).to_dense().sum().backward()
+                self.assertIsNone(X.grad)  # constant kernel is not dependent on X
+
+            # testing algebraic combinations with another kernel
+            base_kernel = MaternKernel().to(device=device)
+
+            with self.subTest(additive=True):
+                sum_kernel = base_kernel + constant_kernel
+                self.assertIsInstance(sum_kernel, AdditiveKernel)
+                self.assertAllClose(
+                    sum_kernel(X).to_dense(),
+                    base_kernel(X).to_dense() + constant_kernel.constant.unsqueeze(-1),
+                )
+
+            # product with constant is equivalent to scale kernel
+            with self.subTest(product=True):
+                product_kernel = base_kernel * constant_kernel
+                self.assertIsInstance(product_kernel, ProductKernel)
+
+                scale_kernel = ScaleKernel(base_kernel, batch_shape=batch_shape)
+                scale_kernel.to(device=device)
+                self.assertAllClose(scale_kernel(X).to_dense(), product_kernel(X).to_dense())
+
+            # setting constant
+            pies = torch.full_like(constant_kernel.constant, torch.pi)
+            constant_kernel.constant = pies
+            self.assertAllClose(constant_kernel.constant, pies)
+
+            # specifying prior
+            constant_kernel = ConstantKernel(constant_prior=GammaPrior(concentration=2.4, rate=2.7))
+
+            with self.assertRaisesRegex(TypeError, "Expected gpytorch.priors.Prior but got"):
+                ConstantKernel(constant_prior=1)
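+
+
+if __name__ == "__main__":
+    # allow running this test file directly, e.g. `python test/kernels/test_constant_kernel.py`
+    unittest.main()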