diff --git a/benchmark/blobs_benchmark.py b/benchmark/blobs_benchmark.py index 029adcae..95f363be 100644 --- a/benchmark/blobs_benchmark.py +++ b/benchmark/blobs_benchmark.py @@ -31,7 +31,6 @@ import json import os import time -from typing import TypeVar import jax import jax.numpy as jnp @@ -40,6 +39,7 @@ from coreax import Data, SlicedScoreMatching from coreax.kernels import ( + ScalarValuedKernel, SquaredExponentialKernel, SteinKernel, median_heuristic, @@ -54,10 +54,8 @@ ) from coreax.weights import MMDWeightsOptimiser -_Solver = TypeVar("_Solver", bound=Solver) - -def setup_kernel(x: jnp.array, random_seed: int = 45) -> SquaredExponentialKernel: +def setup_kernel(x: jax.Array, random_seed: int = 45) -> ScalarValuedKernel: """ Set up a squared exponential kernel using the median heuristic. @@ -74,7 +72,7 @@ def setup_kernel(x: jnp.array, random_seed: int = 45) -> SquaredExponentialKerne def setup_stein_kernel( - sq_exp_kernel: SquaredExponentialKernel, dataset: Data, random_seed: int = 45 + sq_exp_kernel: ScalarValuedKernel, dataset: Data, random_seed: int = 45 ) -> SteinKernel: """ Set up a Stein Kernel for Stein Thinning. @@ -100,10 +98,10 @@ def setup_stein_kernel( def setup_solvers( coreset_size: int, - sq_exp_kernel: SquaredExponentialKernel, + sq_exp_kernel: ScalarValuedKernel, stein_kernel: SteinKernel, random_seed: int = 45, -) -> list[tuple[str, _Solver]]: +) -> list[tuple[str, Solver]]: """ Set up and return a list of solver configurations for reducing a dataset. 
@@ -145,7 +143,7 @@ def setup_solvers( def compute_solver_metrics( - solver: _Solver, + solver: Solver, dataset: Data, mmd_metric: MMD, ksd_metric: KSD, @@ -188,7 +186,7 @@ def compute_solver_metrics( def compute_metrics( - solvers: list[tuple[str, _Solver]], + solvers: list[tuple[str, Solver]], dataset: Data, mmd_metric: MMD, ksd_metric: KSD, @@ -264,7 +262,7 @@ def main() -> None: # pylint: disable=too-many-locals aggregated_results[size][solver_name][metric].append(value) # Average results across seeds - final_results = {"n_samples": n_samples} + final_results: dict = {"n_samples": n_samples} for size, solvers in aggregated_results.items(): final_results[size] = {} for solver_name, metrics in solvers.items(): diff --git a/benchmark/blobs_benchmark_visualiser.py b/benchmark/blobs_benchmark_visualiser.py index 302fb65a..4e237079 100644 --- a/benchmark/blobs_benchmark_visualiser.py +++ b/benchmark/blobs_benchmark_visualiser.py @@ -89,7 +89,7 @@ def plot_benchmarking_results(data): # Adjust layout to avoid overlap plt.subplots_adjust(hspace=15.0, wspace=1.0) - plt.tight_layout(pad=3.0, rect=[0, 0, 1, 0.96]) + plt.tight_layout(pad=3.0, rect=(0.0, 0.0, 1.0, 0.96)) plt.show() diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py index 16c95467..2fd0e59b 100644 --- a/benchmark/mnist_benchmark.py +++ b/benchmark/mnist_benchmark.py @@ -66,6 +66,7 @@ Solver, SteinThinning, ) +from coreax.util import KeyArrayLike # Convert PyTorch dataset to JAX arrays @@ -77,7 +78,8 @@ def convert_to_jax_arrays(pytorch_data: Dataset) -> tuple[jnp.ndarray, jnp.ndarr :return: Tuple of JAX arrays (data, targets). 
""" # Load all data in one batch - data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data)) + # pyright is wrong here, a Dataset object does have __len__ method + data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data)) # type: ignore # Grab the first batch, which is all data _data, _targets = next(iter(data_loader)) # Convert to NumPy first, then JAX array @@ -149,8 +151,8 @@ def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray: class TrainState(train_state.TrainState): """Custom train state with batch statistics and dropout RNG.""" - batch_stats: Optional[dict[str, jnp.ndarray]] = None - dropout_rng: Optional[jnp.ndarray] = None + batch_stats: Optional[dict[str, jnp.ndarray]] + dropout_rng: KeyArrayLike class Metrics(NamedTuple): @@ -161,7 +163,7 @@ class Metrics(NamedTuple): def create_train_state( - rng: jnp.ndarray, _model: nn.Module, learning_rate: float, weight_decay: float + rng: KeyArrayLike, _model: nn.Module, learning_rate: float, weight_decay: float ) -> TrainState: """ Create and initialise the train state. @@ -323,7 +325,7 @@ def train_and_evaluate( train_set: DataSet, test_set: DataSet, _model: nn.Module, - rng: jnp.ndarray, + rng: KeyArrayLike, config: dict[str, Any], ) -> dict[str, float]: """ @@ -427,7 +429,7 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarr def initialise_solvers( - train_data_umap: Data, key: jax.random.PRNGKey + train_data_umap: Data, key: KeyArrayLike ) -> list[Callable[[int], Solver]]: """ Initialise and return a list of solvers for various coreset algorithms. 
@@ -449,7 +451,7 @@ def initialise_solvers( random_seed = 45 generator = np.random.default_rng(random_seed) idx = generator.choice(num_data_points, num_samples_length_scale, replace=False) - length_scale = median_heuristic(train_data_umap[idx]) + length_scale = median_heuristic(jnp.asarray(train_data_umap[idx])) kernel = SquaredExponentialKernel(length_scale=length_scale) def _get_herding_solver(_size: int) -> MapReduce: @@ -479,7 +481,7 @@ def _get_stein_solver(_size: int) -> MapReduce: """ # Generate small dataset for ScoreMatching for Stein Kernel - score_function = KernelDensityMatching(length_scale=length_scale).match( + score_function = KernelDensityMatching(length_scale=length_scale.item()).match( train_data_umap[idx] ) stein_kernel = SteinKernel(kernel, score_function) @@ -513,7 +515,7 @@ def _get_rp_solver(_size: int) -> RPCholesky: def train_model( data_bundle: dict[str, jnp.ndarray], - key: jax.random.PRNGKey, + key: KeyArrayLike, config: dict[str, Union[int, float]], ) -> dict[str, float]: """ diff --git a/benchmark/mnist_benchmark_visualiser.py b/benchmark/mnist_benchmark_visualiser.py index 76287a45..c43876b7 100644 --- a/benchmark/mnist_benchmark_visualiser.py +++ b/benchmark/mnist_benchmark_visualiser.py @@ -199,7 +199,10 @@ def plot_performance( if log_scale: plt.yscale("log") plt.title(title) - plt.xticks(index + bar_width * (n_algorithms - 1) / 2, coreset_sizes) + plt.xticks( + index + bar_width * (n_algorithms - 1) / 2, + [str(size) for size in coreset_sizes], + ) plt.legend() plt.grid(True, linestyle="--", alpha=0.7) plt.tight_layout() diff --git a/benchmark/pounce_benchmark.py b/benchmark/pounce_benchmark.py index ce01ea49..d9a1f395 100644 --- a/benchmark/pounce_benchmark.py +++ b/benchmark/pounce_benchmark.py @@ -60,10 +60,11 @@ def benchmark_coreset_algorithms( raw_data = np.asarray(image_data) reshaped_data = raw_data.reshape(raw_data.shape[0], -1) - umap_model = umap.UMAP(densmap=True, n_components=25) - umap_data = 
umap_model.fit_transform(reshaped_data) + # TODO: Change n_components back to 25 (its value before this change) + umap_model = umap.UMAP(densmap=True, n_components=5) + umap_data = jnp.asarray(umap_model.fit_transform(reshaped_data)) - solver_factories = initialise_solvers(umap_data, random.PRNGKey(45)) + solver_factories = initialise_solvers(Data(umap_data), random.PRNGKey(45)) for solver_creator in solver_factories: solver = solver_creator(coreset_size) @@ -83,7 +84,7 @@ def benchmark_coreset_algorithms( # Extract corresponding frames from original data and save GIF coreset_frames = raw_data[selected_indices] output_gif_path = out_dir / f"{solver_name}_coreset.gif" - imageio.mimsave(output_gif_path, coreset_frames, loop=0) + imageio.mimsave(output_gif_path, list(coreset_frames), loop=0) print(f"Saved {solver_name} coreset GIF to {output_gif_path}") print(f"time taken: {solver_name:<25} {duration:<30.4f}")