Commit c18d95d
[Breaking Change] Remove last_dim_is_batch (#2544)
* Remove the `convert_legacy_grid` option. **Note**: This is not a breaking change; "legacy" grids were deprecated pre v1.0.
* Rework `GridKernel` and `GridInterpolationKernel` to not use `last_dim_is_batch`.
* [Breaking Change] Remove `AdditiveStructureKernel` and `ProductStructureKernel`.
  - The functionality of both kernels has not disappeared, but neither kernel can work without the `last_dim_is_batch` option.
  - The examples/00_Basic_Usage/kernels_with_additive_or_product_structure.ipynb notebook describes how to replicate the functionality of both kernels without `last_dim_is_batch`.
* [Breaking Change] Remove `NewtonGirardAdditiveKernel`.
  - The functionality of this kernel has not disappeared, but it cannot work without the `last_dim_is_batch` option.
  - The examples/00_Basic_Usage/kernels_with_additive_or_product_structure.ipynb notebook describes how to replicate the functionality of this kernel using the `gpytorch.utils.sum_interaction_terms` utility.
* [Breaking Change] Remove `last_dim_is_batch` from the remaining kernels.
1 parent: 128254d
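The referenced notebook (examples/00_Basic_Usage/kernels_with_additive_or_product_structure.ipynb) walks through the migration in detail. As a rough sketch of the general pattern (shapes and names below are illustrative, not taken from the notebook), additive and product structure can be recovered by moving the feature dimension into a kernel batch dimension and reducing over it:

```python
import torch
from gpytorch.kernels import RBFKernel

# Replicate AdditiveStructureKernel / ProductStructureKernel by hand:
# apply a batched 1-D base kernel to each input dimension independently,
# then sum (additive) or multiply (product) over the batch of dimensions.
d, n = 3, 10
x = torch.randn(n, d)

base_kernel = RBFKernel(batch_shape=torch.Size([d]))
x_batched = x.transpose(-1, -2).unsqueeze(-1)         # (d, n, 1)
covar = base_kernel(x_batched, x_batched).to_dense()  # (d, n, n)

additive_covar = covar.sum(dim=-3)    # replaces AdditiveStructureKernel
product_covar = covar.prod(dim=-3)    # replaces ProductStructureKernel
```

For the higher-order interaction terms that `NewtonGirardAdditiveKernel` computed, the commit points to the `gpytorch.utils.sum_interaction_terms` utility; see the notebook for its usage.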


43 files changed, +324 / -1172 lines

docs/source/kernels.rst

Lines changed: 0 additions & 12 deletions
@@ -119,24 +119,12 @@ Composition/Decoration Kernels
 .. autoclass:: MultiDeviceKernel
    :members:

-:hidden:`AdditiveStructureKernel`
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autoclass:: AdditiveStructureKernel
-   :members:
-
 :hidden:`ProductKernel`
 ~~~~~~~~~~~~~~~~~~~~~~~~~

 .. autoclass:: ProductKernel
    :members:

-:hidden:`ProductStructureKernel`
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autoclass:: ProductStructureKernel
-   :members:
-
 :hidden:`ScaleKernel`
 ~~~~~~~~~~~~~~~~~~~~~~~~~

examples/02_Scalable_Exact_GPs/Grid_GP_Regression.ipynb

Lines changed: 74 additions & 78 deletions (large diff not rendered by default)

gpytorch/kernels/__init__.py

Lines changed: 0 additions & 6 deletions
@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
 from . import keops
-from .additive_structure_kernel import AdditiveStructureKernel
 from .arc_kernel import ArcKernel
 from .constant_kernel import ConstantKernel
 from .cosine_kernel import CosineKernel
@@ -19,12 +18,10 @@
 from .matern_kernel import MaternKernel
 from .multi_device_kernel import MultiDeviceKernel
 from .multitask_kernel import MultitaskKernel
-from .newton_girard_additive_kernel import NewtonGirardAdditiveKernel
 from .periodic_kernel import PeriodicKernel
 from .piecewise_polynomial_kernel import PiecewisePolynomialKernel
 from .polynomial_kernel import PolynomialKernel
 from .polynomial_kernel_grad import PolynomialKernelGrad
-from .product_structure_kernel import ProductStructureKernel
 from .rbf_kernel import RBFKernel
 from .rbf_kernel_grad import RBFKernelGrad
 from .rbf_kernel_gradgrad import RBFKernelGradGrad
@@ -39,7 +36,6 @@
     "Kernel",
     "ArcKernel",
     "AdditiveKernel",
-    "AdditiveStructureKernel",
     "ConstantKernel",
     "CylindricalKernel",
     "MultiDeviceKernel",
@@ -55,13 +51,11 @@
     "LinearKernel",
     "MaternKernel",
     "MultitaskKernel",
-    "NewtonGirardAdditiveKernel",
     "PeriodicKernel",
     "PiecewisePolynomialKernel",
     "PolynomialKernel",
     "PolynomialKernelGrad",
     "ProductKernel",
-    "ProductStructureKernel",
     "RBFKernel",
     "RFFKernel",
     "RBFKernelGrad",

gpytorch/kernels/additive_structure_kernel.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

gpytorch/kernels/constant_kernel.py

Lines changed: 0 additions & 10 deletions
@@ -90,25 +90,18 @@ def forward(
         x1: Tensor,
         x2: Tensor,
         diag: Optional[bool] = False,
-        last_dim_is_batch: Optional[bool] = False,
     ) -> Tensor:
         """Evaluates the constant kernel.

         Args:
             x1: First input tensor of shape (batch_shape x n1 x d).
             x2: Second input tensor of shape (batch_shape x n2 x d).
             diag: If True, returns the diagonal of the covariance matrix.
-            last_dim_is_batch: If True, the last dimension of size `d` of the input
-                tensors are treated as a batch dimension.

         Returns:
             A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
             constant covariance values if diag is False, resp. True.
         """
-        if last_dim_is_batch:
-            x1 = x1.transpose(-1, -2).unsqueeze(-1)
-            x2 = x2.transpose(-1, -2).unsqueeze(-1)
-
         dtype = torch.promote_types(x1.dtype, x2.dtype)
         batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
         shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
@@ -117,7 +110,4 @@ def forward(
         if not diag:
             constant = constant.unsqueeze(-1)

-        if last_dim_is_batch:
-            constant = constant.unsqueeze(-1)
-
         return constant.expand(shape)
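The deleted branch above is exactly what `last_dim_is_batch=True` used to do, so callers can reproduce it by reshaping the inputs themselves before invoking the kernel. A minimal sketch (tensor shapes illustrative):

```python
import torch
from gpytorch.kernels import ConstantKernel

# Reproduce the removed `last_dim_is_batch=True` path by hand:
# move the feature dimension d into the batch, as the deleted code did.
x1 = torch.randn(10, 3)  # n1 x d
x2 = torch.randn(8, 3)   # n2 x d

kernel = ConstantKernel()
x1_b = x1.transpose(-1, -2).unsqueeze(-1)  # d x n1 x 1
x2_b = x2.transpose(-1, -2).unsqueeze(-1)  # d x n2 x 1
covar = kernel(x1_b, x2_b)                 # d x n1 x n2 batch of constants
print(covar.shape)                         # torch.Size([3, 10, 8])
```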

gpytorch/kernels/grid_interpolation_kernel.py

Lines changed: 33 additions & 39 deletions
@@ -1,10 +1,12 @@
 #!/usr/bin/env python3

-from typing import List, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple, Union

 import torch
+from jaxtyping import Float
 from linear_operator import to_linear_operator
-from linear_operator.operators import InterpolatedLinearOperator
+from linear_operator.operators import InterpolatedLinearOperator, LinearOperator
+from torch import Tensor

 from ..models.exact_prediction_strategies import InterpolatedPredictionStrategy
 from ..utils.grid import create_grid
@@ -25,14 +27,14 @@ class GridInterpolationKernel(GridKernel):
     .. math::

        \begin{equation*}
-           k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{U,U} \mathbf{w_{x_2}}
+           k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{\boldsymbol Z, \boldsymbol Z} \mathbf{w_{x_2}}
        \end{equation*}

     where

-    * :math:`U` is the set of gridded inducing points
+    * :math:`\boldsymbol Z` is the set of gridded inducing points

-    * :math:`K_{U,U}` is the kernel matrix between the inducing points
+    * :math:`K_{\boldsymbol Z, \boldsymbol Z}` is the kernel matrix between the inducing points

     * :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on
       :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation.
@@ -50,20 +52,13 @@ class GridInterpolationKernel(GridKernel):
     `GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
     Periodic, Spectral Mixture, etc.)

-    Args:
-        base_kernel (Kernel):
-            The kernel to approximate with KISS-GP
-        grid_size (Union[int, List[int]]):
-            The size of the grid in each dimension.
-            If a single int is provided, then every dimension will have the same grid size.
-        num_dims (int):
-            The dimension of the input data. Required if `grid_bounds=None`
-        grid_bounds (tuple(float, float), optional):
-            The bounds of the grid, if known (high performance mode).
-            The length of the tuple must match the number of dimensions.
-            The entries represent the min/max values for each dimension.
-        active_dims (tuple of ints, optional):
-            Passed down to the `base_kernel`.
+    :param base_kernel: The kernel to approximate with KISS-GP.
+    :param grid_size: The size of the grid in each dimension.
+        If a single int is provided, then every dimension will have the same grid size.
+    :param num_dims: The dimension of the input data. Required if `grid_bounds=None`
+    :param grid_bounds: The bounds of the grid, if known (high performance mode).
+        The length of the tuple must match the number of dimensions.
+        The entries represent the min/max values for each dimension.

     .. _Kernel Interpolation for Scalable Structured Gaussian Processes:
         http://proceedings.mlr.press/v37/wilson15.pdf
@@ -72,10 +67,10 @@ class GridInterpolationKernel(GridKernel):
     def __init__(
         self,
         base_kernel: Kernel,
-        grid_size: Union[int, List[int]],
+        grid_size: Union[int, Iterable[int]],
         num_dims: Optional[int] = None,
         grid_bounds: Optional[Tuple[float, float]] = None,
-        active_dims: Optional[Tuple[int, ...]] = None,
+        **kwargs,
     ):
         has_initialized_grid = 0
         grid_is_dynamic = True
@@ -116,8 +111,7 @@ def __init__(
         super(GridInterpolationKernel, self).__init__(
             base_kernel=base_kernel,
             grid=grid,
-            interpolation_mode=True,
-            active_dims=active_dims,
+            **kwargs,
         )
         self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))

@@ -129,23 +123,26 @@ def _tight_grid_bounds(self):
             for bound, spacing in zip(self.grid_bounds, grid_spacings)
         )

-    def _compute_grid(self, inputs, last_dim_is_batch=False):
-        n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
-        if last_dim_is_batch:
-            inputs = inputs.transpose(-1, -2).unsqueeze(-1)
-            n_dimensions = 1
-        batch_shape = inputs.shape[:-2]
-
+    def _compute_grid(self, inputs):
+        *batch_shape, n_data, n_dimensions = inputs.shape
         inputs = inputs.reshape(-1, n_dimensions)
         interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
         interp_indices = interp_indices.view(*batch_shape, n_data, -1)
         interp_values = interp_values.view(*batch_shape, n_data, -1)
         return interp_indices, interp_values

-    def _inducing_forward(self, last_dim_is_batch, **params):
-        return super().forward(self.grid, self.grid, last_dim_is_batch=last_dim_is_batch, **params)
+    def _create_or_update_full_grid(self, grid: Iterable[Tensor]):
+        pass
+
+    def _validate_inputs(self, x: Float[Tensor, "... N D"]) -> bool:
+        return True

-    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
+    def _inducing_forward(self, **params):
+        return super().forward(None, None, **params)
+
+    def forward(
+        self, x1: Float[Tensor, "... N_1 D"], x2: Float[Tensor, "... N_2 D"], diag: bool = False, **params
+    ) -> Float[Union[Tensor, LinearOperator], "... N_1 N_2"]:
         # See if we need to update the grid or not
         if self.grid_is_dynamic:  # This is true if a grid_bounds wasn't passed in
             if torch.equal(x1, x2):
@@ -180,16 +177,13 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
             )
             self.update_grid(grid)

-        base_lazy_tsr = to_linear_operator(self._inducing_forward(last_dim_is_batch=last_dim_is_batch, **params))
-        if last_dim_is_batch and base_lazy_tsr.size(-3) == 1:
-            base_lazy_tsr = base_lazy_tsr.repeat(*x1.shape[:-2], x1.size(-1), 1, 1)
-
-        left_interp_indices, left_interp_values = self._compute_grid(x1, last_dim_is_batch)
+        base_lazy_tsr = to_linear_operator(self._inducing_forward(**params))
+        left_interp_indices, left_interp_values = self._compute_grid(x1)
         if torch.equal(x1, x2):
             right_interp_indices = left_interp_indices
             right_interp_values = left_interp_values
         else:
-            right_interp_indices, right_interp_values = self._compute_grid(x2, last_dim_is_batch)
+            right_interp_indices, right_interp_values = self._compute_grid(x2)

         batch_shape = torch.broadcast_shapes(
             base_lazy_tsr.batch_shape,
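The public constructor keeps the same core arguments (`base_kernel`, `grid_size`, `num_dims`, `grid_bounds`), so most existing KISS-GP models should only need to drop the removed keyword. A minimal usage sketch against the reworked signature (hyperparameters illustrative):

```python
import torch
from gpytorch.kernels import GridInterpolationKernel, RBFKernel, ScaleKernel

# KISS-GP: k(x1, x2) is approximated by w_x1^T K_ZZ w_x2, where Z is a
# fixed interpolation grid and w_x are sparse cubic-interpolation weights.
covar_module = ScaleKernel(
    GridInterpolationKernel(RBFKernel(), grid_size=100, num_dims=1)
)

x = torch.linspace(0, 1, 50).unsqueeze(-1)  # 50 one-dimensional inputs
covar = covar_module(x, x)                  # lazily evaluated operator
print(covar.shape)                          # torch.Size([50, 50])
```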
