@@ -59,7 +59,6 @@ def _top_k_pools(
     max_buffer: Float[Tensor, "batch"],
     split_activations: Float[Tensor, "activations ctx_len"],
     buffer_tokens: Int[Tensor, "batch ctx_len"],
-    max_examples: int,
 ) -> tuple[Int[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
     """
     Get the top k activation pools.
@@ -73,11 +72,10 @@ def _top_k_pools(
     Returns:
         The token windows and activation windows.
     """
-    k = min(max_examples, len(max_buffer))
-    top_values, top_indices = torch.topk(max_buffer, k, sorted=True)
+    sorted_values, sorted_indices = torch.sort(max_buffer, descending=True)

-    activation_windows = torch.stack([split_activations[i] for i in top_indices])
-    token_windows = buffer_tokens[top_indices]
+    activation_windows = torch.stack([split_activations[i] for i in sorted_indices])
+    token_windows = buffer_tokens[sorted_indices]

     return token_windows, activation_windows
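Note on the change above: torch.topk capped the selection at max_examples, whereas torch.sort(..., descending=True) ranks every window and leaves any truncation to the caller (e.g. the sampler). A minimal standalone sketch, not part of the patch, showing the two paths agree on the first k entries when values are distinct:

import torch

max_buffer = torch.tensor([0.2, 3.1, 1.7, 0.9])
k = 2

# old path: keep only the k largest windows
top_values, top_indices = torch.topk(max_buffer, k, sorted=True)
# new path: rank all windows, truncate later if needed
sorted_values, sorted_indices = torch.sort(max_buffer, descending=True)

assert torch.equal(top_indices, sorted_indices[:k])
assert torch.equal(top_values, sorted_values[:k])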
@@ -88,7 +86,6 @@ def pool_max_activation_windows(
     ctx_indices: Int[Tensor, "examples"],
     index_within_ctx: Int[Tensor, "examples"],
     ctx_len: int,
-    max_examples: int,
 ) -> tuple[Int[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
     """
     Pool max activation windows from the buffer output and update the latent record.
@@ -119,9 +116,7 @@ def pool_max_activation_windows(
     new_tensor[inverses, index_within_ctx] = activations
     tokens = tokens[unique_ctx_indices]

-    token_windows, activation_windows = _top_k_pools(
-        max_buffer, new_tensor, tokens, max_examples
-    )
+    token_windows, activation_windows = _top_k_pools(max_buffer, new_tensor, tokens)

     return token_windows, activation_windows
@@ -133,7 +128,6 @@ def pool_centered_activation_windows(
     ctx_indices: Float[Tensor, "examples"],
     index_within_ctx: Float[Tensor, "examples"],
     ctx_len: int,
-    max_examples: int,
 ) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
     """
     Similar to pool_max_activation_windows. Doesn't use the ctx_indices that were
@@ -161,15 +155,14 @@ def pool_centered_activation_windows(
     max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)

     # Get the top max_examples windows
-    k = min(max_examples, len(max_buffer))
-    top_values, top_indices = torch.topk(max_buffer, k, sorted=True)
+    sorted_values, sorted_indices = torch.sort(max_buffer, descending=True)

     # this tensor has the correct activations for each context window
     temp_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype)
     temp_tensor[inverses, index_within_ctx] = activations

-    unique_ctx_indices = unique_ctx_indices[top_indices]
-    temp_tensor = temp_tensor[top_indices]
+    unique_ctx_indices = unique_ctx_indices[sorted_indices]
+    temp_tensor = temp_tensor[sorted_indices]

     # if a element in unique_ctx_indices is divisible by n_windows_per_batch it
     # the start of a new batch, so we discard it
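The segment_reduce call above collapses a flat run of activations into one maximum per context window, given per-window lengths; the descending sort then ranks those windows. A tiny illustrative example, not part of the patch, assuming a 1-D activations tensor:

import torch

activations = torch.tensor([0.1, 0.5, 0.2, 0.9, 0.3])
lengths = torch.tensor([2, 3])  # first window holds 2 activations, the second holds 3

# one max per context window: tensor([0.5, 0.9])
max_buffer = torch.segment_reduce(activations, "max", lengths=lengths)

# descending sort ranks the windows by their peak activation
sorted_values, sorted_indices = torch.sort(max_buffer, descending=True)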
@@ -247,7 +240,6 @@ def constructor(
     example_ctx_len = constructor_cfg.example_ctx_len
     source_non_activating = constructor_cfg.non_activating_source
     n_not_active = constructor_cfg.n_non_activating
-    max_examples = constructor_cfg.max_examples
     min_examples = constructor_cfg.min_examples
     # Get all positions where the latent is active
     flat_indices = (
@@ -276,7 +268,6 @@ def constructor(
             ctx_indices=ctx_indices,
             index_within_ctx=index_within_ctx,
             ctx_len=example_ctx_len,
-            max_examples=max_examples,
         )
     else:
         token_windows, act_windows = pool_centered_activation_windows(
@@ -286,10 +277,7 @@
             ctx_indices=ctx_indices,
             index_within_ctx=index_within_ctx,
             ctx_len=example_ctx_len,
-            max_examples=max_examples,
         )
-    # TODO: We might want to do this in the sampler
-    # we are tokenizing examples that are not going to be used
     record.examples = [
         ActivatingExample(
             tokens=toks,
@@ -433,8 +421,9 @@ def faiss_non_activation_windows(
    cache_path = Path(cache_dir) / embedding_model_name

    # Get activating example texts
+
    activating_texts = [
-        "".join(example.str_tokens)
+        "".join(tokenizer.batch_decode(example.tokens))
        for example in record.examples[: min(10, len(record.examples))]
    ]
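With str_tokens no longer used here, the example text is rebuilt by decoding each token id and joining the pieces. A rough sketch of the pattern, assuming a Hugging Face tokenizer and a 1-D tensor of token ids (the gpt2 checkpoint is only for illustration; exact round-tripping depends on the tokenizer, and byte-level BPE handles plain ASCII cleanly):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed model for illustration
token_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids[0]

# batch_decode on a 1-D sequence yields one string per token id,
# so joining them reconstructs the example text
text = "".join(tokenizer.batch_decode(token_ids))
print(text)  # The quick brown fox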