@@ -4,6 +4,7 @@
 import orjson
 import pandas as pd
 import torch
+from sklearn.metrics import roc_auc_score
 from torch import Tensor
 
 
@@ -33,6 +34,7 @@ def latent_balanced_score_metrics(
3334 "f1_score" : np .average (df ["f1_score" ], weights = weights ),
3435 "precision" : np .average (df ["precision" ], weights = weights ),
3536 "recall" : np .average (df ["recall" ], weights = weights ),
37+ "auc" : np .average (df ["auc" ], weights = weights ),
3638 "false_positives" : np .average (df ["false_positives" ], weights = weights ),
3739 "false_negatives" : np .average (df ["false_negatives" ], weights = weights ),
3840 "true_positives" : np .average (df ["true_positives" ], weights = weights ),
@@ -53,6 +55,7 @@ def latent_balanced_score_metrics(
         print(f"F1 Score: {metrics['f1_score']:.3f}")
         print(f"Precision: {metrics['precision']:.3f}")
         print(f"Recall: {metrics['recall']:.3f}")
+        print(f"AUC: {metrics['auc']:.3f}")
 
     fractions_failed = [
         failed_count / (total_examples + failed_count)
@@ -111,11 +114,11 @@ def parse_score_file(file_path):
     total_positives = (df["activating"]).sum()
     total_negatives = (~df["activating"]).sum()
 
-    # Calculate confusion matrix elements
-    true_positives = ((df["prediction"] == 1) & (df["activating"])).sum()
-    true_negatives = ((df["prediction"] == 0) & (~df["activating"])).sum()
-    false_positives = ((df["prediction"] == 1) & (~df["activating"])).sum()
-    false_negatives = ((df["prediction"] == 0) & (df["activating"])).sum()
+    # Calculate confusion matrix elements using a threshold of 0.5
+    true_positives = ((df["prediction"] >= 0.5) & (df["activating"])).sum()
+    true_negatives = ((df["prediction"] < 0.5) & (~df["activating"])).sum()
+    false_positives = ((df["prediction"] >= 0.5) & (~df["activating"])).sum()
+    false_negatives = ((df["prediction"] < 0.5) & (df["activating"])).sum()
 
     # Calculate rates
     true_positive_rate = true_positives / total_positives if total_positives > 0 else 0
@@ -127,7 +130,7 @@ def parse_score_file(file_path):
         false_negatives / total_positives if total_positives > 0 else 0
     )
 
-    # Calculate precision, recall, f1 (using sklearn for verification)
+    # Calculate precision, recall, F1, and accuracy
     precision = (
         true_positives / (true_positives + false_positives)
         if (true_positives + false_positives) > 0
@@ -139,12 +142,16 @@ def parse_score_file(file_path):
         if (precision + recall) > 0
         else 0
     )
-
-    # Calculate accuracy
     accuracy = (
         (true_positives + true_negatives) / total_examples if total_examples > 0 else 0
     )
 
+    # Calculate ROC AUC score
+    try:
+        auc = roc_auc_score(df["activating"], df["prediction"])
+    except Exception:
+        auc = 0.5
+
     # Add metrics to first row
     metrics = {
         "true_positive_rate": true_positive_rate,
@@ -159,6 +166,7 @@ def parse_score_file(file_path):
159166 "recall" : recall ,
160167 "f1_score" : f1_score ,
161168 "accuracy" : accuracy ,
169+ "auc" : auc ,
162170 "total_examples" : total_examples ,
163171 "total_positives" : total_positives ,
164172 "total_negatives" : total_negatives ,
@@ -189,6 +197,7 @@ def build_scores_df(
189197 "precision" ,
190198 "recall" ,
191199 "f1_score" ,
200+ "auc" ,
192201 "true_positives" ,
193202 "true_negatives" ,
194203 "false_positives" ,
@@ -238,6 +247,8 @@ def build_scores_df(
             df_data["latent_idx"].append(latent_idx)
             df_data["firing_counts"].append(
                 hookpoint_firing_counts[module][latent_idx].item()
+                if module in hookpoint_firing_counts
+                else -1
             )
             df_data["module"].append(module)
             for col in metrics_cols:
@@ -268,14 +279,17 @@ def plot_line(df: pd.DataFrame, visualize_path: Path):
 
 def log_results(scores_path: Path, visualize_path: Path, target_modules: list[str]):
     log_path = scores_path.parent / "log" / "hookpoint_firing_counts.pt"
-    hookpoint_firing_counts: dict[str, Tensor] = torch.load(log_path, weights_only=True)
+    hookpoint_firing_counts: dict[str, Tensor] = (
+        torch.load(log_path, weights_only=True) if log_path.exists() else {}
+    )
     df = build_scores_df(scores_path, target_modules, hookpoint_firing_counts)
 
     # Calculate the number of dead features for each module which will not be in the df
     num_dead_features = sum(
         [
             (hookpoint_firing_counts[module] == 0).sum().item()
             for module in target_modules
+            if module in hookpoint_firing_counts
         ]
     )
     print(f"Number of dead features: {num_dead_features}")
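
For reference, here is a minimal standalone sketch of the scoring behavior this patch introduces in `parse_score_file`: probability-style predictions are thresholded at 0.5 for the confusion-matrix counts, and ROC AUC defaults to chance level (0.5) when it cannot be computed, e.g. when only one class is present. The toy DataFrame below is purely illustrative and not part of the patch.

```python
import pandas as pd
from sklearn.metrics import roc_auc_score

# Hypothetical toy data: boolean ground truth, probability-style predictions.
df = pd.DataFrame(
    {
        "activating": [True, True, False, False],
        "prediction": [0.9, 0.4, 0.6, 0.1],
    }
)

# Threshold at 0.5 rather than comparing against exact 0/1 labels.
true_positives = ((df["prediction"] >= 0.5) & (df["activating"])).sum()
false_positives = ((df["prediction"] >= 0.5) & (~df["activating"])).sum()

# AUC is threshold-free; fall back to 0.5 if it is undefined
# (roc_auc_score raises ValueError with a single class; the patch catches Exception).
try:
    auc = roc_auc_score(df["activating"], df["prediction"])
except ValueError:
    auc = 0.5

print(true_positives, false_positives, auc)  # 1 1 0.75
```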
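And a similar sketch of how the new per-latent `auc` column is rolled up in `latent_balanced_score_metrics`: a weighted average alongside the existing F1/precision/recall aggregation. The per-latent example counts used as weights here are assumptions for illustration only.

```python
import numpy as np
import pandas as pd

# Hypothetical per-latent scores and example counts.
per_latent = pd.DataFrame(
    {
        "auc": [0.92, 0.61, 0.75],
        "total_examples": [100, 40, 60],
    }
)

# Weighted average, mirroring the aggregation in the patch.
weights = per_latent["total_examples"]
summary_auc = np.average(per_latent["auc"], weights=weights)
print(f"AUC: {summary_auc:.3f}")  # 0.807
```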