diff --git a/delphi/__main__.py b/delphi/__main__.py
index d0341684..c0a749ab 100644
--- a/delphi/__main__.py
+++ b/delphi/__main__.py
@@ -86,12 +86,12 @@ def create_neighbours(
 
     elif constructor_cfg.neighbours_type == "decoder_similarity":
         neighbour_calculator = NeighbourCalculator(
-            autoencoder=saes[hookpoint].cuda(), number_of_neighbours=250
+            autoencoder=saes[hookpoint].to("cuda"), number_of_neighbours=250
         )
 
     elif constructor_cfg.neighbours_type == "encoder_similarity":
         neighbour_calculator = NeighbourCalculator(
-            autoencoder=saes[hookpoint].cuda(), number_of_neighbours=250
+            autoencoder=saes[hookpoint].to("cuda"), number_of_neighbours=250
        )
     else:
         raise ValueError(
@@ -131,7 +131,7 @@ async def process_cache(
     }
     # The latent range to explain
     dataset = LatentDataset(
-        raw_dir=str(latents_path),
+        raw_dir=latents_path,
         sampler_cfg=run_cfg.sampler_cfg,
         constructor_cfg=run_cfg.constructor_cfg,
         modules=hookpoints,
diff --git a/delphi/clients/client.py b/delphi/clients/client.py
index a9551fcc..4fd41752 100644
--- a/delphi/clients/client.py
+++ b/delphi/clients/client.py
@@ -17,7 +17,7 @@ def __init__(self, model: str):
     @abstractmethod
     async def generate(
         self, prompt: Union[str, list[dict[str, str]]], **kwargs
-    ) -> Response:
+    ) -> str | Response:
         pass
 
     # @abstractmethod
diff --git a/delphi/clients/offline.py b/delphi/clients/offline.py
index 2c4d078a..0c0a6ee3 100644
--- a/delphi/clients/offline.py
+++ b/delphi/clients/offline.py
@@ -74,7 +74,11 @@ def __init__(
         self.statistics_path = Path("statistics")
         self.statistics_path.mkdir(parents=True, exist_ok=True)
 
-    async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
+    async def process_func(
+        self,
+        batches: Union[str, list[Union[dict[str, str], list[dict[str, str]]]]],
+        kwargs,
+    ):
         """
         Process a single request.
         """
@@ -142,7 +146,9 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
             )
         return new_response
 
-    async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> str:  # type: ignore
+    async def generate(
+        self, prompt: Union[str, list[dict[str, str]]], **kwargs
+    ) -> Response:  # type: ignore
         """
         Enqueue a request and wait for the result.
         """
diff --git a/delphi/clients/openrouter.py b/delphi/clients/openrouter.py
index 2a6e8b85..f5e5cbff 100644
--- a/delphi/clients/openrouter.py
+++ b/delphi/clients/openrouter.py
@@ -5,6 +5,7 @@
 from ..logger import logger
 from .client import Client
+from .types import ChatFormatRequest
 
 # Preferred provider routing arguments.
 # Change depending on what model you'd like to use.
@@ -36,7 +37,11 @@ def postprocess(self, response):
         return Response(msg)
 
     async def generate(  # type: ignore
-        self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs  # type: ignore
+        self,
+        prompt: ChatFormatRequest,
+        raw: bool = False,
+        max_retries: int = 1,
+        **kwargs,  # type: ignore
     ) -> Response:  # type: ignore
         kwargs.pop("schema", None)
         max_tokens = kwargs.pop("max_tokens", 500)
diff --git a/delphi/clients/types.py b/delphi/clients/types.py
new file mode 100644
index 00000000..d5eca235
--- /dev/null
+++ b/delphi/clients/types.py
@@ -0,0 +1,9 @@
+from typing import Literal, TypedDict, Union
+
+
+class Message(TypedDict):
+    content: str
+    role: Literal["system", "user", "assistant"]
+
+
+ChatFormatRequest = Union[str, list[str], list[Message], None]
diff --git a/delphi/explainers/contrastive_explainer.py b/delphi/explainers/contrastive_explainer.py
index 3d675563..359a2234 100644
--- a/delphi/explainers/contrastive_explainer.py
+++ b/delphi/explainers/contrastive_explainer.py
@@ -4,7 +4,7 @@
 import torch
 
 from delphi.explainers.default.prompts import SYSTEM_CONTRASTIVE
-from delphi.explainers.explainer import Explainer, ExplainerResult
+from delphi.explainers.explainer import Explainer, ExplainerResult, Response
 from delphi.latents.latents import ActivatingExample, LatentRecord, NonActivatingExample
 
 
@@ -54,7 +54,11 @@ async def __call__(self, record: LatentRecord) -> ExplainerResult:
         )
 
         try:
-            explanation = self.parse_explanation(response.text)
+            if isinstance(response, Response):
+                response_text = response.text
+            else:
+                response_text = response
+            explanation = self.parse_explanation(response_text)
             if self.verbose:
                 from ..logger import logger
 
diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py
index 5c7c719e..9967dda0 100644
--- a/delphi/explainers/explainer.py
+++ b/delphi/explainers/explainer.py
@@ -8,7 +8,7 @@
 
 import aiofiles
 
-from ..clients.client import Client
+from ..clients.client import Client, Response
 from ..latents.latents import ActivatingExample, LatentRecord
 from ..logger import logger
 
@@ -44,6 +44,7 @@ async def __call__(self, record: LatentRecord) -> ExplainerResult:
         response = await self.client.generate(
             messages, temperature=self.temperature, **self.generation_kwargs
         )
+        assert isinstance(response, Response)
 
         try:
             explanation = self.parse_explanation(response.text)
diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py
index bbfac560..32f74da9 100644
--- a/delphi/latents/cache.py
+++ b/delphi/latents/cache.py
@@ -1,8 +1,8 @@
 import json
 from collections import defaultdict
+from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable
 
 import numpy as np
 import torch
@@ -15,8 +15,49 @@
 from delphi.config import CacheConfig
 from delphi.latents.collect_activations import collect_activations
 
-location_tensor_shape = Float[Tensor, "batch sequence num_latents"]
-token_tensor_shape = Float[Tensor, "batch sequence"]
+location_tensor_type = Int[Tensor, "batch_sequence 3"]
+activation_tensor_type = Float[Tensor, "batch_sequence"]
+token_tensor_type = Int[Tensor, "batch sequence"]
+latent_tensor_type = Float[Tensor, "batch sequence num_latents"]
+
+
+def get_nonzeros_batch(
+    latents: latent_tensor_type,
+) -> tuple[
+    Float[Tensor, "batch sequence num_latents"], Float[Tensor, "batch sequence "]
+]:
+    """
+    Get non-zero activations for large batches that exceed int32 max value.
+
+    Args:
+        latents: Input latent activations.
+
+    Returns:
+        tuple[Tensor, Tensor]: Non-zero latent locations and activations.
+    """
+    # Calculate the maximum batch size that fits within sys.maxsize
+    max_batch_size = torch.iinfo(torch.int32).max // (
+        latents.shape[1] * latents.shape[2]
+    )
+    nonzero_latent_locations = []
+    nonzero_latent_activations = []
+
+    for i in range(0, latents.shape[0], max_batch_size):
+        batch = latents[i : i + max_batch_size]
+
+        # Get nonzero locations and activations
+        batch_locations = torch.nonzero(batch.abs() > 1e-5)
+        batch_activations = batch[batch.abs() > 1e-5]
+
+        # Adjust indices to account for batching
+        batch_locations[:, 0] += i
+        nonzero_latent_locations.append(batch_locations)
+        nonzero_latent_activations.append(batch_activations)
+
+    # Concatenate results
+    nonzero_latent_locations = torch.cat(nonzero_latent_locations, dim=0)
+    nonzero_latent_activations = torch.cat(nonzero_latent_activations, dim=0)
+    return nonzero_latent_locations, nonzero_latent_activations
 
 
 class InMemoryCache:
@@ -37,25 +78,25 @@ def __init__(
             filters: Filters for selecting specific latents.
             batch_size: Size of batches for processing. Defaults to 64.
         """
-        self.latent_locations_batches: dict[str, list[location_tensor_shape]] = (
+        self.latent_locations_batches: dict[str, list[location_tensor_type]] = (
             defaultdict(list)
         )
-        self.latent_activations_batches: dict[str, list[location_tensor_shape]] = (
+        self.latent_activations_batches: dict[str, list[latent_tensor_type]] = (
             defaultdict(list)
         )
-        self.tokens_batches: dict[str, list[token_tensor_shape]] = defaultdict(list)
+        self.tokens_batches: dict[str, list[token_tensor_type]] = defaultdict(list)
 
-        self.latent_locations: dict[str, location_tensor_shape] = {}
-        self.latent_activations: dict[str, location_tensor_shape] = {}
-        self.tokens: dict[str, token_tensor_shape] = {}
+        self.latent_locations: dict[str, location_tensor_type] = {}
+        self.latent_activations: dict[str, latent_tensor_type] = {}
+        self.tokens: dict[str, token_tensor_type] = {}
 
         self.filters = filters
         self.batch_size = batch_size
 
     def add(
         self,
-        latents: location_tensor_shape,
-        tokens: token_tensor_shape,
+        latents: latent_tensor_type,
+        tokens: token_tensor_type,
         batch_number: int,
         module_path: str,
     ):
@@ -96,47 +137,9 @@ def save(self):
                 self.tokens_batches[module_path], dim=0
             )
 
-    def get_nonzeros_batch(
-        self, latents: location_tensor_shape
-    ) -> tuple[
-        Float[Tensor, "batch sequence num_latents"], Float[Tensor, "batch sequence "]
-    ]:
-        """
-        Get non-zero activations for large batches that exceed int32 max value.
-
-        Args:
-            latents: Input latent activations.
-
-        Returns:
-            tuple[Tensor, Tensor]: Non-zero latent locations and activations.
-        """
-        # Calculate the maximum batch size that fits within sys.maxsize
-        max_batch_size = torch.iinfo(torch.int32).max // (
-            latents.shape[1] * latents.shape[2]
-        )
-        nonzero_latent_locations = []
-        nonzero_latent_activations = []
-
-        for i in range(0, latents.shape[0], max_batch_size):
-            batch = latents[i : i + max_batch_size]
-
-            # Get nonzero locations and activations
-            batch_locations = torch.nonzero(batch.abs() > 1e-5)
-            batch_activations = batch[batch.abs() > 1e-5]
-
-            # Adjust indices to account for batching
-            batch_locations[:, 0] += i
-            nonzero_latent_locations.append(batch_locations)
-            nonzero_latent_activations.append(batch_activations)
-
-        # Concatenate results
-        nonzero_latent_locations = torch.cat(nonzero_latent_locations, dim=0)
-        nonzero_latent_activations = torch.cat(nonzero_latent_activations, dim=0)
-        return nonzero_latent_locations, nonzero_latent_activations
-
-    def get_nonzeros(self, latents: location_tensor_shape, module_path: str) -> tuple[
-        location_tensor_shape,
-        location_tensor_shape,
+    def get_nonzeros(self, latents: latent_tensor_type, module_path: str) -> tuple[
+        location_tensor_type,
+        activation_tensor_type,
     ]:
         """
         Get the nonzero latent locations and activations.
@@ -153,7 +156,7 @@ def get_nonzeros(self, latents: location_tensor_shape, module_path: str) -> tupl
             (
                 nonzero_latent_locations,
                 nonzero_latent_activations,
-            ) = self.get_nonzeros_batch(latents)
+            ) = get_nonzeros_batch(latents)
         else:
             nonzero_latent_locations = torch.nonzero(latents.abs() > 1e-5)
             nonzero_latent_activations = latents[latents.abs() > 1e-5]
@@ -209,8 +212,8 @@ def __init__(
         self.filter_submodules(filters)
 
     def load_token_batches(
-        self, n_tokens: int, tokens: token_tensor_shape
-    ) -> list[token_tensor_shape]:
+        self, n_tokens: int, tokens: token_tensor_type
+    ) -> list[token_tensor_type]:
         """
         Load and prepare token batches for processing.
 
@@ -248,7 +251,7 @@ def filter_submodules(self, filters: dict[str, Float[Tensor, "indices"]]):
         ]
         self.hookpoint_to_sparse_encode = filtered_submodules
 
-    def run(self, n_tokens: int, tokens: token_tensor_shape):
+    def run(self, n_tokens: int, tokens: token_tensor_type):
         """
         Run the latent caching process.
@@ -520,11 +523,11 @@ def generate_statistics_cache(
     print(f"Fraction of strong single token latents: {strong_token_fraction:%}")
 
     return CacheStatistics(
-        frac_alive=fraction_alive,
-        frac_fired_1pct=one_percent,
-        frac_fired_10pct=ten_percent,
-        frac_weak_single_token=single_token_fraction,
-        frac_strong_single_token=strong_token_fraction,
+        frac_alive=float(fraction_alive),
+        frac_fired_1pct=float(one_percent),
+        frac_fired_10pct=float(ten_percent),
+        frac_weak_single_token=float(single_token_fraction),
+        frac_strong_single_token=float(strong_token_fraction),
     )
diff --git a/delphi/latents/collect_activations.py b/delphi/latents/collect_activations.py
index ec48eb03..dddc6941 100644
--- a/delphi/latents/collect_activations.py
+++ b/delphi/latents/collect_activations.py
@@ -25,7 +25,9 @@ def collect_activations(
     handles = []
 
     def create_hook(hookpoint: str, transcode: bool = False):
-        def hook_fn(module: nn.Module, input: Any, output: Tensor) -> Tensor | None:
+        def hook_fn(
+            module: nn.Module, input: Any, output: Tensor | tuple[Tensor]
+        ) -> Tensor | None:
             # If output is a tuple (like in some transformer layers), take first element
             if transcode:
                 if isinstance(input, tuple):
diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py
index 25c60652..c64b7f97 100644
--- a/delphi/latents/constructors.py
+++ b/delphi/latents/constructors.py
@@ -6,7 +6,7 @@
 import faiss
 import numpy as np
 import torch
-from jaxtyping import Float
+from jaxtyping import Bool, Float, Int
 from sentence_transformers import SentenceTransformer
 from torch import Tensor
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -31,7 +31,7 @@ def get_model(name: str, device: str = "cuda") -> SentenceTransformer:
 
 
 def prepare_non_activating_examples(
-    tokens: Float[Tensor, "examples ctx_len"],
+    tokens: Int[Tensor, "examples ctx_len"],
     distance: float,
     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
 ) -> list[NonActivatingExample]:
@@ -45,7 +45,7 @@ def prepare_non_activating_examples(
     return [
         NonActivatingExample(
             tokens=toks,
-            activations=torch.zeros_like(toks),
+            activations=torch.zeros_like(toks, dtype=torch.float),
             normalized_activations=None,
             distance=distance,
             str_tokens=tokenizer.batch_decode(toks),
@@ -57,14 +57,14 @@ def prepare_non_activating_examples(
 def _top_k_pools(
     max_buffer: Float[Tensor, "batch"],
     split_activations: Float[Tensor, "activations ctx_len"],
-    buffer_tokens: Float[Tensor, "batch ctx_len"],
+    buffer_tokens: Int[Tensor, "batch ctx_len"],
     max_examples: int,
-) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
+) -> tuple[Int[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
     """
     Get the top k activation pools.
 
     Args:
-        max_buffer: The maximum buffer values.
+        max_buffer: The maxima of each context window's activations.
         split_activations: The split activations.
         buffer_tokens: The buffer tokens.
         max_examples: The maximum number of examples.
@@ -83,12 +83,12 @@ def _top_k_pools(
 
 
 def pool_max_activation_windows(
     activations: Float[Tensor, "examples"],
-    tokens: Float[Tensor, "windows seq"],
-    ctx_indices: Float[Tensor, "examples"],
-    index_within_ctx: Float[Tensor, "examples"],
+    tokens: Int[Tensor, "windows seq"],
+    ctx_indices: Int[Tensor, "examples"],
+    index_within_ctx: Int[Tensor, "examples"],
     ctx_len: int,
     max_examples: int,
-) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
+) -> tuple[Int[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
     """
     Pool max activation windows from the buffer output and update the latent record.
@@ -99,6 +99,8 @@ def pool_max_activation_windows(
         index_within_ctx : The index within the context.
         ctx_len : The context length.
         max_examples : The maximum number of examples.
+    Returns:
+        The token windows and activation windows.
     """
     # unique_ctx_indices: array of distinct context window indices in order of first
     # appearance. sequential integers from 0 to batch_size * cache_token_length//ctx_len
@@ -129,7 +131,7 @@ def constructor(
     record: LatentRecord,
     activation_data: ActivationData,
     constructor_cfg: ConstructorConfig,
-    tokens: Float[Tensor, "batch seq"],
+    tokens: Int[Tensor, "batch seq"],
     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
     all_data: Optional[dict[int, ActivationData]] = None,
    seed: int = 42,
@@ -422,8 +424,8 @@ def faiss_non_activation_windows(
 
 def neighbour_non_activation_windows(
     record: LatentRecord,
-    not_active_mask: Float[Tensor, "windows"],
-    tokens: Float[Tensor, "batch seq"],
+    not_active_mask: Bool[Tensor, "windows"],
+    tokens: Int[Tensor, "batch seq"],
     all_data: dict[int, ActivationData],
     ctx_len: int,
     n_not_active: int,
@@ -513,8 +515,8 @@ def neighbour_non_activation_windows(
 
 
 def random_non_activating_windows(
-    available_indices: Float[Tensor, "windows"],
-    reshaped_tokens: Float[Tensor, "windows ctx_len"],
+    available_indices: Int[Tensor, "windows"],
+    reshaped_tokens: Int[Tensor, "windows ctx_len"],
     n_not_active: int,
     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
     seed: int = 42,
diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py
index cf142f35..843bdba7 100644
--- a/delphi/latents/latents.py
+++ b/delphi/latents/latents.py
@@ -3,7 +3,7 @@
 
 import blobfile as bf
 import orjson
-from jaxtyping import Float
+from jaxtyping import Float, Int
 from torch import Tensor
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
@@ -35,7 +35,7 @@ class ActivationData(NamedTuple):
     Represents the activation data for a latent.
     """
 
-    locations: Float[Tensor, "n_examples 2"]
+    locations: Int[Tensor, "n_examples 3"]
     """Tensor of latent locations."""
 
     activations: Float[Tensor, "n_examples"]
@@ -69,7 +69,7 @@ class Example:
     A single example of latent data.
     """
 
-    tokens: Float[Tensor, "ctx_len"]
+    tokens: Int[Tensor, "ctx_len"]
     """Tokenized input sequence."""
 
     activations: Float[Tensor, "ctx_len"]
diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py
index bab251bc..548ba565 100644
--- a/delphi/latents/loader.py
+++ b/delphi/latents/loader.py
@@ -7,7 +7,7 @@
 
 import numpy as np
 import torch
-from jaxtyping import Float
+from jaxtyping import Float, Int
 from safetensors.numpy import load_file
 from torch import Tensor
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -34,10 +34,10 @@ class TensorBuffer:
     module_path: str
     """Path of the module."""
 
-    latents: Optional[Float[Tensor, "num_latents"]] = None
+    latents: Optional[Int[Tensor, "num_latents"]] = None
     """Tensor of latent indices."""
 
-    _tokens: Optional[Float[Tensor, "batch seq"]] = None
+    _tokens: Optional[Int[Tensor, "batch seq"]] = None
     """Tensor of tokens."""
 
     def __iter__(self):
@@ -60,7 +60,7 @@ def __iter__(self):
             )
 
     @property
-    def tokens(self) -> Float[Tensor, "batch seq"] | None:
+    def tokens(self) -> Int[Tensor, "batch seq"] | None:
         if self._tokens is None:
             self._tokens = self.load_tokens()
         return self._tokens
@@ -82,15 +82,14 @@ def load_data_per_latent(self):
     def load(
         self,
     ) -> tuple[
-        Float[Tensor, "locations 2"],
+        Int[Tensor, "locations 3"],
         Float[Tensor, "activations"],
-        Float[Tensor, "batch seq"] | None,
+        Int[Tensor, "batch seq"] | None,
     ]:
-        """Load the tensor buffer's data.
+        """Load stored tensor buffer data.
 
         Returns:
-            Tuple[Tensor, Tensor, Optional[Tensor]]: Locations, activations,
-            and tokens (if present in the cache).
+            Tuple[Tensor, Tensor, Optional[Tensor]]: Locations, activations, and tokens.
         """
         split_data = load_file(self.path)
         first_latent = int(self.path.split("/")[-1].split("_")[0])
@@ -110,7 +109,7 @@ def load(
 
         return locations, activations, tokens
 
-    def load_tokens(self) -> Float[Tensor, "batch seq"] | None:
+    def load_tokens(self) -> Int[Tensor, "batch seq"] | None:
         _, _, tokens = self.load()
         return tokens
 
@@ -122,7 +121,7 @@ class LatentDataset:
 
     def __init__(
         self,
-        raw_dir: str,
+        raw_dir: os.PathLike,
         sampler_cfg: SamplerConfig,
         constructor_cfg: ConstructorConfig,
         tokenizer: Optional[PreTrainedTokenizer | PreTrainedTokenizerFast] = None,
@@ -181,8 +180,7 @@ def __init__(
         if self.constructor_cfg.non_activating_source == "neighbours":
             # path is always going to end with /latents
-            split_path = raw_dir.split("/")[:-1]
-            neighbours_path = "/".join(split_path) + "/neighbours"
+            neighbours_path = Path(raw_dir).parent / "neighbours"
             self.neighbours = self.load_neighbours(
                 neighbours_path, self.constructor_cfg.neighbours_type
             )
@@ -215,16 +213,16 @@ def load_tokens(self):
             )
         return self.tokens
 
-    def load_neighbours(self, neighbours_path: str, neighbours_type: str):
+    def load_neighbours(self, neighbours_path: Path, neighbours_type: str):
         neighbours = {}
         for hookpoint in self.modules:
             with open(
-                neighbours_path + f"/{hookpoint}-{neighbours_type}.json", "r"
+                neighbours_path / f"{hookpoint}-{neighbours_type}.json", "r"
             ) as f:
                 neighbours[hookpoint] = json.load(f)
         return neighbours
 
-    def _edges(self, raw_dir: str, module: str) -> list[tuple[int, int]]:
+    def _edges(self, raw_dir: os.PathLike, module: str) -> list[tuple[int, int]]:
         module_dir = Path(raw_dir) / module
         safetensor_files = [f for f in module_dir.glob("*.safetensors")]
         edges = []
@@ -234,12 +232,12 @@ def _edges(self, raw_dir: str, module: str) -> list[tuple[int, int]]:
         edges.sort(key=lambda x: x[0])
         return edges
 
-    def _build(self, raw_dir: str):
+    def _build(self, raw_dir: os.PathLike):
         """
         Build dataset buffers which load all cached latents.
 
         Args:
-            raw_dir (str): Directory containing raw latent data.
+            raw_dir (os.PathLike): Directory containing raw latent data.
             modules (Optional[list[str]]): list of module names to include.
         """
 
@@ -256,14 +254,14 @@ def _build(self, raw_dir: str):
     def _build_selected(
         self,
-        raw_dir: str,
+        raw_dir: os.PathLike,
         latents: dict[str, torch.Tensor],
     ):
         """
         Build a dataset buffer which loads only selected latents.
 
         Args:
-            raw_dir (str): Directory containing raw latent data.
+            raw_dir (os.PathLike): Directory containing raw latent data.
             latents (dict[str, Union[int, torch.Tensor]]): Dictionary of latents
                 per module.
         """
@@ -306,7 +304,7 @@ def __len__(self):
         """Return the number of buffers in the dataset."""
         return len(self.buffers)
 
-    def _load_all_data(self, raw_dir: str, modules: list[str]):
+    def _load_all_data(self, raw_dir: os.PathLike, modules: list[str]):
         """For each module, load all locations and activations"""
         all_data = {}
         for buffer in self.buffers:
diff --git a/delphi/latents/samplers.py b/delphi/latents/samplers.py
index 1db76df4..4e1c9321 100644
--- a/delphi/latents/samplers.py
+++ b/delphi/latents/samplers.py
@@ -68,8 +68,9 @@ def train(
                 logger.warning(
                     "n_train is greater than the number of examples, using all examples"
                 )
-
-            selected_examples = random.sample(examples, n_train)
+                selected_examples = examples
+            else:
+                selected_examples = random.sample(examples, n_train)
             selected_examples = normalize_activations(selected_examples, max_activation)
             return selected_examples
         case "quantiles":
diff --git a/delphi/pipeline.py b/delphi/pipeline.py
index b3ead562..92bb2f5a 100644
--- a/delphi/pipeline.py
+++ b/delphi/pipeline.py
@@ -1,15 +1,16 @@
 import asyncio
+from collections.abc import AsyncIterable, Awaitable, Callable
 from functools import wraps
-from typing import Any, AsyncIterable, Callable
+from typing import Any
 
 from tqdm.asyncio import tqdm
 
 
 def process_wrapper(
-    function: Callable,
+    function: Callable[..., Awaitable],
     preprocess: Callable | None = None,
     postprocess: Callable | None = None,
-) -> Callable:
+) -> Callable[..., Awaitable]:
     """
     Wraps a function with optional preprocessing and postprocessing steps.
 
@@ -79,7 +80,7 @@ def __init__(self, loader: AsyncIterable | Callable, *pipes: Pipe | Callable):
 
         Args:
            loader (Callable): The loader to be executed first.
-            *pipes (list[Pipe]): Pipes to be executed in the pipeline.
+            *pipes (list[Pipe | Callable]): Pipes to be executed in the pipeline.
         """
         self.loader = loader
@@ -102,16 +103,14 @@ async def run(self, max_concurrent: int = 10) -> list[Any]:
         progress_bar = tqdm(desc="Processing items")
         number_of_items = 0
 
-        async def process_and_update(item, semaphore, count):
-            result = await self.process_item(item, semaphore, count)
+        async def process_and_update(item, semaphore):
+            result = await self.process_item(item, semaphore)
             progress_bar.update(1)
             return result
 
         async for item in self.generate_items():
             number_of_items += 1
-            task = asyncio.create_task(
-                process_and_update(item, semaphore, number_of_items)
-            )
+            task = asyncio.create_task(process_and_update(item, semaphore))
             tasks.add(task)
 
             if len(tasks) >= max_concurrent:
@@ -148,16 +147,13 @@ async def generate_items(self) -> AsyncIterable[Any]:
         else:
             raise TypeError("The first pipe must be an async iterable or a callable")
 
-    async def process_item(
-        self, item: Any, semaphore: asyncio.Semaphore, count: int
-    ) -> Any:
+    async def process_item(self, item: Any, semaphore: asyncio.Semaphore) -> Any:
         """
         Processes a single item through all pipes except the first one.
 
         Args:
             item (Any): The item to be processed.
             semaphore (asyncio.Semaphore): Semaphore for controlling concurrency.
-            count (int): The count of the current item being processed.
 
         Returns:
             Any: The processed item.
diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py
index 45cb6b94..55ef718e 100644
--- a/delphi/scorers/classifier/classifier.py
+++ b/delphi/scorers/classifier/classifier.py
@@ -2,11 +2,13 @@
 import json
 import random
 import re
+import traceback
 from abc import abstractmethod
+from typing import Literal
 
 import numpy as np
 
-from ...clients.client import Client
+from ...clients.client import Client, Response
 from ...latents import LatentRecord
 from ...logger import logger
 from ..scorer import Scorer, ScorerResult
@@ -100,19 +102,20 @@ async def _generate(
             predictions = [None] * self.n_examples_shown
             probabilities = [None] * self.n_examples_shown
         else:
+            assert isinstance(response, Response)
             selections = response.text
             logprobs = response.logprobs if self.log_prob else None
 
             try:
                 predictions, probabilities = self._parse(selections, logprobs)
-            except Exception as e:
-                logger.error(f"Parsing selections failed: {e}")
+            except Exception:
+                logger.error("Parsing selections failed:\n" + traceback.format_exc())
                 predictions = [None] * self.n_examples_shown
                 probabilities = [None] * self.n_examples_shown
 
         results = []
         for sample, prediction, probability in zip(batch, predictions, probabilities):
             result = sample.data
-            result.prediction = prediction
+            result.prediction = bool(prediction) if prediction is not None else None
             if prediction is not None:
                 result.correct = prediction == result.activating
             else:
@@ -136,7 +139,7 @@ def _parse(self, string, logprobs=None):
         match = re.search(pattern, string)
         if match is None:
             raise ValueError("No match found in string")
-        predictions: list[bool] = json.loads(match.group(0))
+        predictions: list[bool | Literal[0, 1]] = json.loads(match.group(0))
         assert len(predictions) == self.n_examples_shown
         probabilities = (
             self._parse_logprobs(logprobs)
diff --git a/delphi/scorers/simulator/oai_autointerp/activations/activation_records.py b/delphi/scorers/simulator/oai_autointerp/activations/activation_records.py
index f8c4b0a0..de7c4618 100644
--- a/delphi/scorers/simulator/oai_autointerp/activations/activation_records.py
+++ b/delphi/scorers/simulator/oai_autointerp/activations/activation_records.py
@@ -1,7 +1,8 @@
 """Utilities for formatting activation records into prompts."""
 
 import math
-from typing import Optional, Sequence
+from collections.abc import Sequence
+from typing import Optional
 
 from .activations import ActivationRecord
 
diff --git a/delphi/scorers/simulator/oai_autointerp/activations/activations.py b/delphi/scorers/simulator/oai_autointerp/activations/activations.py
index 090535cf..c71d1a95 100644
--- a/delphi/scorers/simulator/oai_autointerp/activations/activations.py
+++ b/delphi/scorers/simulator/oai_autointerp/activations/activations.py
@@ -13,7 +13,7 @@ class ActivationRecord(Serializable):
 
     tokens: list[str]
     """Tokens in the text sequence, represented as strings."""
 
-    activations: list[float]
+    activations: list[int | float]
     """Raw activation values for the neuron on each token in the text sequence."""
 
diff --git a/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py b/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
index dbdc2a16..3141d23a 100644
--- a/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
+++ b/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
@@ -2,7 +2,8 @@
 
 import asyncio
 import logging
-from typing import Any, Callable, Sequence
+from collections.abc import Callable, Sequence
+from typing import Any
 
 import numpy as np
 
diff --git a/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py b/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
index d8c056c4..0306576d 100644
--- a/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
+++ b/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
@@ -6,8 +6,9 @@
 import logging
 from abc import ABC, abstractmethod
 from collections import OrderedDict
+from collections.abc import Sequence
 from enum import Enum
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 from pydantic import BaseModel
diff --git a/delphi/sparse_coders/load_sparsify.py b/delphi/sparse_coders/load_sparsify.py
index 4bf3ce04..6cb77e24 100644
--- a/delphi/sparse_coders/load_sparsify.py
+++ b/delphi/sparse_coders/load_sparsify.py
@@ -1,14 +1,30 @@
+from collections.abc import Callable
 from functools import partial
 from pathlib import Path
-from typing import Callable
+from typing import Optional, Protocol, Union
 
 import torch
+import torch._dynamo.eval_frame
-from sparsify import SparseCoder
+from sparsify import SparseCoder, SparseCoderConfig
+from sparsify.sparse_coder import EncoderOutput
 from torch import Tensor
 from transformers import PreTrainedModel
 
 
-def sae_dense_latents(x: Tensor, sae: SparseCoder) -> Tensor:
+class PotentiallyWrappedSparseCoder(Protocol):
+    def encode(self, x: Tensor) -> EncoderOutput: ...
+
+    def to(
+        self,
+        device: Optional[Union[str, torch.device]] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> torch.nn.Module: ...
+
+    cfg: SparseCoderConfig
+    num_latents: int
+
+
+def sae_dense_latents(x: Tensor, sae: PotentiallyWrappedSparseCoder) -> Tensor:
     """Run `sae` on `x`, yielding the dense activations."""
     x_in = x.reshape(-1, x.shape[-1])
     encoded = sae.encode(x_in)
@@ -51,7 +67,7 @@ def load_sparsify_sparse_coders(
     hookpoints: list[str],
     device: str | torch.device,
     compile: bool = False,
-) -> dict[str, SparseCoder]:
+) -> dict[str, PotentiallyWrappedSparseCoder]:
     """
     Load sparsify sparse coders for specified hookpoints.
 
diff --git a/delphi/sparse_coders/sparse_model.py b/delphi/sparse_coders/sparse_model.py
index fdd5769a..0b7654dc 100644
--- a/delphi/sparse_coders/sparse_model.py
+++ b/delphi/sparse_coders/sparse_model.py
@@ -1,14 +1,17 @@
-from typing import Callable
+from collections.abc import Callable
 
 import torch
 import torch.nn as nn
-from sparsify import SparseCoder
 from transformers import PreTrainedModel
 
 from delphi.config import RunConfig
 
 from .custom.gemmascope import load_gemma_autoencoders
-from .load_sparsify import load_sparsify_hooks, load_sparsify_sparse_coders
+from .load_sparsify import (
+    PotentiallyWrappedSparseCoder,
+    load_sparsify_hooks,
+    load_sparsify_sparse_coders,
+)
 
 
 def load_hooks_sparse_coders(
@@ -75,7 +78,7 @@ def load_sparse_coders(
     run_cfg: RunConfig,
     device: str | torch.device,
     compile: bool = False,
-) -> dict[str, nn.Module] | dict[str, SparseCoder]:
+) -> dict[str, nn.Module] | dict[str, PotentiallyWrappedSparseCoder]:
     """
     Load sparse coders for specified hookpoints.
 
diff --git a/delphi/tests/client_test.py b/delphi/tests/client_test.py
new file mode 100644
index 00000000..0c781529
--- /dev/null
+++ b/delphi/tests/client_test.py
@@ -0,0 +1,211 @@
+import logging
+import os
+from typing import Literal
+
+import dotenv
+import fire
+import torch
+from beartype.claw import beartype_package
+from jaxtyping import Int
+from torch import Tensor
+from transformers import AutoTokenizer
+
+beartype_package("delphi")
+
+from delphi.clients import Client, Offline, OpenRouter  # noqa: E402
+from delphi.clients.types import Message  # noqa: E402
+from delphi.explainers import DefaultExplainer  # noqa: E402
+from delphi.latents.latents import (  # noqa: E402
+    ActivatingExample,
+    Latent,
+    LatentRecord,
+    NonActivatingExample,
+)
+from delphi.latents.samplers import SamplerConfig, sampler  # noqa: E402
+from delphi.logger import logger  # noqa: E402
+from delphi.scorers import DetectionScorer, FuzzingScorer  # noqa: E402
+
+logger.addHandler(logging.StreamHandler())
+
+
+async def main(
+    explainer_provider: Literal["offline", "openrouter"] = "offline",
+    # meta-llama/llama-3.3-70b-instruct
+    explainer_model: str = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
+    model_max_len: int = 5120,
+    num_gpus: int = 1,
+    scorer_type: Literal["fuzz", "detect"] = "fuzz",
+):
+    """Test different client and that the explainer
+    and scorer are calling the client correctly.
+
+    Args:
+        explainer_provider (Literal["offline", "openrouter"], optional):
+            Which client type to use. Defaults to "offline".
+        explainer_model (str, optional): VLLM model name or OpenRouter ID.
+            Defaults to "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4".
+        model_max_len (int, optional): Maximum length for VLLM. Defaults to 5120.
+        num_gpus (int, optional): Number of GPUs to use for VLLM (TP size).
+            Defaults to 1.
+        scorer_type (Literal["fuzz", "detect"], optional): Scoring type to use.
+            Defaults to "fuzz".
+    """
+
+    def make_scorer(client: Client):
+        if scorer_type == "fuzz":
+            return FuzzingScorer(client, verbose=True)
+        elif scorer_type == "detect":
+            return DetectionScorer(client, verbose=True)
+        # other cases impossible due to beartype
+
+    print("Creating data")
+
+    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
+
+    texts_activating = [
+        "I like dogs. Dogs are great.",
+        "Dog dog dog dog dog dog",
+    ]
+    texts_non_activating = [
+        "I like cats. Cats are great.",
+        "Cat cat cat cat cat cat",
+    ]
+    explanation = "Sentences mentioning dogs."
+    activating_examples = []
+    for text in texts_activating:
+        token_ids: Int[Tensor, "ctx_len"] = tokenizer(text, return_tensors="pt")[
+            "input_ids"
+        ][
+            0
+        ]  # type: ignore
+        dog_tokens = tokenizer("Dog dog Dog dogs Dogs", return_tensors="pt")[
+            "input_ids"
+        ][
+            0
+        ]  # type: ignore
+        activating_examples.append(
+            ActivatingExample(
+                tokens=token_ids,
+                activations=(token_ids[:, None] == dog_tokens[None, :])
+                .any(dim=1)
+                .float(),
+                str_tokens=tokenizer.batch_decode(token_ids, skip_special_tokens=True),
+            )
+        )
+    non_activating_examples = []
+
+    for text in texts_non_activating:
+        token_ids: Int[Tensor, "ctx_len"] = tokenizer(text, return_tensors="pt")[
+            "input_ids"
+        ][0]
+        non_activating_examples.append(
+            NonActivatingExample(
+                tokens=token_ids,
+                activations=torch.rand_like(token_ids, dtype=torch.float32),
+                str_tokens=tokenizer.batch_decode(token_ids, skip_special_tokens=True),
+            )
+        )
+
+    record = LatentRecord(
+        latent=Latent("test", 0),
+        examples=activating_examples,
+        not_active=non_activating_examples,
+        explanation=explanation,
+    )
+    record = sampler(
+        record,
+        SamplerConfig(
+            n_examples_train=len(activating_examples),
+            n_examples_test=len(activating_examples),
+            n_quantiles=1,
+            train_type="quantiles",
+            test_type="quantiles",
+        ),
+    )
+
+    most_recent_generation = None
+
+    class MockClient(Client):
+        async def generate(self, prompt, **kwargs) -> str:
+            nonlocal most_recent_generation
+            most_recent_generation = prompt, kwargs
+            raise NotImplementedError("Prompt received")
+
+    client = MockClient("")
+    gen_kwargs_dict = {
+        "max_length": 100,
+        "num_return_sequences": 1,
+    }
+    explainer = DefaultExplainer(
+        client, verbose=True, generation_kwargs=gen_kwargs_dict
+    )
+    try:
+        await explainer(record)
+    except NotImplementedError:
+        pass
+    assert most_recent_generation is not None, "Prompt not received"
+    full_prompt: list[Message] = most_recent_generation[0]
+    last_element: Message = full_prompt[-1]
+    assert "dog" in last_element["content"]
+    print(last_element["content"])
+    scorer = make_scorer(client)
+    try:
+        await scorer(record)
+    except NotImplementedError:
+        pass
+    assert most_recent_generation is not None, "Prompt not received"
+    full_prompt: list[Message] = most_recent_generation[0]
+    last_element: Message = full_prompt[-1]
+    assert "dog" in last_element["content"]
+
+    print("Mock tests passed")
+
+    print("Loading model")
+    if explainer_provider == "offline":
+        client = Offline(
+            explainer_model,
+            max_memory=0.9,
+            # Explainer models context length - must be able to accommodate the longest
+            # set of examples
+            max_model_len=model_max_len,
+            num_gpus=num_gpus,
+            statistics=False,
+        )
+    elif explainer_provider == "openrouter":
+        client = OpenRouter(
+            explainer_model,
+            api_key=os.environ["OPENROUTER_API_KEY"],
+        )
+
+    explainer = DefaultExplainer(client, verbose=True)
+    scorer = make_scorer(client)
+
+    print("Testing explainer")
+    explainer_result = await explainer(record)
+    assert explainer_result.explanation, "No explanation generated"
+    # assert "dog" in explainer_result.explanation.lower(), \
+    #     f'Explanation does not contain "dog": {explainer_result.explanation}'
+    print("Explanation:", explainer_result.explanation)
+
+    print("Testing scorer")
+    scorer_result = await scorer(record)
+    accuracy = 0
+    n_failing = 0
+    for output in scorer_result.score:
+        if output.correct is None:
+            n_failing += 1
+        else:
+            accuracy += int(output.correct)
+    assert n_failing <= 1, f"Scorer failed {n_failing} times"
+    accuracy /= len(scorer_result.score)
+    assert accuracy > 0.5, f"Accuracy is {accuracy}"
+
+    print("All tests passed!")
+    if explainer_provider == "offline":
+        assert isinstance(client, Offline)
+        await client.close()
+
+
+if __name__ == "__main__":
+    dotenv.load_dotenv()
+    fire.Fire(main)
diff --git a/delphi/tests/e2e.py b/delphi/tests/e2e.py
index 6bc62e03..17d5ccec 100644
--- a/delphi/tests/e2e.py
+++ b/delphi/tests/e2e.py
@@ -3,10 +3,21 @@
 from pathlib import Path
 
 import torch
+from beartype.claw import beartype_package
 
-from delphi.__main__ import run
-from delphi.config import CacheConfig, ConstructorConfig, RunConfig, SamplerConfig
-from delphi.log.result_analysis import build_scores_df, latent_balanced_score_metrics
+beartype_package("delphi")
+
+from delphi.__main__ import run  # noqa: E402
+from delphi.config import (  # noqa: E402
+    CacheConfig,
+    ConstructorConfig,
+    RunConfig,
+    SamplerConfig,
+)
+from delphi.log.result_analysis import (  # noqa: E402
+    build_scores_df,
+    latent_balanced_score_metrics,
+)
 
 
 async def test():
diff --git a/delphi/tests/pipeline.py b/delphi/tests/pipeline.py
new file mode 100644
index 00000000..38bb23c6
--- /dev/null
+++ b/delphi/tests/pipeline.py
@@ -0,0 +1,158 @@
+import random
+from collections.abc import AsyncIterable, Awaitable
+
+import pytest
+
+from delphi.pipeline import Pipe, Pipeline, process_wrapper
+
+
+@pytest.fixture
+def random_seed():
+    random.seed(0)
+
+
+TEST_ELEMENTS_HASHABLE = [1, "a", (1, 2, 3)]
+TEST_ELEMENTS = TEST_ELEMENTS_HASHABLE + [{1: 2}, [1, 2, 3]]
+
+
+async def async_generate():
+    for elem in TEST_ELEMENTS_HASHABLE:
+        yield elem
+
+
+async def async_generate_100_times():
+    for _ in range(100):
+        for elem in TEST_ELEMENTS_HASHABLE:
+            yield elem
+
+
+def sync_generate():
+    for elem in TEST_ELEMENTS_HASHABLE:
+        yield elem
+
+
+def sync_generate_100_times():
+    for _ in range(100):
+        for elem in TEST_ELEMENTS_HASHABLE:
+            yield elem
+
+
+TEST_LOADERS = [
+    async_generate,
+    async_generate_100_times,
+    lambda: sync_generate,
+    lambda: sync_generate_100_times,
+]
+
+
+def identity(x):
+    return x
+
+
+def test_process_wrapper_not_async():
+    not_async_fn = identity
+    wrapped_fn = process_wrapper(not_async_fn)
+    with pytest.warns(RuntimeWarning):
+        assert isinstance(wrapped_fn(1), Awaitable)
+
+
+@pytest.mark.xfail
+@pytest.mark.asyncio
+async def test_async_fails():
+    not_async_fn = identity
+    wrapped_fn = process_wrapper(not_async_fn)
+    assert await wrapped_fn(1) == 1
+
+
+@pytest.mark.parametrize("x", TEST_ELEMENTS)
+@pytest.mark.asyncio
+async def test_identity(x):
+    async_fn = async_identity
+    wrapped_fn = process_wrapper(async_fn)
+    assert await wrapped_fn(x) == x
+    wrapped_fn = process_wrapper(
+        async_fn, preprocess=lambda x: x, postprocess=lambda x: x
+    )
+    assert await wrapped_fn(x) == x
+
+
+async def async_identity(x):
+    return x
+
+
+@pytest.mark.xfail
+@pytest.mark.asyncio
+async def test_process_wrapper_async():
+    async_fn = async_identity
+    wrapped_fn = process_wrapper(async_fn, preprocess=async_fn, postprocess=async_fn)
+    assert await wrapped_fn(1) == 1
+
+
+async def add_1(x):
+    return x + 1
+
+
+async def add_2(x):
+    return x + 2
+
+
+async def add_3(x):
+    return x + 3
+
+
+@pytest.mark.asyncio
+async def test_pipe():
+    pipe = Pipe(add_1, add_2, add_3)
+    assert await pipe(1) == [2, 3, 4]
+
+
+def unflatten(x):
+    while isinstance(x, list):
+        x = x[0]
+    return x
+
+
+@pytest.mark.parametrize("get_loader", TEST_LOADERS)
+@pytest.mark.parametrize("repeats", [1, 2, 3])
+@pytest.mark.parametrize("convert_into_pipe", [True, False])
+@pytest.mark.parametrize("max_concurrent", [1, 2, 10, 100, 1000])
+@pytest.mark.asyncio
+async def test_identity_pipeline(
+    get_loader, repeats, convert_into_pipe, max_concurrent
+):
+    pipe = Pipe(async_identity) if convert_into_pipe else async_identity
+    elems = []
+    if isinstance(get_loader(), AsyncIterable):
+        async for elem in get_loader():
+            elems.append(elem)
+    elif callable(get_loader()):
+        for elem in get_loader()():
+            elems.append(elem)
+    pipeline = Pipeline(get_loader(), *([pipe] * repeats))
+    assert set(
+        (unflatten(elem) if convert_into_pipe else elem)
+        for elem in await pipeline.run()
+    ) == set(elems)
+
+    elems_async = []
+
+    async def add_to_list(x):
+        elems_async.append(x)
+        return x
+
+    pipeline = Pipeline(get_loader(), *([add_to_list] * repeats))
+    results_async = await pipeline.run(max_concurrent=max_concurrent)
+    assert set(results_async) == set(elems)
+    assert set(elems_async) == set(elems)
+    assert len(elems_async) == len(elems) * repeats
+    assert len(elems_async) == len(elems) * repeats
+
+
+@pytest.mark.asyncio
+async def test_pipeline_failure():
+    async def raise_error(_):
+        raise ValueError
+
+    pipeline = Pipeline(async_generate(), raise_error)
+    with pytest.raises(ValueError):
+        await pipeline.run()
diff --git a/delphi/tests/test_autoencoders/test_sparse_coders.py b/delphi/tests/test_autoencoders/test_sparse_coders.py
index 853a1248..455cd7a4 100644
--- a/delphi/tests/test_autoencoders/test_sparse_coders.py
+++ b/delphi/tests/test_autoencoders/test_sparse_coders.py
@@ -1,6 +1,9 @@
 import pytest
 import torch
 import torch.nn as nn
+from transformers import PreTrainedModel
+
+from delphi.config import RunConfig
 
 # Import the function to be tested
 from delphi.sparse_coders import load_hooks_sparse_coders
@@ -15,6 +18,10 @@ def __init__(self, sparse_model, hookpoints):
         self.model = "dummy_model"
         self.hf_token = ""
 
+    @property
+    def __class__(self) -> type:  # type: ignore
+        return RunConfig
+
 
 class DummyLayer(nn.Module):
     def __init__(self):
@@ -38,6 +45,10 @@ def __init__(self):
     def forward(self, x):
         return x
 
+    @property
+    def __class__(self) -> type:  # type: ignore
+        return PreTrainedModel
+
 
 @pytest.fixture
 def dummy_model():
diff --git a/delphi/tests/test_latents/test_constructor.py b/delphi/tests/test_latents/test_constructor.py
new file mode 100644
index 00000000..dd679551
--- /dev/null
+++ b/delphi/tests/test_latents/test_constructor.py
@@ -0,0 +1,135 @@
+import random
+from itertools import chain
+from typing import Any, Literal
+
+import pytest
+import torch
+from jaxtyping import Int
+from torch import Tensor
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+
+from delphi.config import ConstructorConfig, SamplerConfig
+from delphi.latents import (
+    ActivatingExample,
+    Latent,
+    LatentDataset,
+    LatentRecord,
+    constructor,
+    sampler,
+)
+from delphi.latents.latents import ActivationData
+
+
+def test_save_load_cache(
+    cache_setup: dict[str, Any],
+    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
+):
+    sampler_cfg = SamplerConfig(
+        n_examples_train=3,
+        n_examples_test=3,
+        n_quantiles=3,
+        train_type="quantiles",
+        test_type="quantiles",
+    )
+    dataset = LatentDataset(
+        cache_setup["temp_dir"], sampler_cfg, ConstructorConfig(), tokenizer
+    )
+    tokens: Int[Tensor, "examples ctx_len"] = dataset.load_tokens()  # type: ignore
+    assert (tokens == cache_setup["tokens"][: len(tokens)]).all()
+    for record in dataset:
+        assert len(record.train) <= sampler_cfg.n_examples_train
+        assert len(record.test) <= sampler_cfg.n_examples_test
+
+
+@pytest.fixture(scope="module")
+def seed():
+    random.seed(0)
+    torch.manual_seed(0)
+
+
+@pytest.mark.parametrize("n_samples", [5, 10, 100, 1000])
+@pytest.mark.parametrize("n_quantiles", [2, 5, 10, 23])
+@pytest.mark.parametrize("n_examples", [0, 2, 5, 10, 20])
+@pytest.mark.parametrize("train_type", ["top", "random", "quantiles"])
+def test_simple_cache(
+    n_samples: int,
+    n_quantiles: int,
+    n_examples: int,
+    train_type: Literal["top", "random", "quantiles"],
+    ctx_len: int = 32,
+    seed: None = None,
+    *,
+    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
+):
+    torch.manual_seed(0)
+    tokens = torch.randint(
+        0,
+        100,
+        (
+            n_samples,
+            ctx_len,
+        ),
+    )
+    all_activation_data = []
+    all_activations = []
+    for feature_idx in range(2):
+        activations = torch.rand(n_samples, ctx_len, 1) * (
+            torch.rand(n_samples)[..., None, None] ** 2
+        )
+        all_activations.append(activations)
+        mask = activations > 0.1
+        locations = torch.nonzero(mask)
+        locations[..., 2] = feature_idx
+        all_activation_data.append(ActivationData(locations, activations[mask]))
+    activation_data, other_activation_data = all_activation_data
+    activations, other_activations = all_activations
+    record = LatentRecord(latent=Latent("test", 0), examples=[])
+    constructor(
+        record,
+        activation_data,
+        constructor_cfg=ConstructorConfig(
+            example_ctx_len=ctx_len,
+            min_examples=1,
+            max_examples=100,
+            n_non_activating=50,
+            non_activating_source="neighbours",
+        ),
+        tokens=tokens,
+        tokenizer=tokenizer,
+        all_data={0: activation_data, 1: other_activation_data},
+    )
+    for i, j in zip(record.examples[:-1], record.examples[1:]):
+        assert i.max_activation >= j.max_activation
+    for i in record.examples:
+        index = (tokens == i.tokens).all(dim=-1).float().argmax()
+        assert (tokens[index] == i.tokens).all()
+        assert activations[index].max() == i.max_activation
+    sampler(
+        record,
+        SamplerConfig(
+            n_examples_train=n_examples,
+            n_examples_test=n_examples,
+            n_quantiles=n_quantiles,
+            train_type=train_type,
+            test_type="quantiles",
+        ),
+    )
+    assert len(record.train) <= n_examples
+    assert len(record.test) <= n_examples
+    for neighbor in record.neighbours:
+        assert neighbor.latent_index == 1
+    for example in chain(record.train, record.test):
+        assert isinstance(example, ActivatingExample)
+        assert example.normalized_activations is not None
+        assert example.normalized_activations.shape == example.activations.shape
+        assert (example.normalized_activations <= 10).all()
+        assert (example.normalized_activations >= 0).all()
+    for quantile_list in (record.test,) + (  # type: ignore
+        (record.train,) if train_type == "quantiles" else ()
+    ):
+        quantile_list: list[ActivatingExample] = quantile_list
+        for k, i in enumerate(quantile_list):
+            for j in quantile_list[k + 1 :]:
+                if i.quantile != j.quantile:
+                    assert i.max_activation >= j.max_activation
+                    assert i.quantile < j.quantile
diff --git a/delphi/utils.py b/delphi/utils.py
index a148d238..09549787 100644
--- a/delphi/utils.py
+++ b/delphi/utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Type, TypeVar, cast
+from typing import Any, TypeVar, cast
 
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
@@ -36,7 +36,7 @@ def load_tokenized_data(
 T = TypeVar("T")
 
 
-def assert_type(typ: Type[T], obj: Any) -> T:
+def assert_type(typ: type[T], obj: Any) -> T:
     """Assert that an object is of a given type at runtime and return it."""
     if not isinstance(obj, typ):
         raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}")
diff --git a/pyproject.toml b/pyproject.toml
index 5985e90b..c798f867 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,8 @@ dependencies = [
 [project.optional-dependencies]
 dev = [
     "pytest",
-    "pyright==1.1.378"
+    "pyright==1.1.378",
+    "pytest-beartype", "pytest-asyncio", "setuptools", "pre-commit"
 ]
 visualize = [
     "kaleido==0.2.1",
@@ -60,3 +61,6 @@ line-length = 88
 # Enable pycodestyle (`E`), Pyflakes (`F`), and isort (`I`) codes
 # See https://beta.ruff.rs/docs/rules/ for more possible rules
 select = ["E", "F", "I"]
+
+[tool.pytest.ini_options]
+beartype_packages = "delphi"