Commit f719145 (1 parent: e3d8a5e)
Showing 5 changed files with 246 additions and 7 deletions.
New file: the ConstantKernel implementation.
@@ -0,0 +1,119 @@
#!/usr/bin/env python3

from typing import Optional, Tuple

import torch

from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel


class ConstantKernel(Kernel):
    """
    Constant covariance kernel for the probabilistic inference of constant coefficients.

    ConstantKernel represents the prior variance `k(x1, x2) = var(c)` of a constant `c`.
    The prior variance of the constant is optimized during the GP hyper-parameter
    optimization stage. The actual value of the constant is computed (implicitly) using
    the linear algebraic approaches for the computation of GP samples and posteriors.

    The kernel (`k_constant`) is most useful as a modification of an arbitrary base
    kernel `k_base`:

    1) Additive constants: The modification `k_base + k_constant` allows the GP to
       infer a non-zero asymptotic value far from the training data, which generally
       leads to more accurate extrapolation. Notably, the uncertainty in this constant
       value affects the posterior covariances through the posterior inference
       equations. This is not the case when a constant prior mean is used, since the
       prior mean does not show up in the posterior covariance and is not regularized
       by the log-determinant during the optimization of the marginal likelihood.

    2) Multiplicative constants: The modification `k_base * k_constant` allows the GP
       to modulate the variance of the kernel `k_base`, and is mathematically identical
       to `ScaleKernel(base_kernel)` with the same constant.
    """

    has_lengthscale = False

    def __init__(
        self,
        batch_shape: Optional[torch.Size] = None,
        constant_prior: Optional[Prior] = None,
        constant_constraint: Optional[Interval] = None,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        """Constructor of ConstantKernel.

        Args:
            batch_shape: The batch shape of the kernel.
            constant_prior: Prior over the constant parameter.
            constant_constraint: Constraint to place on the constant parameter.
            active_dims: The input dimensions with which the kernel computes covariances.
        """
        super().__init__(batch_shape=batch_shape, active_dims=active_dims)

        self.register_parameter(
            name="raw_constant",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )

        if constant_prior is not None:
            if not isinstance(constant_prior, Prior):
                raise TypeError("Expected gpytorch.priors.Prior but got " + type(constant_prior).__name__)
            self.register_prior(
                "constant_prior",
                constant_prior,
                lambda m: m.constant,
                lambda m, v: m._set_constant(v),
            )

        if constant_constraint is None:
            constant_constraint = Positive()
        self.register_constraint("raw_constant", constant_constraint)

    @property
    def constant(self) -> torch.Tensor:
        return self.raw_constant_constraint.transform(self.raw_constant)

    @constant.setter
    def constant(self, value: torch.Tensor) -> None:
        self._set_constant(value)

    def _set_constant(self, value: torch.Tensor) -> None:
        value = value.view(*self.batch_shape, 1)
        self.initialize(raw_constant=self.raw_constant_constraint.inverse_transform(value))

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        diag: Optional[bool] = False,
        last_dim_is_batch: Optional[bool] = False,
    ) -> torch.Tensor:
        """Evaluates the constant kernel.

        Args:
            x1: First input tensor of shape (batch_shape x n1 x d).
            x2: Second input tensor of shape (batch_shape x n2 x d).
            diag: If True, returns the diagonal of the covariance matrix.
            last_dim_is_batch: If True, the last dimension of size `d` of the input
                tensors is treated as a batch dimension.

        Returns:
            A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
            constant covariance values if diag is False, resp. True.
        """
        if last_dim_is_batch:
            # Fold the feature dimension of size d into the batch dimensions.
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)

        dtype = torch.promote_types(x1.dtype, x2.dtype)
        batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
        shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
        constant = self.constant.to(dtype=dtype, device=x1.device)

        if not diag:
            # Add a trailing dimension so the constant broadcasts over n2.
            constant = constant.unsqueeze(-1)

        if last_dim_is_batch:
            constant = constant.unsqueeze(-1)

        return constant.expand(shape)
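
To make the docstring's additive use case concrete, here is a minimal usage sketch (not part of the commit): a standard gpytorch ExactGP model whose covariance is `MaternKernel() + ConstantKernel()`, so the GP can infer a non-zero asymptotic level away from the training data. The model class `SketchGP`, the toy data, and the evaluation range are made up for illustration, and the `ConstantKernel` import assumes this commit is installed.

# Usage sketch (illustrative only): additive ConstantKernel in an ExactGP model.
import torch
import gpytorch
from gpytorch.kernels import ConstantKernel, MaternKernel


class SketchGP(gpytorch.models.ExactGP):  # hypothetical example model
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        # k_base + k_constant: the constant's variance is a hyperparameter that is
        # regularized by the marginal likelihood, unlike a constant prior mean.
        self.covar_module = MaternKernel(nu=2.5) + ConstantKernel()

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)


# Toy data: a sine with a vertical offset that the constant component can absorb.
train_x = torch.linspace(0, 1, 20).unsqueeze(-1)
train_y = torch.sin(6 * train_x).squeeze(-1) + 0.5

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = SketchGP(train_x, train_y, likelihood)

model.eval()
likelihood.eval()
with torch.no_grad():
    test_x = torch.linspace(-1.0, 2.0, 50).unsqueeze(-1)
    pred = likelihood(model(test_x))  # posterior over a range that extrapolates

Hyperparameter training via the marginal likelihood is omitted here; the point is only how the kernel composes with an arbitrary base kernel.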
New file: unit tests for ConstantKernel.
@@ -0,0 +1,113 @@
#!/usr/bin/env python3

import itertools
import unittest

import torch
from torch import Tensor

from gpytorch.kernels import AdditiveKernel, ConstantKernel, MaternKernel, ProductKernel, ScaleKernel
from gpytorch.lazy import LazyEvaluatedKernelTensor
from gpytorch.priors.torch_priors import GammaPrior
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestConstantKernel(unittest.TestCase, BaseKernelTestCase):
    def create_kernel_no_ard(self, **kwargs):
        return ConstantKernel(**kwargs)

    def test_constant_kernel(self):
        with self.subTest(device="cpu"):
            self._test_constant_kernel(torch.device("cpu"))

        if torch.cuda.is_available():
            with self.subTest(device="cuda"):
                self._test_constant_kernel(torch.device("cuda"))

    def _test_constant_kernel(self, device: torch.device):
        n, d = 3, 5
        dtypes = [torch.float, torch.double]
        batch_shapes = [(), (2,), (7, 2)]
        torch.manual_seed(123)
        for dtype, batch_shape in itertools.product(dtypes, batch_shapes):
            tkwargs = {"dtype": dtype, "device": device}
            places = 6 if dtype == torch.float else 12
            X = torch.rand(*batch_shape, n, d, **tkwargs)

            constant_kernel = ConstantKernel(batch_shape=batch_shape)
            KL = constant_kernel(X)
            self.assertIsInstance(KL, LazyEvaluatedKernelTensor)
            KM = KL.to_dense()
            self.assertIsInstance(KM, Tensor)
            self.assertEqual(KM.shape, (*batch_shape, n, n))
            self.assertEqual(KM.dtype, dtype)
            self.assertEqual(KM.device.type, device.type)
            # standard deviation is zero iff KM is constant
            self.assertAlmostEqual(KM.std().item(), 0, places=places)

            # testing last_dim_is_batch
            with self.subTest(last_dim_is_batch=True):
                KD = constant_kernel(X, last_dim_is_batch=True).to(device=device)
                self.assertIsInstance(KD, LazyEvaluatedKernelTensor)
                KM = KD.to_dense()
                self.assertIsInstance(KM, Tensor)
                self.assertEqual(KM.shape, (*batch_shape, d, n, n))
                self.assertAlmostEqual(KM.std().item(), 0, places=places)
                self.assertEqual(KM.dtype, dtype)
                self.assertEqual(KM.device.type, device.type)

            # testing diag
            with self.subTest(diag=True):
                KD = constant_kernel(X, diag=True)
                self.assertIsInstance(KD, Tensor)
                self.assertEqual(KD.shape, (*batch_shape, n))
                self.assertAlmostEqual(KD.std().item(), 0, places=places)
                self.assertEqual(KD.dtype, dtype)
                self.assertEqual(KD.device.type, device.type)

            # testing diag and last_dim_is_batch
            with self.subTest(diag=True, last_dim_is_batch=True):
                KD = constant_kernel(X, diag=True, last_dim_is_batch=True)
                self.assertIsInstance(KD, Tensor)
                self.assertEqual(KD.shape, (*batch_shape, d, n))
                self.assertAlmostEqual(KD.std().item(), 0, places=places)
                self.assertEqual(KD.dtype, dtype)
                self.assertEqual(KD.device.type, device.type)

            # testing AD
            with self.subTest(requires_grad=True):
                X.requires_grad = True
                constant_kernel(X).to_dense().sum().backward()
                self.assertIsNone(X.grad)  # constant kernel is not dependent on X

            # testing algebraic combinations with another kernel
            base_kernel = MaternKernel().to(device=device)

            with self.subTest(additive=True):
                sum_kernel = base_kernel + constant_kernel
                self.assertIsInstance(sum_kernel, AdditiveKernel)
                self.assertAllClose(
                    sum_kernel(X).to_dense(),
                    base_kernel(X).to_dense() + constant_kernel.constant.unsqueeze(-1),
                )

            # product with constant is equivalent to scale kernel
            with self.subTest(product=True):
                product_kernel = base_kernel * constant_kernel
                self.assertIsInstance(product_kernel, ProductKernel)

                scale_kernel = ScaleKernel(base_kernel, batch_shape=batch_shape)
                scale_kernel.to(device=device)
                self.assertAllClose(scale_kernel(X).to_dense(), product_kernel(X).to_dense())

            # setting constant
            pies = torch.full_like(constant_kernel.constant, torch.pi)
            constant_kernel.constant = pies
            self.assertAllClose(constant_kernel.constant, pies)

            # specifying prior
            constant_kernel = ConstantKernel(constant_prior=GammaPrior(concentration=2.4, rate=2.7))

            with self.assertRaisesRegex(TypeError, "Expected gpytorch.priors.Prior but got"):
                ConstantKernel(constant_prior=1)
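
The test above verifies the product/ScaleKernel equivalence inside the suite; the standalone sketch below (not part of the commit) performs the same comparison directly. It assumes default hyperparameters and copies the kernel's constant into the ScaleKernel outputscale so the two parametrizations agree; inputs and sizes are arbitrary.

# Standalone check (illustrative sketch): base_kernel * ConstantKernel matches
# ScaleKernel(base_kernel) when the constant equals the outputscale.
import torch
from gpytorch.kernels import ConstantKernel, MaternKernel, ScaleKernel

torch.manual_seed(0)
X = torch.rand(4, 3)  # 4 points, 3 input dimensions

base_kernel = MaternKernel()
constant_kernel = ConstantKernel()
scale_kernel = ScaleKernel(base_kernel)  # wraps the same base kernel instance

# Make the outputscale equal to the constant so the two kernels are identical.
scale_kernel.outputscale = constant_kernel.constant.item()

product = (base_kernel * constant_kernel)(X).to_dense()
scaled = scale_kernel(X).to_dense()
print(torch.allclose(product, scaled))  # expected: True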