Skip to content

Commit

Permalink
Merge pull request #759 from gchq/docs/reorder-files
Browse files Browse the repository at this point in the history
docs: reordered docs for solvers and kernels
  • Loading branch information
tp832944 authored Oct 2, 2024
2 parents 41b3922 + 24a7fbc commit 82ab5ec
Show file tree
Hide file tree
Showing 6 changed files with 478 additions and 455 deletions.
28 changes: 25 additions & 3 deletions coreax/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import jit
from jaxtyping import Array, ArrayLike, Shaped
from jaxtyping import Array, Shaped
from typing_extensions import Self


Expand Down Expand Up @@ -134,8 +134,30 @@ def __getitem__(self, key) -> Self:
"""Support `Array` style indexing of `Data` objects."""
return jtu.tree_map(lambda x: x[key], self)

def __jax_array__(self) -> Shaped[ArrayLike, " n d"]:
"""Register `ArrayLike` behaviour - return for `jnp.asarray(Data(...))`."""
@overload
def __jax_array__(
self: "Data",
) -> Shaped[Array, " n d"]: ...

@overload
def __jax_array__( # pyright:ignore[reportOverlappingOverload]
self: "SupervisedData",
) -> Shaped[Array, " n d+p"]: ...

def __jax_array__(
self: Union["Data", "SupervisedData"],
) -> Union[Shaped[Array, " n d"], Shaped[Array, " n d+p"]]:
"""
Return value of `jnp.asarray(Data(...))` and `jnp.asarray(SupervisedData(...))`.
.. note::
When ``self`` is a `SupervisedData` instance `jnp.asarray` will return
a single array where the ``supervision`` array has been
right-concatenated onto the``data`` array.
"""
if isinstance(self, SupervisedData):
return jnp.hstack((self.data, self.supervision))
return self.data

def __len__(self) -> int:
Expand Down
26 changes: 13 additions & 13 deletions coreax/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@
from coreax.kernels.util import median_heuristic

__all__ = [
"median_heuristic",
"ScalarValuedKernel",
"UniCompositeKernel",
"PowerKernel",
"DuoCompositeKernel",
"AdditiveKernel",
"ProductKernel",
"LinearKernel",
"PolynomialKernel",
"ExponentialKernel",
"LaplacianKernel",
"LinearKernel",
"LocallyPeriodicKernel",
"SquaredExponentialKernel",
"PCIMQKernel",
"PeriodicKernel",
"PolynomialKernel",
"RationalQuadraticKernel",
"SquaredExponentialKernel",
"AdditiveKernel",
"ProductKernel",
"SteinKernel",
"median_heuristic",
"DuoCompositeKernel",
"UniCompositeKernel",
"PowerKernel",
"PoissonKernel",
"MaternKernel",
"PeriodicKernel",
"LocallyPeriodicKernel",
"PoissonKernel",
"SteinKernel",
]
Loading

0 comments on commit 82ab5ec

Please sign in to comment.