1
1
#!/usr/bin/env python3
2
2
3
- from typing import List , Optional , Tuple , Union
3
+ from typing import Iterable , Optional , Tuple , Union
4
4
5
5
import torch
6
+ from jaxtyping import Float
6
7
from linear_operator import to_linear_operator
7
- from linear_operator .operators import InterpolatedLinearOperator
8
+ from linear_operator .operators import InterpolatedLinearOperator , LinearOperator
9
+ from torch import Tensor
8
10
9
11
from ..models .exact_prediction_strategies import InterpolatedPredictionStrategy
10
12
from ..utils .grid import create_grid
@@ -25,14 +27,14 @@ class GridInterpolationKernel(GridKernel):
25
27
.. math::
26
28
27
29
\begin{equation*}
28
- k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{U,U } \mathbf{w_{x_2}}
30
+ k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{\boldsymbol Z, \boldsymbol Z } \mathbf{w_{x_2}}
29
31
\end{equation*}
30
32
31
33
where
32
34
33
- * :math:`U ` is the set of gridded inducing points
35
+ * :math:`\boldsymbol Z ` is the set of gridded inducing points
34
36
35
- * :math:`K_{U,U }` is the kernel matrix between the inducing points
37
+ * :math:`K_{\boldsymbol Z, \boldsymbol Z }` is the kernel matrix between the inducing points
36
38
37
39
* :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on
38
40
:math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation.
@@ -50,20 +52,13 @@ class GridInterpolationKernel(GridKernel):
50
52
`GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
51
53
Periodic, Spectral Mixture, etc.)
52
54
53
- Args:
54
- base_kernel (Kernel):
55
- The kernel to approximate with KISS-GP
56
- grid_size (Union[int, List[int]]):
57
- The size of the grid in each dimension.
58
- If a single int is provided, then every dimension will have the same grid size.
59
- num_dims (int):
60
- The dimension of the input data. Required if `grid_bounds=None`
61
- grid_bounds (tuple(float, float), optional):
62
- The bounds of the grid, if known (high performance mode).
63
- The length of the tuple must match the number of dimensions.
64
- The entries represent the min/max values for each dimension.
65
- active_dims (tuple of ints, optional):
66
- Passed down to the `base_kernel`.
55
+ :param base_kernel: The kernel to approximate with KISS-GP.
56
+ :param grid_size: The size of the grid in each dimension.
57
+ If a single int is provided, then every dimension will have the same grid size.
58
+ :param num_dims: The dimension of the input data. Required if `grid_bounds=None`.
59
+ :param grid_bounds: The bounds of the grid, if known (high performance mode).
60
+ The length of the tuple must match the number of dimensions.
61
+ The entries represent the min/max values for each dimension.
67
62
68
63
.. _Kernel Interpolation for Scalable Structured Gaussian Processes:
69
64
http://proceedings.mlr.press/v37/wilson15.pdf
@@ -72,10 +67,10 @@ class GridInterpolationKernel(GridKernel):
72
67
def __init__ (
73
68
self ,
74
69
base_kernel : Kernel ,
75
- grid_size : Union [int , List [int ]],
70
+ grid_size : Union [int , Iterable [int ]],
76
71
num_dims : Optional [int ] = None ,
77
72
grid_bounds : Optional [Tuple [float , float ]] = None ,
78
- active_dims : Optional [ Tuple [ int , ...]] = None ,
73
+ ** kwargs ,
79
74
):
80
75
has_initialized_grid = 0
81
76
grid_is_dynamic = True
@@ -116,8 +111,7 @@ def __init__(
116
111
super (GridInterpolationKernel , self ).__init__ (
117
112
base_kernel = base_kernel ,
118
113
grid = grid ,
119
- interpolation_mode = True ,
120
- active_dims = active_dims ,
114
+ ** kwargs ,
121
115
)
122
116
self .register_buffer ("has_initialized_grid" , torch .tensor (has_initialized_grid , dtype = torch .bool ))
123
117
@@ -129,23 +123,26 @@ def _tight_grid_bounds(self):
129
123
for bound , spacing in zip (self .grid_bounds , grid_spacings )
130
124
)
131
125
132
def _compute_grid(self, inputs):
    """Compute interpolation indices and values of ``inputs`` against ``self.grid``.

    All leading (batch) dimensions of ``inputs`` are flattened together with the
    data dimension before ``Interpolation().interpolate`` is called, and are
    restored on both outputs afterwards.

    :param inputs: Tensor whose shape unpacks as ``(*batch_shape, n_data, n_dimensions)``.
    :return: Tuple ``(interp_indices, interp_values)``, each viewed as
        ``(*batch_shape, n_data, -1)``.
    """
    *batch_shape, n_data, n_dimensions = inputs.shape
    # Collapse batch and data dimensions so the interpolation routine sees a
    # plain 2-D tensor with one point per row.
    inputs = inputs.reshape(-1, n_dimensions)
    interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
    # Restore the batch structure; the trailing dimension is inferred.
    interp_indices = interp_indices.view(*batch_shape, n_data, -1)
    interp_values = interp_values.view(*batch_shape, n_data, -1)
    return interp_indices, interp_values
144
133
145
- def _inducing_forward (self , last_dim_is_batch , ** params ):
146
- return super ().forward (self .grid , self .grid , last_dim_is_batch = last_dim_is_batch , ** params )
134
+ def _create_or_update_full_grid (self , grid : Iterable [Tensor ]):
135
+ pass
136
+
137
def _validate_inputs(self, x: Float[Tensor, "... N D"]) -> bool:
    """Report every input as valid.

    No check is performed on ``x``; the method unconditionally returns ``True``.

    NOTE(review): presumably overrides a stricter check in the parent class --
    confirm against the base class.

    :param x: Input tensor annotated as ``... N D``; not inspected.
    :return: Always ``True``.
    """
    return True
147
139
148
- def forward (self , x1 , x2 , diag = False , last_dim_is_batch = False , ** params ):
140
def _inducing_forward(self, **params):
    """Evaluate the parent class's ``forward`` with no explicit inputs.

    Both positional inputs are passed as ``None``; all keyword arguments are
    forwarded unchanged.

    NOTE(review): presumably the parent substitutes its stored grid when the
    inputs are ``None`` -- confirm against the base class.

    :param params: Extra keyword arguments handed through to ``super().forward``.
    :return: Whatever ``super().forward(None, None, **params)`` returns.
    """
    return super().forward(None, None, **params)
142
+
143
+ def forward (
144
+ self , x1 : Float [Tensor , "... N_1 D" ], x2 : Float [Tensor , "... N_2 D" ], diag : bool = False , ** params
145
+ ) -> Float [Union [Tensor , LinearOperator ], "... N_1 N_2" ]:
149
146
# See if we need to update the grid or not
150
147
if self .grid_is_dynamic : # This is true if a grid_bounds wasn't passed in
151
148
if torch .equal (x1 , x2 ):
@@ -180,16 +177,13 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
180
177
)
181
178
self .update_grid (grid )
182
179
183
- base_lazy_tsr = to_linear_operator (self ._inducing_forward (last_dim_is_batch = last_dim_is_batch , ** params ))
184
- if last_dim_is_batch and base_lazy_tsr .size (- 3 ) == 1 :
185
- base_lazy_tsr = base_lazy_tsr .repeat (* x1 .shape [:- 2 ], x1 .size (- 1 ), 1 , 1 )
186
-
187
- left_interp_indices , left_interp_values = self ._compute_grid (x1 , last_dim_is_batch )
180
+ base_lazy_tsr = to_linear_operator (self ._inducing_forward (** params ))
181
+ left_interp_indices , left_interp_values = self ._compute_grid (x1 )
188
182
if torch .equal (x1 , x2 ):
189
183
right_interp_indices = left_interp_indices
190
184
right_interp_values = left_interp_values
191
185
else :
192
- right_interp_indices , right_interp_values = self ._compute_grid (x2 , last_dim_is_batch )
186
+ right_interp_indices , right_interp_values = self ._compute_grid (x2 )
193
187
194
188
batch_shape = torch .broadcast_shapes (
195
189
base_lazy_tsr .batch_shape ,
0 commit comments