Extended tests, runtime type checking #101

Open · wants to merge 25 commits into main from more-tests
Commits (25)
043cc03  Add beartype to tests, fix type hints (neverix, Mar 8, 2025)
7074dce  Pipeline tests (neverix, Mar 8, 2025)
275a3ba  Beartype for E2E test, discover small typing bugs (neverix, Mar 8, 2025)
22ba1af  Fix beartype errors, E2E test fails (neverix, Mar 8, 2025)
c272bb1  Fix E2E beartype issue (neverix, Mar 8, 2025)
c35b64d  Basic constructor tests, update cache documentation (neverix, Mar 9, 2025)
fdaf8da  Test samplers (neverix, Mar 9, 2025)
eccbf8d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 9, 2025)
cc7ef74  Tests for the clients (neverix, Mar 9, 2025)
7b2b2dc  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 9, 2025)
d03ea50  Fix ruff (neverix, Mar 9, 2025)
ec867b0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 9, 2025)
6000985  Fix some type errors (neverix, Mar 9, 2025)
c56ad2e  Merge remote-tracking branch 'origin/more-tests' into more-tests (neverix, Mar 9, 2025)
409ce58  Fix typing and formatting (neverix, Mar 10, 2025)
6bcf708  Fix types again (neverix, Mar 10, 2025)
bc94e71  Merge branch 'main' into more-tests (neverix, Mar 20, 2025)
ff87cdc  Fix beartype errors (neverix, Mar 20, 2025)
0b03585  Merge branch 'main' into more-tests (neverix, Mar 21, 2025)
cf3042a  Fix beartype warnings & errors (neverix, Mar 21, 2025)
d7dec01  Fix SparseCoder protocol type errors (neverix, Mar 21, 2025)
0b8dc24  Merge branch 'main' into more-tests (neverix, Apr 6, 2025)
fe54b85  Fix tests (neverix, Apr 6, 2025)
f821471  Fix some failing tests by removing runtime checking (inspect.get_stat…) (neverix, Apr 6, 2025)
9f6a262  Fix tests (neverix, Apr 6, 2025)
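The commits above add beartype-based runtime type checking to the test suite. As a rough sketch of the general idea (not necessarily how this branch wires it up), beartype's import hooks can be installed from a test-suite conftest.py so that annotated functions are checked against their hints while the tests run:

# conftest.py (illustrative only; hooking the "delphi" package is an assumption)
from beartype.claw import beartype_package

# Install an import hook so modules imported from the package are
# instrumented with runtime checks of their type annotations.
beartype_package("delphi")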
6 changes: 3 additions & 3 deletions delphi/__main__.py
@@ -86,12 +86,12 @@ def create_neighbours(
elif constructor_cfg.neighbours_type == "decoder_similarity":

neighbour_calculator = NeighbourCalculator(
- autoencoder=saes[hookpoint].cuda(), number_of_neighbours=250
+ autoencoder=saes[hookpoint].to("cuda"), number_of_neighbours=250
)

elif constructor_cfg.neighbours_type == "encoder_similarity":
neighbour_calculator = NeighbourCalculator(
- autoencoder=saes[hookpoint].cuda(), number_of_neighbours=250
+ autoencoder=saes[hookpoint].to("cuda"), number_of_neighbours=250
)
else:
raise ValueError(
@@ -131,7 +131,7 @@ async def process_cache(
} # The latent range to explain

dataset = LatentDataset(
- raw_dir=str(latents_path),
+ raw_dir=latents_path,
sampler_cfg=run_cfg.sampler_cfg,
constructor_cfg=run_cfg.constructor_cfg,
modules=hookpoints,
2 changes: 1 addition & 1 deletion delphi/clients/client.py
@@ -17,7 +17,7 @@ def __init__(self, model: str):
@abstractmethod
async def generate(
self, prompt: Union[str, list[dict[str, str]]], **kwargs
- ) -> Response:
+ ) -> str | Response:
pass

# @abstractmethod
10 changes: 8 additions & 2 deletions delphi/clients/offline.py
@@ -74,7 +74,11 @@ def __init__(
self.statistics_path = Path("statistics")
self.statistics_path.mkdir(parents=True, exist_ok=True)

- async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
+ async def process_func(
+ self,
+ batches: Union[str, list[Union[dict[str, str], list[dict[str, str]]]]],
+ kwargs,
+ ):
"""
Process a single request.
"""
@@ -142,7 +146,9 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
)
return new_response

- async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> str: # type: ignore
+ async def generate(
+ self, prompt: Union[str, list[dict[str, str]]], **kwargs
+ ) -> Response: # type: ignore
"""
Enqueue a request and wait for the result.
"""
7 changes: 6 additions & 1 deletion delphi/clients/openrouter.py
@@ -5,6 +5,7 @@

from ..logger import logger
from .client import Client
+ from .types import ChatFormatRequest

# Preferred provider routing arguments.
# Change depending on what model you'd like to use.
@@ -36,7 +37,11 @@ def postprocess(self, response):
return Response(msg)

async def generate( # type: ignore
- self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs # type: ignore
+ self,
+ prompt: ChatFormatRequest,
+ raw: bool = False,
+ max_retries: int = 1,
+ **kwargs, # type: ignore
) -> Response: # type: ignore
kwargs.pop("schema", None)
max_tokens = kwargs.pop("max_tokens", 500)
9 changes: 9 additions & 0 deletions delphi/clients/types.py
@@ -0,0 +1,9 @@
from typing import Literal, TypedDict, Union


class Message(TypedDict):
content: str
role: Literal["system", "user", "assistant"]


ChatFormatRequest = Union[str, list[str], list[Message], None]
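The new Message and ChatFormatRequest types describe the shapes a prompt may take. A minimal usage sketch (the client call is hypothetical and shown only for illustration):

from delphi.clients.types import ChatFormatRequest, Message

prompt: ChatFormatRequest = [
    Message(role="system", content="You describe latent activation patterns."),
    Message(role="user", content="Explain what this latent responds to."),
]
# response = await client.generate(prompt)  # hypothetical call site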
8 changes: 6 additions & 2 deletions delphi/explainers/contrastive_explainer.py
@@ -4,7 +4,7 @@
import torch

from delphi.explainers.default.prompts import SYSTEM_CONTRASTIVE
- from delphi.explainers.explainer import Explainer, ExplainerResult
+ from delphi.explainers.explainer import Explainer, ExplainerResult, Response
from delphi.latents.latents import ActivatingExample, LatentRecord, NonActivatingExample


@@ -54,7 +54,11 @@ async def __call__(self, record: LatentRecord) -> ExplainerResult:
)

try:
- explanation = self.parse_explanation(response.text)
+ if isinstance(response, Response):
+ response_text = response.text
+ else:
+ response_text = response
+ explanation = self.parse_explanation(response_text)
if self.verbose:
from ..logger import logger

3 changes: 2 additions & 1 deletion delphi/explainers/explainer.py
@@ -8,7 +8,7 @@

import aiofiles

- from ..clients.client import Client
+ from ..clients.client import Client, Response
from ..latents.latents import ActivatingExample, LatentRecord
from ..logger import logger

@@ -44,6 +44,7 @@ async def __call__(self, record: LatentRecord) -> ExplainerResult:
response = await self.client.generate(
messages, temperature=self.temperature, **self.generation_kwargs
)
+ assert isinstance(response, Response)

try:
explanation = self.parse_explanation(response.text)
125 changes: 64 additions & 61 deletions delphi/latents/cache.py
@@ -1,8 +1,8 @@
import json
from collections import defaultdict
+ from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
- from typing import Callable

import numpy as np
import torch
@@ -15,8 +15,49 @@
from delphi.config import CacheConfig
from delphi.latents.collect_activations import collect_activations

- location_tensor_shape = Float[Tensor, "batch sequence num_latents"]
- token_tensor_shape = Float[Tensor, "batch sequence"]
+ location_tensor_type = Int[Tensor, "batch_sequence 3"]
+ activation_tensor_type = Float[Tensor, "batch_sequence"]
+ token_tensor_type = Int[Tensor, "batch sequence"]
+ latent_tensor_type = Float[Tensor, "batch sequence num_latents"]
+
+
+ def get_nonzeros_batch(
+ latents: latent_tensor_type,
+ ) -> tuple[
+ Float[Tensor, "batch sequence num_latents"], Float[Tensor, "batch sequence "]
+ ]:
+ """
+ Get non-zero activations for large batches that exceed int32 max value.
+
+ Args:
+ latents: Input latent activations.
+
+ Returns:
+ tuple[Tensor, Tensor]: Non-zero latent locations and activations.
+ """
+ # Calculate the maximum batch size that fits within sys.maxsize
+ max_batch_size = torch.iinfo(torch.int32).max // (
+ latents.shape[1] * latents.shape[2]
+ )
+ nonzero_latent_locations = []
+ nonzero_latent_activations = []
+
+ for i in range(0, latents.shape[0], max_batch_size):
+ batch = latents[i : i + max_batch_size]
+
+ # Get nonzero locations and activations
+ batch_locations = torch.nonzero(batch.abs() > 1e-5)
+ batch_activations = batch[batch.abs() > 1e-5]
+
+ # Adjust indices to account for batching
+ batch_locations[:, 0] += i
+ nonzero_latent_locations.append(batch_locations)
+ nonzero_latent_activations.append(batch_activations)
+
+ # Concatenate results
+ nonzero_latent_locations = torch.cat(nonzero_latent_locations, dim=0)
+ nonzero_latent_activations = torch.cat(nonzero_latent_activations, dim=0)
+ return nonzero_latent_locations, nonzero_latent_activations


class InMemoryCache:
Expand All @@ -37,25 +78,25 @@ def __init__(
filters: Filters for selecting specific latents.
batch_size: Size of batches for processing. Defaults to 64.
"""
- self.latent_locations_batches: dict[str, list[location_tensor_shape]] = (
+ self.latent_locations_batches: dict[str, list[location_tensor_type]] = (
defaultdict(list)
)
- self.latent_activations_batches: dict[str, list[location_tensor_shape]] = (
+ self.latent_activations_batches: dict[str, list[latent_tensor_type]] = (
defaultdict(list)
)
- self.tokens_batches: dict[str, list[token_tensor_shape]] = defaultdict(list)
+ self.tokens_batches: dict[str, list[token_tensor_type]] = defaultdict(list)

- self.latent_locations: dict[str, location_tensor_shape] = {}
- self.latent_activations: dict[str, location_tensor_shape] = {}
- self.tokens: dict[str, token_tensor_shape] = {}
+ self.latent_locations: dict[str, location_tensor_type] = {}
+ self.latent_activations: dict[str, latent_tensor_type] = {}
+ self.tokens: dict[str, token_tensor_type] = {}

self.filters = filters
self.batch_size = batch_size

def add(
self,
- latents: location_tensor_shape,
- tokens: token_tensor_shape,
+ latents: latent_tensor_type,
+ tokens: token_tensor_type,
batch_number: int,
module_path: str,
):
@@ -96,47 +137,9 @@ def save(self):
self.tokens_batches[module_path], dim=0
)

- def get_nonzeros_batch(
- self, latents: location_tensor_shape
- ) -> tuple[
- Float[Tensor, "batch sequence num_latents"], Float[Tensor, "batch sequence "]
- ]:
- """
- Get non-zero activations for large batches that exceed int32 max value.
-
- Args:
- latents: Input latent activations.
-
- Returns:
- tuple[Tensor, Tensor]: Non-zero latent locations and activations.
- """
- # Calculate the maximum batch size that fits within sys.maxsize
- max_batch_size = torch.iinfo(torch.int32).max // (
- latents.shape[1] * latents.shape[2]
- )
- nonzero_latent_locations = []
- nonzero_latent_activations = []
-
- for i in range(0, latents.shape[0], max_batch_size):
- batch = latents[i : i + max_batch_size]
-
- # Get nonzero locations and activations
- batch_locations = torch.nonzero(batch.abs() > 1e-5)
- batch_activations = batch[batch.abs() > 1e-5]
-
- # Adjust indices to account for batching
- batch_locations[:, 0] += i
- nonzero_latent_locations.append(batch_locations)
- nonzero_latent_activations.append(batch_activations)
-
- # Concatenate results
- nonzero_latent_locations = torch.cat(nonzero_latent_locations, dim=0)
- nonzero_latent_activations = torch.cat(nonzero_latent_activations, dim=0)
- return nonzero_latent_locations, nonzero_latent_activations
-
- def get_nonzeros(self, latents: location_tensor_shape, module_path: str) -> tuple[
- location_tensor_shape,
- location_tensor_shape,
+ def get_nonzeros(self, latents: latent_tensor_type, module_path: str) -> tuple[
+ location_tensor_type,
+ activation_tensor_type,
]:
"""
Get the nonzero latent locations and activations.
@@ -153,7 +156,7 @@ def get_nonzeros(self, latents: location_tensor_shape, module_path: str) -> tupl
(
nonzero_latent_locations,
nonzero_latent_activations,
- ) = self.get_nonzeros_batch(latents)
+ ) = get_nonzeros_batch(latents)
else:
nonzero_latent_locations = torch.nonzero(latents.abs() > 1e-5)
nonzero_latent_activations = latents[latents.abs() > 1e-5]
@@ -209,8 +212,8 @@ def __init__(
self.filter_submodules(filters)

def load_token_batches(
- self, n_tokens: int, tokens: token_tensor_shape
- ) -> list[token_tensor_shape]:
+ self, n_tokens: int, tokens: token_tensor_type
+ ) -> list[token_tensor_type]:
"""
Load and prepare token batches for processing.

@@ -248,7 +251,7 @@ def filter_submodules(self, filters: dict[str, Float[Tensor, "indices"]]):
]
self.hookpoint_to_sparse_encode = filtered_submodules

- def run(self, n_tokens: int, tokens: token_tensor_shape):
+ def run(self, n_tokens: int, tokens: token_tensor_type):
"""
Run the latent caching process.

@@ -520,11 +523,11 @@ def generate_statistics_cache(
print(f"Fraction of strong single token latents: {strong_token_fraction:%}")

return CacheStatistics(
- frac_alive=fraction_alive,
- frac_fired_1pct=one_percent,
- frac_fired_10pct=ten_percent,
- frac_weak_single_token=single_token_fraction,
- frac_strong_single_token=strong_token_fraction,
+ frac_alive=float(fraction_alive),
+ frac_fired_1pct=float(one_percent),
+ frac_fired_10pct=float(ten_percent),
+ frac_weak_single_token=float(single_token_fraction),
+ frac_strong_single_token=float(strong_token_fraction),
)
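The get_nonzeros_batch helper promoted to module level above chunks the activation tensor so that each torch.nonzero call sees fewer than int32-max elements. A toy usage sketch (shapes invented for illustration):

import torch

from delphi.latents.cache import get_nonzeros_batch

latents = torch.rand(4, 16, 32)  # [batch, sequence, num_latents]
latents[latents < 0.9] = 0.0     # sparsify so most entries are zero
locations, activations = get_nonzeros_batch(latents)
# locations: [num_nonzero, 3] rows of (batch, sequence position, latent index)
# activations: the matching activation values, aligned row-for-row with locations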


4 changes: 3 additions & 1 deletion delphi/latents/collect_activations.py
@@ -25,7 +25,9 @@ def collect_activations(
handles = []

def create_hook(hookpoint: str, transcode: bool = False):
- def hook_fn(module: nn.Module, input: Any, output: Tensor) -> Tensor | None:
+ def hook_fn(
+ module: nn.Module, input: Any, output: Tensor | tuple[Tensor]
+ ) -> Tensor | None:
# If output is a tuple (like in some transformer layers), take first element
if transcode:
if isinstance(input, tuple):
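The widened hook_fn signature above reflects that forward hooks may receive either a tensor or a tuple from some transformer layers. A minimal standalone sketch of that pattern (not the project's exact hook):

from typing import Any

from torch import Tensor, nn


def first_tensor_hook(
    module: nn.Module, input: Any, output: Tensor | tuple[Tensor]
) -> Tensor | None:
    # Some transformer blocks return (hidden_states, ...); unwrap the first
    # element before collecting or modifying the activation.
    if isinstance(output, tuple):
        output = output[0]
    return None  # returning None leaves the module's output unchanged

A hook like this would be registered with module.register_forward_hook(first_tensor_hook).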