fix: pyright fixes
qh681248 committed Jan 21, 2025
1 parent b1c83a4 commit 62cfb83
Showing 5 changed files with 29 additions and 25 deletions.
18 changes: 8 additions & 10 deletions benchmark/blobs_benchmark.py
@@ -31,7 +31,6 @@
 import json
 import os
 import time
-from typing import TypeVar

 import jax
 import jax.numpy as jnp
@@ -40,6 +39,7 @@

 from coreax import Data, SlicedScoreMatching
 from coreax.kernels import (
+    ScalarValuedKernel,
     SquaredExponentialKernel,
     SteinKernel,
     median_heuristic,
@@ -54,10 +54,8 @@
 )
 from coreax.weights import MMDWeightsOptimiser

-_Solver = TypeVar("_Solver", bound=Solver)
-

-def setup_kernel(x: jnp.array, random_seed: int = 45) -> SquaredExponentialKernel:
+def setup_kernel(x: jax.Array, random_seed: int = 45) -> ScalarValuedKernel:
     """
     Set up a squared exponential kernel using the median heuristic.
@@ -74,7 +72,7 @@ def setup_kernel(x: jnp.array, random_seed: int = 45) -> SquaredExponentialKernel:


 def setup_stein_kernel(
-    sq_exp_kernel: SquaredExponentialKernel, dataset: Data, random_seed: int = 45
+    sq_exp_kernel: ScalarValuedKernel, dataset: Data, random_seed: int = 45
 ) -> SteinKernel:
     """
     Set up a Stein Kernel for Stein Thinning.
@@ -100,10 +98,10 @@ def setup_stein_kernel(

 def setup_solvers(
     coreset_size: int,
-    sq_exp_kernel: SquaredExponentialKernel,
+    sq_exp_kernel: ScalarValuedKernel,
     stein_kernel: SteinKernel,
     random_seed: int = 45,
-) -> list[tuple[str, _Solver]]:
+) -> list[tuple[str, Solver]]:
     """
     Set up and return a list of solver configurations for reducing a dataset.
@@ -145,7 +143,7 @@ def setup_solvers(


 def compute_solver_metrics(
-    solver: _Solver,
+    solver: Solver,
     dataset: Data,
     mmd_metric: MMD,
     ksd_metric: KSD,
@@ -188,7 +186,7 @@ def compute_solver_metrics(


 def compute_metrics(
-    solvers: list[tuple[str, _Solver]],
+    solvers: list[tuple[str, Solver]],
     dataset: Data,
     mmd_metric: MMD,
     ksd_metric: KSD,
@@ -264,7 +262,7 @@ def main() -> None:  # pylint: disable=too-many-locals
                aggregated_results[size][solver_name][metric].append(value)

     # Average results across seeds
-    final_results = {"n_samples": n_samples}
+    final_results: dict = {"n_samples": n_samples}
     for size, solvers in aggregated_results.items():
         final_results[size] = {}
         for solver_name, metrics in solvers.items():
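
A note on the _Solver removal in this file: a bound TypeVar only carries information when it links two or more annotations in one signature; used in a single position, pyright reports it as unsolvable, and the bound class itself is the right annotation. A minimal sketch of the distinction, with stand-in classes rather than coreax's real Solver hierarchy:

    from typing import TypeVar

    class Solver:  # stand-in base class, not coreax's Solver
        pass

    class MapReduce(Solver):  # stand-in subclass
        pass

    _S = TypeVar("_S", bound=Solver)

    def tweak(solver: _S) -> _S:  # TypeVar earns its keep: input and output types match
        return solver

    def setup() -> list[tuple[str, Solver]]:  # single use: the base class suffices
        return [("map-reduce", MapReduce())]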
2 changes: 1 addition & 1 deletion benchmark/blobs_benchmark_visualiser.py
@@ -89,7 +89,7 @@ def plot_benchmarking_results(data):

     # Adjust layout to avoid overlap
     plt.subplots_adjust(hspace=15.0, wspace=1.0)
-    plt.tight_layout(pad=3.0, rect=[0, 0, 1, 0.96])
+    plt.tight_layout(pad=3.0, rect=(0.0, 0.0, 1.0, 0.96))
     plt.show()
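
The rect change is purely a typing fix: Matplotlib's type stubs appear to annotate tight_layout's rect parameter as a tuple of four floats, so a list literal trips pyright even though it works at runtime. A minimal sketch, assuming nothing beyond Matplotlib itself:

    import matplotlib.pyplot as plt

    plt.plot([1, 2, 3])
    # A 4-tuple of floats matches the annotated parameter type; a list does not.
    plt.tight_layout(pad=3.0, rect=(0.0, 0.0, 1.0, 0.96))
    plt.show()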


20 changes: 11 additions & 9 deletions benchmark/mnist_benchmark.py
@@ -66,6 +66,7 @@
     Solver,
     SteinThinning,
 )
+from coreax.util import KeyArrayLike


 # Convert PyTorch dataset to JAX arrays
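
The new KeyArrayLike import replaces the jax.random.PRNGKey and jnp.ndarray annotations used for RNG keys below: jax.random.PRNGKey is a function, not a type, so pyright rejects it in annotation position. A minimal usage sketch, assuming coreax.util.KeyArrayLike aliases JAX's PRNG-key type as this diff suggests:

    import jax
    from coreax.util import KeyArrayLike

    def sample_noise(key: KeyArrayLike) -> jax.Array:
        # jax.random functions take the typed key directly
        return jax.random.normal(key, (3,))

    sample_noise(jax.random.PRNGKey(45))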
@@ -77,7 +78,8 @@ def convert_to_jax_arrays(pytorch_data: Dataset) -> tuple[jnp.ndarray, jnp.ndarray]:
     :return: Tuple of JAX arrays (data, targets).
     """
     # Load all data in one batch
-    data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data))
+    # pyright is wrong here, a Dataset object does have __len__ method
+    data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data))  # type: ignore
     # Grab the first batch, which is all data
     _data, _targets = next(iter(data_loader))
     # Convert to NumPy first, then JAX array
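
For context on the type: ignore above: torch.utils.data.Dataset does not declare __len__ in its stubs (only concrete subclasses such as torchvision's MNIST implement it), so pyright flags len(pytorch_data) even though it succeeds at runtime. An alternative to the ignore comment, sketched under that assumption, is an explicit runtime narrowing:

    from typing import Sized

    from torch.utils.data import DataLoader, Dataset

    def full_batch_loader(dataset: Dataset) -> DataLoader:
        # The assert doubles as a pyright type narrowing, so len() checks cleanly
        assert isinstance(dataset, Sized)
        return DataLoader(dataset, batch_size=len(dataset))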
@@ -149,8 +151,8 @@ def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
 class TrainState(train_state.TrainState):
     """Custom train state with batch statistics and dropout RNG."""

-    batch_stats: Optional[dict[str, jnp.ndarray]] = None
-    dropout_rng: Optional[jnp.ndarray] = None
+    batch_stats: Optional[dict[str, jnp.ndarray]]
+    dropout_rng: KeyArrayLike


 class Metrics(NamedTuple):
@@ -161,7 +163,7 @@ class Metrics(NamedTuple):


 def create_train_state(
-    rng: jnp.ndarray, _model: nn.Module, learning_rate: float, weight_decay: float
+    rng: KeyArrayLike, _model: nn.Module, learning_rate: float, weight_decay: float
 ) -> TrainState:
     """
     Create and initialise the train state.
@@ -323,7 +325,7 @@ def train_and_evaluate(
     train_set: DataSet,
     test_set: DataSet,
     _model: nn.Module,
-    rng: jnp.ndarray,
+    rng: KeyArrayLike,
     config: dict[str, Any],
 ) -> dict[str, float]:
     """
@@ -427,7 +429,7 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:


 def initialise_solvers(
-    train_data_umap: Data, key: jax.random.PRNGKey
+    train_data_umap: Data, key: KeyArrayLike
 ) -> list[Callable[[int], Solver]]:
     """
     Initialise and return a list of solvers for various coreset algorithms.
@@ -449,7 +451,7 @@ def initialise_solvers(
     random_seed = 45
     generator = np.random.default_rng(random_seed)
     idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
-    length_scale = median_heuristic(train_data_umap[idx])
+    length_scale = median_heuristic(jnp.asarray(train_data_umap[idx]))
     kernel = SquaredExponentialKernel(length_scale=length_scale)

     def _get_herding_solver(_size: int) -> MapReduce:
@@ -479,7 +481,7 @@ def _get_stein_solver(_size: int) -> MapReduce:
         """
         # Generate small dataset for ScoreMatching for Stein Kernel

-        score_function = KernelDensityMatching(length_scale=length_scale).match(
+        score_function = KernelDensityMatching(length_scale=length_scale.item()).match(
             train_data_umap[idx]
         )
         stein_kernel = SteinKernel(kernel, score_function)
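
The .item() call matters because median_heuristic evidently returns a zero-dimensional jax.Array, while KernelDensityMatching's length_scale is annotated as a plain float. Illustrated with a bare JAX scalar standing in for the heuristic's output:

    import jax.numpy as jnp

    length_scale = jnp.asarray(0.65)  # stand-in for median_heuristic's 0-d result
    assert isinstance(length_scale.item(), float)  # .item() unwraps a Python float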
@@ -513,7 +515,7 @@ def _get_rp_solver(_size: int) -> RPCholesky:

 def train_model(
     data_bundle: dict[str, jnp.ndarray],
-    key: jax.random.PRNGKey,
+    key: KeyArrayLike,
     config: dict[str, Union[int, float]],
 ) -> dict[str, float]:
     """
5 changes: 4 additions & 1 deletion benchmark/mnist_benchmark_visualiser.py
@@ -199,7 +199,10 @@ def plot_performance(
     if log_scale:
         plt.yscale("log")
     plt.title(title)
-    plt.xticks(index + bar_width * (n_algorithms - 1) / 2, coreset_sizes)
+    plt.xticks(
+        index + bar_width * (n_algorithms - 1) / 2,
+        [str(size) for size in coreset_sizes],
+    )
     plt.legend()
     plt.grid(True, linestyle="--", alpha=0.7)
     plt.tight_layout()
9 changes: 5 additions & 4 deletions benchmark/pounce_benchmark.py
@@ -60,10 +60,11 @@ def benchmark_coreset_algorithms(
     raw_data = np.asarray(image_data)
     reshaped_data = raw_data.reshape(raw_data.shape[0], -1)

-    umap_model = umap.UMAP(densmap=True, n_components=25)
-    umap_data = umap_model.fit_transform(reshaped_data)
+    # TODO: Change n_components back to 20 something
+    umap_model = umap.UMAP(densmap=True, n_components=5)
+    umap_data = jnp.asarray(umap_model.fit_transform(reshaped_data))

-    solver_factories = initialise_solvers(umap_data, random.PRNGKey(45))
+    solver_factories = initialise_solvers(Data(umap_data), random.PRNGKey(45))
     for solver_creator in solver_factories:
         solver = solver_creator(coreset_size)

@@ -83,7 +84,7 @@
     # Extract corresponding frames from original data and save GIF
     coreset_frames = raw_data[selected_indices]
     output_gif_path = out_dir / f"{solver_name}_coreset.gif"
-    imageio.mimsave(output_gif_path, coreset_frames, loop=0)
+    imageio.mimsave(output_gif_path, list(coreset_frames), loop=0)
     print(f"Saved {solver_name} coreset GIF to {output_gif_path}")
     print(f"time taken: {solver_name:<25} {duration:<30.4f}")
