From 944063afa17c9524b16a30b0e812b34e354962bc Mon Sep 17 00:00:00 2001
From: qh681248 <181246904+qh681248@users.noreply.github.com>
Date: Thu, 16 Jan 2025 16:38:56 +0000
Subject: [PATCH] chore: add kernel thinning to benchmarking script #919

---
 benchmark/blobs_benchmark.py | 14 ++++++++
 benchmark/david_benchmark.py |  6 +++-
 benchmark/mnist_benchmark.py | 64 +++++++++++++++++++++++++++++++++---
 coreax/solvers/coresubset.py |  2 +-
 4 files changed, 79 insertions(+), 7 deletions(-)

diff --git a/benchmark/blobs_benchmark.py b/benchmark/blobs_benchmark.py
index 029adcae3..5fd9b0311 100644
--- a/benchmark/blobs_benchmark.py
+++ b/benchmark/blobs_benchmark.py
@@ -47,6 +47,7 @@
 from coreax.metrics import KSD, MMD
 from coreax.solvers import (
     KernelHerding,
+    KernelThinning,
     RandomSample,
     RPCholesky,
     Solver,
@@ -102,6 +103,7 @@ def setup_solvers(
     coreset_size: int,
     sq_exp_kernel: SquaredExponentialKernel,
     stein_kernel: SteinKernel,
+    delta: float,
     random_seed: int = 45,
 ) -> list[tuple[str, _Solver]]:
     """
@@ -110,12 +112,14 @@
     :param coreset_size: The size of the coresets to be generated by the solvers.
     :param sq_exp_kernel: A Squared Exponential kernel for KernelHerding and RPCholesky.
     :param stein_kernel: A Stein kernel object used for the SteinThinning solver.
+    :param delta: The delta parameter for the KernelThinning solver.
     :param random_seed: An integer seed for the random number generator.
     :return: A list of tuples, where each tuple contains the name of the solver
         and the corresponding solver object.
     """
     random_key = jax.random.PRNGKey(random_seed)
+    sqrt_kernel = sq_exp_kernel.get_sqrt_kernel(2)
     return [
         (
             "KernelHerding",
@@ -141,6 +145,16 @@
                 regularise=False,
             ),
         ),
+        (
+            "KernelThinning",
+            KernelThinning(
+                coreset_size=coreset_size,
+                kernel=sq_exp_kernel,
+                random_key=random_key,
+                delta=delta,
+                sqrt_kernel=sqrt_kernel,
+            ),
+        ),
     ]

diff --git a/benchmark/david_benchmark.py b/benchmark/david_benchmark.py
index 2734d54d0..fbbdaf622 100644
--- a/benchmark/david_benchmark.py
+++ b/benchmark/david_benchmark.py
@@ -41,6 +41,7 @@
 from benchmark.mnist_benchmark import get_solver_name, initialise_solvers
 from coreax import Data
+from coreax.solvers import MapReduce
 from examples.david_map_reduce_weighted import downsample_opencv
 
 MAX_8BIT = 255
@@ -50,7 +51,7 @@
 def benchmark_coreset_algorithms(
     in_path: Path = Path("../examples/data/david_orig.png"),
     out_path: Optional[Path] = Path("david_benchmark_results.png"),
-    downsampling_factor: int = 6,
+    downsampling_factor: int = 1,
 ):
     """
     Benchmark the performance of coreset algorithms on a downsampled greyscale image.
@@ -93,6 +94,9 @@ def benchmark_coreset_algorithms(
     for solver_creator in solver_factories:
         solver = solver_creator(coreset_size)
+        # There is no need to use MapReduce, as the dataset is small.
+        if isinstance(solver, MapReduce):
+            solver = solver.base_solver
         solver_name = get_solver_name(solver_creator)
         start_time = time.perf_counter()
         coreset, _ = eqx.filter_jit(solver.reduce)(data)

diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py
index 16c95467f..ae8fea9ad 100644
--- a/benchmark/mnist_benchmark.py
+++ b/benchmark/mnist_benchmark.py
@@ -52,6 +52,7 @@
 import umap
 from flax import linen as nn
 from flax.training import train_state
+from jaxtyping import Array, Float, Int
 from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
@@ -60,6 +61,7 @@
 from coreax.score_matching import KernelDensityMatching
 from coreax.solvers import (
     KernelHerding,
+    KernelThinning,
     MapReduce,
     RandomSample,
     RPCholesky,
@@ -426,6 +428,30 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarr
     return train_data_jax, train_targets_jax, test_data_jax, test_targets_jax
 
 
+def calculate_delta(n: Int[Array, "1"]) -> Float[Array, "1"]:
+    """
+    Calculate the delta parameter for kernel thinning.
+
+    The function evaluates the following cases:
+    1. If `jnp.log(n)` is positive:
+       - Further evaluates `jnp.log(jnp.log(n))`.
+         * If this is also positive, returns `1 / (n * jnp.log(jnp.log(n)))`.
+         * Otherwise, returns `1 / (n * jnp.log(n))`.
+    2. If `jnp.log(n)` is not positive:
+       - Returns `1 / n`.
+
+    :param n: The size of the dataset we wish to reduce.
+    :return: The calculated delta value based on the described conditions.
+    """
+    log_n = jnp.log(n)
+    if log_n > 0:
+        log_log_n = jnp.log(log_n)
+        if log_log_n > 0:
+            return 1 / (n * log_log_n)
+        return 1 / (n * log_n)
+    return 1 / n
+
+
 def initialise_solvers(
     train_data_umap: Data, key: jax.random.PRNGKey
 ) -> list[Callable[[int], Solver]]:
@@ -451,6 +477,28 @@
     idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
     length_scale = median_heuristic(train_data_umap[idx])
     kernel = SquaredExponentialKernel(length_scale=length_scale)
+    sqrt_kernel = kernel.get_sqrt_kernel(16)
+
+    def _get_thinning_solver(_size: int) -> MapReduce:
+        """
+        Set up KernelThinning to use ``MapReduce``.
+
+        Create a KernelThinning solver with the specified size and return
+        it wrapped in a MapReduce object for reducing a large dataset like
+        the MNIST dataset.
+
+        :param _size: The size of the coreset to be generated.
+        :return: A MapReduce solver with KernelThinning as the base solver.
+        """
+        thinning_solver = KernelThinning(
+            coreset_size=_size,
+            kernel=kernel,
+            random_key=key,
+            delta=calculate_delta(num_data_points),
+            sqrt_kernel=sqrt_kernel,
+        )
+
+        return MapReduce(thinning_solver, leaf_size=3 * _size)
 
     def _get_herding_solver(_size: int) -> MapReduce:
         """
@@ -461,7 +509,7 @@
         MNIST dataset.
 
         :param _size: The size of the coreset to be generated.
-        :return: A tuple containing the solver name and the MapReduce solver.
+        :return: A MapReduce solver with KernelHerding as the base solver.
         """
         herding_solver = KernelHerding(_size, kernel)
         return MapReduce(herding_solver, leaf_size=3 * _size)
@@ -475,7 +523,7 @@
         a subset of the dataset.
 
         :param _size: The size of the coreset to be generated.
-        :return: A tuple containing the solver name and the MapReduce solver.
+        :return: A MapReduce solver with SteinThinning as the base solver.
""" # Generate small dataset for ScoreMatching for Stein Kernel @@ -493,7 +541,7 @@ def _get_random_solver(_size: int) -> RandomSample: Set up Random Sampling to generate a coreset. :param _size: The size of the coreset to be generated. - :return: A tuple containing the solver name and the RandomSample solver. + :return: A RandomSample solver. """ random_solver = RandomSample(_size, key) return random_solver @@ -503,12 +551,18 @@ def _get_rp_solver(_size: int) -> RPCholesky: Set up Randomised Cholesky solver. :param _size: The size of the coreset to be generated. - :return: A tuple containing the solver name and the RPCholesky solver. + :return: A RPCholesky solver. """ rp_solver = RPCholesky(coreset_size=_size, kernel=kernel, random_key=key) return rp_solver - return [_get_random_solver, _get_rp_solver, _get_herding_solver, _get_stein_solver] + return [ + _get_random_solver, + _get_rp_solver, + _get_herding_solver, + _get_stein_solver, + _get_thinning_solver, + ] def train_model( diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py index c5fa34549..a8838a51c 100644 --- a/coreax/solvers/coresubset.py +++ b/coreax/solvers/coresubset.py @@ -1101,7 +1101,7 @@ def probabilistic_swap( prob = jax.random.uniform(key1) return lax.cond( - prob < 1 / 2 * (1 - alpha / a), + prob > 1 / 2 * (1 - alpha / a), lambda _: (2 * i, 2 * i + 1), # first case: val1 = x1, val2 = x2 lambda _: (2 * i + 1, 2 * i), # second case: val1 = x2, val2 = x1 None,