Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding check for unsorted input coordinates when using QuasisepSolver #123

Merged
merged 3 commits into from
Oct 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions docs/tutorials/quasisep-custom.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.10.6 ('tinygp')",
"language": "python",
"name": "python3"
},
Expand All @@ -684,7 +684,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.10.6"
},
"vscode": {
"interpreter": {
"hash": "d20ea8a315da34b3e8fab0dbd7b542a0ef3c8cf12937343660e6bc10a20768e3"
}
}
},
"nbformat": 4,
Expand Down
2 changes: 2 additions & 0 deletions news/123.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added check for sorted input coordinates when using the ``QuasisepSolver``;
a ``ValueError`` is thrown if they are not.
7 changes: 6 additions & 1 deletion src/tinygp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
solver: Optional[Any] = None,
mean_value: Optional[JAXArray] = None,
covariance_value: Optional[Any] = None,
**solver_kwargs: Any,
):
self.kernel = kernel
self.X = X
Expand Down Expand Up @@ -101,7 +102,11 @@ def __init__(
else:
solver = DirectSolver
self.solver = solver.init(
kernel, self.X, self.noise, covariance=covariance_value
kernel,
self.X,
self.noise,
covariance=covariance_value,
**solver_kwargs,
)

@property
Expand Down
27 changes: 19 additions & 8 deletions src/tinygp/kernels/quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
]

from abc import ABCMeta, abstractmethod
from typing import Optional, Union
from typing import Any, Optional, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -151,7 +151,7 @@ def __add__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
)
return Sum(self, other)

def __radd__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
def __radd__(self, other: Any) -> "Kernel":
# We'll hit this first branch when using the `sum` function
if other == 0:
return self
Expand All @@ -171,7 +171,7 @@ def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
)
return Scale(kernel=self, scale=other)

def __rmul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
def __rmul__(self, other: Any) -> "Kernel":
if isinstance(other, Quasisep):
return Product(other, self)
if isinstance(other, Kernel) or jnp.ndim(other) != 0:
Expand Down Expand Up @@ -204,6 +204,9 @@ class Wrapper(Quasisep, metaclass=ABCMeta):

kernel: Quasisep

def coord_to_sortable(self, X: JAXArray) -> JAXArray:
return self.kernel.coord_to_sortable(X)

def design_matrix(self) -> JAXArray:
return self.kernel.design_matrix()

Expand All @@ -226,6 +229,10 @@ class Sum(Quasisep):
kernel1: Quasisep
kernel2: Quasisep

def coord_to_sortable(self, X: JAXArray) -> JAXArray:
"""We assume that both kernels use the same coordinates"""
return self.kernel1.coord_to_sortable(X)

def design_matrix(self) -> JAXArray:
return jsp.linalg.block_diag(
self.kernel1.design_matrix(), self.kernel2.design_matrix()
Expand Down Expand Up @@ -259,6 +266,10 @@ class Product(Quasisep):
kernel1: Quasisep
kernel2: Quasisep

def coord_to_sortable(self, X: JAXArray) -> JAXArray:
"""We assume that both kernels use the same coordinates"""
return self.kernel1.coord_to_sortable(X)

def design_matrix(self) -> JAXArray:
F1 = self.kernel1.design_matrix()
F2 = self.kernel2.design_matrix()
Expand Down Expand Up @@ -699,14 +710,14 @@ def init(
params = jnp.linalg.solve(
params, 0.5 * sigma**2 * jnp.eye(p, 1, k=-p + 1)
)[:, 0]
stn = []
stn_ = []
for j in range(p):
stn.append([jnp.zeros(()) for _ in range(p)])
stn_.append([jnp.zeros(()) for _ in range(p)])
for n, k in enumerate(range(j - 2, -1, -2)):
stn[-1][k] = (2 * (n % 2) - 1) * params[j - n - 1]
stn_[-1][k] = (2 * (n % 2) - 1) * params[j - n - 1]
for n, k in enumerate(range(j, p, 2)):
stn[-1][k] = (1 - 2 * (n % 2)) * params[n + j]
stn = jnp.array(list(map(jnp.stack, stn)))
stn_[-1][k] = (1 - 2 * (n % 2)) * params[n + j]
stn = jnp.array(list(map(jnp.stack, stn_)))

return cls(
sigma=sigma,
Expand Down
23 changes: 20 additions & 3 deletions src/tinygp/solvers/quasisep/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__all__ = ["QuasisepSolver"]

from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -41,6 +41,7 @@ def init(
noise: Noise,
*,
covariance: Optional[Any] = None,
assume_sorted: bool = False,
) -> "QuasisepSolver":
"""Build a :class:`QuasisepSolver` for a given kernel and coordinates

Expand All @@ -52,15 +53,24 @@ def init(
covariance: Optionally, a pre-computed
:class:`tinygp.solvers.quasisep.core.QSM` with the covariance
matrix.
assume_sorted: If ``True``, assume that the input coordinates are
sorted. If ``False``, check that they are sorted and throw an
error if they are not. This can introduce a runtime overhead,
and you can pass ``assume_sorted=True`` to get the best
performance.
"""
from tinygp.kernels.quasisep import Quasisep

if covariance is None:
assert isinstance(kernel, Quasisep)
if TYPE_CHECKING:
assert isinstance(kernel, Quasisep)
if not assume_sorted:
jax.debug.callback(_check_sorted, kernel.coord_to_sortable(X))
matrix = kernel.to_symm_qsm(X)
matrix += noise.to_qsm()
else:
assert isinstance(covariance, SymmQSM)
if TYPE_CHECKING:
assert isinstance(covariance, SymmQSM)
matrix = covariance
factor = matrix.cholesky()
return cls(X=X, matrix=matrix, factor=factor)
Expand Down Expand Up @@ -125,3 +135,10 @@ def condition(

A = self.solve_triangular(Ks)
return Kss - A.transpose() @ A


def _check_sorted(X: JAXArray) -> None:
if np.any(np.diff(X) < 0.0):
raise ValueError(
"Input coordinates must be sorted in order to use the QuasisepSolver"
)
21 changes: 20 additions & 1 deletion tests/test_solvers/test_quasisep/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_consistent_with_direct(kernel_pair, data):

@pytest.mark.skipif(celerite is None, reason="'celerite' must be installed")
def test_celerite(data):
x, y, t = data
x, y, _ = data
yerr = 0.1

a, b, c, d = 1.1, 0.8, 0.9, 0.1
Expand All @@ -125,3 +125,22 @@ def test_celerite(data):
calc = gp.log_probability(y)

np.testing.assert_allclose(calc, expected)


def test_unsorted(data):
random = np.random.default_rng(0)
inds = random.permutation(len(data[0]))
x_ = data[0][inds]
y_ = data[1][inds]

kernel = quasisep.Matern32(sigma=1.8, scale=1.5)
with pytest.raises(ValueError):
GaussianProcess(kernel, x_, diag=0.1)

@jax.jit
def impl(X, y):
return GaussianProcess(kernel, X, diag=0.1).log_probability(y)

with pytest.raises(jax.lib.xla_extension.XlaRuntimeError) as exc_info:
impl(x_, y_).block_until_ready()
assert exc_info.match(r"Input coordinates must be sorted")