Feature/add kernel thinning to benchmarking #927

Open · wants to merge 7 commits into base: main
14 changes: 14 additions & 0 deletions benchmark/blobs_benchmark.py
@@ -47,6 +47,7 @@
from coreax.metrics import KSD, MMD
from coreax.solvers import (
KernelHerding,
KernelThinning,
RandomSample,
RPCholesky,
Solver,
@@ -102,6 +103,7 @@ def setup_solvers(
coreset_size: int,
sq_exp_kernel: SquaredExponentialKernel,
stein_kernel: SteinKernel,
delta: float,
random_seed: int = 45,
) -> list[tuple[str, _Solver]]:
"""
@@ -110,12 +112,14 @@
:param coreset_size: The size of the coresets to be generated by the solvers.
:param sq_exp_kernel: A Squared Exponential kernel for KernelHerding and RPCholesky.
:param stein_kernel: A Stein kernel object used for the SteinThinning solver.
:param delta: The delta parameter for the KernelThinning solver.
:param random_seed: An integer seed for the random number generator.

:return: A list of tuples, where each tuple contains the name of the solver
and the corresponding solver object.
"""
random_key = jax.random.PRNGKey(random_seed)
sqrt_kernel = sq_exp_kernel.get_sqrt_kernel(2)
return [
(
"KernelHerding",
@@ -141,6 +145,16 @@
regularise=False,
),
),
(
"KernelThinning",
KernelThinning(
coreset_size=coreset_size,
kernel=sq_exp_kernel,
random_key=random_key,
delta=delta,
sqrt_kernel=sqrt_kernel,
),
),
]


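
For context, a minimal call-site sketch for the extended setup_solvers (an editorial illustration, not part of the diff: sq_exp_kernel and stein_kernel are assumed to be constructed earlier in the script, and the delta value follows the calculate_delta heuristic from mnist_benchmark.py below):

import jax.numpy as jnp

n_samples = 1_000  # hypothetical dataset size
# First branch of the delta heuristic: log(n) and log(log(n)) are both positive.
delta = float(1 / (n_samples * jnp.log(jnp.log(n_samples))))
solvers = setup_solvers(
    coreset_size=100,
    sq_exp_kernel=sq_exp_kernel,
    stein_kernel=stein_kernel,
    delta=delta,
)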
7 changes: 5 additions & 2 deletions benchmark/david_benchmark.py
@@ -29,6 +29,7 @@
"""

import os
import sys
import time
from pathlib import Path
from typing import Optional
@@ -38,11 +39,14 @@
import matplotlib.pyplot as plt
import numpy as np
from jax import random
from mnist_benchmark import get_solver_name, initialise_solvers

# The repository root must be on sys.path before the ``benchmark`` and
# ``examples`` packages can be imported when this script is run directly.
sys.path.append(str(Path(__file__).parent.parent))

from benchmark.mnist_benchmark import get_solver_name, initialise_solvers
from coreax import Data
from examples.david_map_reduce_weighted import downsample_opencv


MAX_8BIT = 255


@@ -65,7 +69,6 @@ def benchmark_coreset_algorithms(
"""
# Base directory of the current script
base_dir = os.path.dirname(os.path.abspath(__file__))

# Convert to absolute paths using os.path.join
if not in_path.is_absolute():
in_path = Path(os.path.join(base_dir, in_path))
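
As an aside, the same absolute-path normalisation can be written with pathlib alone; a minimal sketch under the same variable names (editorial, not part of this diff):

base_dir = Path(__file__).resolve().parent
if not in_path.is_absolute():
    in_path = base_dir / in_path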
68 changes: 61 additions & 7 deletions benchmark/mnist_benchmark.py
@@ -52,6 +52,7 @@
import umap
from flax import linen as nn
from flax.training import train_state
from jaxtyping import Array, Float, Int
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

@@ -60,6 +61,7 @@
from coreax.score_matching import KernelDensityMatching
from coreax.solvers import (
KernelHerding,
KernelThinning,
MapReduce,
RandomSample,
RPCholesky,
@@ -426,6 +428,30 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
return train_data_jax, train_targets_jax, test_data_jax, test_targets_jax


def calculate_delta(n: Int[Array, "1"]) -> Float[Array, "1"]:
"""
Calculate the delta parameter for kernel thinning.

The function evaluates the following cases:
1. If `jnp.log(n)` is positive:
- Further evaluates `jnp.log(jnp.log(n))`.
* If this is also positive, returns `1 / (n * jnp.log(jnp.log(n)))`.
* Otherwise, returns `1 / (n * jnp.log(n))`.
2. If `jnp.log(n)` is not positive:
- Returns `1 / n`.

:param n: The size of the dataset we wish to reduce.
:return: The calculated delta value based on the described conditions.
"""
log_n = jnp.log(n)
if log_n > 0:
log_log_n = jnp.log(log_n)
if log_log_n > 0:
return 1 / (n * log_log_n)
return 1 / (n * log_n)
return 1 / n
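
As a rough sanity check of the heuristic (illustrative numbers, not benchmark output), an MNIST-sized dataset falls into the first branch:

import jax.numpy as jnp

n = 60_000                   # roughly the MNIST training-set size
log_n = jnp.log(n)           # ~11.0, positive
log_log_n = jnp.log(log_n)   # ~2.4, also positive
delta = 1 / (n * log_log_n)  # ~7e-6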


def initialise_solvers(
train_data_umap: Data, key: jax.random.PRNGKey
) -> list[Callable[[int], Solver]]:
@@ -451,6 +477,28 @@ def initialise_solvers(
idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
length_scale = median_heuristic(train_data_umap[idx])
kernel = SquaredExponentialKernel(length_scale=length_scale)
sqrt_kernel = kernel.get_sqrt_kernel(16)

def _get_thinning_solver(_size: int) -> MapReduce:
"""
Set up KernelThinning to use ``MapReduce``.

Create a KernelThinning solver with the specified coreset size and
wrap it in a MapReduce object so that a large dataset such as MNIST
can be reduced.

:param _size: The size of the coreset to be generated.
:return: MapReduce solver with KernelThinning as the base solver.
"""
thinning_solver = KernelThinning(
coreset_size=_size,
kernel=kernel,
random_key=key,
delta=calculate_delta(num_data_points),
sqrt_kernel=sqrt_kernel,
)

return MapReduce(thinning_solver, leaf_size=15_000)

def _get_herding_solver(_size: int) -> MapReduce:
"""
@@ -461,10 +509,10 @@ def _get_herding_solver(_size: int) -> MapReduce:
MNIST dataset.

:param _size: The size of the coreset to be generated.
:return: A tuple containing the solver name and the MapReduce solver.
:return: MapReduce solver with KernelHerding as the base solver.
"""
herding_solver = KernelHerding(_size, kernel)
return MapReduce(herding_solver, leaf_size=3 * _size)
return MapReduce(herding_solver, leaf_size=15_000)

def _get_stein_solver(_size: int) -> MapReduce:
"""
@@ -475,7 +523,7 @@ def _get_stein_solver(_size: int) -> MapReduce:
a subset of the dataset.

:param _size: The size of the coreset to be generated.
:return: A tuple containing the solver name and the MapReduce solver.
:return: MapReduce solver with SteinThinning as the base solver.
"""
# Generate small dataset for ScoreMatching for Stein Kernel

@@ -486,14 +534,14 @@
stein_solver = SteinThinning(
coreset_size=_size, kernel=stein_kernel, regularise=False
)
return MapReduce(stein_solver, leaf_size=3 * _size)
return MapReduce(stein_solver, leaf_size=15_000)

def _get_random_solver(_size: int) -> RandomSample:
"""
Set up Random Sampling to generate a coreset.

:param _size: The size of the coreset to be generated.
:return: A tuple containing the solver name and the RandomSample solver.
:return: A RandomSample solver.
"""
random_solver = RandomSample(_size, key)
return random_solver
@@ -503,12 +551,18 @@ def _get_rp_solver(_size: int) -> RPCholesky:
Set up Randomised Cholesky solver.

:param _size: The size of the coreset to be generated.
:return: A tuple containing the solver name and the RPCholesky solver.
:return: An RPCholesky solver.
"""
rp_solver = RPCholesky(coreset_size=_size, kernel=kernel, random_key=key)
return rp_solver

return [_get_random_solver, _get_rp_solver, _get_herding_solver, _get_stein_solver]
return [
_get_random_solver,
_get_rp_solver,
_get_herding_solver,
_get_stein_solver,
_get_thinning_solver,
]
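
Downstream, each factory in this list is called with a coreset size to build a solver; a minimal usage sketch (editorial: variable names follow this module, and reduce returning a (coreset, state) pair is the usual coreax solver convention, assumed here):

solver_factories = initialise_solvers(train_data_umap, key)
for get_solver in solver_factories:
    solver = get_solver(1_000)  # hypothetical coreset size
    coreset, _ = solver.reduce(train_data_umap)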


def train_model(