Feature/tree recombination #504

Merged
merged 15 commits into from
Sep 8, 2024
8 changes: 8 additions & 0 deletions .cspell/people.txt
@@ -8,16 +8,21 @@ Ferenc
Frobenius
Garg
Garreau
Goaoc
Halko
Helly
Huszar
Jaehoon
Jiaxin
Jitkrittum
Kanagawa
Litterer
Loera
Lyons
Martinsson
Meunier
Motonobu
Nabil
Nystr
Nystrom
Qiang
@@ -28,10 +33,13 @@ Sahaj
Schreuder
Smirnov
Smola
Sperner
Staber
Tchakaloff
Tchernychova
Teichmann
Tropp
Tverberg
Veiga
Wittawat
Yifan
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `RecombinationSolver`: an abstract base class for recombination solvers.
- `CaratheodoryRecombination`: a simple deterministic approach to solving recombination problems.
- `TreeRecombination`: an advanced deterministic approach that utilises `CaratheodoryRecombination`,
but provides superior performance for solving all but the smallest recombination problems.
but is faster for solving all but the smallest recombination problems.
- Added supervised coreset construction algorithm in `coreax.solvers.GreedyKernelPoints`
- Added `coreax.kernels.PowerKernel` to replace repeated calls of `coreax.kernels.ProductKernel`
within the `**` magic method of `coreax.kernel.ScalarValuedKernel`
93 changes: 66 additions & 27 deletions coreax/solvers/recombination.py
@@ -35,7 +35,7 @@
.. math::
\begin{align}
\mathbf{Y} \mathbf{\hat{w}} &= \text{CoM}(\mu_n),\\
\mathbf{\hat{w}} &\ge \mathbf{0},
\mathbf{\hat{w}} &\ge 0,
\end{align}

where the system variables and "centre-of-mass" are defined as
@@ -65,10 +65,11 @@
I = \{i \mid \hat{w_i} \neq 0\, \forall i \in \{1, \dots, n\}\}.
\end{gather}

Due to Tchakaloff's theorem, which follows from Caratheodory's convex hull theorem, we
know there always exists a basic-feasible solution to the linear-program, with at most
:math:`m^\prime = \text{dim}(\text{span}(\Phi))` non-zero weights. Hence, we have an
upper bound on the size of a coresubset, controlled by the choice of test-functions.
Due to Tchakaloff's theorem :cite:`tchakaloff1957,bayer2006tchakaloff`, which follows
from Caratheodory's convex hull theorem :cite:`caratheodory1907,loera2018caratheodory`,
we know there always exists a basic-feasible solution to the linear-program, with at
most :math:`m^\prime = \text{dim}(\text{span}(\Phi))` non-zero weights. Hence, we have
an upper bound on the size of a coresubset, controlled by the choice of test-functions.
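
The existence argument above is constructive: repeatedly step along a left null-space direction of the test-function matrix until enough weights hit zero. A minimal NumPy sketch of this Caratheodory-style elimination with the identity test-functions (illustrative only — the library's solvers are JAX-based, JIT-compatible, and handle rank deficiency via `rcond`, none of which is reproduced here; the function name is not coreax API):

```python
import numpy as np

def caratheodory_reduce(nodes, weights):
    """Reduce a weighted point set to <= d + 1 points with the same mean."""
    n, d = nodes.shape
    w = weights.astype(float).copy()
    # Affine augmentation: prepend the constant test-function x -> 1, so that
    # preserving the centre-of-mass also preserves the total mass.
    phi = np.hstack([np.ones((n, 1)), nodes])  # (n, d + 1)
    while np.count_nonzero(w) > d + 1:
        support = np.flatnonzero(w)
        a = phi[support]  # (k, d + 1) with k > d + 1
        # Any left null-space vector of `a` is a direction along which the
        # weights can move without changing the centre-of-mass.
        u, _, _ = np.linalg.svd(a, full_matrices=True)
        v = u[:, -1]
        if not np.any(v > 0):  # ensure a positive entry for the ratio test
            v = -v
        # Ratio test: largest step before the first weight would go negative.
        ratios = np.where(v > 0, w[support] / v, np.inf)
        step = ratios.min()
        w_new = np.clip(w[support] - step * v, 0.0, None)
        w_new[np.argmin(ratios)] = 0.0  # clamp the eliminated weight exactly
        w[support] = w_new
    return w
```

Each pass zeroes at least one weight, so the loop ends with at most :math:`d + 1` non-zero weights — the basic feasible solution guaranteed by Tchakaloff's theorem — while the weighted mean and total mass are unchanged up to floating-point error.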

.. note::
A basic feasible solution (coresubset produced by recombination) is non-unique. In
@@ -105,18 +106,18 @@

class RecombinationSolver(CoresubsetSolver[_Data, _State], Generic[_Data, _State]):
r"""
Solver which returns a :class:`coreax.coreset.Coresubset` via recombination.
Solver which returns a :class:`~coreax.coreset.Coresubset` via recombination.

Given :math:`m-1` explicitly provided test-functions :math:`\Phi^\prime`, a
recombination solver finds a coresubset with :math:`m^\prime \le m` points, whose
push-forward :math:`\hat\{mu}_{\m^\prime}` has the same "centre-of-mass" as the
dataset push-forward :math:`\mu_n := \Phi_* \nu_n`.
push-forward :math:`\hat{\mu}_{m^\prime}` has the same "centre-of-mass" (CoM) as
the dataset push-forward :math:`\mu_n := \Phi_* \nu_n`.

:param test_functions: A callable that applies a set of specified test-functions
:math:`\Phi^\prime = \{\phi_1,\dots,\phi_{m-1}\}` where each function is a map
:math:`\phi_i \colon \Omega\to\mathbb{R}`; a value of none implies the identity
map :math:`\Phi^\prime \colon x \mapsto x`, and necessarily assumes that
:math:`x \in \Omega \subseteq \mathbb{R}^{m-1}`
:math:`\phi_i \colon \Omega\to\mathbb{R}`; a value of :data:`None` implies the
identity map :math:`\Phi^\prime \colon x \mapsto x`, and necessarily assumes
that :math:`x \in \Omega \subseteq \mathbb{R}^{m-1}`
:param mode: 'implicit-explicit' explicitly removes :math:`n - m` points, yielding
a coreset of size :math:`m`, with :math:`m - m^\prime` zero-weighted (implicitly
removed) points; 'implicit' explicitly removes no points, yielding a coreset of
@@ -126,7 +127,7 @@ class RecombinationSolver(CoresubsetSolver[_Data, _State], Generic[_Data, _State
compatible as the coreset size :math:`m^\prime` is unknown at compile time.
"""

test_functions: Union[Callable[[Omega], Real[Array, " m-1"]], None] = None
test_functions: Optional[Callable[[Omega], Real[Array, " m-1"]]] = None
mode: Literal["implicit-explicit", "implicit", "explicit"] = "implicit-explicit"

def __check_init__(self):
@@ -172,10 +173,10 @@ class CaratheodoryRecombination(RecombinationSolver[Data, None]):
compatible as the coreset size :math:`m^\prime` is unknown at compile time.
:param rcond: A relative condition number; any singular value :math:`s` below the
threshold :math:`\text{rcond} * \text{max}(s)` is treated as equal to zero; if
:code:`rcond is None`, it defaults to `floating point eps * max(n, d)`
rcond is :data:`None`, it defaults to `floating point eps * max(n, d)`
"""

rcond: Union[float, None] = None
rcond: Optional[float] = None

@override
def reduce(
@@ -206,6 +207,9 @@ def _eliminate_cond(state: _EliminationState) -> bool:
vectors, is due to the dimension of the null space being unknown at JIT
compile time, preventing us from slicing the left singular vectors down
to only those which form a basis for the left null space.

:param state: Elimination state information
:return: Boolean indicating if to continue/exit the elimination loop.
"""
*_, basis_index = state
return basis_index < null_space_rank
@@ -222,20 +226,36 @@ def _eliminate(state: _EliminationState) -> _EliminationState:
If the procedure is repeated until all the left null space basis vectors
are eliminated, the resulting weights (when combined with the original
nodes) are a BFS to the recombination problem/linear-program.

:param state: Elimination state information
:return: Updated `state` information resulting from the elimination step.
"""
# Algorithm 6 - Chapter 3.3 of :cite:`tchernychova2016recombination`
# Our Notation -> Their Notation
# - `basis_index` (loop iteration) -> i
# - `elimination_index` -> k^{(i)}
# - `elimination_rescaling_factor` -> \alpha_{(i)}
# - `updated_weights` -> \underline\beta^{(i)}
# - `null_space_basis_update` -> d_{l+1}^{(i)}\phi_1^{(i-1)}
# - `updated_null_space_basis` -> \Psi^{(i)}
_weights, null_space_basis, basis_index = state
basis_vector = null_space_basis[basis_index]
_elimination_condition = _weights / basis_vector
# Equation 3: Select the weight to eliminate.
elimination_condition = jnp.where(
basis_vector > 0, _elimination_condition, jnp.inf
basis_vector > 0, _weights / basis_vector, jnp.inf
)
elimination_index = jnp.argmin(elimination_condition)
elimination_rescaling_factor = elimination_condition[elimination_index]
# Equation 4: Eliminate the selected weight and redistribute its mass.
# NOTE: Equation 5 is implicit from Equation 4 and is performed outside
# of `_eliminate` via `_coresubset_nodes`.
updated_weights = _weights - elimination_rescaling_factor * basis_vector
updated_weights = updated_weights.at[elimination_index].set(0)
rescaled_basis_vector = basis_vector / basis_vector[elimination_index]
# Equations 6, 7 and 8: Update the Null space basis.
null_space_basis_update = jnp.tensordot(
null_space_basis[:, elimination_index], rescaled_basis_vector, axes=0
null_space_basis[:, elimination_index],
basis_vector / basis_vector[elimination_index],
axes=0,
)
updated_null_space_basis = null_space_basis - null_space_basis_update
updated_null_space_basis = updated_null_space_basis.at[basis_index].set(0)
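
The rank-1 basis update in Equations 6–8 can be sanity-checked numerically: after eliminating an index, every basis vector should have a zero entry there while remaining in the left null space. A minimal NumPy sketch (the matrix and indices are illustrative, not library code; the library itself operates on JAX arrays):

```python
import numpy as np

rng = np.random.default_rng(1)
phi = rng.normal(size=(5, 2))                 # 5 points, 2 test-functions
u, _, _ = np.linalg.svd(phi, full_matrices=True)
basis = u[:, 2:].T                            # (3, 5); rows satisfy basis @ phi = 0
basis_vector = basis[0]                       # vector driving this elimination
elim = int(np.argmax(np.abs(basis_vector)))   # eliminated index (entry is nonzero)
# Rank-1 update from Equations 6--8: subtract the outer product of the
# eliminated column with the rescaled driving basis vector.
update = np.outer(basis[:, elim], basis_vector / basis_vector[elim])
new_basis = basis - update
assert np.allclose(new_basis @ phi, 0.0, atol=1e-10)  # still annihilates phi
assert np.allclose(new_basis[:, elim], 0.0)           # eliminated column zeroed
assert np.allclose(new_basis[0], 0.0)                 # spent basis vector vanishes
```

The last assertion shows why the diff's explicit `.at[basis_index].set(0)` is consistent: the update already annihilates the spent basis vector, and the explicit set makes this exact under floating point.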
@@ -259,7 +279,7 @@ def _eliminate(state: _EliminationState) -> _EliminationState:

def _push_forward(
nodes: Shaped[Array, "n d"],
test_functions: Union[Callable[[Omega], Real[Array, " m-1"]], None],
test_functions: Optional[Callable[[Omega], Real[Array, " m-1"]]],
augment: bool = True,
) -> Shaped[Array, "n m"]:
r"""
@@ -271,9 +291,9 @@ def _push_forward(
:math:`\phi_i \colon \Omega \to \mathbb{R}`; a value of :data:`None` implies the identity
map :math:`\Phi^\prime \colon x \mapsto x`, and necessarily assumes that
:math:`x \in \Omega \subseteq \mathbb{R}^{m-1}`
:param augment: If to prepend prepend the affine-augmentation test function
:param augment: Whether to prepend the affine-augmentation test function
:math:`\{x \mapsto 1\}` to the explicitly pushed-forward nodes
:math:`\Phi^\prime(x)`, yielding :math:`\Phi(x)`; the default behaviour
prepends the affine-augmentation function
:return: The pushed-forward nodes.
"""
if test_functions is None:
Expand Down Expand Up @@ -320,10 +340,14 @@ def _co_linearize(
non_zero_weights_mask = weights > 0
zero_weights_mask = 1 - non_zero_weights_mask
n_zeros = zero_weights_mask.sum()
weights = weights.at[max_index].divide(n_zeros + 1)
# Create a new set of indices that replace the zero-weighted node indices with the
# maximum weighted node's index.
indices = jnp.arange(weights.shape[0])
indices *= non_zero_weights_mask
indices += zero_weights_mask * max_index
# Renormalize the maximum weight; ensures the weight sum is preserved under the new
# (co-linearized) indices; prevents co-linearization from changing the weight sum.
weights = weights.at[max_index].divide(n_zeros + 1)
return nodes[indices], weights[indices], indices
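
The index redirection and weight renormalisation above can be sketched in a few lines. A hedged NumPy illustration (not the JAX implementation; names mirror the diff for readability):

```python
import numpy as np

def co_linearize(nodes, weights):
    """NumPy sketch of the zero-weight index redirection above."""
    max_index = int(np.argmax(weights))
    non_zero = (weights > 0).astype(int)
    zero = 1 - non_zero
    n_zeros = int(zero.sum())
    # Zero-weighted slots are redirected to the maximum-weight node's index...
    indices = np.arange(weights.shape[0]) * non_zero + zero * max_index
    # ...and that node's weight is split over its n_zeros + 1 occurrences, so
    # the total mass (and hence the centre-of-mass) is unchanged.
    weights = weights.copy()
    weights[max_index] /= n_zeros + 1
    return nodes[indices], weights[indices], indices

nodes = np.arange(10.0).reshape(5, 2)
weights = np.array([0.5, 0.0, 0.3, 0.0, 0.2])
new_nodes, new_weights, idx = co_linearize(nodes, weights)
assert idx.tolist() == [0, 0, 2, 0, 4]          # zeros redirected to index 0
assert np.isclose(new_weights.sum(), 1.0)       # mass preserved
assert np.allclose(new_weights @ new_nodes, weights @ nodes)  # CoM preserved
```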


@@ -444,9 +468,9 @@ class TreeRecombination(RecombinationSolver[Data, None]):
introduced in :cite:`litterer2012recombination`.

The time complexity is of order :math:`\mathcal{O}(\log_2(\frac{n}{c_r m}) m^3)`,
where `c_r = tree_reduction_factor`. The time complexity can be equivalently
expressed as :math:`\mathcal{O}(m^3)`, using the same arguments as used in
:class:`CaratheodoryRecombination`.
where :math:`c_r` is the `tree_reduction_factor`. The time complexity can be
equivalently expressed as :math:`\mathcal{O}(m^3)`, using the same arguments as used
in :class:`CaratheodoryRecombination`.

.. note::
As the ratio of :math:`n / m` grows, the constant factor for the time complexity
@@ -457,7 +481,7 @@ class TreeRecombination(RecombinationSolver[Data, None]):

:param test_functions: the map :math:`\Phi^\prime = \{ \phi_1, \dots, \phi_{M-1} \}`
where each :math:`\phi_i \colon \Omega \to \mathbb{R}` represents a linearly
independent test-function; a value of `None` implies the identity function
independent test-function; a value of :data:`None` implies the identity function
(necessarily assuming :math:`\Omega \subseteq \mathbb{R}^{M-1}`)
:param mode: 'implicit-explicit' explicitly removes :math:`n - m` points, yielding
a coreset of size :math:`m`, with :math:`m - m^\prime` zero-weighted (implicitly
@@ -491,7 +515,7 @@ def reduce(
padding, count, depth = _prepare_tree(n, m + 1, self.tree_reduction_factor)
car_recomb_solver = CaratheodoryRecombination(rcond=self.rcond, mode="implicit")

def _tree_reduce(_, state):
def _tree_reduce(_, state: tuple[Array, Array]) -> tuple[Array, Array]:
"""
Apply Tree-Based Caratheodory Recombination (Gaussian-Elimination).

@@ -504,20 +528,35 @@ def _tree_reduce(_, state):
number of remaining clusters down to 'm'. We can repeat the process until
each cluster contains, at most, a single non-zero weighted point (at this
point the recombination problem has been solved).

:param _: Not used
:param state: Tuple of node weights and indices; indices are passed to keep
a correspondence between the original data indices and the sorted indices
used to construct the balanced centroids
:return: Updated tuple of node weights and indices; weights are zeroed
(implicitly removed) where appropriate; indices are shuffled to ensure
balanced centroids in subsequent iterations (centroids are balanced when
they are all constructed from subsets with as near to an equal number
of non-zero weighted nodes as possible).
"""
_weights, _indices = state
# Index weights to a centroid; argsort ensures that centroids are balanced.
centroid_indices = jnp.argsort(_weights).reshape(count, -1, order="F")
centroid_nodes, centroid_weights = _centroid(
push_forward_nodes[_indices[centroid_indices]],
_weights[centroid_indices],
)
centroid_dataset = Data(centroid_nodes, centroid_weights)
# Solve the measure reduction problem on the centroid dataset.
centroid_coresubset, _ = car_recomb_solver.reduce(centroid_dataset)
coresubset_indices = centroid_coresubset.unweighted_indices
coresubset_weights = centroid_coresubset.coreset.weights
# Propagate centroid coresubset weights to the underlying weights for each
# centroid, as defined by `centroid_indices`.
weight_update_indices = centroid_indices[coresubset_indices]
weight_update = coresubset_weights / centroid_weights[coresubset_indices]
updated_weights = _weights[weight_update_indices] * weight_update[..., None]
# Maintain a correspondence between the original data indices and the sorted
# indices, used to construct the balanced centroids.
updated_indices = _indices[weight_update_indices.reshape(-1, order="F")]
return updated_weights.reshape(-1, order="F"), updated_indices
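
The centroid step relies on an invariant: the centroid measure shares the dataset's centre-of-mass, so recombination on `count` centroids stands in for recombination on all `n` points. A small NumPy check of this invariant, mirroring the `argsort`/`reshape` pattern above (values are illustrative, not library code):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, count = 12, 2, 4                    # 12 points grouped into 4 centroids
nodes = rng.normal(size=(n, d))
weights = rng.random(n)
weights /= weights.sum()
# Balanced assignment: argsort the weights, then split column-major so every
# centroid draws a mix of light and heavy weights.
centroid_indices = np.argsort(weights).reshape(count, -1, order="F")
centroid_weights = weights[centroid_indices].sum(axis=1)
centroid_nodes = (
    weights[centroid_indices, None] * nodes[centroid_indices]
).sum(axis=1) / centroid_weights[:, None]
# The centroid measure preserves both the total mass and the centre-of-mass.
assert np.isclose(centroid_weights.sum(), 1.0)
assert np.allclose(centroid_weights @ centroid_nodes, weights @ nodes)
```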

2 changes: 2 additions & 0 deletions documentation/source/conf.py
@@ -191,6 +191,8 @@
("py:obj", "coreax.solvers.coresubset._Data"),
("py:obj", "coreax.solvers.coresubset._State"),
("py:obj", "coreax.solvers.coresubset._Coreset"),
("py:obj", "coreax.solvers.recombination._Data"),
("py:obj", "coreax.solvers.recombination._State"),
("py:obj", "coreax.weights._Data"),
("py:obj", "coreax.metrics._Data"),
("py:obj", "coreax.solvers.coresubset._SupervisedData"),
45 changes: 45 additions & 0 deletions documentation/source/references.bib
@@ -128,3 +128,48 @@ @phdthesis{tchernychova2016recombination
year = {2016},
url={https://ora.ox.ac.uk/objects/uuid:a3a10980-d35d-467b-b3c0-d10d2e491f2d}
}

// cSpell:disable - French,
@article{tchakaloff1957,
title = {Formules de cubatures mécaniques à coefficients non négatifs},
author = {Tchakaloff, V},
year = {1957},
journal = {Bulletin des Sciences Mathématiques},
number = {2},
volume = {81},
pages = {123--134}
}
// cSpell:enable

@article{bayer2006tchakaloff,
title = {The proof of Tchakaloff's Theorem},
author = {Bayer, C. and Teichmann, J.},
year = {2006},
journal = {Proceedings of the American Mathematical Society},
volume = {134},
pages = {3035--3040},
url = {https://doi.org/10.1090/S0002-9939-06-08249-9}
}

// cSpell:disable - German,
@article{caratheodory1907,
title = {Über den Variabilitätsbereich der Koeffizienten von Potenzreihen, die gegebene Werte nicht annehmen},
author = {Carathéodory, C.},
year = {1907},
journal = {Mathematische Annalen},
volume = {64},
issue = {1},
pages = {95--115},
url = {https://doi.org/10.1007/BF01449883}
}
// cSpell:enable

@misc{loera2018caratheodory,
title = {The discrete yet ubiquitous theorems of Carathéodory, Helly, Sperner, Tucker, and Tverberg},
author = {Jesus A. De Loera and Xavier Goaoc and Frédéric Meunier and Nabil Mustafa},
year = {2018},
eprint = {1706.05975},
archivePrefix = {arXiv},
primaryClass = {math.CO},
url = {https://arxiv.org/abs/1706.05975},
}
2 changes: 1 addition & 1 deletion tests/unit/test_solvers.py
@@ -217,7 +217,7 @@ def check_solution_invariants(
1. Check 'sum(coreset.weights)' is one.
1. Check 'len(coreset)' is less than or equal to the upper bound `m`.
2. Check 'len(coreset[idx]) where idx = jnp.nonzero(coreset.weights)' is less
than or equal to the rank, `m^\prime`, of the pushed forward nodes.
than or equal to the rank, :math:`m^\prime`, of the pushed forward nodes.
3. Check the push-forward of the coreset preserves the "centre-of-mass" (CoM) of
the pushed-forward dataset (with implicit and explicit zero weight removal).
4. Check the default value of 'test_functions' is the identity map.