From dc3b72e6ab38856db90f01ef594afcb6e78c5b00 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:18:18 +0000 Subject: [PATCH] chore: update david benchmark script to import from examples #919 --- benchmark/david_benchmark.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/benchmark/david_benchmark.py b/benchmark/david_benchmark.py index fbbdaf62..0606480d 100644 --- a/benchmark/david_benchmark.py +++ b/benchmark/david_benchmark.py @@ -29,6 +29,7 @@ """ import os +import sys import time from pathlib import Path from typing import Optional @@ -38,12 +39,15 @@ import matplotlib.pyplot as plt import numpy as np from jax import random +from mnist_benchmark import get_solver_name, initialise_solvers -from benchmark.mnist_benchmark import get_solver_name, initialise_solvers from coreax import Data from coreax.solvers import MapReduce from examples.david_map_reduce_weighted import downsample_opencv +sys.path.append(str(Path(__file__).parent.parent)) + + MAX_8BIT = 255 @@ -51,7 +55,7 @@ def benchmark_coreset_algorithms( in_path: Path = Path("../examples/data/david_orig.png"), out_path: Optional[Path] = Path("david_benchmark_results.png"), - downsampling_factor: int = 1, + downsampling_factor: int = 6, ): """ Benchmark the performance of coreset algorithms on a downsampled greyscale image. @@ -66,7 +70,6 @@ def benchmark_coreset_algorithms( """ # Base directory of the current script base_dir = os.path.dirname(os.path.abspath(__file__)) - # Convert to absolute paths using os.path.join if not in_path.is_absolute(): in_path = Path(os.path.join(base_dir, in_path))