[Breaking Change] Remove last_dim_is_batch (#2544)
* Remove `convert_legacy_grid` option.

**Note**: This is not a breaking change; "legacy" grids were deprecated
pre v1.0.

* Rework GridKernel and GridInterpolationKernel to not use last_dim_is_batch

* [Breaking Change] remove AdditiveStructureKernel and ProductStructureKernel

- The functionality of both kernels has not disappeared, but neither
  kernel can work without the `last_dim_is_batch` option.
- The examples/00_Basic_Usage/kernels_with_additive_or_product_structure.ipynb
  notebook describes how to replicate the functionality of both kernels
  without last_dim_is_batch; a minimal sketch follows below.
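
  A minimal sketch of the replacement pattern (illustrative, not taken
  verbatim from the notebook): compose one single-dimension kernel per input
  dimension via `active_dims`, then sum or multiply. Note that, unlike the
  removed kernels, this version gives each dimension its own lengthscale.

```python
import torch
import gpytorch

d = 3  # number of input dimensions (illustrative)

# One sub-kernel per input dimension, restricted to that dimension via
# `active_dims`. Summing replicates AdditiveStructureKernel; multiplying
# replicates ProductStructureKernel.
additive_kernel = gpytorch.kernels.AdditiveKernel(
    *[gpytorch.kernels.RBFKernel(active_dims=[i]) for i in range(d)]
)
product_kernel = gpytorch.kernels.ProductKernel(
    *[gpytorch.kernels.RBFKernel(active_dims=[i]) for i in range(d)]
)

x = torch.randn(8, d)
K_add = additive_kernel(x)    # sum of per-dimension covariances
K_prod = product_kernel(x)    # product of per-dimension covariances
```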

* [Breaking Change] remove NewtonGirardAdditiveKernel

- The functionality of this kernel has not disappeared, but the kernel
  cannot work without the `last_dim_is_batch` option.
- The examples/00_Basic_Usage/kernels_with_additive_or_product_structure.ipynb
  notebook describes how to replicate the functionality of this kernel
  using the gpytorch.utils.sum_interaction_terms utility; a minimal
  sketch follows below.
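
  A minimal sketch, assuming the new utility takes a stack of per-dimension
  covariance matrices and a signature along the lines of
  `sum_interaction_terms(covars, max_degree, dim)` (the notebook is the
  authoritative reference):

```python
import torch
import gpytorch
from gpytorch.utils import sum_interaction_terms

d, n = 4, 8
x = torch.randn(n, d)

# A batch of d one-dimensional kernels: a (d x n x 1) input yields a
# (d x n x n) stack of per-dimension covariance matrices.
base_kernel = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([d]))
covars = base_kernel(x.mT.unsqueeze(-1)).to_dense()  # d x n x n

# Assumed usage: Newton-Girard summation of all interaction terms
# (products of per-dimension covariances) up to degree `max_degree`.
K = sum_interaction_terms(covars, max_degree=2, dim=-3)
```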

* [Breaking Change] remove last_dim_is_batch from remaining kernels
gpleiss authored Aug 13, 2024
1 parent 128254d commit c18d95d
Showing 43 changed files with 324 additions and 1,172 deletions.
12 changes: 0 additions & 12 deletions docs/source/kernels.rst
@@ -119,24 +119,12 @@ Composition/Decoration Kernels
.. autoclass:: MultiDeviceKernel
:members:

:hidden:`AdditiveStructureKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AdditiveStructureKernel
:members:

:hidden:`ProductKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ProductKernel
:members:

:hidden:`ProductStructureKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ProductStructureKernel
:members:

:hidden:`ScaleKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~

152 changes: 74 additions & 78 deletions examples/02_Scalable_Exact_GPs/Grid_GP_Regression.ipynb

Large diffs are not rendered by default.

6 changes: 0 additions & 6 deletions gpytorch/kernels/__init__.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
from . import keops
from .additive_structure_kernel import AdditiveStructureKernel
from .arc_kernel import ArcKernel
from .constant_kernel import ConstantKernel
from .cosine_kernel import CosineKernel
@@ -19,12 +18,10 @@
from .matern_kernel import MaternKernel
from .multi_device_kernel import MultiDeviceKernel
from .multitask_kernel import MultitaskKernel
from .newton_girard_additive_kernel import NewtonGirardAdditiveKernel
from .periodic_kernel import PeriodicKernel
from .piecewise_polynomial_kernel import PiecewisePolynomialKernel
from .polynomial_kernel import PolynomialKernel
from .polynomial_kernel_grad import PolynomialKernelGrad
from .product_structure_kernel import ProductStructureKernel
from .rbf_kernel import RBFKernel
from .rbf_kernel_grad import RBFKernelGrad
from .rbf_kernel_gradgrad import RBFKernelGradGrad
@@ -39,7 +36,6 @@
"Kernel",
"ArcKernel",
"AdditiveKernel",
"AdditiveStructureKernel",
"ConstantKernel",
"CylindricalKernel",
"MultiDeviceKernel",
@@ -55,13 +51,11 @@
"LinearKernel",
"MaternKernel",
"MultitaskKernel",
"NewtonGirardAdditiveKernel",
"PeriodicKernel",
"PiecewisePolynomialKernel",
"PolynomialKernel",
"PolynomialKernelGrad",
"ProductKernel",
"ProductStructureKernel",
"RBFKernel",
"RFFKernel",
"RBFKernelGrad",
73 changes: 0 additions & 73 deletions gpytorch/kernels/additive_structure_kernel.py

This file was deleted.

10 changes: 0 additions & 10 deletions gpytorch/kernels/constant_kernel.py
@@ -90,25 +90,18 @@ def forward(
x1: Tensor,
x2: Tensor,
diag: Optional[bool] = False,
last_dim_is_batch: Optional[bool] = False,
) -> Tensor:
"""Evaluates the constant kernel.
Args:
x1: First input tensor of shape (batch_shape x n1 x d).
x2: Second input tensor of shape (batch_shape x n2 x d).
diag: If True, returns the diagonal of the covariance matrix.
last_dim_is_batch: If True, the last dimension of size `d` of the input
tensors are treated as a batch dimension.
Returns:
A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
constant covariance values if diag is False, resp. True.
"""
if last_dim_is_batch:
x1 = x1.transpose(-1, -2).unsqueeze(-1)
x2 = x2.transpose(-1, -2).unsqueeze(-1)

dtype = torch.promote_types(x1.dtype, x2.dtype)
batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
@@ -117,7 +110,4 @@ def forward(
if not diag:
constant = constant.unsqueeze(-1)

if last_dim_is_batch:
constant = constant.unsqueeze(-1)

return constant.expand(shape)
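
The deleted branch above shows everything `last_dim_is_batch=True` did internally: move the trailing feature dimension into a batch dimension. Callers migrating off the flag can apply the same reshape themselves; a minimal sketch:

```python
import torch
import gpytorch

kernel = gpytorch.kernels.RBFKernel()
x = torch.randn(8, 3)  # n x d

# Before: kernel(x, last_dim_is_batch=True) returned a d x n x n output.
# After: perform the transpose/unsqueeze the kernel used to do internally.
x_batched = x.transpose(-1, -2).unsqueeze(-1)  # d x n x 1
K = kernel(x_batched)  # d x n x n batch of one-dimensional covariances
```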
72 changes: 33 additions & 39 deletions gpytorch/kernels/grid_interpolation_kernel.py
@@ -1,10 +1,12 @@
#!/usr/bin/env python3

from typing import List, Optional, Tuple, Union
from typing import Iterable, Optional, Tuple, Union

import torch
from jaxtyping import Float
from linear_operator import to_linear_operator
from linear_operator.operators import InterpolatedLinearOperator
from linear_operator.operators import InterpolatedLinearOperator, LinearOperator
from torch import Tensor

from ..models.exact_prediction_strategies import InterpolatedPredictionStrategy
from ..utils.grid import create_grid
@@ -25,14 +27,14 @@ class GridInterpolationKernel(GridKernel):
.. math::
\begin{equation*}
k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{U,U} \mathbf{w_{x_2}}
k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{\boldsymbol Z, \boldsymbol Z} \mathbf{w_{x_2}}
\end{equation*}
where
* :math:`U` is the set of gridded inducing points
* :math:`\boldsymbol Z` is the set of gridded inducing points
* :math:`K_{U,U}` is the kernel matrix between the inducing points
* :math:`K_{\boldsymbol Z, \boldsymbol Z}` is the kernel matrix between the inducing points
* :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on
:math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation.
@@ -50,20 +52,13 @@ class GridInterpolationKernel(GridKernel):
`GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
Periodic, Spectral Mixture, etc.)
Args:
base_kernel (Kernel):
The kernel to approximate with KISS-GP
grid_size (Union[int, List[int]]):
The size of the grid in each dimension.
If a single int is provided, then every dimension will have the same grid size.
num_dims (int):
The dimension of the input data. Required if `grid_bounds=None`
grid_bounds (tuple(float, float), optional):
The bounds of the grid, if known (high performance mode).
The length of the tuple must match the number of dimensions.
The entries represent the min/max values for each dimension.
active_dims (tuple of ints, optional):
Passed down to the `base_kernel`.
:param base_kernel: The kernel to approximate with KISS-GP.
:param grid_size: The size of the grid in each dimension.
If a single int is provided, then every dimension will have the same grid size.
:param num_dims: The dimension of the input data. Required if `grid_bounds=None`
:param grid_bounds: The bounds of the grid, if known (high performance mode).
The length of the tuple must match the number of dimensions.
The entries represent the min/max values for each dimension.
.. _Kernel Interpolation for Scalable Structured Gaussian Processes:
http://proceedings.mlr.press/v37/wilson15.pdf
@@ -72,10 +67,10 @@ class GridInterpolationKernel(GridKernel):
def __init__(
self,
base_kernel: Kernel,
grid_size: Union[int, List[int]],
grid_size: Union[int, Iterable[int]],
num_dims: Optional[int] = None,
grid_bounds: Optional[Tuple[float, float]] = None,
active_dims: Optional[Tuple[int, ...]] = None,
**kwargs,
):
has_initialized_grid = 0
grid_is_dynamic = True
@@ -116,8 +111,7 @@ def __init__(
super(GridInterpolationKernel, self).__init__(
base_kernel=base_kernel,
grid=grid,
interpolation_mode=True,
active_dims=active_dims,
**kwargs,
)
self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))

@@ -129,23 +123,26 @@ def _tight_grid_bounds(self):
for bound, spacing in zip(self.grid_bounds, grid_spacings)
)

def _compute_grid(self, inputs, last_dim_is_batch=False):
n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
if last_dim_is_batch:
inputs = inputs.transpose(-1, -2).unsqueeze(-1)
n_dimensions = 1
batch_shape = inputs.shape[:-2]

def _compute_grid(self, inputs):
*batch_shape, n_data, n_dimensions = inputs.shape
inputs = inputs.reshape(-1, n_dimensions)
interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
interp_indices = interp_indices.view(*batch_shape, n_data, -1)
interp_values = interp_values.view(*batch_shape, n_data, -1)
return interp_indices, interp_values

def _inducing_forward(self, last_dim_is_batch, **params):
return super().forward(self.grid, self.grid, last_dim_is_batch=last_dim_is_batch, **params)
def _create_or_update_full_grid(self, grid: Iterable[Tensor]):
pass

def _validate_inputs(self, x: Float[Tensor, "... N D"]) -> bool:
return True

def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
def _inducing_forward(self, **params):
return super().forward(None, None, **params)

def forward(
self, x1: Float[Tensor, "... N_1 D"], x2: Float[Tensor, "... N_2 D"], diag: bool = False, **params
) -> Float[Union[Tensor, LinearOperator], "... N_1 N_2"]:
# See if we need to update the grid or not
if self.grid_is_dynamic: # This is true if a grid_bounds wasn't passed in
if torch.equal(x1, x2):
@@ -180,16 +177,13 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
)
self.update_grid(grid)

base_lazy_tsr = to_linear_operator(self._inducing_forward(last_dim_is_batch=last_dim_is_batch, **params))
if last_dim_is_batch and base_lazy_tsr.size(-3) == 1:
base_lazy_tsr = base_lazy_tsr.repeat(*x1.shape[:-2], x1.size(-1), 1, 1)

left_interp_indices, left_interp_values = self._compute_grid(x1, last_dim_is_batch)
base_lazy_tsr = to_linear_operator(self._inducing_forward(**params))
left_interp_indices, left_interp_values = self._compute_grid(x1)
if torch.equal(x1, x2):
right_interp_indices = left_interp_indices
right_interp_values = left_interp_values
else:
right_interp_indices, right_interp_values = self._compute_grid(x2, last_dim_is_batch)
right_interp_indices, right_interp_values = self._compute_grid(x2)

batch_shape = torch.broadcast_shapes(
base_lazy_tsr.batch_shape,
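
Based on the updated signature above (`base_kernel`, `grid_size`, `num_dims`, `grid_bounds`, plus `**kwargs` forwarded to `GridKernel`), a minimal usage sketch of the reworked kernel in dynamic-grid mode:

```python
import torch
import gpytorch

# num_dims is required when grid_bounds is omitted (dynamic-grid mode).
covar_module = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.GridInterpolationKernel(
        gpytorch.kernels.RBFKernel(), grid_size=64, num_dims=2
    )
)

x1 = torch.randn(100, 2)
x2 = torch.randn(50, 2)
K = covar_module(x1, x2)  # ... x N_1 x N_2, per the annotated forward() above
```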
