chore: add kernel thinning to benchmarking script
qh681248 committed Jan 16, 2025
1 parent dd1332d commit 944063a
Showing 4 changed files with 79 additions and 7 deletions.
14 changes: 14 additions & 0 deletions benchmark/blobs_benchmark.py
@@ -47,6 +47,7 @@
from coreax.metrics import KSD, MMD
from coreax.solvers import (
KernelHerding,
KernelThinning,
RandomSample,
RPCholesky,
Solver,
@@ -102,6 +103,7 @@ def setup_solvers(
coreset_size: int,
sq_exp_kernel: SquaredExponentialKernel,
stein_kernel: SteinKernel,
delta: float,
random_seed: int = 45,
) -> list[tuple[str, _Solver]]:
"""
@@ -110,12 +112,14 @@
:param coreset_size: The size of the coresets to be generated by the solvers.
:param sq_exp_kernel: A Squared Exponential kernel for KernelHerding, RPCholesky, and KernelThinning.
:param stein_kernel: A Stein kernel object used for the SteinThinning solver.
:param delta: The delta parameter (failure probability) for the KernelThinning solver.
:param random_seed: An integer seed for the random number generator.
:return: A list of tuples, where each tuple contains the name of the solver
and the corresponding solver object.
"""
random_key = jax.random.PRNGKey(random_seed)
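# Square-root kernel required by KernelThinning; the argument 2 is assumed
# to match the dimensionality of the 2-D blobs data.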
sqrt_kernel = sq_exp_kernel.get_sqrt_kernel(2)
return [
(
"KernelHerding",
@@ -141,6 +145,16 @@
regularise=False,
),
),
(
"KernelThinning",
KernelThinning(
coreset_size=coreset_size,
kernel=sq_exp_kernel,
random_key=random_key,
delta=delta,
sqrt_kernel=sqrt_kernel,
),
),
]
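
For context, a minimal sketch of constructing the new KernelThinning solver on its own, using only the keyword arguments shown in this diff (the coreax.kernels import path is an assumption; adjust to the project's actual layout):

    import jax

    from coreax import Data
    from coreax.kernels import SquaredExponentialKernel  # assumed import path
    from coreax.solvers import KernelThinning

    key = jax.random.PRNGKey(45)
    x = jax.random.normal(key, (1000, 2))  # stand-in for the 2-D blobs data

    kernel = SquaredExponentialKernel(length_scale=1.0)
    solver = KernelThinning(
        coreset_size=100,
        kernel=kernel,
        random_key=key,
        delta=1e-3,  # failure probability; the MNIST benchmark derives this from n
        sqrt_kernel=kernel.get_sqrt_kernel(2),
    )
    coreset, _ = solver.reduce(Data(x))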


6 changes: 5 additions & 1 deletion benchmark/david_benchmark.py
@@ -41,6 +41,7 @@

from benchmark.mnist_benchmark import get_solver_name, initialise_solvers
from coreax import Data
from coreax.solvers import MapReduce
from examples.david_map_reduce_weighted import downsample_opencv

MAX_8BIT = 255
@@ -50,7 +51,7 @@
def benchmark_coreset_algorithms(
in_path: Path = Path("../examples/data/david_orig.png"),
out_path: Optional[Path] = Path("david_benchmark_results.png"),
- downsampling_factor: int = 6,
+ downsampling_factor: int = 1,
):
"""
Benchmark the performance of coreset algorithms on a downsampled greyscale image.
@@ -93,6 +94,9 @@ def benchmark_coreset_algorithms(

for solver_creator in solver_factories:
solver = solver_creator(coreset_size)
# There is no need to use MapReduce here, as the data size is small
if isinstance(solver, MapReduce):
solver = solver.base_solver
solver_name = get_solver_name(solver_creator)
start_time = time.perf_counter()
coreset, _ = eqx.filter_jit(solver.reduce)(data)
64 changes: 59 additions & 5 deletions benchmark/mnist_benchmark.py
@@ -52,6 +52,7 @@
import umap
from flax import linen as nn
from flax.training import train_state
from jaxtyping import Array, Float, Int
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

@@ -60,6 +61,7 @@
from coreax.score_matching import KernelDensityMatching
from coreax.solvers import (
KernelHerding,
KernelThinning,
MapReduce,
RandomSample,
RPCholesky,
@@ -426,6 +428,30 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
return train_data_jax, train_targets_jax, test_data_jax, test_targets_jax


def calculate_delta(n: Int[Array, "1"]) -> Float[Array, "1"]:
"""
Calculate the delta parameter for kernel thinning.
The function evaluates the following cases:
1. If `jnp.log(n)` is positive:
- Further evaluates `jnp.log(jnp.log(n))`.
* If this is also positive, returns `1 / n * jnp.log(jnp.log(n))`.
* Otherwise, returns `1 / n * jnp.log(n)`.
2. If `jnp.log(n)` is negative:
- Returns `1 / n`.
:param n: The size of the dataset we wish to reduce.
:return: The calculated delta value based on the described conditions.
"""
log_n = jnp.log(n)
if log_n > 0:
log_log_n = jnp.log(log_n)
if log_log_n > 0:
return 1 / (n * log_log_n)
return 1 / (n * log_n)
return 1 / n
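
A quick sanity check of the branch logic (a sketch; values rounded):

    # n = 2:  log(2) ≈ 0.69 > 0 but log(log(2)) ≈ -0.37 <= 0 -> 1 / (2 * 0.69) ≈ 0.72
    # n = 10: log(10) ≈ 2.30 > 0 and log(log(10)) ≈ 0.83 > 0 -> 1 / (10 * 0.83) ≈ 0.12
    # n = 1:  log(1) = 0, not positive                        -> 1 / 1 = 1.0
    print(calculate_delta(2), calculate_delta(10), calculate_delta(1))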


def initialise_solvers(
train_data_umap: Data, key: jax.random.PRNGKey
) -> list[Callable[[int], Solver]]:
@@ -451,6 +477,28 @@ def initialise_solvers(
idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
length_scale = median_heuristic(train_data_umap[idx])
kernel = SquaredExponentialKernel(length_scale=length_scale)
sqrt_kernel = kernel.get_sqrt_kernel(16)

def _get_thinning_solver(_size: int) -> MapReduce:
    """
    Set up KernelThinning to use ``MapReduce``.

    Create a KernelThinning solver with the specified size and wrap it in a
    MapReduce object for reducing a large dataset like the MNIST dataset.

    :param _size: The size of the coreset to be generated.
    :return: MapReduce solver with KernelThinning as the base solver.
    """
    thinning_solver = KernelThinning(
        coreset_size=_size,
        kernel=kernel,
        random_key=key,
        delta=calculate_delta(num_data_points),
        sqrt_kernel=sqrt_kernel,
    )
    return MapReduce(thinning_solver, leaf_size=3 * _size)

def _get_herding_solver(_size: int) -> MapReduce:
"""
@@ -461,7 +509,7 @@ def _get_herding_solver(_size: int) -> MapReduce:
MNIST dataset.
:param _size: The size of the coreset to be generated.
- :return: A tuple containing the solver name and the MapReduce solver.
+ :return: MapReduce solver with KernelHerding as the base solver.
"""
herding_solver = KernelHerding(_size, kernel)
return MapReduce(herding_solver, leaf_size=3 * _size)
@@ -475,7 +523,7 @@ def _get_stein_solver(_size: int) -> MapReduce:
a subset of the dataset.
:param _size: The size of the coreset to be generated.
- :return: A tuple containing the solver name and the MapReduce solver.
+ :return: MapReduce solver with SteinThinning as the base solver.
"""
# Generate a small dataset for ScoreMatching with the Stein kernel

@@ -493,7 +541,7 @@ def _get_random_solver(_size: int) -> RandomSample:
Set up Random Sampling to generate a coreset.
:param _size: The size of the coreset to be generated.
- :return: A tuple containing the solver name and the RandomSample solver.
+ :return: A RandomSample solver.
"""
random_solver = RandomSample(_size, key)
return random_solver
@@ -503,12 +551,18 @@ def _get_rp_solver(_size: int) -> RPCholesky:
Set up Randomised Cholesky solver.
:param _size: The size of the coreset to be generated.
- :return: A tuple containing the solver name and the RPCholesky solver.
+ :return: An RPCholesky solver.
"""
rp_solver = RPCholesky(coreset_size=_size, kernel=kernel, random_key=key)
return rp_solver

- return [_get_random_solver, _get_rp_solver, _get_herding_solver, _get_stein_solver]
+ return [
+     _get_random_solver,
+     _get_rp_solver,
+     _get_herding_solver,
+     _get_stein_solver,
+     _get_thinning_solver,
+ ]
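
For reference, a minimal sketch of how these factories are consumed downstream, mirroring the loop in david_benchmark.py (the coreset size here is illustrative):

    import equinox as eqx

    for solver_creator in initialise_solvers(train_data_umap, key):
        solver = solver_creator(100)  # illustrative coreset size
        coreset, _ = eqx.filter_jit(solver.reduce)(train_data_umap)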


def train_model(
2 changes: 1 addition & 1 deletion coreax/solvers/coresubset.py
@@ -1101,7 +1101,7 @@ def probabilistic_swap(

prob = jax.random.uniform(key1)
return lax.cond(
- prob < 1 / 2 * (1 - alpha / a),
+ prob > 1 / 2 * (1 - alpha / a),
lambda _: (2 * i, 2 * i + 1), # first case: val1 = x1, val2 = x2
lambda _: (2 * i + 1, 2 * i), # second case: val1 = x2, val2 = x1
None,
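
The one-line fix above flips the comparison so that the swap branch (val1 = x2, val2 = x1) is taken with probability 1 / 2 * (1 - alpha / a), as the branch comments intend: since `prob` is Uniform(0, 1), `prob > threshold` selects the keep-order branch with probability 1 - threshold, leaving the swap branch with probability threshold. A minimal, self-contained illustration of the corrected lax.cond pattern (names here are illustrative, not coreax's API):

    import jax
    from jax import lax

    def swap_pair(key, i, threshold):
        # Keep the pair order with probability 1 - threshold; swap otherwise.
        prob = jax.random.uniform(key)
        return lax.cond(
            prob > threshold,
            lambda _: (2 * i, 2 * i + 1),  # keep: val1 = x1, val2 = x2
            lambda _: (2 * i + 1, 2 * i),  # swap: val1 = x2, val2 = x1
            None,
        )

    print(swap_pair(jax.random.PRNGKey(0), 3, 0.25))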
