
Commit a158e44

Merge pull request #2511 from SebastianAment/constant_kernel
`ConstantKernel`
2 parents e3d8a5e + 89dfb46

File tree

6 files changed, +251 −7 lines

docs/source/kernels.rst

Lines changed: 8 additions & 1 deletion

@@ -9,7 +9,7 @@ gpytorch.kernels


 If you don't know what kernel to use, we recommend that you start out with a
-:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())`.
+:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) + gpytorch.kernels.ConstantKernel()`.


 Kernel
@@ -22,6 +22,13 @@ Kernel
 Standard Kernels
 -----------------------------

+:hidden:`ConstantKernel`
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: ConstantKernel
+   :members:
+
+
 :hidden:`CosineKernel`
 ~~~~~~~~~~~~~~~~~~~~~~
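A minimal sketch of the docs' new recommendation, using a toy ExactGP model (the model class and training data here are hypothetical, and assume a gpytorch build that includes this commit):

import torch
import gpytorch

class SimpleGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        # Recommended default covariance: scaled RBF plus a constant kernel.
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel()
        ) + gpytorch.kernels.ConstantKernel()

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

train_x = torch.linspace(0, 1, 10)
train_y = torch.sin(6 * train_x) + 0.5  # constant offset the ConstantKernel can absorb
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = SimpleGP(train_x, train_y, likelihood)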

gpytorch/kernels/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -2,6 +2,7 @@
 from . import keops
 from .additive_structure_kernel import AdditiveStructureKernel
 from .arc_kernel import ArcKernel
+from .constant_kernel import ConstantKernel
 from .cosine_kernel import CosineKernel
 from .cylindrical_kernel import CylindricalKernel
 from .distributional_input_kernel import DistributionalInputKernel
@@ -38,6 +39,7 @@
     "ArcKernel",
     "AdditiveKernel",
     "AdditiveStructureKernel",
+    "ConstantKernel",
     "CylindricalKernel",
     "MultiDeviceKernel",
     "CosineKernel",

gpytorch/kernels/constant_kernel.py

Lines changed: 123 additions & 0 deletions (new file)

#!/usr/bin/env python3

from typing import Optional, Tuple

import torch
from torch import Tensor

from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel


class ConstantKernel(Kernel):
    """
    Constant covariance kernel for the probabilistic inference of constant coefficients.

    ConstantKernel represents the prior variance `k(x1, x2) = var(c)` of a constant `c`.
    The prior variance of the constant is optimized during the GP hyper-parameter
    optimization stage. The actual value of the constant is computed (implicitly) using
    the linear algebraic approaches for the computation of GP samples and posteriors.

    The constant kernel `k_constant` is most useful as a modification of an arbitrary
    base kernel `k_base`:
    1) Additive constants: The modification `k_base + k_constant` allows the GP to
    infer a non-zero asymptotic value far from the training data, which generally
    leads to more accurate extrapolation. Notably, the uncertainty in this constant
    value affects the posterior covariances through the posterior inference equations.
    This is not the case when a constant prior mean is used instead, since the prior
    mean does not show up in the posterior covariance and is regularized by the
    log-determinant during the optimization of the marginal likelihood.
    2) Multiplicative constants: The modification `k_base * k_constant` allows the GP
    to modulate the variance of the kernel `k_base`, and is mathematically identical
    to `ScaleKernel(base_kernel)` with the same constant.
    """

    has_lengthscale = False

    def __init__(
        self,
        batch_shape: Optional[torch.Size] = None,
        constant_prior: Optional[Prior] = None,
        constant_constraint: Optional[Interval] = None,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        """Constructor of ConstantKernel.

        Args:
            batch_shape: The batch shape of the kernel.
            constant_prior: Prior over the constant parameter.
            constant_constraint: Constraint to place on the constant parameter.
            active_dims: The dimensions of the input with which to evaluate the kernel.
                This is moot for the constant kernel, but added for compatibility with
                the Kernel API.
        """
        super().__init__(batch_shape=batch_shape, active_dims=active_dims)

        self.register_parameter(
            name="raw_constant",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )

        if constant_prior is not None:
            if not isinstance(constant_prior, Prior):
                raise TypeError("Expected gpytorch.priors.Prior but got " + type(constant_prior).__name__)
            self.register_prior(
                "constant_prior",
                constant_prior,
                lambda m: m.constant,
                lambda m, v: m._set_constant(v),
            )

        if constant_constraint is None:
            constant_constraint = Positive()
        self.register_constraint("raw_constant", constant_constraint)

    @property
    def constant(self) -> Tensor:
        return self.raw_constant_constraint.transform(self.raw_constant)

    @constant.setter
    def constant(self, value: Tensor) -> None:
        self._set_constant(value)

    def _set_constant(self, value: Tensor) -> None:
        value = value.view(*self.batch_shape, 1)
        self.initialize(raw_constant=self.raw_constant_constraint.inverse_transform(value))

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: Optional[bool] = False,
        last_dim_is_batch: Optional[bool] = False,
    ) -> Tensor:
        """Evaluates the constant kernel.

        Args:
            x1: First input tensor of shape (batch_shape x n1 x d).
            x2: Second input tensor of shape (batch_shape x n2 x d).
            diag: If True, returns the diagonal of the covariance matrix.
            last_dim_is_batch: If True, the last dimension of size `d` of the input
                tensors is treated as a batch dimension.

        Returns:
            A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
            constant covariance values if diag is False, resp. True.
        """
        if last_dim_is_batch:
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)

        dtype = torch.promote_types(x1.dtype, x2.dtype)
        batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
        shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
        constant = self.constant.to(dtype=dtype, device=x1.device)

        if not diag:
            constant = constant.unsqueeze(-1)

        if last_dim_is_batch:
            constant = constant.unsqueeze(-1)

        return constant.expand(shape)
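A short usage sketch of the new kernel's API, following the docstrings above (the inputs are hypothetical; the shapes follow the `forward` docstring):

import torch
from gpytorch.kernels import ConstantKernel, MaternKernel

X = torch.rand(10, 3)  # n = 10 points, d = 3 dimensions

const = ConstantKernel()
K = const(X).to_dense()       # (10, 10), every entry equals const.constant
K_diag = const(X, diag=True)  # (10,), the constant repeated n times

# Additive constant: lets the GP infer a non-zero asymptotic value far
# from the training data.
sum_kernel = MaternKernel() + const

# Multiplicative constant: modulates the variance of the base kernel,
# mathematically identical to ScaleKernel(MaternKernel()) with the same constant.
prod_kernel = MaternKernel() * const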

gpytorch/test/base_kernel_test_case.py

Lines changed: 4 additions & 6 deletions

@@ -122,23 +122,21 @@ def test_no_batch_kernel_double_batch_x_ard(self):
         actual_diag = actual_covar_mat.diagonal(dim1=-1, dim2=-2)
         self.assertAllClose(kernel_diag, actual_diag, rtol=1e-3, atol=1e-5)

-    def test_smoke_double_batch_kernel_double_batch_x_no_ard(self):
+    def test_smoke_double_batch_kernel_double_batch_x_no_ard(self) -> None:
         kernel = self.create_kernel_no_ard(batch_shape=torch.Size([3, 2]))
         x = self.create_data_double_batch()
-        batch_covar_mat = kernel(x).evaluate_kernel().to_dense()
+        kernel(x).evaluate_kernel().to_dense()
         kernel(x, diag=True)
-        return batch_covar_mat

-    def test_smoke_double_batch_kernel_double_batch_x_ard(self):
+    def test_smoke_double_batch_kernel_double_batch_x_ard(self) -> None:
         try:
             kernel = self.create_kernel_ard(num_dims=2, batch_shape=torch.Size([3, 2]))
         except NotImplementedError:
             return

         x = self.create_data_double_batch()
-        batch_covar_mat = kernel(x).evaluate_kernel().to_dense()
+        kernel(x).evaluate_kernel().to_dense()
         kernel(x, diag=True)
-        return batch_covar_mat

     def test_kernel_getitem_single_batch(self):
         kernel = self.create_kernel_no_ard(batch_shape=torch.Size([2]))

setup.py

Lines changed: 1 addition & 0 deletions

@@ -82,6 +82,7 @@ def find_version(*file_paths):
     "nbclient<=0.7.3",
     "nbformat<=5.8.0",
     "nbsphinx<=0.9.1",
+    "lxml_html_clean",
     "platformdirs<=3.2.0",
     "setuptools_scm<=7.1.0",
     "sphinx<=6.2.1",

test/kernels/test_constant_kernel.py

Lines changed: 113 additions & 0 deletions (new file)

#!/usr/bin/env python3

import itertools
import unittest

import torch
from torch import Tensor

from gpytorch.kernels import AdditiveKernel, ConstantKernel, MaternKernel, ProductKernel, ScaleKernel
from gpytorch.lazy import LazyEvaluatedKernelTensor
from gpytorch.priors.torch_priors import GammaPrior
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestConstantKernel(unittest.TestCase, BaseKernelTestCase):
    def create_kernel_no_ard(self, **kwargs):
        return ConstantKernel(**kwargs)

    def test_constant_kernel(self):
        with self.subTest(device="cpu"):
            self._test_constant_kernel(torch.device("cpu"))

        if torch.cuda.is_available():
            with self.subTest(device="cuda"):
                self._test_constant_kernel(torch.device("cuda"))

    def _test_constant_kernel(self, device: torch.device):
        n, d = 3, 5
        dtypes = [torch.float, torch.double]
        batch_shapes = [(), (2,), (7, 2)]
        torch.manual_seed(123)
        for dtype, batch_shape in itertools.product(dtypes, batch_shapes):
            tkwargs = {"dtype": dtype, "device": device}
            places = 6 if dtype == torch.float else 12
            X = torch.rand(*batch_shape, n, d, **tkwargs)

            constant_kernel = ConstantKernel(batch_shape=batch_shape)
            KL = constant_kernel(X)
            self.assertIsInstance(KL, LazyEvaluatedKernelTensor)
            KM = KL.to_dense()
            self.assertIsInstance(KM, Tensor)
            self.assertEqual(KM.shape, (*batch_shape, n, n))
            self.assertEqual(KM.dtype, dtype)
            self.assertEqual(KM.device.type, device.type)
            # standard deviation is zero iff KM is constant
            self.assertAlmostEqual(KM.std().item(), 0, places=places)

            # testing last_dim_is_batch
            with self.subTest(last_dim_is_batch=True):
                KD = constant_kernel(X, last_dim_is_batch=True).to(device=device)
                self.assertIsInstance(KD, LazyEvaluatedKernelTensor)
                KM = KD.to_dense()
                self.assertIsInstance(KM, Tensor)
                self.assertEqual(KM.shape, (*batch_shape, d, n, n))
                self.assertAlmostEqual(KM.std().item(), 0, places=places)
                self.assertEqual(KM.dtype, dtype)
                self.assertEqual(KM.device.type, device.type)

            # testing diag
            with self.subTest(diag=True):
                KD = constant_kernel(X, diag=True)
                self.assertIsInstance(KD, Tensor)
                self.assertEqual(KD.shape, (*batch_shape, n))
                self.assertAlmostEqual(KD.std().item(), 0, places=places)
                self.assertEqual(KD.dtype, dtype)
                self.assertEqual(KD.device.type, device.type)

            # testing diag and last_dim_is_batch
            with self.subTest(diag=True, last_dim_is_batch=True):
                KD = constant_kernel(X, diag=True, last_dim_is_batch=True)
                self.assertIsInstance(KD, Tensor)
                self.assertEqual(KD.shape, (*batch_shape, d, n))
                self.assertAlmostEqual(KD.std().item(), 0, places=places)
                self.assertEqual(KD.dtype, dtype)
                self.assertEqual(KD.device.type, device.type)

            # testing AD
            with self.subTest(requires_grad=True):
                X.requires_grad = True
                constant_kernel(X).to_dense().sum().backward()
                self.assertIsNone(X.grad)  # constant kernel is not dependent on X

            # testing algebraic combinations with another kernel
            base_kernel = MaternKernel().to(device=device)

            with self.subTest(additive=True):
                sum_kernel = base_kernel + constant_kernel
                self.assertIsInstance(sum_kernel, AdditiveKernel)
                self.assertAllClose(
                    sum_kernel(X).to_dense(),
                    base_kernel(X).to_dense() + constant_kernel.constant.unsqueeze(-1),
                )

            # product with constant is equivalent to scale kernel
            with self.subTest(product=True):
                product_kernel = base_kernel * constant_kernel
                self.assertIsInstance(product_kernel, ProductKernel)

                scale_kernel = ScaleKernel(base_kernel, batch_shape=batch_shape)
                scale_kernel.to(device=device)
                self.assertAllClose(scale_kernel(X).to_dense(), product_kernel(X).to_dense())

            # setting constant
            pies = torch.full_like(constant_kernel.constant, torch.pi)
            constant_kernel.constant = pies
            self.assertAllClose(constant_kernel.constant, pies)

            # specifying prior
            constant_kernel = ConstantKernel(constant_prior=GammaPrior(concentration=2.4, rate=2.7))

            with self.assertRaisesRegex(TypeError, "Expected gpytorch.priors.Prior but got"):
                ConstantKernel(constant_prior=1)
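A note on the product subtest above: the assertion holds without explicitly matching parameters because both `ScaleKernel.outputscale` and `ConstantKernel.constant` default to a `Positive` (softplus) constraint over a raw parameter initialized to zero, so both multipliers equal softplus(0) ≈ 0.6931 at initialization. A minimal standalone sketch of the same check (hypothetical inputs):

import torch
from gpytorch.kernels import ConstantKernel, MaternKernel, ScaleKernel

X = torch.rand(4, 2)
base = MaternKernel()

product = (base * ConstantKernel())(X).to_dense()
scaled = ScaleKernel(base)(X).to_dense()

# Both multipliers are softplus(0) at initialization, so the matrices agree.
assert torch.allclose(product, scaled)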
