Skip to content

Commit

Permalink
chore: Merge add kernel-thinning-to-benchmarking branch to pyright-benchmarking branch
Browse files Browse the repository at this point in the history

#912
  • Loading branch information
qh681248 committed Jan 21, 2025
1 parent 98ede6f commit 98369c8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
7 changes: 3 additions & 4 deletions benchmark/blobs_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@

from coreax import Data, SlicedScoreMatching
from coreax.kernels import (
ScalarValuedKernel,
SquaredExponentialKernel,
SteinKernel,
median_heuristic,
Expand All @@ -56,7 +55,7 @@
from coreax.weights import MMDWeightsOptimiser


def setup_kernel(x: jax.Array, random_seed: int = 45) -> ScalarValuedKernel:
def setup_kernel(x: jax.Array, random_seed: int = 45) -> SquaredExponentialKernel:
"""
Set up a squared exponential kernel using the median heuristic.
Expand All @@ -73,7 +72,7 @@ def setup_kernel(x: jax.Array, random_seed: int = 45) -> ScalarValuedKernel:


def setup_stein_kernel(
sq_exp_kernel: ScalarValuedKernel, dataset: Data, random_seed: int = 45
sq_exp_kernel: SquaredExponentialKernel, dataset: Data, random_seed: int = 45
) -> SteinKernel:
"""
Set up a Stein Kernel for Stein Thinning.
Expand All @@ -99,7 +98,7 @@ def setup_stein_kernel(

def setup_solvers(
coreset_size: int,
sq_exp_kernel: ScalarValuedKernel,
sq_exp_kernel: SquaredExponentialKernel,
stein_kernel: SteinKernel,
delta: float,
random_seed: int = 45,
Expand Down
8 changes: 4 additions & 4 deletions benchmark/mnist_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
import umap
from flax import linen as nn
from flax.training import train_state
from jaxtyping import Array, Float, Int
from jaxtyping import Array, Float
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

Expand Down Expand Up @@ -430,7 +430,7 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarr
return train_data_jax, train_targets_jax, test_data_jax, test_targets_jax


def calculate_delta(n: Int[Array, "1"]) -> Float[Array, "1"]:
def calculate_delta(n: int) -> Float[Array, "1"]:
"""
Calculate the delta parameter for kernel thinning.
Expand All @@ -451,7 +451,7 @@ def calculate_delta(n: Int[Array, "1"]) -> Float[Array, "1"]:
if log_log_n > 0:
return 1 / (n * log_log_n)
return 1 / (n * log_n)
return 1 / n
return jnp.array(1 / n)


def initialise_solvers(
Expand Down Expand Up @@ -496,7 +496,7 @@ def _get_thinning_solver(_size: int) -> MapReduce:
coreset_size=_size,
kernel=kernel,
random_key=key,
delta=calculate_delta(num_data_points),
delta=calculate_delta(num_data_points).item(),
sqrt_kernel=sqrt_kernel,
)

Expand Down

0 comments on commit 98369c8

Please sign in to comment.