Skip to content

Commit 3c2f764

Browse files
author
AMindToThink
committed
Cache the repeated SAE instead of making a new one every time.
1 parent 505015b commit 3c2f764

File tree

2 files changed

+138
-3
lines changed

2 files changed

+138
-3
lines changed

lm_eval/models/sae_steered_beta.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,22 @@ def from_csv(
6363
# Read steering configurations
6464
df = pd.read_csv(csv_path)
6565
# Create hooks for each row in the CSV
66+
sae_cache = {}
6667
hooks = []
68+
69+
def get_sae(sae_release, sae_id):
70+
cache_key = (sae_release, sae_id)
71+
if cache_key not in sae_cache:
72+
sae_cache[cache_key] = SAE.from_pretrained(
73+
sae_release, sae_id, device=str(device)
74+
)[0]
75+
return sae_cache[cache_key]
76+
6777
for _, row in df.iterrows():
68-
sae = SAE.from_pretrained(
69-
row["sae_release"], row["sae_id"], device=str(device)
70-
)[0]
78+
sae_release = row["sae_release"]
79+
sae_id = row["sae_id"]
80+
81+
sae = get_sae(sae_release=sae_release, sae_id=sae_id)
7182
sae.eval()
7283
hook = partial(
7384
steering_hook,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
{
2+
"results": {
3+
"mmlu_abstract_algebra": {
4+
"alias": "abstract_algebra",
5+
"acc,none": 0.34,
6+
"acc_stderr,none": 0.047609522856952344
7+
}
8+
},
9+
"group_subtasks": {
10+
"mmlu_abstract_algebra": []
11+
},
12+
"configs": {
13+
"mmlu_abstract_algebra": {
14+
"task": "mmlu_abstract_algebra",
15+
"task_alias": "abstract_algebra",
16+
"tag": "mmlu_stem_tasks",
17+
"dataset_path": "hails/mmlu_no_train",
18+
"dataset_name": "abstract_algebra",
19+
"dataset_kwargs": {
20+
"trust_remote_code": true
21+
},
22+
"test_split": "test",
23+
"fewshot_split": "dev",
24+
"doc_to_text": "{{question.strip()}}\nA. {{choices[0]}}\nB. {{choices[1]}}\nC. {{choices[2]}}\nD. {{choices[3]}}\nAnswer:",
25+
"doc_to_target": "answer",
26+
"unsafe_code": false,
27+
"doc_to_choice": [
28+
"A",
29+
"B",
30+
"C",
31+
"D"
32+
],
33+
"description": "The following are multiple choice questions (with answers) about abstract algebra.\n\n",
34+
"target_delimiter": " ",
35+
"fewshot_delimiter": "\n\n",
36+
"fewshot_config": {
37+
"sampler": "first_n"
38+
},
39+
"num_fewshot": 0,
40+
"metric_list": [
41+
{
42+
"metric": "acc",
43+
"aggregation": "mean",
44+
"higher_is_better": true
45+
}
46+
],
47+
"output_type": "multiple_choice",
48+
"repeats": 1,
49+
"should_decontaminate": false,
50+
"metadata": {
51+
"version": 1.0
52+
}
53+
}
54+
},
55+
"versions": {
56+
"mmlu_abstract_algebra": 1.0
57+
},
58+
"n-shot": {
59+
"mmlu_abstract_algebra": 0
60+
},
61+
"higher_is_better": {
62+
"mmlu_abstract_algebra": {
63+
"acc": true
64+
}
65+
},
66+
"n-samples": {
67+
"mmlu_abstract_algebra": {
68+
"original": 100,
69+
"effective": 100
70+
}
71+
},
72+
"config": {
73+
"model": "sae_steered_beta",
74+
"model_args": "base_name=google/gemma-2-2b,csv_path=/home/cs29824/matthew/lm-evaluation-harness/examples/dog_steer.csv",
75+
"model_num_parameters": 0,
76+
"model_dtype": null,
77+
"model_revision": "main",
78+
"model_sha": "c5ebcd40d208330abc697524c919956e692655cf",
79+
"batch_size": "auto",
80+
"batch_sizes": [
81+
16
82+
],
83+
"device": "cuda:0",
84+
"use_cache": null,
85+
"limit": null,
86+
"bootstrap_iters": 100000,
87+
"gen_kwargs": null,
88+
"random_seed": 0,
89+
"numpy_seed": 1234,
90+
"torch_seed": 1234,
91+
"fewshot_seed": 1234
92+
},
93+
"git_hash": "e16afa2f",
94+
"date": 1737419939.4888458,
95+
"pretty_env_info": "PyTorch version: 2.5.1+cu124\nIs debug build: False\nCUDA used to build PyTorch: 12.4\nROCM used to build PyTorch: N/A\n\nOS: Ubuntu 20.04.6 LTS (x86_64)\nGCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0\nClang version: Could not collect\nCMake version: version 3.16.3\nLibc version: glibc-2.31\n\nPython version: 3.11.11 | packaged by conda-forge | (main, Dec 5 2024, 14:17:24) [GCC 13.3.0] (64-bit runtime)\nPython platform: Linux-5.4.0-1125-kvm-x86_64-with-glibc2.31\nIs CUDA available: True\nCUDA runtime version: 10.1.243\nCUDA_MODULE_LOADING set to: LAZY\nGPU models and configuration: \nGPU 0: Quadro RTX 8000\nGPU 1: Quadro RTX 8000\n\nNvidia driver version: 545.23.08\ncuDNN version: Probably one of the following:\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1\n/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1\nHIP runtime version: N/A\nMIOpen runtime version: N/A\nIs XNNPACK available: True\n\nCPU:\nArchitecture: x86_64\nCPU op-mode(s): 32-bit, 64-bit\nByte Order: Little Endian\nAddress sizes: 46 bits physical, 48 bits virtual\nCPU(s): 16\nOn-line CPU(s) list: 0-15\nThread(s) per core: 1\nCore(s) per socket: 1\nSocket(s): 16\nNUMA node(s): 1\nVendor ID: GenuineIntel\nCPU family: 6\nModel: 85\nModel name: Intel Xeon Processor (Cascadelake)\nStepping: 6\nCPU MHz: 2294.608\nBogoMIPS: 4589.21\nVirtualization: VT-x\nHypervisor vendor: KVM\nVirtualization type: full\nL1d cache: 512 KiB\nL1i cache: 512 KiB\nL2 cache: 64 MiB\nL3 cache: 256 MiB\nNUMA node0 CPU(s): 0-15\nVulnerability Gather data sampling: Unknown: Dependent on hypervisor 
status\nVulnerability Itlb multihit: Not affected\nVulnerability L1tf: Not affected\nVulnerability Mds: Not affected\nVulnerability Meltdown: Not affected\nVulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown\nVulnerability Retbleed: Mitigation; Enhanced IBRS\nVulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp\nVulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization\nVulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI Vulnerable, KVM SW loop\nVulnerability Srbds: Not affected\nVulnerability Tsx async abort: Mitigation; TSX disabled\nFlags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat umip pku avx512_vnni md_clear arch_capabilities\n\nVersions of relevant libraries:\n[pip3] mypy==1.14.1\n[pip3] mypy-extensions==1.0.0\n[pip3] numpy==1.26.4\n[pip3] torch==2.5.1\n[pip3] triton==3.1.0\n[conda] numpy 1.26.4 pypi_0 pypi\n[conda] torch 2.5.1 pypi_0 pypi\n[conda] triton 3.1.0 pypi_0 pypi",
96+
"transformers_version": "4.48.1",
97+
"upper_git_hash": null,
98+
"tokenizer_pad_token": [
99+
"<pad>",
100+
"0"
101+
],
102+
"tokenizer_eos_token": [
103+
"<eos>",
104+
"1"
105+
],
106+
"tokenizer_bos_token": [
107+
"<bos>",
108+
"2"
109+
],
110+
"eot_token_id": 1,
111+
"max_length": 8192,
112+
"task_hashes": {},
113+
"model_source": "sae_steered_beta",
114+
"model_name": "/home/cs29824/matthew/lm-evaluation-harness/examples/dog_steer.csv",
115+
"model_name_sanitized": "__home__cs29824__matthew__lm-evaluation-harness__examples__dog_steer.csv",
116+
"system_instruction": null,
117+
"system_instruction_sha": null,
118+
"fewshot_as_multiturn": false,
119+
"chat_template": null,
120+
"chat_template_sha": null,
121+
"start_time": 2970008.635285475,
122+
"end_time": 2970078.697630497,
123+
"total_evaluation_time_seconds": "70.06234502233565"
124+
}

0 commit comments

Comments
 (0)