
docs: reordered docs for solvers and kernels #759

Merged · 14 commits · Oct 2, 2024
28 changes: 25 additions & 3 deletions coreax/data.py
@@ -20,7 +20,7 @@
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import jit
from jaxtyping import Array, ArrayLike, Shaped
from jaxtyping import Array, Shaped
from typing_extensions import Self


@@ -134,8 +134,30 @@ def __getitem__(self, key) -> Self:
"""Support `Array` style indexing of `Data` objects."""
return jtu.tree_map(lambda x: x[key], self)

def __jax_array__(self) -> Shaped[ArrayLike, " n d"]:
"""Register `ArrayLike` behaviour - return for `jnp.asarray(Data(...))`."""
@overload
def __jax_array__(
self: "Data",
) -> Shaped[Array, " n d"]: ...

@overload
def __jax_array__( # pyright:ignore[reportOverlappingOverload]
self: "SupervisedData",
) -> Shaped[Array, " n d + p"]: ...

def __jax_array__(
self: Union["Data", "SupervisedData"],
) -> Union[Shaped[Array, " n d"], Shaped[Array, " n d + p"]]:
"""
Return value of `jnp.asarray(Data(...))` and `jnp.asarray(SupervisedData(...))`.

.. note::

When ``self`` is a `SupervisedData` instance `jnp.asarray` will return
a single array where the ``supervision`` array has been
right-concatenated onto the ``data`` array.
"""
if isinstance(self, SupervisedData):
return jnp.hstack((self.data, self.supervision))
return self.data

def __len__(self) -> int:
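For context on the behaviour documented in this hunk, here is a minimal usage sketch (not part of the diff; it assumes `Data` and `SupervisedData` accept positional array arguments):

```python
import jax.numpy as jnp

from coreax.data import Data, SupervisedData

x = jnp.ones((4, 2))   # n = 4 points, d = 2 features
y = jnp.zeros((4, 1))  # p = 1 supervision column

# A Data instance converts to its underlying (n, d) array.
assert jnp.asarray(Data(x)).shape == (4, 2)
# A SupervisedData instance right-concatenates supervision onto data: (n, d + p).
assert jnp.asarray(SupervisedData(x, y)).shape == (4, 3)
```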
92 changes: 49 additions & 43 deletions coreax/kernels/scalar_valued.py
@@ -71,9 +71,9 @@ class PolynomialKernel(ScalarValuedKernel):
r"""
Define a polynomial kernel.

Given :math:`\rho =``output_scale`, :math:`c =`'constant', and :math:`d=`'degree',
the polynomial kernel is defined as
:math:`k: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}`,
Given :math:`\rho =` ``output_scale``, :math:`c =` ``constant``, and
:math:`d=` ``degree``, the polynomial kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho (x^Ty + c)^d`.

:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
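A quick numeric check of the documented polynomial form (a sketch, not part of the diff; the import path and constructor arguments are assumed from this docstring):

```python
import jax.numpy as jnp

from coreax.kernels import PolynomialKernel  # assumed import path

kernel = PolynomialKernel(output_scale=2.0, constant=1.0, degree=3)
x, y = jnp.array([1.0, 2.0]), jnp.array([0.5, -1.0])
expected = 2.0 * (x @ y + 1.0) ** 3  # rho * (x^T y + c)^d
assert jnp.isclose(kernel.compute_elementwise(x, y), expected)
```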
@@ -140,10 +140,11 @@ class ExponentialKernel(ScalarValuedKernel):
Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
exponential kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * \exp( -\frac{ ||x-y|| }{ 2 \lambda^2 } )` where
:math:`k(x, y) = \rho * \exp(-\frac{||x-y||}{2 \lambda^2})` where
:math:`||\cdot||` is the usual :math:`L_2`-norm.

.. warning::

The exponential kernel is not differentiable when :math:`x=y`.

:param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
@@ -162,25 +163,25 @@ def __check_init__(self):
raise ValueError("'output_scale' must be positive")

@override
def compute_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def compute_elementwise(self, x, y):
return self.output_scale * jnp.exp(
-jnp.linalg.norm(jnp.subtract(x, y)) / (2 * self.length_scale**2)
)

@override
def grad_x_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)

@override
def grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def grad_y_elementwise(self, x, y):
sub = jnp.subtract(x, y)
dist = jnp.linalg.norm(sub)
factor = 2 * self.length_scale**2
return self.output_scale * sub * jnp.exp(-dist / factor) / (factor * dist)

@override
def divergence_x_grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
d = len(jnp.asarray(x))
def divergence_x_grad_y_elementwise(self, x, y):
d = len(jnp.atleast_1d(x))
sub = jnp.subtract(x, y)
dist = jnp.linalg.norm(sub)
factor = 2 * self.length_scale**2
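The analytic gradients above can be cross-checked against autodiff away from :math:`x=y`, where the kernel is differentiable (a sketch, not part of the diff; assumed import path):

```python
import jax
import jax.numpy as jnp

from coreax.kernels import ExponentialKernel  # assumed import path

kernel = ExponentialKernel(length_scale=1.0, output_scale=2.0)
x, y = jnp.array([0.0, 1.0]), jnp.array([1.0, 3.0])  # x != y
auto = jax.grad(kernel.compute_elementwise, argnums=0)(x, y)
assert jnp.allclose(auto, kernel.grad_x_elementwise(x, y))
```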
@@ -198,9 +199,9 @@ class LaplacianKernel(ScalarValuedKernel):
r"""
Define a Laplacian kernel.

Given :math:`\lambda =``length_scale` and :math:`\rho =``output_scale`, the
Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
Laplacian kernel is defined as
:math:`k: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}`,
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * \exp(-\frac{||x-y||_1}{2 \lambda^2})` where
:math:`||\cdot||_1` is the :math:`L_1`-norm.

@@ -220,25 +221,25 @@ def __check_init__(self):
raise ValueError("'output_scale' must be positive")

@override
def compute_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def compute_elementwise(self, x, y):
return self.output_scale * jnp.exp(
-jnp.linalg.norm(jnp.subtract(x, y), ord=1) / (2 * self.length_scale**2)
)

@override
def grad_x_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)

@override
def grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def grad_y_elementwise(self, x, y):
return (
jnp.sign(jnp.subtract(x, y))
/ (2 * self.length_scale**2)
* self.compute_elementwise(x, y)
)

@override
def divergence_x_grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def divergence_x_grad_y_elementwise(self, x, y):
k = self.compute_elementwise(x, y)
d = len(jnp.asarray(x))
return -d * k / (4 * self.length_scale**4)
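A numeric check of the :math:`L_1` form (a sketch, not part of the diff; assumed import path):

```python
import jax.numpy as jnp

from coreax.kernels import LaplacianKernel  # assumed import path

kernel = LaplacianKernel(length_scale=0.5, output_scale=1.0)
x, y = jnp.array([1.0, 2.0, 3.0]), jnp.array([0.0, 2.0, 5.0])
# rho * exp(-||x - y||_1 / (2 * lambda^2))
expected = jnp.exp(-jnp.sum(jnp.abs(x - y)) / (2 * 0.5**2))
assert jnp.isclose(kernel.compute_elementwise(x, y), expected)
```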
@@ -248,9 +249,9 @@
r"""
Define a squared exponential kernel.

Given :math:`\lambda =``length_scale` and :math:`\rho =``output_scale`, the squared
exponential kernel is defined as
:math:`k: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}`,
Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
squared exponential kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * \exp(-\frac{||x-y||^2}{2 \lambda^2})` where
:math:`||\cdot||` is the usual :math:`L_2`-norm.

@@ -297,9 +298,9 @@ class PCIMQKernel(ScalarValuedKernel):
r"""
Define a pre-conditioned inverse multi-quadric (PCIMQ) kernel.

Given :math:`\lambda =``length_scale` and :math:`\rho =``output_scale`, the
Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
PCIMQ kernel is defined as
:math:`k: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}`,
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \frac{\rho}{\sqrt{1 + \frac{||x-y||^2}{2 \lambda^2}}}`
where :math:`||\cdot||` is the usual :math:`L_2`-norm.

@@ -319,17 +320,17 @@ def __check_init__(self):
raise ValueError("'output_scale' must be positive")

@override
def compute_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def compute_elementwise(self, x, y):
scaling = 2 * self.length_scale**2
mq_array = squared_distance(x, y) / scaling
return self.output_scale / jnp.sqrt(1 + mq_array)

@override
def grad_x_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)

@override
def grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def grad_y_elementwise(self, x, y):
return (
self.output_scale
* jnp.subtract(x, y)
@@ -338,7 +339,7 @@ def grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
)

@override
def divergence_x_grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def divergence_x_grad_y_elementwise(self, x, y):
k = self.compute_elementwise(x, y) / self.output_scale
scale = 2 * self.length_scale**2
d = len(jnp.asarray(x))
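A numeric check of the PCIMQ form (a sketch, not part of the diff; assumed import path):

```python
import jax.numpy as jnp

from coreax.kernels import PCIMQKernel  # assumed import path

kernel = PCIMQKernel(length_scale=1.0, output_scale=3.0)
x, y = jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0])  # ||x - y||^2 = 2
expected = 3.0 / jnp.sqrt(1 + 2.0 / (2 * 1.0**2))  # rho / sqrt(1 + ||x-y||^2 / (2 lambda^2))
assert jnp.isclose(kernel.compute_elementwise(x, y), expected)
```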
@@ -353,9 +354,9 @@
r"""
Define a rational quadratic kernel.

Given :math:`\lambda =``length_scale`, :math:`\rho =``output_scale`, and
:math:`\alpha =``relative_weighting`, the rational quadratic kernel is defined as
:math:`k: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}`,
Given :math:`\lambda =` ``length_scale``, :math:`\rho =` ``output_scale``, and
:math:`\alpha =` ``relative_weighting``, the rational quadratic kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * (1 + \frac{||x-y||^2}{2 \alpha \lambda^2})^{-\alpha}` where
:math:`||\cdot||` is the usual :math:`L_2`-norm.

Expand All @@ -382,7 +383,7 @@ def __check_init__(self):
raise ValueError("'relative_weighting' must be non-negative")

@override
def compute_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def compute_elementwise(self, x, y):
return (
self.output_scale
* (
@@ -398,16 +399,16 @@ def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)

@override
def grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def grad_y_elementwise(self, x, y):
return (self.output_scale * jnp.subtract(x, y) / self.length_scale**2) * (
1
+ squared_distance(x, y)
/ (2 * self.relative_weighting * self.length_scale**2)
) ** (-self.relative_weighting - 1)

@override
def divergence_x_grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
d = len(jnp.asarray(x))
def divergence_x_grad_y_elementwise(self, x, y):
d = len(jnp.atleast_1d(x))
sq_dist = squared_distance(x, y)
power = self.relative_weighting + 1
div = self.relative_weighting * self.length_scale**2
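For large ``relative_weighting`` the rational quadratic kernel approaches the squared exponential kernel, since :math:`(1 + s/(2\alpha\lambda^2))^{-\alpha} \to \exp(-s/(2\lambda^2))` as :math:`\alpha \to \infty`; a sketch of that cross-check (not part of the diff; assumed import paths):

```python
import jax.numpy as jnp

from coreax.kernels import RationalQuadraticKernel, SquaredExponentialKernel

x, y = jnp.array([0.0, 1.0]), jnp.array([1.0, -1.0])
rq = RationalQuadraticKernel(length_scale=1.0, output_scale=1.0, relative_weighting=1e3)
se = SquaredExponentialKernel(length_scale=1.0, output_scale=1.0)
# The two kernels agree up to O(1/alpha) for alpha = 1e3.
assert jnp.isclose(rq.compute_elementwise(x, y), se.compute_elementwise(x, y), atol=1e-3)
```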
@@ -430,6 +431,7 @@ class MaternKernel(ScalarValuedKernel):
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,

.. math::

k(x, y) = \rho^2 * \exp\left(-\frac{\sqrt{2p+1}||x-y||}{\lambda}\right)
\frac{p!}{(2p)!}\sum_{i=0}^p\frac{(p+i)!}{i!(p-i)!}
\left(2\sqrt{2p+1}\frac{||x-y||}{\lambda}\right)^{p-i}
@@ -466,6 +468,7 @@ def _compute_summation_term(
Given :math:`p =` ``degree``, :math:`p \in \mathbb{N}`, compute

.. math::

\gamma := \sum_{i=0}^p\frac{(p+i)!}{i!(p-i)!}
\left(2\sqrt{2p+1}\frac{||x-y||}{\lambda}\right)^{p-i}.

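For :math:`p = 0` the prefactor :math:`p!/(2p)!` and the summation term are both :math:`1`, so the kernel reduces to :math:`\rho^2 \exp(-||x-y||/\lambda)`; a sketch of that check (not part of the diff; the constructor arguments are assumed from the docstring):

```python
import jax.numpy as jnp

from coreax.kernels import MaternKernel  # assumed import path

kernel = MaternKernel(length_scale=1.0, output_scale=2.0, degree=0)  # assumed signature
x, y = jnp.array([0.0, 3.0]), jnp.array([4.0, 0.0])  # ||x - y|| = 5
expected = 2.0**2 * jnp.exp(-5.0 / 1.0)
assert jnp.isclose(kernel.compute_elementwise(x, y), expected)
```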
@@ -503,18 +506,19 @@ class PeriodicKernel(ScalarValuedKernel):
Define a periodic kernel.

Given :math:`\lambda =` ``length_scale``, :math:`\rho =` ``output_scale``, and
:math:`\p =` ``periodicity``, the periodic kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * \exp ( \frac{ -2 \sin^2( \pi ||x-y|| / p ) }{ \lambda^2 } )`
where :math:`||\cdot||` is the usual :math:`L_2`-norm.
:math:`p =` ``periodicity``, the periodic kernel is defined as
:math:`k: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * \exp(\frac{-2 \sin^2(\pi ||x-y||/p)}{\lambda^2})` where
:math:`||\cdot||` is the usual :math:`L_2`-norm.

.. warning::

.. Warning::
The periodic kernel is not differentiable when :math:`x=y`.

:param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
positive
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
:param periodicity: Parameter controlling the periodicity of the kernel. :\math: `p`
:param periodicity: Parameter controlling the periodicity of the kernel :math:`p`
"""

length_scale: float = 1.0
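A numeric check of the periodic form away from :math:`x=y` (a sketch, not part of the diff; assumed import path):

```python
import jax.numpy as jnp

from coreax.kernels import PeriodicKernel  # assumed import path

kernel = PeriodicKernel(length_scale=1.0, output_scale=1.0, periodicity=2.0)
x, y = jnp.array([0.0]), jnp.array([0.5])  # ||x - y|| = 0.5
expected = jnp.exp(-2 * jnp.sin(jnp.pi * 0.5 / 2.0) ** 2 / 1.0**2)
assert jnp.isclose(kernel.compute_elementwise(x, y), expected)
```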
@@ -593,11 +597,12 @@ class LocallyPeriodicKernel(ProductKernel):
Define a locally periodic kernel.

The periodic kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = r(x,y)l(x,y)` where :math:`r` is the periodic kernel and
:math:`l` is the squared exponential kernel.

.. Warning::
.. warning::

The locally periodic kernel is not differentiable when :math:`x=y`.

:param periodic_length_scale: Periodic kernel smoothing/bandwidth parameter
@@ -640,6 +645,7 @@ class PoissonKernel(ScalarValuedKernel):
:math:`k(x, y) = \frac{\rho}{1 - 2r\cos(x-y) + r^2}`.

.. warning::

Unlike many other kernels in Coreax, the Poisson kernel is not defined on
arbitrary :math:`\mathbb{R}^d`, but instead a subset of the positive real line
:math:`[0, 2\pi)`. We do not check that inputs to methods in this class lie in
@@ -664,19 +670,19 @@ def __check_init__(self):
raise ValueError("'output_scale' must be positive")

@override
def compute_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def compute_elementwise(self, x, y):
return self.output_scale / (
1
- 2 * self.index * jnp.cos(jnp.linalg.norm(jnp.subtract(x, y)))
+ self.index**2
)

@override
def grad_x_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)

@override
def grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def grad_y_elementwise(self, x, y):
# Note that we do not take a norm here in order to maintain the dimensionality
# of the vectors x and y, this ensures calls to 'grad_y' and 'grad_x' have
# expected dimensionality.
@@ -686,7 +692,7 @@ def grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
) ** 2

@override
def divergence_x_grad_y_elementwise(self, x: ArrayLike, y: ArrayLike) -> Array:
def divergence_x_grad_y_elementwise(self, x, y):
distance = jnp.linalg.norm(jnp.subtract(x, y))
div = 1 - 2 * self.index * jnp.cos(distance) + self.index**2
first_term = (2 * self.output_scale * self.index * jnp.cos(distance)) / div**2
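A numeric check of the Poisson form on its :math:`[0, 2\pi)` domain (a sketch, not part of the diff; it assumes the constructor exposes ``index`` for :math:`r`, matching ``self.index`` above):

```python
import jax.numpy as jnp

from coreax.kernels import PoissonKernel  # assumed import path

kernel = PoissonKernel(index=0.4, output_scale=1.0)
x, y = jnp.array([0.5]), jnp.array([2.0])  # both in [0, 2*pi)
expected = 1.0 / (1 - 2 * 0.4 * jnp.cos(jnp.linalg.norm(x - y)) + 0.4**2)
assert jnp.isclose(kernel.compute_elementwise(x, y), expected)
```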
3 changes: 3 additions & 0 deletions coreax/solvers/__init__.py
@@ -54,4 +54,7 @@
"RPCholesky",
"GreedyKernelPointsState",
"GreedyKernelPoints",
"RecombinationSolver",
"CaratheodoryRecombination",
"TreeRecombination",
]
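With these entries added to ``__all__``, the recombination solvers are importable from the public namespace; a minimal smoke test (not part of the diff):

```python
from coreax.solvers import (
    CaratheodoryRecombination,
    RecombinationSolver,
    TreeRecombination,
)
```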
9 changes: 5 additions & 4 deletions coreax/solvers/coresubset.py
@@ -43,8 +43,6 @@
from coreax.util import KeyArrayLike, sample_batch_indices, tree_zero_pad_leading_axis

_Data = TypeVar("_Data", bound=Data)
_SupervisedData = TypeVar("_SupervisedData", bound=SupervisedData)


MSG = "'coreset_size' must be less than 'len(dataset)' by definition of a coreset"

@@ -91,7 +89,7 @@ def reduce(
p=selection_weights,
replace=not self.unique,
)
return Coresubset(random_indices, dataset), solver_state
return Coresubset(Data(random_indices), dataset), solver_state
except ValueError as err:
if self.coreset_size > len(dataset) and self.unique:
raise ValueError(MSG) from err
@@ -150,7 +148,10 @@ def _greedy_kernel_selection(
unroll=unroll,
)

def _greedy_body(i: int, val: tuple[Array, Array]) -> tuple[Array, ArrayLike]:
def _greedy_body(
i: int,
val: tuple[Shaped[Array, " coreset_size"], Shaped[Array, " n"]],
) -> tuple[Shaped[Array, " coreset_size"], Shaped[Array, " n"]]:
coreset_indices, kernel_similarity_penalty = val
valid_kernel_similarity_penalty = jnp.where(
weights > 0, kernel_similarity_penalty, jnp.nan
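For context on the ``reduce`` fix above, a usage sketch (not part of the diff; the ``RandomSample`` constructor arguments are assumed):

```python
import jax.numpy as jnp
import jax.random as jr

from coreax.data import Data
from coreax.solvers import RandomSample  # assumed import path

dataset = Data(jnp.arange(20.0).reshape(10, 2))
solver = RandomSample(coreset_size=3, random_key=jr.PRNGKey(0))
# The returned coresubset now wraps its random indices in a Data container.
coresubset, _ = solver.reduce(dataset)
```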
9 changes: 7 additions & 2 deletions tests/unit/test_data.py
@@ -132,10 +132,15 @@ def test_getitem(self, data_type, index):
_expected_indexed_data = jtu.tree_map(lambda x: x[index], _data)
assert eqx.tree_equal(_data[index], _expected_indexed_data)

def test_arraylike(self, data_type):
def test_asarray(self, data_type):
"""Test interpreting data as a JAX array."""
_data = data_type()
assert eqx.tree_equal(jnp.asarray(_data), _data.data)
if isinstance(_data, coreax.data.SupervisedData):
assert eqx.tree_equal(
jnp.asarray(_data), jnp.hstack((_data.data, _data.supervision))
)
else:
assert eqx.tree_equal(jnp.asarray(_data), _data.data)

def test_len(self, data_type):
"""Test length of data."""