From 89dfb46655affd12ed60d7a94ec9afa355b15e4c Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Thu, 18 Apr 2024 15:56:49 -0400
Subject: [PATCH] ConstantKernel

---
 docs/source/kernels.rst                |   9 +-
 gpytorch/kernels/__init__.py           |   2 +
 gpytorch/kernels/constant_kernel.py    | 123 +++++++++++++++++++++++++
 gpytorch/test/base_kernel_test_case.py |  10 +-
 setup.py                               |   1 +
 test/kernels/test_constant_kernel.py   | 113 +++++++++++++++++++++++
 6 files changed, 251 insertions(+), 7 deletions(-)
 create mode 100644 gpytorch/kernels/constant_kernel.py
 create mode 100644 test/kernels/test_constant_kernel.py

diff --git a/docs/source/kernels.rst b/docs/source/kernels.rst
index 5c7ae0945..5fa89b916 100644
--- a/docs/source/kernels.rst
+++ b/docs/source/kernels.rst
@@ -9,7 +9,7 @@ gpytorch.kernels
 
 
 If you don't know what kernel to use, we recommend that you start out with a
-:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())`.
+:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) + gpytorch.kernels.ConstantKernel()`.
 
 
 Kernel
@@ -22,6 +22,13 @@ Kernel
 Standard Kernels
 -----------------------------
 
+:hidden:`ConstantKernel`
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: ConstantKernel
+   :members:
+
+
 :hidden:`CosineKernel`
 ~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/gpytorch/kernels/__init__.py b/gpytorch/kernels/__init__.py
index cc85fe624..1d87e764b 100644
--- a/gpytorch/kernels/__init__.py
+++ b/gpytorch/kernels/__init__.py
@@ -2,6 +2,7 @@
 from . import keops
 from .additive_structure_kernel import AdditiveStructureKernel
 from .arc_kernel import ArcKernel
+from .constant_kernel import ConstantKernel
 from .cosine_kernel import CosineKernel
 from .cylindrical_kernel import CylindricalKernel
 from .distributional_input_kernel import DistributionalInputKernel
@@ -38,6 +39,7 @@
     "ArcKernel",
     "AdditiveKernel",
     "AdditiveStructureKernel",
+    "ConstantKernel",
     "CylindricalKernel",
     "MultiDeviceKernel",
     "CosineKernel",
diff --git a/gpytorch/kernels/constant_kernel.py b/gpytorch/kernels/constant_kernel.py
new file mode 100644
index 000000000..98a3560e2
--- /dev/null
+++ b/gpytorch/kernels/constant_kernel.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+from ..constraints import Interval, Positive
+from ..priors import Prior
+from .kernel import Kernel
+
+
+class ConstantKernel(Kernel):
+    """
+    Constant covariance kernel for the probabilistic inference of constant coefficients.
+
+    ConstantKernel represents the prior variance `k(x1, x2) = var(c)` of a constant `c`.
+    The prior variance of the constant is optimized during the GP hyper-parameter
+    optimization stage. The actual value of the constant is computed (implicitly) using
+    the linear algebraic approaches for the computation of GP samples and posteriors.
+
+    The constant kernel `k_constant` is most useful as a modification of an arbitrary
+    base kernel `k_base`:
+    1) Additive constants: The modification `k_base + k_constant` allows the GP to
+    infer a non-zero asymptotic value far from the training data, which generally
+    leads to more accurate extrapolation. Notably, the uncertainty in this constant
+    value affects the posterior covariances through the posterior inference equations.
+    This is not the case when a constant prior mean is used instead, since the prior
+    mean does not show up in the posterior covariance and is regularized by the
+    log-determinant during the optimization of the marginal likelihood.
+    2) Multiplicative constants: The modification `k_base * k_constant` allows the GP
+    to modulate the variance of the kernel `k_base`, and is mathematically identical
+    to `ScaleKernel(base_kernel)` with the same constant.
+    """
+
+    has_lengthscale = False
+
+    def __init__(
+        self,
+        batch_shape: Optional[torch.Size] = None,
+        constant_prior: Optional[Prior] = None,
+        constant_constraint: Optional[Interval] = None,
+        active_dims: Optional[Tuple[int, ...]] = None,
+    ):
+        """Constructor of ConstantKernel.
+
+        Args:
+            batch_shape: The batch shape of the kernel.
+            constant_prior: Prior over the constant parameter.
+            constant_constraint: Constraint to place on the constant parameter.
+            active_dims: The dimensions of the input with which to evaluate the kernel.
+                This has no effect for the constant kernel, but is included for
+                compatibility with the Kernel API.
+        """
+        super().__init__(batch_shape=batch_shape, active_dims=active_dims)
+
+        self.register_parameter(
+            name="raw_constant",
+            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
+        )
+
+        if constant_prior is not None:
+            if not isinstance(constant_prior, Prior):
+                raise TypeError("Expected gpytorch.priors.Prior but got " + type(constant_prior).__name__)
+            self.register_prior(
+                "constant_prior",
+                constant_prior,
+                lambda m: m.constant,
+                lambda m, v: m._set_constant(v),
+            )
+
+        if constant_constraint is None:
+            constant_constraint = Positive()
+        self.register_constraint("raw_constant", constant_constraint)
+
+    @property
+    def constant(self) -> Tensor:
+        return self.raw_constant_constraint.transform(self.raw_constant)
+
+    @constant.setter
+    def constant(self, value: Tensor) -> None:
+        self._set_constant(value)
+
+    def _set_constant(self, value: Tensor) -> None:
+        value = value.view(*self.batch_shape, 1)
+        self.initialize(raw_constant=self.raw_constant_constraint.inverse_transform(value))
+
+    def forward(
+        self,
+        x1: Tensor,
+        x2: Tensor,
+        diag: Optional[bool] = False,
+        last_dim_is_batch: Optional[bool] = False,
+    ) -> Tensor:
+        """Evaluates the constant kernel.
+
+        Args:
+            x1: First input tensor of shape (batch_shape x n1 x d).
+            x2: Second input tensor of shape (batch_shape x n2 x d).
+            diag: If True, returns the diagonal of the covariance matrix.
+            last_dim_is_batch: If True, the last dimension of size `d` of the input
+                tensors is treated as a batch dimension.
+
+        Returns:
+            A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
+            constant covariance values if diag is False, resp. True.
+ """ + if last_dim_is_batch: + x1 = x1.transpose(-1, -2).unsqueeze(-1) + x2 = x2.transpose(-1, -2).unsqueeze(-1) + + dtype = torch.promote_types(x1.dtype, x2.dtype) + batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) + shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],)) + constant = self.constant.to(dtype=dtype, device=x1.device) + + if not diag: + constant = constant.unsqueeze(-1) + + if last_dim_is_batch: + constant = constant.unsqueeze(-1) + + return constant.expand(shape) diff --git a/gpytorch/test/base_kernel_test_case.py b/gpytorch/test/base_kernel_test_case.py index 5301ce2d9..88f6afbd5 100644 --- a/gpytorch/test/base_kernel_test_case.py +++ b/gpytorch/test/base_kernel_test_case.py @@ -122,23 +122,21 @@ def test_no_batch_kernel_double_batch_x_ard(self): actual_diag = actual_covar_mat.diagonal(dim1=-1, dim2=-2) self.assertAllClose(kernel_diag, actual_diag, rtol=1e-3, atol=1e-5) - def test_smoke_double_batch_kernel_double_batch_x_no_ard(self): + def test_smoke_double_batch_kernel_double_batch_x_no_ard(self) -> None: kernel = self.create_kernel_no_ard(batch_shape=torch.Size([3, 2])) x = self.create_data_double_batch() - batch_covar_mat = kernel(x).evaluate_kernel().to_dense() + kernel(x).evaluate_kernel().to_dense() kernel(x, diag=True) - return batch_covar_mat - def test_smoke_double_batch_kernel_double_batch_x_ard(self): + def test_smoke_double_batch_kernel_double_batch_x_ard(self) -> None: try: kernel = self.create_kernel_ard(num_dims=2, batch_shape=torch.Size([3, 2])) except NotImplementedError: return x = self.create_data_double_batch() - batch_covar_mat = kernel(x).evaluate_kernel().to_dense() + kernel(x).evaluate_kernel().to_dense() kernel(x, diag=True) - return batch_covar_mat def test_kernel_getitem_single_batch(self): kernel = self.create_kernel_no_ard(batch_shape=torch.Size([2])) diff --git a/setup.py b/setup.py index df580c8aa..d5a05fbe9 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ def find_version(*file_paths): "nbclient<=0.7.3", "nbformat<=5.8.0", "nbsphinx<=0.9.1", + "lxml_html_clean", "platformdirs<=3.2.0", "setuptools_scm<=7.1.0", "sphinx<=6.2.1", diff --git a/test/kernels/test_constant_kernel.py b/test/kernels/test_constant_kernel.py new file mode 100644 index 000000000..849ec3996 --- /dev/null +++ b/test/kernels/test_constant_kernel.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 + +import itertools +import unittest + +import torch + +from torch import Tensor + +from gpytorch.kernels import AdditiveKernel, ConstantKernel, MaternKernel, ProductKernel, ScaleKernel +from gpytorch.lazy import LazyEvaluatedKernelTensor +from gpytorch.priors.torch_priors import GammaPrior +from gpytorch.test.base_kernel_test_case import BaseKernelTestCase + + +class TestConstantKernel(unittest.TestCase, BaseKernelTestCase): + def create_kernel_no_ard(self, **kwargs): + return ConstantKernel(**kwargs) + + def test_constant_kernel(self): + with self.subTest(device="cpu"): + self._test_constant_kernel(torch.device("cpu")) + + if torch.cuda.is_available(): + with self.subTest(device="cuda"): + self._test_constant_kernel(torch.device("cuda")) + + def _test_constant_kernel(self, device: torch.device): + n, d = 3, 5 + dtypes = [torch.float, torch.double] + batch_shapes = [(), (2,), (7, 2)] + torch.manual_seed(123) + for dtype, batch_shape in itertools.product(dtypes, batch_shapes): + tkwargs = {"dtype": dtype, "device": device} + places = 6 if dtype == torch.float else 12 + X = torch.rand(*batch_shape, n, d, **tkwargs) + + constant_kernel = 
+            constant_kernel = ConstantKernel(batch_shape=batch_shape)
+            KL = constant_kernel(X)
+            self.assertIsInstance(KL, LazyEvaluatedKernelTensor)
+            KM = KL.to_dense()
+            self.assertIsInstance(KM, Tensor)
+            self.assertEqual(KM.shape, (*batch_shape, n, n))
+            self.assertEqual(KM.dtype, dtype)
+            self.assertEqual(KM.device.type, device.type)
+            # standard deviation is zero iff KM is constant
+            self.assertAlmostEqual(KM.std().item(), 0, places=places)
+
+            # testing last_dim_is_batch
+            with self.subTest(last_dim_is_batch=True):
+                KD = constant_kernel(X, last_dim_is_batch=True).to(device=device)
+                self.assertIsInstance(KD, LazyEvaluatedKernelTensor)
+                KM = KD.to_dense()
+                self.assertIsInstance(KM, Tensor)
+                self.assertEqual(KM.shape, (*batch_shape, d, n, n))
+                self.assertAlmostEqual(KM.std().item(), 0, places=places)
+                self.assertEqual(KM.dtype, dtype)
+                self.assertEqual(KM.device.type, device.type)
+
+            # testing diag
+            with self.subTest(diag=True):
+                KD = constant_kernel(X, diag=True)
+                self.assertIsInstance(KD, Tensor)
+                self.assertEqual(KD.shape, (*batch_shape, n))
+                self.assertAlmostEqual(KD.std().item(), 0, places=places)
+                self.assertEqual(KD.dtype, dtype)
+                self.assertEqual(KD.device.type, device.type)
+
+            # testing diag and last_dim_is_batch
+            with self.subTest(diag=True, last_dim_is_batch=True):
+                KD = constant_kernel(X, diag=True, last_dim_is_batch=True)
+                self.assertIsInstance(KD, Tensor)
+                self.assertEqual(KD.shape, (*batch_shape, d, n))
+                self.assertAlmostEqual(KD.std().item(), 0, places=places)
+                self.assertEqual(KD.dtype, dtype)
+                self.assertEqual(KD.device.type, device.type)
+
+            # testing AD
+            with self.subTest(requires_grad=True):
+                X.requires_grad = True
+                constant_kernel(X).to_dense().sum().backward()
+                self.assertIsNone(X.grad)  # constant kernel is not dependent on X
+
+            # testing algebraic combinations with another kernel
+            base_kernel = MaternKernel().to(device=device)
+
+            with self.subTest(additive=True):
+                sum_kernel = base_kernel + constant_kernel
+                self.assertIsInstance(sum_kernel, AdditiveKernel)
+                self.assertAllClose(
+                    sum_kernel(X).to_dense(),
+                    base_kernel(X).to_dense() + constant_kernel.constant.unsqueeze(-1),
+                )
+
+            # product with constant is equivalent to scale kernel
+            with self.subTest(product=True):
+                product_kernel = base_kernel * constant_kernel
+                self.assertIsInstance(product_kernel, ProductKernel)
+
+                scale_kernel = ScaleKernel(base_kernel, batch_shape=batch_shape)
+                scale_kernel.to(device=device)
+                self.assertAllClose(scale_kernel(X).to_dense(), product_kernel(X).to_dense())
+
+            # setting constant
+            pies = torch.full_like(constant_kernel.constant, torch.pi)
+            constant_kernel.constant = pies
+            self.assertAllClose(constant_kernel.constant, pies)
+
+            # specifying prior
+            constant_kernel = ConstantKernel(constant_prior=GammaPrior(concentration=2.4, rate=2.7))
+
+            with self.assertRaisesRegex(TypeError, "Expected gpytorch.priors.Prior but got"):
+                ConstantKernel(constant_prior=1)
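
A minimal usage sketch of the additive modification described in the
ConstantKernel docstring, assuming this patch is applied; the input shapes and
the RBF base kernel below are illustrative choices, not mandated by the patch:

    import torch
    from gpytorch.kernels import ConstantKernel, RBFKernel, ScaleKernel

    # Hypothetical inputs: 10 points in 3 dimensions.
    train_x = torch.rand(10, 3)

    # Adding ConstantKernel lets the GP infer a non-zero asymptotic value far
    # from the training data, per the recommendation in docs/source/kernels.rst.
    kernel = ScaleKernel(RBFKernel()) + ConstantKernel()

    # Evaluate the prior covariance; the constant raises all entries uniformly.
    covar = kernel(train_x).to_dense()  # shape (10, 10)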
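
The docstring's multiplicative claim, that `k_base * k_constant` is
mathematically identical to `ScaleKernel(k_base)` with the same constant, can
be checked numerically along the lines of the new test. A sketch, assuming the
patch is applied and using a Matern base kernel as an arbitrary stand-in:

    import torch
    from gpytorch.kernels import ConstantKernel, MaternKernel, ScaleKernel

    x = torch.rand(5, 2)  # illustrative inputs
    base = MaternKernel()
    const = ConstantKernel()

    # Kernel multiplication dispatches to ProductKernel under the hood.
    product = base * const

    # Align ScaleKernel's outputscale with the kernel's constant; in the
    # non-batched case the constant has shape (1,), squeezed to a scalar here.
    scale = ScaleKernel(base)
    scale.outputscale = const.constant.squeeze(-1)

    assert torch.allclose(product(x).to_dense(), scale(x).to_dense())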
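
The constructor's prior and constraint hooks, exercised at the end of the new
test file, compose as below; the Gamma parameters and interval bounds are
arbitrary placeholders:

    import torch
    from gpytorch.constraints import Interval
    from gpytorch.kernels import ConstantKernel
    from gpytorch.priors.torch_priors import GammaPrior

    kernel = ConstantKernel(
        constant_prior=GammaPrior(concentration=2.0, rate=4.0),  # placeholder prior
        constant_constraint=Interval(1e-4, 10.0),  # placeholder bounds on var(c)
    )

    # The setter routes through _set_constant, which stores raw_constant via
    # the constraint's inverse transform.
    kernel.constant = torch.tensor([0.5])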