
Commit 31d3986

Merge pull request #190 from stanfordnlp/gemma2
[P2] Add Gemma 2 model
2 parents 2242266 + b7addfc commit 31d3986

File tree

3 files changed: +94 −0 lines changed

pyvene/models/gemma2/__init__.py

Whitespace-only changes.
pyvene/models/gemma2/modelings_intervenable_gemma2.py

@@ -0,0 +1,89 @@
"""
Each modeling file in this library is a mapping between
abstract naming of intervention anchor points and the actual
model modules defined in the huggingface library.

We also want to let the intervention library know how to
configure the dimensions of interventions based on the model config
defined in the huggingface library.
"""


import torch
from ..constants import *


# component name -> (module path template, hook type[, (reshaping fn, head-count key)])
gemma2_type_to_module_mapping = {
    "block_input": ("layers[%s]", CONST_INPUT_HOOK),
    "block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
    "mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
    "mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
    "mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
    "attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
    "head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
    "attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
    "attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
    "query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
    "key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
    "value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
    "head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
    "head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
    "head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
}


# component name -> Gemma2Config attribute that gives that component's dimensionality
gemma2_type_to_dimension_mapping = {
    "n_head": ("num_attention_heads",),
    "n_kv_head": ("num_key_value_heads",),
    "block_input": ("hidden_size",),
    "block_output": ("hidden_size",),
    "mlp_activation": ("intermediate_size",),
    "mlp_output": ("hidden_size",),
    "mlp_input": ("hidden_size",),
    "attention_value_output": ("hidden_size",),
    "head_attention_value_output": ("head_dim",),
    "attention_output": ("hidden_size",),
    "attention_input": ("hidden_size",),
    "query_output": ("hidden_size",),
    "key_output": ("hidden_size",),
    "value_output": ("hidden_size",),
    "head_query_output": ("head_dim",),
    "head_key_output": ("head_dim",),
    "head_value_output": ("head_dim",),
}


"""gemma2 model with LM head"""
# Gemma2ForCausalLM wraps the decoder in a .model attribute, so prefix the paths
gemma2_lm_type_to_module_mapping = {}
for k, v in gemma2_type_to_module_mapping.items():
    gemma2_lm_type_to_module_mapping[k] = (f"model.{v[0]}",) + v[1:]


gemma2_lm_type_to_dimension_mapping = gemma2_type_to_dimension_mapping


"""gemma2 model with classifier head"""
# Gemma2ForSequenceClassification also wraps the decoder in a .model attribute
gemma2_classifier_type_to_module_mapping = {}
for k, v in gemma2_type_to_module_mapping.items():
    gemma2_classifier_type_to_module_mapping[k] = (f"model.{v[0]}",) + v[1:]


gemma2_classifier_type_to_dimension_mapping = gemma2_type_to_dimension_mapping


def create_gemma2(
    name="google/gemma-2-2b", cache_dir=None, dtype=torch.bfloat16
):
    """Creates a Gemma 2 Causal LM config, tokenizer, and model from the given name"""
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
    gemma = AutoModelForCausalLM.from_pretrained(
        name,
        config=config,
        cache_dir=cache_dir,
        torch_dtype=dtype,
    )
    print("loaded model")
    return config, tokenizer, gemma
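
The mapping values above are string templates: "%s" is later filled in with a layer index, and the *_lm_* and *_classifier_* variants prepend "model." because Gemma2ForCausalLM and Gemma2ForSequenceClassification wrap the decoder in a .model attribute. As a sanity check, a mapped path can be resolved by hand against a loaded checkpoint. The resolve helper and its bracket-to-dot rewrite below are illustrative only, not part of pyvene, and running it assumes access to the gated Gemma 2 weights on the Hub.

import re

from pyvene.models.gemma2.modelings_intervenable_gemma2 import (
    create_gemma2,
    gemma2_lm_type_to_module_mapping,
)


def resolve(model, template, layer):
    # "model.layers[%s].self_attn.o_proj" -> "model.layers.0.self_attn.o_proj"
    dotted = re.sub(r"\[(\d+)\]", r".\1", template % layer)
    return model.get_submodule(dotted)


config, tokenizer, gemma = create_gemma2()  # downloads the default Gemma 2 checkpoint

path_template, hook = gemma2_lm_type_to_module_mapping["attention_value_output"][:2]
module = resolve(gemma, path_template, layer=0)
print(path_template, hook, module)  # the o_proj Linear of decoder layer 0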
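
A minimal end-to-end sketch of using one of the new component names through pyvene's dict-style IntervenableModel constructor follows, assuming Gemma 2 behaves like the other registered models and that the forward-call pattern matches pyvene's README examples; the prompt, layer choice, and zero source representation are arbitrary.

import torch
import pyvene as pv

from pyvene.models.gemma2.modelings_intervenable_gemma2 import create_gemma2

config, tokenizer, gemma = create_gemma2()

# replace the residual-stream output of decoder layer 5 with zeros on every forward pass
pv_gemma = pv.IntervenableModel(
    {
        "layer": 5,
        "component": "block_output",
        "source_representation": torch.zeros(config.hidden_size, dtype=torch.bfloat16),
    },
    model=gemma,
)

base = tokenizer("The capital of France is", return_tensors="pt")
original_outputs, intervened_outputs = pv_gemma(base=base, output_original_output=True)
print(intervened_outputs.logits.shape)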

pyvene/models/intervenable_modelcard.py

+5
@@ -2,6 +2,7 @@
 from .llama.modelings_intervenable_llama import *
 from .mistral.modellings_intervenable_mistral import *
 from .gemma.modelings_intervenable_gemma import *
+from .gemma2.modelings_intervenable_gemma2 import *
 from .gpt2.modelings_intervenable_gpt2 import *
 from .gpt_neo.modelings_intervenable_gpt_neo import *
 from .gpt_neox.modelings_intervenable_gpt_neox import *
@@ -58,6 +59,8 @@
     hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_module_mapping,
     hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_module_mapping,
     hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_module_mapping,
+    hf_models.gemma2.modeling_gemma2.Gemma2Model: gemma2_type_to_module_mapping,
+    hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_module_mapping,
     hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_module_mapping,
     hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_module_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
@@ -91,6 +94,8 @@
     hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_dimension_mapping,
     hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_dimension_mapping,
     hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_dimension_mapping,
+    hf_models.gemma2.modeling_gemma2.Gemma2Model: gemma2_type_to_dimension_mapping,
+    hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_dimension_mapping,
     hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_dimension_mapping,
     hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_dimension_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
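
These two hunks register the new classes in the model card's class-to-mapping tables, which is how pyvene selects the right component and dimension mappings from type(model) at runtime. A quick registration check is sketched below; the table names type_to_module_mapping and type_to_dimension_mapping are assumed from the surrounding entries and may differ in the actual module.

import transformers.models as hf_models

# assumed table names; adjust if intervenable_modelcard.py exports different identifiers
from pyvene.models.intervenable_modelcard import (
    type_to_module_mapping,
    type_to_dimension_mapping,
)

gemma2_lm_cls = hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM
assert gemma2_lm_cls in type_to_module_mapping

print(type_to_module_mapping[gemma2_lm_cls]["block_output"])      # ("model.layers[%s]", <output hook>)
print(type_to_dimension_mapping[gemma2_lm_cls]["block_output"])   # ("hidden_size",)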
