Skip to content

Commit

Permalink
make jahs optional
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman committed Aug 24, 2023
1 parent 6b772a5 commit 4a203d4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 8 deletions.
33 changes: 27 additions & 6 deletions mfpbench/jahs/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from abc import ABC
from pathlib import Path
from typing import Any
from typing import Any, TYPE_CHECKING

import jahs_bench
from ConfigSpace import Configuration, ConfigurationSpace
from typing_extensions import Literal

from mfpbench.benchmark import Benchmark
from mfpbench.download import DATAROOT
Expand All @@ -14,6 +14,9 @@
from mfpbench.jahs.spaces import jahs_configspace
from mfpbench.util import rename

if TYPE_CHECKING:
import jahs_bench


class JAHSBenchmark(Benchmark[JAHSConfig, JAHSResult, int], ABC):
"""Manages access to jahs-bench.
Expand Down Expand Up @@ -117,6 +120,25 @@ def load(self) -> None:
def bench(self) -> jahs_bench.Benchmark:
"""The underlying benchmark used."""
if not self._bench:
try:
import jahs_bench
except ImportError as e:
raise ImportError(
"jahs-bench not installed, please install it with "
"`pip install jahs-bench`"
) from e

tasks = {
"cifar10": jahs_bench.BenchmarkTasks.CIFAR10,
"colorectal_histology": jahs_bench.BenchmarkTasks.ColorectalHistology,
"fashion_mnist": jahs_bench.BenchmarkTasks.FashionMNIST,
}
task = tasks.get(self.task, None)
if task is None:
raise ValueError(
f"Unknown task {self.task}, must be one of {list(tasks.keys())}"
)

self._bench = jahs_bench.Benchmark(
task=self.task,
save_dir=self.datadir,
Expand Down Expand Up @@ -242,12 +264,11 @@ def __repr__(self) -> str:


class JAHSCifar10(JAHSBenchmark):
task = jahs_bench.BenchmarkTasks.CIFAR10
task = "cifar10"


class JAHSColorectalHistology(JAHSBenchmark):
task = jahs_bench.BenchmarkTasks.ColorectalHistology

task = "colorectal_histology"

class JAHSFashionMNIST(JAHSBenchmark):
task = jahs_bench.BenchmarkTasks.FashionMNIST
task = "fashion_mnist"
9 changes: 8 additions & 1 deletion mfpbench/jahs/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Constant,
UniformFloatHyperparameter,
)
from jahs_bench.lib.core.constants import Activations


def jahs_configspace(
Expand All @@ -33,6 +32,14 @@ def jahs_configspace(
if isinstance(seed, np.random.RandomState):
seed = seed.tomaxint()

try:
from jahs_bench.lib.core.constants import Activations
except ImportError as e:
raise ImportError(
"jahs-bench not installed, please install it with "
"`pip install jahs-bench`"
) from e

space = ConfigurationSpace(name=name, seed=seed)
space.add_hyperparameters(
[
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ pyyaml = "*"
tqdm = "*"
numpy = "1.*"
yahpo-gym = "1.0.1"
jahs-bench = { git = "https://github.com/automl/jahs_bench_201.git", rev = "880fbcb35a83df7b6c02440a6c13adb921f54657" }
xgboost = "^1"

[tool.poetry.group.jahs.dependencies]
jahs-bench = { git = "https://github.com/automl/jahs_bench_201.git", rev = "880fbcb35a83df7b6c02440a6c13adb921f54657" }


[tool.poetry.group.dev.dependencies]
ruff = "^0.0.177"
Expand Down

0 comments on commit 4a203d4

Please sign in to comment.