Skip to content

Commit 7e9df72

Browse files
author
SrGonao
committed
Remove max_examples
1 parent 6b9af04 commit 7e9df72

File tree

2 files changed

+9
-23
lines changed

2 files changed

+9
-23
lines changed

delphi/config.py

-3
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,6 @@ class ConstructorConfig(Serializable):
4848
If the number of examples is less than this, the
4949
latent will not be explained and scored."""
5050

51-
max_examples: int = 10_000
52-
"""Maximum number of activating examples to generate for a single latent."""
53-
5451
n_non_activating: int = 50
5552
"""Number of non-activating examples to be constructed."""
5653

delphi/latents/constructors.py

+9-20
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def _top_k_pools(
5959
max_buffer: Float[Tensor, "batch"],
6060
split_activations: Float[Tensor, "activations ctx_len"],
6161
buffer_tokens: Int[Tensor, "batch ctx_len"],
62-
max_examples: int,
6362
) -> tuple[Int[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
6463
"""
6564
Get the top k activation pools.
@@ -73,11 +72,10 @@ def _top_k_pools(
7372
Returns:
7473
The token windows and activation windows.
7574
"""
76-
k = min(max_examples, len(max_buffer))
77-
top_values, top_indices = torch.topk(max_buffer, k, sorted=True)
75+
sorted_values, sorted_indices = torch.sort(max_buffer, descending=True)
7876

79-
activation_windows = torch.stack([split_activations[i] for i in top_indices])
80-
token_windows = buffer_tokens[top_indices]
77+
activation_windows = torch.stack([split_activations[i] for i in sorted_indices])
78+
token_windows = buffer_tokens[sorted_indices]
8179

8280
return token_windows, activation_windows
8381

@@ -88,7 +86,6 @@ def pool_max_activation_windows(
8886
ctx_indices: Int[Tensor, "examples"],
8987
index_within_ctx: Int[Tensor, "examples"],
9088
ctx_len: int,
91-
max_examples: int,
9289
) -> tuple[Int[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
9390
"""
9491
Pool max activation windows from the buffer output and update the latent record.
@@ -119,9 +116,7 @@ def pool_max_activation_windows(
119116
new_tensor[inverses, index_within_ctx] = activations
120117
tokens = tokens[unique_ctx_indices]
121118

122-
token_windows, activation_windows = _top_k_pools(
123-
max_buffer, new_tensor, tokens, max_examples
124-
)
119+
token_windows, activation_windows = _top_k_pools(max_buffer, new_tensor, tokens)
125120

126121
return token_windows, activation_windows
127122

@@ -133,7 +128,6 @@ def pool_centered_activation_windows(
133128
ctx_indices: Float[Tensor, "examples"],
134129
index_within_ctx: Float[Tensor, "examples"],
135130
ctx_len: int,
136-
max_examples: int,
137131
) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
138132
"""
139133
Similar to pool_max_activation_windows. Doesn't use the ctx_indices that were
@@ -161,15 +155,14 @@ def pool_centered_activation_windows(
161155
max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)
162156

163157
# Get the top max_examples windows
164-
k = min(max_examples, len(max_buffer))
165-
top_values, top_indices = torch.topk(max_buffer, k, sorted=True)
158+
sorted_values, sorted_indices = torch.sort(max_buffer, descending=True)
166159

167160
# this tensor has the correct activations for each context window
168161
temp_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
169162
temp_tensor[inverses, index_within_ctx] = activations
170163

171-
unique_ctx_indices = unique_ctx_indices[top_indices]
172-
temp_tensor = temp_tensor[top_indices]
164+
unique_ctx_indices = unique_ctx_indices[sorted_indices]
165+
temp_tensor = temp_tensor[sorted_indices]
173166

174167
# if a element in unique_ctx_indices is divisible by n_windows_per_batch it
175168
# the start of a new batch, so we discard it
@@ -247,7 +240,6 @@ def constructor(
247240
example_ctx_len = constructor_cfg.example_ctx_len
248241
source_non_activating = constructor_cfg.non_activating_source
249242
n_not_active = constructor_cfg.n_non_activating
250-
max_examples = constructor_cfg.max_examples
251243
min_examples = constructor_cfg.min_examples
252244
# Get all positions where the latent is active
253245
flat_indices = (
@@ -276,7 +268,6 @@ def constructor(
276268
ctx_indices=ctx_indices,
277269
index_within_ctx=index_within_ctx,
278270
ctx_len=example_ctx_len,
279-
max_examples=max_examples,
280271
)
281272
else:
282273
token_windows, act_windows = pool_centered_activation_windows(
@@ -286,10 +277,7 @@ def constructor(
286277
ctx_indices=ctx_indices,
287278
index_within_ctx=index_within_ctx,
288279
ctx_len=example_ctx_len,
289-
max_examples=max_examples,
290280
)
291-
# TODO: We might want to do this in the sampler
292-
# we are tokenizing examples that are not going to be used
293281
record.examples = [
294282
ActivatingExample(
295283
tokens=toks,
@@ -433,8 +421,9 @@ def faiss_non_activation_windows(
433421
cache_path = Path(cache_dir) / embedding_model_name
434422

435423
# Get activating example texts
424+
436425
activating_texts = [
437-
"".join(example.str_tokens)
426+
"".join(tokenizer.batch_decode(example.tokens))
438427
for example in record.examples[: min(10, len(record.examples))]
439428
]
440429

0 commit comments

Comments
 (0)