diff --git a/.github/workflows/new_tasks.yml b/.github/workflows/new_tasks.yml index 0f83a6b283..d43082a8c2 100644 --- a/.github/workflows/new_tasks.yml +++ b/.github/workflows/new_tasks.yml @@ -16,7 +16,7 @@ jobs: name: Scan for changed tasks steps: - name: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 2 # OR "2" -> To retrieve the preceding commit. @@ -25,7 +25,7 @@ jobs: # and prepends the filter name to the standard output names. - name: Check task folders id: changed-tasks - uses: tj-actions/changed-files@v46.0.5 + uses: tj-actions/changed-files@24d32ffd492484c1d75e0c0b894501ddb9d30d62 with: # tasks checks the tasks folder and api checks the api folder for changes files_yaml: | @@ -44,28 +44,24 @@ jobs: echo "One or more test file(s) has changed." echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}" - - name: Set up Python 3.10 + - name: Install uv if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true' - uses: actions/setup-python@v5 + uses: astral-sh/setup-uv@v7 with: - python-version: '3.10' - cache: 'pip' - cache-dependency-path: pyproject.toml + enable-cache: true + python-version: "3.10" + activate-environment: true - name: Install dependencies if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true' run: | - python -m pip install --upgrade pip - pip install -e '.[dev,ifeval,unitxt,math,longbench]' --extra-index-url https://download.pytorch.org/whl/cpu - # Install optional git dependencies - # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt - # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + uv pip install -e '.[dev,ifeval,unitxt,math,longbench,hf]' --extra-index-url https://download.pytorch.org/whl/cpu - name: Test with pytest # if new tasks are added, run tests on them if: steps.changed-tasks.outputs.tasks_any_modified == 'true' - run: python -m pytest tests/test_tasks.py -s -vv + run: pytest -x -s -vv tests/test_tasks.py # if api is modified, run tests on it - name: Test more tasks with pytest env: API: true if: steps.changed-tasks.outputs.api_any_modified == 'true' - run: python -m pytest tests/test_tasks.py -s -vv + run: pytest -x -s -vv -n=auto tests/test_tasks.py diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index b8c4b96ea5..3c6b3a040e 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -21,13 +21,15 @@ jobs: steps: - name: Checkout Code - uses: actions/checkout@v4 - - name: Set up Python 3.10 - uses: actions/setup-python@v5 + uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: - python-version: '3.10' - cache: pip - cache-dependency-path: pyproject.toml + enable-cache: true + python-version: "3.10" + activate-environment: true + - name: Install pip + run: uv pip install pip - name: Pre-Commit env: SKIP: "no-commit-to-branch,mypy" @@ -43,13 +45,13 @@ jobs: timeout-minutes: 30 steps: - name: Checkout Code - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true python-version: ${{ matrix.python-version }} - cache: pip - cache-dependency-path: pyproject.toml + activate-environment: true # Cache 
HuggingFace cache directory for CPU tests - name: Cache HuggingFace cache (CPU tests) @@ -63,17 +65,16 @@ jobs: - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install -e '.[dev,unitxt]' --extra-index-url https://download.pytorch.org/whl/cpu - pip install hf_xet + uv pip install -e '.[dev,unitxt,hf]' --extra-index-url https://download.pytorch.org/whl/cpu + uv pip install hf_xet - name: Test with pytest - run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/models/test_openvino.py --ignore=tests/models/test_hf_steered.py - continue-on-error: true # Continue workflow even if tests fail + run: pytest -x --showlocals -s -vv -n=auto --ignore=tests/models/test_openvino.py --ignore=tests/models/test_hf_steered.py --ignore=tests/scripts/test_zeno_visualize.py # Save test artifacts - name: Archive test artifacts - uses: actions/upload-artifact@v4 + if: always() # Upload artifacts even if tests fail + uses: actions/upload-artifact@v5 with: name: output_testcpu${{ matrix.python-version }} path: | diff --git a/.gitignore b/.gitignore index 9ae167be97..f9b4150b63 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,6 @@ examples/wandb/ # PyInstaller *.spec + +#uv +uv.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a7fe20c34a..0dba723591 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: - id: mixed-line-ending args: [ --fix=lf ] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.13.2 + rev: v0.14.6 hooks: # Run the linter. - id: ruff-check @@ -46,7 +46,7 @@ repos: args: [ --check-filenames, --check-hidden, --ignore-words=ignore.txt ] - repo: https://github.com/jackdewinter/pymarkdown - rev: v0.9.32 + rev: v0.9.33 hooks: - id: pymarkdown exclude: ^(lm_eval/tasks/.*|docs/footguns\.md)$ diff --git a/lm_eval/api/registry.py b/lm_eval/api/registry.py index 4673b157b1..edd2d688ed 100644 --- a/lm_eval/api/registry.py +++ b/lm_eval/api/registry.py @@ -1,58 +1,518 @@ -import logging -from typing import Callable, Dict, Union +"""Registry system for lm_eval components. + +This module provides a centralized registration system for models, tasks, metrics, +filters, and other components in the lm_eval framework. -import evaluate as hf_evaluate +## Usage Examples +### Registering a Model +```python +from lm_eval.api.registry import register_model from lm_eval.api.model import LM +@register_model("my-model") +class MyModel(LM): + def __init__(self, **kwargs): + ... 
+``` + +### Registering with Lazy Loading +```python +# Register without importing the actual implementation +model_registry.register("lazy-model", lazy="my_package.models:LazyModel") +``` + +### Looking up Components +```python +from lm_eval.api.registry import get_model + +# Get a model class +model_cls = get_model("gpt-j") +model = model_cls(**config) +``` +""" + +from __future__ import annotations + +import importlib +import importlib.metadata as md +import inspect +import logging +import threading +from collections.abc import Callable +from functools import lru_cache +from types import MappingProxyType +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, overload + eval_logger = logging.getLogger(__name__) -MODEL_REGISTRY = {} + +if TYPE_CHECKING: + from lm_eval.api.filter import Filter + from lm_eval.api.model import LM + + +__all__ = [ + # Core registry class + "Registry", + # Registry instances + "model_registry", + "task_registry", + "filter_registry", + "aggregation_registry", + "metric_registry", + "metric_agg_registry", + "higher_is_better_registry", + "freeze_all", + # Helper functions + "register_model", + "get_model", + "register_task", + "get_task", + "register_group", + "register_metric", + "get_metric", + "register_aggregation", + "get_aggregation", + "get_metric_aggregation", + "is_higher_better", + "register_filter", + "get_filter", + # Backward compat aliases (point to Registry instances) + "MODEL_REGISTRY", + "TASK_REGISTRY", + "FILTER_REGISTRY", + "METRIC_REGISTRY", + "METRIC_AGGREGATION_REGISTRY", + "AGGREGATION_REGISTRY", + "HIGHER_IS_BETTER_REGISTRY", + # Legacy global state + "GROUP_REGISTRY", + "ALL_TASKS", + "DEFAULT_METRIC_REGISTRY", +] + + +T = TypeVar("T") +D = TypeVar("D") +Placeholder = Union[str, md.EntryPoint] + +# Sentinel for distinguishing "no default" from "default=None" +_MISSING: Any = object() + + +@lru_cache(maxsize=16) +def _materialise_placeholder(ph: Placeholder) -> Any: + """Materialize a lazy placeholder into the actual object. + + This is at module level to avoid memory leaks from lru_cache on instance methods. + + Args: + ph: Either a string path "module:object" or an EntryPoint instance + + Returns: + The loaded object + + Raises: + ValueError: If the string format is invalid + ImportError: If the module cannot be imported + AttributeError: If the object doesn't exist in the module + """ + if isinstance(ph, str): + mod, _, attr = ph.partition(":") + if not attr: + raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'") + return getattr(importlib.import_module(mod), attr) + return ph.load() + + +class Registry(Generic[T]): + """Thread-safe dict mapping string aliases to objects or lazy placeholders. + + Lazy placeholders ("module.path:attr" strings or EntryPoints) are + materialized on first access via `get()`. Optional `base_cls` enforces + type constraints. Call `freeze()` to make read-only. + """ + + def __init__( + self, + name: str, + *, + base_cls: type[T] | None = None, + ) -> None: + """Initialize a new registry. 
+ + Args: + name: Human-readable name for error messages (e.g., "model", "metric") + base_cls: Optional base class that all registered objects must inherit from + """ + self._name = name + self._base_cls = base_cls + self._objs: dict[str, T | Placeholder] = {} + self._lock = threading.RLock() + + # Registration (decorator or direct call) -------------------------------------- + + def register( + self, + *aliases: str, + lazy: T | Placeholder | None = None, + ) -> Callable[[T], T]: + """Register an object under one or more aliases. + + Can be used as a decorator or called directly for lazy registration. + + Args: + *aliases: Names to register the object under. If empty, uses object's __name__ + lazy: For direct calls only - a placeholder string "module:object" or EntryPoint + + Returns: + Decorator function (or no-op if lazy registration) + + Examples: + >>> # As decorator + >>> @model_registry.register("name1", "name2") + >>> class MyModel(LM): + ... pass + >>> + >>> # Direct lazy registration + >>> model_registry.register("lazy-name", lazy="mymodule:MyModel") + + Raises: + ValueError: If alias is already registered with a different target + TypeError: If an object doesn't inherit from base_cls (when specified) + """ + + def _store(alias: str, target: T | Placeholder) -> None: + current = self._objs.get(alias) + # collision handling ------------------------------------------ + if current is not None and current != target: + # allow placeholder → real object upgrade + if ( + isinstance(current, str) + and isinstance(target, type) + and current == f"{target.__module__}:{target.__name__}" + ): + self._objs[alias] = target + return + raise ValueError( + f"{self._name!r} alias '{alias}' already registered (" + f"existing={current}, new={target})" + ) + # type check for concrete classes ---------------------------------------------- + if ( + self._base_cls is not None + and isinstance(target, type) + and not issubclass(target, self._base_cls) + ): + raise TypeError( + f"{target} must inherit from {self._base_cls} to be a {self._name}" + ) + self._objs[alias] = target + + def decorator(obj: T) -> T: # type: ignore[valid-type] + names = aliases or (getattr(obj, "__name__", str(obj)),) + with self._lock: + for name in names: + _store(name, obj) + return obj + + # Direct call with *lazy* placeholder + if lazy is not None: + if len(aliases) != 1: + raise ValueError("Exactly one alias required when using 'lazy='") + with self._lock: + _store(aliases[0], lazy) # type: ignore[arg-type] + # return no‑op decorator for accidental use + return lambda x: x # type: ignore[return-value] + + return decorator + + # Lookup & materialisation -------------------------------------------------- + + def _materialise(self, ph: Placeholder) -> T: + """Materialize a placeholder using the module-level cached function. + + Args: + ph: Placeholder to materialize + + Returns: + The materialized object, cast to type T + """ + return cast(T, _materialise_placeholder(ph)) + + @overload + def get(self, alias: str) -> T: ... + + @overload + def get(self, alias: str, default: D) -> T | D: ... + + def get(self, alias: str, default: D | Any = _MISSING) -> T | D: + """Retrieve an object by alias, materializing if needed. + + Thread-safe lazy loading: if the alias points to a placeholder, + it will be loaded and cached before returning. 
+ + Args: + alias: The registered name to look up + default: Default value to return if alias not found (can be None) + + Returns: + The registered object, or default if not found + + Raises: + KeyError: If alias not found and no default provided + TypeError: If materialized object doesn't match base_cls + ImportError/AttributeError: If lazy loading fails + """ + try: + target = self._objs[alias] + except KeyError as exc: + if default is not _MISSING: + return default + raise KeyError( + f"Unknown {self._name} '{alias}'. Available: {', '.join(self._objs)}" + ) from exc + + if isinstance(target, (str, md.EntryPoint)): + with self._lock: + # Re‑check under lock (another thread might have resolved it) + fresh = self._objs[alias] + if isinstance(fresh, (str, md.EntryPoint)): + concrete = self._materialise(fresh) + # Only update if not frozen (MappingProxyType) + if not isinstance(self._objs, MappingProxyType): + self._objs[alias] = concrete + else: + concrete = fresh # another thread did the job + target = concrete + + # Late type/validator checks + if ( + self._base_cls is not None + and isinstance(target, type) + and not issubclass(target, self._base_cls) + ): + raise TypeError( + f"{target} does not inherit from {self._base_cls} (alias '{alias}')" + ) from None + return target + + def __getitem__(self, alias: str) -> T: + """Allow dict-style access: registry[alias].""" + return self.get(alias) + + def __contains__(self, alias: str) -> bool: + """Check if alias is registered.""" + return alias in self._objs + + def __iter__(self): + """Iterate over registered aliases.""" + return iter(self._objs) + + def __len__(self): + """Return number of registered aliases.""" + return len(self._objs) + + def keys(self): + """Return all registered aliases.""" + return self._objs.keys() + + def values(self): + """Return all registered objects. + + Note: Objects may be placeholders that haven't been materialized yet. + """ + return self._objs.values() + + def items(self): + """Return (alias, object) pairs. + + Note: Objects may be placeholders that haven't been materialized yet. + """ + return self._objs.items() + + # Utilities ------------------------------------------------------------- + + def origin(self, alias: str) -> str | None: + """Get the source location of a registered object. + + Args: + alias: The registered name + + Returns: + "path/to/file.py:line_number" or None if not available + """ + obj = self._objs.get(alias) + if isinstance(obj, (str, md.EntryPoint)): + return None + try: + path = inspect.getfile(obj) # type: ignore[arg-type] + line = inspect.getsourcelines(obj)[1] # type: ignore[arg-type] + return f"{path}:{line}" + except Exception: # pragma: no cover – best‑effort only + return None + + def freeze(self): + """Make the registry read-only to prevent further modifications. + + After freezing, attempts to register new objects will fail. + This is useful for ensuring registry contents don't change after + initialization. + """ + with self._lock: + self._objs = MappingProxyType(dict(self._objs)) # type: ignore[assignment] + + # Test helper -------------------------------- + def _clear(self): # pragma: no cover + """Erase registry (for isolated tests). + + Clears both the registry contents and the materialization cache. + Only use this in test code to ensure clean state between tests. 
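For reference, a minimal runnable sketch of the lazy-placeholder flow described above (the "demo" registry name and "join" alias are illustrative; the target is a stdlib function so nothing extra is needed):

```python
# Minimal sketch of lazy registration and first-access materialization.
import os.path

from lm_eval.api.registry import Registry

reg: Registry = Registry("demo")
reg.register("join", lazy="os.path:join")   # stored as the placeholder string "os.path:join"
assert "join" in reg                        # __contains__ works before materialization

fn = reg.get("join")                        # first access imports os.path and caches the result
assert fn is os.path.join
assert reg["join"] is os.path.join          # dict-style access returns the cached object
```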
+ """ + if isinstance(self._objs, MappingProxyType): + self._objs = dict(self._objs) # type: ignore[assignment] + self._objs.clear() + _materialise_placeholder.cache_clear() + + +# ============================================================================= +# Registry instances +# ============================================================================= + +model_registry: Registry[type[LM]] = Registry("model") +task_registry: Registry[Callable[..., Any]] = Registry("task") +filter_registry: Registry[type[Filter]] = Registry("filter") +aggregation_registry: Registry[Callable[..., float]] = Registry("aggregation") +metric_registry: Registry[Callable] = Registry("metric") +metric_agg_registry: Registry[Callable] = Registry("metric_aggregation") +higher_is_better_registry: Registry[bool] = Registry("higher_is_better") + + +def freeze_all(): + """Freeze all registries to prevent further modifications. + + This is useful for ensuring registry contents are immutable after + initialization, preventing accidental modifications during runtime. + """ + for r in ( + model_registry, + task_registry, + filter_registry, + aggregation_registry, + metric_registry, + metric_agg_registry, + higher_is_better_registry, + ): + r.freeze() + + +# ============================================================================= +# Legacy global state (for backward compatibility) +# ============================================================================= + +GROUP_REGISTRY: dict[str, list] = {} +ALL_TASKS: set = set() +func2task_index: dict[str, str] = {} +OUTPUT_TYPE_REGISTRY: dict[str, Any] = {} + +# Backward compat aliases - these now point to Registry instances +METRIC_REGISTRY = metric_registry +METRIC_AGGREGATION_REGISTRY = metric_agg_registry +AGGREGATION_REGISTRY = aggregation_registry +HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry + +DEFAULT_METRIC_REGISTRY = { + "loglikelihood": [ + "perplexity", + "acc", + ], + "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"], + "multiple_choice": ["acc", "acc_norm"], + "generate_until": ["exact_match"], +} + + +# ============================================================================= +# Model registration (using new Registry class) +# ============================================================================= def register_model(*names): - # either pass a list or a single alias. - # function receives them as a tuple of strings + """Decorator to register a model class. - def decorate(cls): - for name in names: - assert issubclass(cls, LM), ( - f"Model '{name}' ({cls.__name__}) must extend LM class" - ) + Args: + *names: One or more names to register the model under - assert name not in MODEL_REGISTRY, ( - f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead." - ) + Returns: + Decorator function - MODEL_REGISTRY[name] = cls + Example: + >>> @register_model("my-model", "my-model-alias") + >>> class MyModel(LM): + ... pass + """ + # Import here to avoid circular import at module load time + from lm_eval.api.model import LM + + def decorate(cls): + assert issubclass(cls, LM), f"Model '{cls.__name__}' must extend LM class" + # Use Registry's public API - it handles placeholder→concrete upgrades + model_registry.register(*names)(cls) return cls return decorate -def get_model(model_name): +def get_model(model_name: str): + """Get a model class by name. 
+ + Args: + model_name: The registered name of the model + + Returns: + The model class + + Raises: + ValueError: If model name is not found + """ + # Auto-import models module if registry is empty (lazy initialization) + if len(model_registry) == 0: + import lm_eval.models # noqa: F401 + try: - return MODEL_REGISTRY[model_name] + return model_registry.get(model_name) except KeyError: raise ValueError( - f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}" + f"Attempted to load model '{model_name}', but no model for this name found! " + f"Supported model names: {', '.join(model_registry.keys())}" ) -TASK_REGISTRY = {} -GROUP_REGISTRY = {} -ALL_TASKS = set() -func2task_index = {} +# Backward compatibility alias +MODEL_REGISTRY = model_registry + + +# ============================================================================= +# Task registration (using new Registry class) +# ============================================================================= def register_task(name): - def decorate(fn): - assert name not in TASK_REGISTRY, ( - f"task named '{name}' conflicts with existing registered task!" - ) + """Decorator to register a task. + + Args: + name: Name to register the task under + + Returns: + Decorator function + """ - TASK_REGISTRY[name] = fn + def decorate(fn): + # Use Registry's public API for registration + task_registry.register(name)(fn) + # Also update legacy global state for backward compatibility ALL_TASKS.add(name) func2task_index[fn.__name__] = name return fn @@ -60,7 +520,32 @@ def decorate(fn): return decorate +def get_task(name): + """Get a task by name. + + Args: + name: The registered name of the task + + Returns: + The task function/class + """ + return task_registry.get(name) + + +# Backward compatibility alias +TASK_REGISTRY = task_registry + + def register_group(name): + """Decorator to register a task group. + + Args: + name: Name of the group + + Returns: + Decorator function + """ + def decorate(fn): func_name = func2task_index[fn.__name__] if name in GROUP_REGISTRY: @@ -73,124 +558,203 @@ def decorate(fn): return decorate -OUTPUT_TYPE_REGISTRY = {} -METRIC_REGISTRY = {} -METRIC_AGGREGATION_REGISTRY = {} -AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {} -HIGHER_IS_BETTER_REGISTRY = {} -FILTER_REGISTRY = {} +# ============================================================================= +# Filter registration (using new Registry class) +# ============================================================================= -DEFAULT_METRIC_REGISTRY = { - "loglikelihood": [ - "perplexity", - "acc", - ], - "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"], - "multiple_choice": ["acc", "acc_norm"], - "generate_until": ["exact_match"], -} + +def register_filter(name): + """Decorator to register a filter class. + + Args: + name: Name to register the filter under + + Returns: + Decorator function + """ + + def decorate(cls): + if name in filter_registry: + eval_logger.info(f"Registering filter `{name}` that is already in Registry") + # Use Registry's public API for registration + filter_registry.register(name)(cls) + return cls + + return decorate + + +def get_filter(filter_name: str | Callable) -> Callable: + """Get a filter by name. 
+ + Args: + filter_name: The registered name of the filter, or a callable + + Returns: + The filter class/function + + Raises: + KeyError: If filter name is not found and is not callable + """ + if callable(filter_name): + return filter_name + try: + return filter_registry.get(filter_name) + except KeyError as e: + eval_logger.warning(f"filter `{filter_name}` is not registered!") + raise e + + +# Backward compatibility alias +FILTER_REGISTRY = filter_registry + + +# ============================================================================= +# Metric registration (using new Registry class) +# ============================================================================= def register_metric(**args): - # TODO: do we want to enforce a certain interface to registered metrics? + """Decorator to register a metric function. + + Args: + **args: Keyword arguments including: + - metric: Name to register the metric under (required) + - higher_is_better: Whether higher scores are better + - aggregation: Name of aggregation function to use + + Returns: + Decorator function + """ + def decorate(fn): assert "metric" in args name = args["metric"] - for key, registry in [ - ("metric", METRIC_REGISTRY), - ("higher_is_better", HIGHER_IS_BETTER_REGISTRY), - ("aggregation", METRIC_AGGREGATION_REGISTRY), - ]: - if key in args: - value = args[key] - assert value not in registry, ( - f"{key} named '{value}' conflicts with existing registered {key}!" - ) + # Register the metric function + metric_registry.register(name)(fn) - if key == "metric": - registry[name] = fn - elif key == "aggregation": - registry[name] = AGGREGATION_REGISTRY[value] - else: - registry[name] = value + # Register higher_is_better if provided + if "higher_is_better" in args: + higher_is_better_registry.register(name, lazy=args["higher_is_better"]) + + # Register aggregation if provided + if "aggregation" in args: + agg_fn = aggregation_registry.get(args["aggregation"]) + metric_agg_registry.register(name, lazy=agg_fn) return fn return decorate -def get_metric(name: str, hf_evaluate_metric=False) -> Callable: +def get_metric(name: str, hf_evaluate_metric: bool = False) -> Callable | None: + """Get a metric function by name. + + Args: + name: The metric name + hf_evaluate_metric: If True, skip local registry and use HF evaluate + + Returns: + The metric compute function, or None if not found + """ + # Auto-import metrics module if registry is empty (lazy initialization) + if len(metric_registry) == 0: + import lm_eval.api.metrics # noqa: F401 + if not hf_evaluate_metric: - if name in METRIC_REGISTRY: - return METRIC_REGISTRY[name] + if name in metric_registry: + return metric_registry.get(name) else: eval_logger.warning( f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..." ) try: + import evaluate as hf_evaluate + metric_object = hf_evaluate.load(name) return metric_object.compute except Exception: eval_logger.error( f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric", ) + return None def register_aggregation(name: str): - def decorate(fn): - assert name not in AGGREGATION_REGISTRY, ( - f"aggregation named '{name}' conflicts with existing registered aggregation!" - ) + """Decorator to register an aggregation function. 
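A small sketch of how the metric and aggregation decorators above compose; the `demo_*` names are illustrative and mirror the pattern exercised in `tests/test_registry.py` below:

```python
# Illustrative only: demo_mean / demo_hits are made-up names, not shipped metrics.
from lm_eval.api.registry import (
    get_metric_aggregation,
    is_higher_better,
    register_aggregation,
    register_metric,
)


@register_aggregation("demo_mean")
def demo_mean(items):
    return sum(items) / len(items)


@register_metric(metric="demo_hits", higher_is_better=True, aggregation="demo_mean")
def demo_hits(gold, pred):
    return float(gold == pred)


assert is_higher_better("demo_hits") is True
assert get_metric_aggregation("demo_hits") is demo_mean
```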
+ + Args: + name: Name to register the aggregation under + + Returns: + Decorator function + """ - AGGREGATION_REGISTRY[name] = fn + def decorate(fn): + aggregation_registry.register(name)(fn) return fn return decorate -def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]: +def get_aggregation(name: str) -> Callable[..., float] | None: + """Get an aggregation function by name. + + Args: + name: The aggregation name + + Returns: + The aggregation function, or None if not found + """ + # Auto-import metrics module if registry is empty (lazy initialization) + if len(aggregation_registry) == 0: + import lm_eval.api.metrics # noqa: F401 + try: - return AGGREGATION_REGISTRY[name] + return aggregation_registry.get(name) except KeyError: eval_logger.warning(f"{name} not a registered aggregation metric!") + return None -def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]: - try: - return METRIC_AGGREGATION_REGISTRY[name] - except KeyError: - eval_logger.warning(f"{name} metric is not assigned a default aggregation!") +def get_metric_aggregation(name: str) -> Callable[..., float] | None: + """Get the aggregation function for a metric. + Args: + name: The metric name + + Returns: + The aggregation function for that metric, or None if not found + """ + # Auto-import metrics module if registry is empty (lazy initialization) + if len(metric_agg_registry) == 0: + import lm_eval.api.metrics # noqa: F401 -def is_higher_better(metric_name) -> bool: try: - return HIGHER_IS_BETTER_REGISTRY[metric_name] + return metric_agg_registry.get(name) except KeyError: - eval_logger.warning( - f"higher_is_better not specified for metric '{metric_name}'!" - ) + eval_logger.warning(f"{name} metric is not assigned a default aggregation!") + return None -def register_filter(name): - def decorate(cls): - if name in FILTER_REGISTRY: - eval_logger.info( - f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}" - ) - FILTER_REGISTRY[name] = cls - return cls +def is_higher_better(metric_name: str) -> bool | None: + """Check if higher values are better for a metric. - return decorate + Args: + metric_name: The metric name + Returns: + True if higher is better, False otherwise, None if not found + """ + # Auto-import metrics module if registry is empty (lazy initialization) + if len(higher_is_better_registry) == 0: + import lm_eval.api.metrics # noqa: F401 -def get_filter(filter_name: Union[str, Callable]) -> Callable: try: - return FILTER_REGISTRY[filter_name] - except KeyError as e: - if callable(filter_name): - return filter_name - else: - eval_logger.warning(f"filter `{filter_name}` is not registered!") - raise e + return higher_is_better_registry.get(metric_name) + except KeyError: + eval_logger.warning( + f"higher_is_better not specified for metric '{metric_name}'!" 
+ ) + return None diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index a0f6179bf3..dd8cb43916 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -8,12 +8,11 @@ from typing import TYPE_CHECKING, List, Optional, Union import numpy as np -import torch import lm_eval.api.metrics +import lm_eval.api.model import lm_eval.api.registry import lm_eval.api.task -import lm_eval.models from lm_eval.caching.cache import delete_cache from lm_eval.evaluator_utils import ( consolidate_group_results, @@ -33,6 +32,7 @@ hash_dict_images, hash_string, positional_deprecated, + set_torch_seed, setup_logging, simple_parse_args_string, wrap_text, @@ -193,7 +193,7 @@ def simple_evaluate( if torch_random_seed is not None: seed_message.append(f"Setting torch manual seed to {torch_random_seed}") - torch.manual_seed(torch_random_seed) + set_torch_seed(torch_random_seed) if fewshot_random_seed is not None: seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}") @@ -387,7 +387,7 @@ def _adjust_config(task_dict): "model_args": model_args, } # add more detailed model info if available - if isinstance(lm, lm_eval.models.huggingface.HFLM): + if hasattr(lm, "get_model_info"): results["config"].update(lm.get_model_info()) # add info about execution results["config"].update( @@ -553,6 +553,8 @@ def evaluate( requests[reqtype].append(instance) if lm.world_size > 1: + import torch + instances_rnk = torch.tensor(len(task._instances), device=lm.device) gathered_item = ( lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist() @@ -661,6 +663,8 @@ def evaluate( task_output.sample_metrics[(metric, filter_key)].append(value) if WORLD_SIZE > 1: + import torch + # if multigpu, then gather data across all ranks to rank 0 # first gather logged samples across all ranks for task_output in eval_tasks: diff --git a/lm_eval/filters/__init__.py b/lm_eval/filters/__init__.py index be5c9d4362..0049be4e01 100644 --- a/lm_eval/filters/__init__.py +++ b/lm_eval/filters/__init__.py @@ -1,25 +1,33 @@ +from __future__ import annotations + from functools import partial -from typing import List from lm_eval.api.filter import FilterEnsemble -from lm_eval.api.registry import get_filter +from lm_eval.api.registry import filter_registry, get_filter from . import custom, extraction, selection, transformation def build_filter_ensemble( - filter_name: str, components: List[List[str]] + filter_name: str, + components: list[tuple[str, dict[str, str | int | float] | None]], ) -> FilterEnsemble: """ Create a filtering pipeline. """ - filters = [] - for function, kwargs in components: - if kwargs is None: - kwargs = {} - # create a filter given its name in the registry - f = partial(get_filter(function), **kwargs) - # add the filter as a pipeline step - filters.append(f) + # create filters given its name in the registry, and add each as a pipeline step + return FilterEnsemble( + name=filter_name, + filters=[ + partial(get_filter(func), **(kwargs or {})) for func, kwargs in components + ], + ) + - return FilterEnsemble(name=filter_name, filters=filters) +__all__ = [ + "custom", + "extraction", + "selection", + "transformation", + "build_filter_ensemble", +] diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py index abedc5535e..e9739c6693 100644 --- a/lm_eval/models/__init__.py +++ b/lm_eval/models/__init__.py @@ -1,28 +1,53 @@ -from . 
import ( - anthropic_llms, - api_models, - dummy, - gguf, - hf_audiolm, - hf_steered, - hf_vlms, - huggingface, - ibm_watsonx_ai, - mamba_lm, - nemo_lm, - neuron_optimum, - openai_completions, - optimum_ipex, - optimum_lm, - sglang_causallms, - sglang_generate_API, - textsynth, - vllm_causallms, - vllm_vlms, -) - - -# TODO: implement __all__ +# Models are now lazily loaded via the registry system +# No need to import them all at once - they're loaded on demand + +# Define model mappings for lazy registration +MODEL_MAPPING = { + "anthropic-completions": "lm_eval.models.anthropic_llms:AnthropicLM", + "anthropic-chat": "lm_eval.models.anthropic_llms:AnthropicChatLM", + "anthropic-chat-completions": "lm_eval.models.anthropic_llms:AnthropicCompletionsLM", + "local-completions": "lm_eval.models.openai_completions:LocalCompletionsAPI", + "local-chat-completions": "lm_eval.models.openai_completions:LocalChatCompletion", + "openai-completions": "lm_eval.models.openai_completions:OpenAICompletionsAPI", + "openai-chat-completions": "lm_eval.models.openai_completions:OpenAIChatCompletion", + "dummy": "lm_eval.models.dummy:DummyLM", + "gguf": "lm_eval.models.gguf:GGUFLM", + "ggml": "lm_eval.models.gguf:GGUFLM", + "hf-audiolm-qwen": "lm_eval.models.hf_audiolm:HFAudioLM", + "steered": "lm_eval.models.hf_steered:SteeredHF", + "hf-multimodal": "lm_eval.models.hf_vlms:HFMultimodalLM", + "hf-auto": "lm_eval.models.huggingface:HFLM", + "hf": "lm_eval.models.huggingface:HFLM", + "huggingface": "lm_eval.models.huggingface:HFLM", + "watsonx_llm": "lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI", + "mamba_ssm": "lm_eval.models.mamba_lm:MambaLMWrapper", + "nemo_lm": "lm_eval.models.nemo_lm:NeMoLM", + "neuronx": "lm_eval.models.neuron_optimum:NeuronModelForCausalLM", + "ipex": "lm_eval.models.optimum_ipex:IPEXForCausalLM", + "openvino": "lm_eval.models.optimum_lm:OptimumLM", + "sglang": "lm_eval.models.sglang_causallms:SGLANG", + "sglang-generate": "lm_eval.models.sglang_generate_API:SGAPI", + "textsynth": "lm_eval.models.textsynth:TextSynthLM", + "vllm": "lm_eval.models.vllm_causallms:VLLM", + "vllm-vlm": "lm_eval.models.vllm_vlms:VLLM_VLM", +} + + +def _register_all_models(): + """Register all known models lazily in the registry.""" + from lm_eval.api.registry import model_registry + + for name, path in MODEL_MAPPING.items(): + # Only register if not already present (avoids conflicts when modules are imported) + if name not in model_registry: + # Register the lazy placeholder using lazy parameter + model_registry.register(name, lazy=path) + + +# Call registration on module import +_register_all_models() + +__all__ = ["MODEL_MAPPING"] try: diff --git a/lm_eval/models/dummy.py b/lm_eval/models/dummy.py index 014ad49ee3..949eccc6fb 100644 --- a/lm_eval/models/dummy.py +++ b/lm_eval/models/dummy.py @@ -8,7 +8,7 @@ @register_model("dummy") class DummyLM(LM): - def __init__(self) -> None: + def __init__(self, *args, **kwargs) -> None: super().__init__() @classmethod diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index eeb20dbf91..f925138607 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -2,9 +2,10 @@ import inspect import logging import os +from collections.abc import Mapping from functools import partial from pathlib import Path -from typing import Dict, List, Mapping, Optional, Union +from typing import Dict, List, Optional, Union from lm_eval import utils from lm_eval.api.group import ConfigurableGroup, GroupConfig @@ -25,10 +26,10 @@ class TaskManager: def __init__( self, 
- verbosity: Optional[str] = None, - include_path: Optional[Union[str, List]] = None, + verbosity: str | None = None, + include_path: str | list | None = None, include_defaults: bool = True, - metadata: Optional[dict] = None, + metadata: dict | None = None, ) -> None: if verbosity is not None: utils.setup_logging(verbosity) @@ -57,7 +58,7 @@ def __init__( def initialize_tasks( self, - include_path: Optional[Union[str, List]] = None, + include_path: str | list | None = None, include_defaults: bool = True, ) -> dict[str, dict]: """Creates a dictionary of tasks indexes. @@ -257,9 +258,9 @@ def _class_has_config_in_constructor(self, cls): def _load_individual_task_or_group( self, - name_or_config: Optional[Union[str, dict]] = None, - parent_name: Optional[str] = None, - update_config: Optional[dict] = None, + name_or_config: str | dict | None = None, + parent_name: str | None = None, + update_config: dict | None = None, ) -> Mapping: def _load_task(config, task): if "include" in config: @@ -411,7 +412,7 @@ def _process_group_config( group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) } - def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict: + def load_task_or_group(self, task_list: str | list | None = None) -> dict: """Loads a dictionary of task objects from a list :param task_list: Union[str, list] = None @@ -433,7 +434,7 @@ def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> di ) return all_loaded_tasks - def load_config(self, config: Dict): + def load_config(self, config: dict): return self._load_individual_task_or_group(config) def _get_task_and_group(self, task_dir: str): @@ -551,7 +552,7 @@ def _populate_tags_and_groups(config, task, tasks_and_groups, print_info): return tasks_and_groups -def get_task_name_from_config(task_config: Dict[str, str]) -> str: +def get_task_name_from_config(task_config: dict[str, str]) -> str: if "task" in task_config: return task_config["task"] if "dataset_name" in task_config: @@ -601,8 +602,8 @@ def _check_duplicates(task_dict: dict) -> None: def get_task_dict( - task_name_list: Union[str, List[Union[str, Dict, Task]]], - task_manager: Optional[TaskManager] = None, + task_name_list: str | list[str | dict | Task], + task_manager: TaskManager | None = None, ): """Creates a dictionary of task objects from either a name of task, config, or prepared Task object. 
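For context on the lazy `MODEL_MAPPING` registration in `lm_eval/models/__init__.py` above, a short sketch of the consumer side (only the `dummy` backend is actually imported here; the heavy ones stay as placeholders):

```python
# Sketch: get_model() auto-imports lm_eval.models when the registry is empty,
# which registers every MODEL_MAPPING entry as a lazy "module:Class" placeholder.
from lm_eval.api.registry import get_model, model_registry

DummyLM = get_model("dummy")       # materializes only lm_eval.models.dummy:DummyLM
lm = DummyLM()                     # new __init__(*args, **kwargs) accepts passthrough args
assert "vllm" in model_registry    # still present as an unmaterialized placeholder
```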
@@ -680,10 +681,14 @@ def pretty_print_task(task_name, task_manager, indent: int): yaml_path = task_manager.task_index[task_name]["yaml_path"] yaml_path = Path(yaml_path) lm_eval_tasks_path = Path(__file__).parent - relative_yaml_path = yaml_path.relative_to(lm_eval_tasks_path) + try: + display_path = yaml_path.relative_to(lm_eval_tasks_path) + except ValueError: + # Path is outside lm_eval/tasks (e.g., from include_path) + display_path = yaml_path pad = " " * indent - eval_logger.info(f"{pad}Task: {task_name} ({relative_yaml_path})") + eval_logger.info(f"{pad}Task: {task_name} ({display_path})") # NOTE: Only nicely logs: # 1/ group diff --git a/lm_eval/utils.py b/lm_eval/utils.py index 07f262f4d6..9cc45013b0 100644 --- a/lm_eval/utils.py +++ b/lm_eval/utils.py @@ -840,3 +840,9 @@ def _request_with_retries(method, url, **kwargs): return False return True + + +def set_torch_seed(seed: int): + import torch + + torch.manual_seed(seed) diff --git a/pyproject.toml b/pyproject.toml index 67fe89dde8..9930a4ff0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=40.8.0", "wheel"] +requires = ["setuptools>=64.0"] build-backend = "setuptools.build_meta" [project] @@ -13,28 +13,20 @@ readme = "README.md" classifiers = [ "Development Status :: 3 - Alpha", "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] requires-python = ">=3.10" -license = { "text" = "MIT" } +license = { text = "MIT" } dependencies = [ - "accelerate>=0.26.0", - "evaluate", "datasets>=2.16.0", "evaluate>=0.4.0", + "jinja2", "jsonlines", - "numexpr", - "peft>=0.2.0", - "pybind11>=2.6.2", "pytablewriter", "rouge-score>=0.0.4", "sacrebleu>=1.5.0", "scikit-learn>=0.24.1", "sqlitedict", - "torch>=1.8", - "tqdm-multiprocess", - "transformers>=4.1", "zstandard", "dill", "word2number", @@ -57,48 +49,71 @@ Homepage = "https://github.com/EleutherAI/lm-evaluation-harness" Repository = "https://github.com/EleutherAI/lm-evaluation-harness" [project.optional-dependencies] -acpbench = ["lark>=1.1.9", "tarski[clingo]==0.8.2", "pddl==0.4.2", "kstar-planner==1.4.2"] +# Model backend dependencies api = ["requests", "aiohttp", "tenacity", "tqdm", "tiktoken"] -audiolm_qwen = ["librosa", "soundfile"] -dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "requests", "aiohttp", "tenacity", "tqdm", "tiktoken", "sentencepiece"] +hf = ["transformers>=4.1","torch>=1.8", "accelerate>=0.26.0", "peft>=0.2.0",] +vllm = ["vllm>=0.4.2"] gptq = ["auto-gptq[triton]>=0.6.0"] gptqmodel = ["gptqmodel>=1.0.9"] -hf_transfer = ["hf_transfer"] +ipex = ["optimum-intel"] ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"] +# mamba requires CUDA (nvcc) - cannot build on macOS/CPU-only systems +# mamba = ["mamba_ssm", "causal-conv1d==1.0.2", "torch"] +optimum = ["optimum[openvino]"] +sparsify = ["sparsify"] +sae_lens = ["sae_lens"] +# Task specific dependencies +acpbench = ["lark>=1.1.9", "tarski[clingo]==0.8.2", "pddl==0.4.2", "kstar-planner==1.4.2"] +audiolm_qwen = ["librosa", "soundfile"] +dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "requests", "aiohttp", "tenacity", "tqdm", "tiktoken", "sentencepiece", "ruff"] +hf_transfer = ["hf_transfer"] ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"] -ipex = ["optimum"] japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"] longbench = ["jieba", "fuzzywuzzy", "rouge"] libra = ["pymorphy2"] -mamba = ["mamba_ssm", 
"causal-conv1d==1.0.2", "torch"] math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"] multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"] -neuronx = ["optimum[neuronx]"] -optimum = ["optimum[openvino]"] -promptsource = ["promptsource>=0.2.3"] +#promptsource = [ +# "promptsource>=0.2.3 ; python_version <= '3.12'", +#] ruler = ["nltk", "wonderwords", "scipy"] -sae_lens = ["sae_lens"] sentencepiece = ["sentencepiece>=0.1.98"] -sparsify = ["sparsify"] discrim_eval = ["statsmodels==0.14.4"] -testing = ["pytest", "pytest-cov", "pytest-xdist"] unitxt = ["unitxt==1.22.0"] -vllm = ["vllm>=0.4.2"] wandb = ["wandb>=0.16.3", "pandas", "numpy"] zeno = ["pandas", "zeno-client"] tasks = [ - "lm_eval[acpbench]", "lm_eval[discrim_eval]", "lm_eval[ifeval]", "lm_eval[japanese_leaderboard]", "lm_eval[longbench]", "lm_eval[libra]", - "lm_eval[mamba]", "lm_eval[math]", "lm_eval[multilingual]", "lm_eval[ruler]", ] +[dependency-groups] +dev = [ + "lm_eval[api]", "lm_eval[dev]", "lm_eval[hf]","sentencepiece" +] + +[tool.uv] +conflicts = [ + [ + { extra = "acpbench" }, + { extra = "math" }, + ], + [ + { extra = "acpbench" }, + { extra = "tasks" }, + ], + [ + { extra = "gptq" }, + { extra = "vllm" }, + ], +] + [tool.pymarkdown] plugins.md013.enabled = false # line-length plugins.md024.allow_different_nesting = true # no-duplicate-headers @@ -107,18 +122,44 @@ plugins.md028.enabled = false # no-blanks-blockquote plugins.md029.allow_extended_start_values = true # ol-prefix plugins.md034.enabled = false # no-bare-urls -[tool.ruff.lint] -extend-select = ["I", "W605"] - -[tool.ruff.lint.isort] -lines-after-imports = 2 -known-first-party = ["lm_eval"] +[tool.ruff] +lint.extend-select = [ + "I", # isort + "UP", # pyupgrade + "E", # pycodestyle errors + "C419", # unnecessary-comprehension-in-call + "F", # pyflakes + "B", # flake8-bugbear + "SIM", # flake8-simplify + "RUF034", # useless-if-else + "W605", # invalid-escape-sequence + "FURB", # refurb +] +lint.fixable = [ + "I001", # unsorted-imports + "F401", # unused-import + "UP", # pyupgrade fixes +] +lint.ignore = [ + "E402", # module-import-not-at-top-of-file + "E731", # lambda-assignment + "E501", # line-too-long + "E111", # indentation-with-invalid-multiple + "E114", # indentation-with-invalid-multiple-comment + "E117", # over-indented + "E741", # ambiguous-variable-name + "E701", # multiple-statements-on-one-line-colon +] [tool.ruff.lint.extend-per-file-ignores] -"__init__.py" = ["F401", "F402", "F403"] -"utils.py" = ["F401"] - -[dependency-groups] -dev = [ - "api", "dev", "sentencepiece" +"__init__.py" = [ + "F401", # unused-import + "F402", # import-shadowed-by-loop-var + "F403", # undefined-local-with-import-star + "F405", # undefined-local-with-import-star-usage ] + +[tool.ruff.lint.isort] +combine-as-imports = true +known-first-party = ["lm_eval"] +lines-after-imports = 2 diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index d6e1198b1a..0000000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ --e . 
diff --git a/setup.py b/setup.py deleted file mode 100644 index b5d8fabb86..0000000000 --- a/setup.py +++ /dev/null @@ -1,5 +0,0 @@ -import setuptools - - -# This is to make sure that the package supports editable installs -setuptools.setup() diff --git a/tests/models/test_bos_handling.py b/tests/models/test_bos_handling.py index 0435f6ccc8..128287f890 100644 --- a/tests/models/test_bos_handling.py +++ b/tests/models/test_bos_handling.py @@ -46,16 +46,16 @@ def pythia_tokenizer(): @pytest.fixture(scope="module") -def gemma_tokenizer(): +def olmo_tokenizer(): """ - Load gemma-2-2b-it tokenizer for testing. + Load OLMo-3-7B-Instruct tokenizer for testing. Properties: - - BOS token: '' (ID: 2) + - BOS token: '<|endoftext|>' (ID: 100257) - DOES add BOS by default (add_bos_token=True in tokenizer) - Used to test tokenizers that add BOS by default """ - tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it") + tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-3-7B-Instruct") # Set pad token to avoid padding errors in batch encoding tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" @@ -150,14 +150,14 @@ def test_both_none_returns_empty(self): class TestDefaultsToNone: """Test that add_bos_token defaults to None, allowing tokenizer defaults.""" - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) def test_huggingface_none_uses_tokenizer_default(self, tokenizer_name, request): """ HF: When add_bos_token=None, should respect tokenizer's default. Tests both tokenizer types: - Pythia: Doesn't add BOS by default - - Gemma: DOES add BOS by default + - OLMo: DOES add BOS by default """ tokenizer = request.getfixturevalue(tokenizer_name) mock_hflm = create_hf_mock(tokenizer, add_bos_token=None) @@ -166,14 +166,14 @@ def test_huggingface_none_uses_tokenizer_default(self, tokenizer_name, request): expected = tokenizer.encode("Hello") assert result == expected - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) def test_vllm_none_uses_tokenizer_default(self, tokenizer_name, request): """ vLLM: When add_bos_token=None, should respect tokenizer's default. 
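With `requirements.txt` and `setup.py` removed and `torch` moved into the optional `hf` extra, the deferred imports added in `evaluator.py` and `utils.py` above are what keep the core package importable; a hedged check of that behavior (assuming no other module in the import chain pulls in torch):

```python
# Sketch: the evaluator module no longer does `import torch` at load time;
# torch is only imported inside set_torch_seed / the multi-GPU branches.
import importlib.util

import lm_eval.evaluator  # noqa: F401  (importable without a top-level torch import)
from lm_eval.utils import set_torch_seed

if importlib.util.find_spec("torch") is not None:  # only seed when torch is installed
    set_torch_seed(1234)
```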
Tests both tokenizer types: - Pythia: Doesn't add BOS by default - - Gemma: DOES add BOS by default + - OLMo: DOES add BOS by default """ tokenizer = request.getfixturevalue(tokenizer_name) mock_vllm = create_vllm_mock(tokenizer, add_bos_token=None) @@ -191,7 +191,7 @@ def test_vllm_none_uses_tokenizer_default(self, tokenizer_name, request): class TestNoDuplicateBos: """Test that BOS tokens are never duplicated when already present.""" - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) def test_huggingface_detects_bos_in_single_string(self, tokenizer_name, request): """HF: Should detect BOS prefix and avoid duplication.""" tokenizer = request.getfixturevalue(tokenizer_name) @@ -215,7 +215,7 @@ def test_huggingface_detects_bos_in_single_string(self, tokenizer_name, request) # Should avoid duplication (fewer or equal tokens) assert input_ids.shape[1] <= without_detection.shape[1] - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) def test_huggingface_adds_bos_when_missing(self, tokenizer_name, request): """HF: Should add BOS when string doesn't have it (using add_special_tokens=True)""" tokenizer = request.getfixturevalue(tokenizer_name) @@ -228,13 +228,13 @@ def test_huggingface_adds_bos_when_missing(self, tokenizer_name, request): assert input_ids.tolist() == expected.tolist() - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) def test_huggingface_follows_tokenizer_default(self, tokenizer_name, request): """ HF: When add_bos_token is not set (None), follows tokenizer default. - Pythia: Doesn't add BOS by default - - Gemma: DOES add BOS by default + - OLMo: DOES add BOS by default """ tokenizer = request.getfixturevalue(tokenizer_name) mock_hflm = create_hf_mock(tokenizer, add_bos_token=None) @@ -244,7 +244,7 @@ def test_huggingface_follows_tokenizer_default(self, tokenizer_name, request): assert input_ids.tolist() == expected.tolist() - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) @pytest.mark.parametrize("add_bos_token", [None, True]) def test_vllm_handles_mixed_batch(self, tokenizer_name, add_bos_token, request): """ @@ -284,7 +284,7 @@ def test_vllm_handles_mixed_batch(self, tokenizer_name, add_bos_token, request): for i, exp in enumerate(expected): assert result[i] == exp - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) @pytest.mark.parametrize("add_bos_token", [None, True]) def test_vllm_preserves_order_in_mixed_batch( self, tokenizer_name, add_bos_token, request @@ -328,7 +328,7 @@ def test_vllm_preserves_order_in_mixed_batch( class TestChatTemplateCompatibility: """Test that chat templates (which add BOS) work without duplication.""" - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) def test_huggingface_chat_template_no_duplicate_bos(self, tokenizer_name, request): """ HF: Chat template adds BOS, tokenizer should not add another. 
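The fixture swap relies on the tokenizer defaults documented in the fixture docstrings; a standalone spot-check (the Pythia checkpoint name is an assumption, any Pythia tokenizer behaves the same; the OLMo checkpoint and BOS id are the ones stated in the tests):

```python
# Spot-check of the BOS defaults these tests depend on (downloads tokenizers from the Hub).
from transformers import AutoTokenizer

pythia = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")      # assumed checkpoint
olmo = AutoTokenizer.from_pretrained("allenai/OLMo-3-7B-Instruct")   # as in the fixture

print(pythia.encode("Hello"))  # no BOS prepended by default
print(olmo.encode("Hello"))    # starts with BOS <|endoftext|> (id 100257) by default
```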
@@ -352,7 +352,7 @@ def test_huggingface_chat_template_no_duplicate_bos(self, tokenizer_name, reques assert torch.equal(input_ids, expected) - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) @pytest.mark.parametrize("add_bos_token", [None, True]) def test_vllm_mixed_chat_batch(self, tokenizer_name, add_bos_token, request): """ @@ -424,7 +424,7 @@ def test_huggingface_seq2seq_skips_causal_bos_logic(self, pythia_tokenizer): class TestLoglikelihoodBosHandling: """Test BOS handling in loglikelihood method.""" - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) @pytest.mark.parametrize("add_bos_token", [None, True]) def test_empty_context_continuation_with_bos( self, tokenizer_name, add_bos_token, request @@ -479,7 +479,7 @@ def capture_and_return(reqs, disable_tqdm=False): ) assert continuation_enc == continuation_without_bos[1:] # Skip the BOS token - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) @pytest.mark.parametrize("add_bos_token", [None, True]) def test_empty_context_continuation_without_bos( self, tokenizer_name, add_bos_token, request @@ -523,7 +523,7 @@ def capture_and_return(reqs, disable_tqdm=False): expected_continuation = tokenizer.encode(continuation, add_special_tokens=False) assert continuation_enc == expected_continuation - @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "gemma_tokenizer"]) + @pytest.mark.parametrize("tokenizer_name", ["pythia_tokenizer", "olmo_tokenizer"]) @pytest.mark.parametrize("add_bos_token", [None, True]) def test_context_with_bos_prefix(self, tokenizer_name, add_bos_token, request): """When context starts with BOS (e.g., from chat template), should not duplicate BOS.""" diff --git a/tests/scripts/test_zeno_visualize.py b/tests/scripts/test_zeno_visualize.py index cdbe7e5c54..d6a0ab2caf 100644 --- a/tests/scripts/test_zeno_visualize.py +++ b/tests/scripts/test_zeno_visualize.py @@ -3,10 +3,11 @@ import pytest + +pytest.importorskip("zeno_client") from scripts.zeno_visualize import sanitize_string -@pytest.skip("requires zeno_client dependency") def test_zeno_sanitize_string(): """ Test that the model_args handling logic in zeno_visualize.py properly handles diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000000..9c9d544f8c --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,404 @@ +"""Tests for the registry system.""" + +import threading + +import pytest + +from lm_eval.api.registry import ( + Registry, + aggregation_registry, + filter_registry, + get_aggregation, + get_filter, + get_metric, + get_metric_aggregation, + get_model, + higher_is_better_registry, + is_higher_better, + metric_agg_registry, + metric_registry, + model_registry, + register_aggregation, + register_filter, + register_metric, + task_registry, +) + + +class TestRegistryBasics: + """Test basic Registry class functionality.""" + + def test_create_registry(self): + """Test creating a basic registry.""" + reg = Registry("test") + assert len(reg) == 0 + assert list(reg) == [] + + def test_decorator_registration(self): + """Test decorator-based registration.""" + reg = Registry("test") + + @reg.register("my_class") + class MyClass: + pass + + assert "my_class" in reg + 
assert reg.get("my_class") is MyClass + assert reg["my_class"] is MyClass + + def test_decorator_multiple_aliases(self): + """Test decorator with multiple aliases.""" + reg = Registry("test") + + @reg.register("alias1", "alias2", "alias3") + class MyClass: + pass + + assert reg.get("alias1") is MyClass + assert reg.get("alias2") is MyClass + assert reg.get("alias3") is MyClass + + def test_decorator_auto_name(self): + """Test decorator using class name when no alias provided.""" + reg = Registry("test") + + @reg.register() + class AutoNamedClass: + pass + + assert reg.get("AutoNamedClass") is AutoNamedClass + + def test_lazy_registration(self): + """Test lazy loading with module paths.""" + reg = Registry("test") + + # Register with lazy loading + reg.register("join", lazy="os.path:join") + + # Check it's stored as a string (placeholder) + assert isinstance(reg._objs["join"], str) + + # Access triggers materialization + import os.path + + result = reg.get("join") + assert result is os.path.join + assert callable(result) + + def test_unknown_key_error(self): + """Test error when accessing unknown key.""" + reg = Registry("test") + + with pytest.raises(KeyError) as exc_info: + reg.get("unknown") + + assert "Unknown test 'unknown'" in str(exc_info.value) + + def test_default_value(self): + """Test default value when key not found.""" + reg = Registry("test") + + assert reg.get("missing", "default") == "default" + assert reg.get("missing", None) is None + assert reg.get("missing", 0) == 0 + + def test_iteration(self): + """Test registry iteration.""" + reg = Registry("test") + + reg.register("a", lazy="os:getcwd") + reg.register("b", lazy="os:getenv") + reg.register("c", lazy="os:getpid") + + assert set(reg) == {"a", "b", "c"} + assert len(reg) == 3 + + def test_contains(self): + """Test 'in' operator.""" + reg = Registry("test") + reg.register("exists", lazy="os:getcwd") + + assert "exists" in reg + assert "missing" not in reg + + def test_keys_values_items(self): + """Test dict-like methods.""" + reg = Registry("test") + reg.register("a", lazy="os:getcwd") + reg.register("b", lazy="os:getenv") + + assert set(reg.keys()) == {"a", "b"} + assert len(list(reg.values())) == 2 + assert len(list(reg.items())) == 2 + + +class TestRegistryCollisions: + """Test collision handling in Registry.""" + + def test_duplicate_raises_error(self): + """Test that registering different objects under same alias raises error.""" + reg = Registry("test") + + @reg.register("name") + class First: + pass + + with pytest.raises(ValueError) as exc_info: + + @reg.register("name") + class Second: + pass + + assert "already registered" in str(exc_info.value) + + def test_placeholder_upgrade(self): + """Test that placeholder can be upgraded to concrete class.""" + reg = Registry("test") + + # Create a class to test with + class MyTestClass: + pass + + # Register as placeholder with correct module:class path + placeholder_path = f"{MyTestClass.__module__}:{MyTestClass.__name__}" + reg.register("my_alias", lazy=placeholder_path) + + # Registering the actual class should upgrade the placeholder + reg.register("my_alias")(MyTestClass) + + assert reg.get("my_alias") is MyTestClass + + def test_same_object_no_error(self): + """Test that registering same object twice doesn't raise error.""" + reg = Registry("test") + + class MyClass: + pass + + reg.register("name")(MyClass) + reg.register("name")(MyClass) # Should not raise + + assert reg.get("name") is MyClass + + +class TestRegistryFreeze: + """Test registry freezing.""" + + def 
test_freeze(self): + """Test that freeze makes registry read-only.""" + reg = Registry("test") + reg.register("a", lazy="os:getcwd") + reg.freeze() + + # Can still read + assert "a" in reg + + # But modifications fail + with pytest.raises(TypeError): + reg.register("b", lazy="os:getenv") + + def test_freeze_all(self): + """Test freeze_all function.""" + + # This would freeze all global registries - skip in test + # freeze_all() + pass + + +class TestRegistryThreadSafety: + """Test thread safety of Registry.""" + + def test_concurrent_registration(self): + """Test concurrent registration from multiple threads.""" + reg = Registry("test") + errors = [] + + def register_item(i): + try: + reg.register(f"item_{i}", lazy="os:getcwd") + except Exception as e: + errors.append(e) + + threads = [ + threading.Thread(target=register_item, args=(i,)) for i in range(100) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + assert len(reg) == 100 + + def test_concurrent_access(self): + """Test concurrent access from multiple threads.""" + reg = Registry("test") + reg.register("item", lazy="os.path:join") + + results = [] + + def access_item(): + result = reg.get("item") + results.append(result) + + threads = [threading.Thread(target=access_item) for _ in range(50)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All threads should get the same result + import os.path + + assert all(r is os.path.join for r in results) + + +class TestModelRegistry: + """Test model registry integration.""" + + def test_model_registry_exists(self): + """Test that model_registry is properly initialized.""" + assert model_registry is not None + + def test_lazy_model_loading(self): + """Test lazy loading of models. + + get_model() auto-imports lm_eval.models if registry is empty, + so this works without explicit import. 
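The skipped `test_freeze_all` above hints at the intended production flow; a hedged sketch of what freezing looks like in practice (the late registration attempt and its module path are illustrative):

```python
# Sketch: freeze after start-up registration; later registration fails loudly.
import lm_eval.models                       # populate the lazy model placeholders first
from lm_eval.api.registry import freeze_all, model_registry

freeze_all()                                # every global registry becomes read-only
assert "hf" in model_registry               # lookups (and lazy materialization) still work
try:
    model_registry.register("late-model", lazy="my_pkg.models:LateModel")  # hypothetical path
except TypeError:
    print("registries are frozen; register before calling freeze_all()")
```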
+ """ + # get_model auto-imports models if registry is empty + dummy_cls = get_model("dummy") + assert dummy_cls is not None + assert "DummyLM" in dummy_cls.__name__ + + # After get_model, registry should be populated + assert "dummy" in model_registry + assert "hf" in model_registry + + def test_get_model_error(self): + """Test get_model with unknown model.""" + with pytest.raises(ValueError) as exc_info: + get_model("nonexistent_model_xyz") + + assert "no model for this name found" in str(exc_info.value) + + +class TestFilterRegistry: + """Test filter registry integration.""" + + def test_filter_registry_exists(self): + """Test that filter_registry is properly initialized.""" + assert filter_registry is not None + + def test_register_filter(self): + """Test registering a filter.""" + from lm_eval.api.filter import Filter + + @register_filter("test_filter_unique") + class TestFilter(Filter): + def apply(self, resps, docs): + return resps + + assert "test_filter_unique" in filter_registry + assert get_filter("test_filter_unique") is TestFilter + + def test_get_filter_callable(self): + """Test get_filter with callable input.""" + + def my_filter(x): + return x + + assert get_filter(my_filter) is my_filter + + +class TestMetricRegistry: + """Test metric registry integration.""" + + def test_metric_registry_exists(self): + """Test that metric_registry is properly initialized.""" + assert metric_registry is not None + + def test_aggregation_registry_exists(self): + """Test that aggregation_registry is properly initialized.""" + assert aggregation_registry is not None + + def test_register_aggregation(self): + """Test registering an aggregation function.""" + + @register_aggregation("test_agg_unique") + def test_agg(items): + return sum(items) / len(items) + + assert "test_agg_unique" in aggregation_registry + assert get_aggregation("test_agg_unique") is test_agg + + def test_register_metric(self): + """Test registering a metric.""" + + # First register the aggregation + @register_aggregation("test_metric_agg") + def mean_agg(items): + return sum(items) / len(items) + + @register_metric( + metric="test_metric_unique", + higher_is_better=True, + aggregation="test_metric_agg", + ) + def test_metric(items): + return sum(1 for i in items if i) + + assert "test_metric_unique" in metric_registry + assert get_metric("test_metric_unique") is test_metric + assert is_higher_better("test_metric_unique") is True + assert get_metric_aggregation("test_metric_unique") is mean_agg + + def test_builtin_metrics_loaded(self): + """Test that built-in metrics are loaded.""" + # Import metrics module to trigger registration + from lm_eval.api import metrics # noqa: F401 + + # Check some common metrics are registered + assert "acc" in metric_registry + assert "mean" in aggregation_registry + + +class TestBackwardCompatibility: + """Test backward compatibility aliases.""" + + def test_registry_aliases(self): + """Test that UPPER_CASE aliases point to Registry instances.""" + from lm_eval.api.registry import ( + AGGREGATION_REGISTRY, + FILTER_REGISTRY, + HIGHER_IS_BETTER_REGISTRY, + METRIC_AGGREGATION_REGISTRY, + METRIC_REGISTRY, + MODEL_REGISTRY, + TASK_REGISTRY, + ) + + assert MODEL_REGISTRY is model_registry + assert TASK_REGISTRY is task_registry + assert FILTER_REGISTRY is filter_registry + assert METRIC_REGISTRY is metric_registry + assert AGGREGATION_REGISTRY is aggregation_registry + assert METRIC_AGGREGATION_REGISTRY is metric_agg_registry + assert HIGHER_IS_BETTER_REGISTRY is higher_is_better_registry + + 
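Because the UPPER_CASE aliases are now `Registry` instances rather than plain dicts, legacy dict-style call sites keep working; a short sketch:

```python
# Old-style access patterns that remain valid against the Registry-backed aliases.
import lm_eval.models                        # registers the lazy model placeholders
from lm_eval.api.registry import MODEL_REGISTRY

assert "hf" in MODEL_REGISTRY                # __contains__, as before
print(sorted(MODEL_REGISTRY.keys())[:5])     # keys() still behaves like a dict view
DummyLM = MODEL_REGISTRY["dummy"]            # __getitem__ now also materializes lazily
```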
+class TestRegistryClear: + """Test registry clear functionality (for test isolation).""" + + def test_clear(self): + """Test _clear method for test isolation.""" + reg = Registry("test") + reg.register("a", lazy="os:getcwd") + reg.register("b", lazy="os:getenv") + + assert len(reg) == 2 + + reg._clear() + + assert len(reg) == 0 + assert "a" not in reg
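One way downstream test suites might build on `_clear()` for isolation (a possible pytest fixture, not part of this patch; note that `_clear` is a private test helper):

```python
# Possible per-test isolation fixture built on Registry._clear() (illustrative).
import pytest

from lm_eval.api.registry import Registry


@pytest.fixture
def scratch_registry():
    reg = Registry("scratch")
    yield reg
    reg._clear()  # drops entries and clears the shared materialization cache


def test_scratch(scratch_registry):
    scratch_registry.register("cwd", lazy="os:getcwd")
    assert callable(scratch_registry.get("cwd"))
```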