Commit d7e077e

Merge pull request #118 from EleutherAI/update-config
Update config defaults
2 parents: daeb584 + 9f9170a

3 files changed, +34 -11 lines changed

delphi/__main__.py

Lines changed: 7 additions & 0 deletions
@@ -235,6 +235,13 @@ def scorer_postprocess(result, score_dir):
         scorer_pipe,
     )
 
+    if run_cfg.pipeline_num_proc > 1 and run_cfg.explainer_provider == "openrouter":
+        print(
+            "OpenRouter does not support multiprocessing,"
+            " setting pipeline_num_proc to 1"
+        )
+        run_cfg.pipeline_num_proc = 1
+
     await pipeline.run(run_cfg.pipeline_num_proc)
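
For context, a minimal sketch of how the new guard behaves at run time. The dataclass below is a hypothetical stand-in for delphi's RunConfig, reduced to the two fields the check reads; the conditional itself mirrors the added code.

    from dataclasses import dataclass


    @dataclass
    class DemoRunConfig:
        # Hypothetical stand-in for RunConfig with only the fields the guard uses
        pipeline_num_proc: int = 4
        explainer_provider: str = "openrouter"


    run_cfg = DemoRunConfig()

    # Same check as the added code in delphi/__main__.py: OpenRouter does not
    # support multiprocessing, so the pipeline falls back to a single process.
    if run_cfg.pipeline_num_proc > 1 and run_cfg.explainer_provider == "openrouter":
        print(
            "OpenRouter does not support multiprocessing,"
            " setting pipeline_num_proc to 1"
        )
        run_cfg.pipeline_num_proc = 1

    assert run_cfg.pipeline_num_proc == 1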

delphi/config.py

Lines changed: 4 additions & 2 deletions
@@ -66,7 +66,7 @@ class ConstructorConfig(Serializable):
 
 @dataclass
 class CacheConfig(Serializable):
-    dataset_repo: str = "EleutherAI/fineweb-edu-dedup-10b"
+    dataset_repo: str = "EleutherAI/SmolLM2-135M-10B"
     """Dataset repository to use for generating latent activations."""
 
     dataset_split: str = "train[:1%]"
@@ -145,7 +145,9 @@ class RunConfig(Serializable):
     load_in_8bit: bool = False
     """Load the model in 8-bit mode."""
 
-    hf_token: str | None = None
+    # Use a dummy encoding function to prevent the token from being saved
+    # to disk in plain text
+    hf_token: str | None = field(default=None, encoding_fn=lambda _: None)
     """Huggingface API token for downloading models."""
 
     pipeline_num_proc: int = field(
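
A small sketch of what the encoding_fn change buys. This assumes the config module uses simple_parsing's Serializable and field helpers, which the existing field(...) defaults suggest; ExampleConfig is hypothetical. Serializing the config drops the token, so saved run configs never contain the secret in plain text.

    from dataclasses import dataclass

    from simple_parsing import Serializable, field


    @dataclass
    class ExampleConfig(Serializable):
        # Same pattern as the new hf_token default: the encoding function
        # replaces the real value with None whenever the config is serialized
        hf_token: str | None = field(default=None, encoding_fn=lambda _: None)


    cfg = ExampleConfig(hf_token="hf_example_secret")
    print(cfg.hf_token)   # the in-memory value is untouched
    print(cfg.to_dict())  # {'hf_token': None} -- nothing sensitive is written out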

delphi/log/result_analysis.py

Lines changed: 23 additions & 9 deletions
@@ -4,6 +4,7 @@
 import orjson
 import pandas as pd
 import torch
+from sklearn.metrics import roc_auc_score
 from torch import Tensor
 
 
@@ -33,6 +34,7 @@ def latent_balanced_score_metrics(
         "f1_score": np.average(df["f1_score"], weights=weights),
         "precision": np.average(df["precision"], weights=weights),
         "recall": np.average(df["recall"], weights=weights),
+        "auc": np.average(df["auc"], weights=weights),
         "false_positives": np.average(df["false_positives"], weights=weights),
         "false_negatives": np.average(df["false_negatives"], weights=weights),
         "true_positives": np.average(df["true_positives"], weights=weights),
@@ -53,6 +55,7 @@ def latent_balanced_score_metrics(
         print(f"F1 Score: {metrics['f1_score']:.3f}")
         print(f"Precision: {metrics['precision']:.3f}")
         print(f"Recall: {metrics['recall']:.3f}")
+        print(f"AUC: {metrics['auc']:.3f}")
 
         fractions_failed = [
             failed_count / (total_examples + failed_count)
@@ -111,11 +114,11 @@ def parse_score_file(file_path):
     total_positives = (df["activating"]).sum()
     total_negatives = (~df["activating"]).sum()
 
-    # Calculate confusion matrix elements
-    true_positives = ((df["prediction"] == 1) & (df["activating"])).sum()
-    true_negatives = ((df["prediction"] == 0) & (~df["activating"])).sum()
-    false_positives = ((df["prediction"] == 1) & (~df["activating"])).sum()
-    false_negatives = ((df["prediction"] == 0) & (df["activating"])).sum()
+    # Calculate confusion matrix elements using a threshold of 0.5
+    true_positives = ((df["prediction"] >= 0.5) & (df["activating"])).sum()
+    true_negatives = ((df["prediction"] < 0.5) & (~df["activating"])).sum()
+    false_positives = ((df["prediction"] >= 0.5) & (~df["activating"])).sum()
+    false_negatives = ((df["prediction"] < 0.5) & (df["activating"])).sum()
 
     # Calculate rates
     true_positive_rate = true_positives / total_positives if total_positives > 0 else 0
@@ -127,7 +130,7 @@ def parse_score_file(file_path):
         false_negatives / total_positives if total_positives > 0 else 0
     )
 
-    # Calculate precision, recall, f1 (using sklearn for verification)
+    # Calculate precision, recall, F1, and accuracy
     precision = (
         true_positives / (true_positives + false_positives)
         if (true_positives + false_positives) > 0
@@ -139,12 +142,16 @@ def parse_score_file(file_path):
         if (precision + recall) > 0
         else 0
     )
-
-    # Calculate accuracy
     accuracy = (
         (true_positives + true_negatives) / total_examples if total_examples > 0 else 0
     )
 
+    # Calculate ROC AUC score
+    try:
+        auc = roc_auc_score(df["activating"], df["prediction"])
+    except Exception:
+        auc = 0.5
+
     # Add metrics to first row
     metrics = {
         "true_positive_rate": true_positive_rate,
@@ -159,6 +166,7 @@ def parse_score_file(file_path):
         "recall": recall,
         "f1_score": f1_score,
         "accuracy": accuracy,
+        "auc": auc,
         "total_examples": total_examples,
         "total_positives": total_positives,
         "total_negatives": total_negatives,
@@ -189,6 +197,7 @@ def build_scores_df(
         "precision",
         "recall",
        "f1_score",
+        "auc",
         "true_positives",
         "true_negatives",
         "false_positives",
@@ -238,6 +247,8 @@ def build_scores_df(
             df_data["latent_idx"].append(latent_idx)
             df_data["firing_counts"].append(
                 hookpoint_firing_counts[module][latent_idx].item()
+                if module in hookpoint_firing_counts
+                else -1
             )
             df_data["module"].append(module)
             for col in metrics_cols:
@@ -268,14 +279,17 @@ def plot_line(df: pd.DataFrame, visualize_path: Path):
 
 def log_results(scores_path: Path, visualize_path: Path, target_modules: list[str]):
     log_path = scores_path.parent / "log" / "hookpoint_firing_counts.pt"
-    hookpoint_firing_counts: dict[str, Tensor] = torch.load(log_path, weights_only=True)
+    hookpoint_firing_counts: dict[str, Tensor] = (
+        torch.load(log_path, weights_only=True) if log_path.exists() else {}
+    )
     df = build_scores_df(scores_path, target_modules, hookpoint_firing_counts)
 
     # Calculate the number of dead features for each module which will not be in the df
     num_dead_features = sum(
         [
             (hookpoint_firing_counts[module] == 0).sum().item()
             for module in target_modules
+            if module in hookpoint_firing_counts
         ]
     )
     print(f"Number of dead features: {num_dead_features}")
