Skip to content

Commit 82ab5ec

Browse files
authored
Merge pull request #759 from gchq/docs/reorder-files
docs: reordered docs for solvers and kernels
2 parents 41b3922 + 24a7fbc commit 82ab5ec

File tree

6 files changed

+478
-455
lines changed

6 files changed

+478
-455
lines changed

coreax/data.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import jax.numpy as jnp
2121
import jax.tree_util as jtu
2222
from jax import jit
23-
from jaxtyping import Array, ArrayLike, Shaped
23+
from jaxtyping import Array, Shaped
2424
from typing_extensions import Self
2525

2626

@@ -134,8 +134,30 @@ def __getitem__(self, key) -> Self:
134134
"""Support `Array` style indexing of `Data` objects."""
135135
return jtu.tree_map(lambda x: x[key], self)
136136

137-
def __jax_array__(self) -> Shaped[ArrayLike, " n d"]:
138-
"""Register `ArrayLike` behaviour - return for `jnp.asarray(Data(...))`."""
137+
@overload
138+
def __jax_array__(
139+
self: "Data",
140+
) -> Shaped[Array, " n d"]: ...
141+
142+
@overload
143+
def __jax_array__( # pyright:ignore[reportOverlappingOverload]
144+
self: "SupervisedData",
145+
) -> Shaped[Array, " n d+p"]: ...
146+
147+
def __jax_array__(
148+
self: Union["Data", "SupervisedData"],
149+
) -> Union[Shaped[Array, " n d"], Shaped[Array, " n d+p"]]:
150+
"""
151+
Return value of `jnp.asarray(Data(...))` and `jnp.asarray(SupervisedData(...))`.
152+
153+
.. note::
154+
155+
When ``self`` is a `SupervisedData` instance `jnp.asarray` will return
156+
a single array where the ``supervision`` array has been
157+
right-concatenated onto the``data`` array.
158+
"""
159+
if isinstance(self, SupervisedData):
160+
return jnp.hstack((self.data, self.supervision))
139161
return self.data
140162

141163
def __len__(self) -> int:

coreax/kernels/__init__.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,23 @@
3939
from coreax.kernels.util import median_heuristic
4040

4141
__all__ = [
42+
"median_heuristic",
4243
"ScalarValuedKernel",
44+
"UniCompositeKernel",
45+
"PowerKernel",
46+
"DuoCompositeKernel",
47+
"AdditiveKernel",
48+
"ProductKernel",
49+
"LinearKernel",
50+
"PolynomialKernel",
4351
"ExponentialKernel",
4452
"LaplacianKernel",
45-
"LinearKernel",
46-
"LocallyPeriodicKernel",
53+
"SquaredExponentialKernel",
4754
"PCIMQKernel",
48-
"PeriodicKernel",
49-
"PolynomialKernel",
5055
"RationalQuadraticKernel",
51-
"SquaredExponentialKernel",
52-
"AdditiveKernel",
53-
"ProductKernel",
54-
"SteinKernel",
55-
"median_heuristic",
56-
"DuoCompositeKernel",
57-
"UniCompositeKernel",
58-
"PowerKernel",
59-
"PoissonKernel",
6056
"MaternKernel",
57+
"PeriodicKernel",
58+
"LocallyPeriodicKernel",
59+
"PoissonKernel",
60+
"SteinKernel",
6161
]

0 commit comments

Comments
 (0)