Commit d135da4

haozha111 authored and copybara-github committed
set qkv_fused_interleaved=False for gemma2.
PiperOrigin-RevId: 736317628
1 parent a4c0214 commit d135da4

1 file changed: +3 -0 lines changed

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 3 additions & 0 deletions
@@ -247,6 +247,9 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
         rotary_base=10000,
         rotary_percentage=1.0,
         qkv_transpose_before_split=True,
+        # The safetensors from HF is not using the interleaved qkv format, so
+        # we need to disable interleaving here in the model config.
+        qkv_fused_interleaved=False,
         logit_softcap=50.0,
         sliding_window_size=4096,
         attn_type=(
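
For context on why the flag matters, here is a minimal sketch of the two fused QKV layouts the code comment distinguishes. It is a toy illustration in plain Python, not the ai_edge_torch implementation: the helper names `split_contiguous` and `split_interleaved` are hypothetical, and the exact per-group layout used by the library is an assumption here. The point it shows is that reading a contiguously stored checkpoint with an interleaved split silently mixes Q, K, and V entries.

```python
# Illustrative sketch only; helper names and layouts are assumptions,
# not the ai_edge_torch API.

def split_contiguous(fused, num_q, num_kv):
    """Split a fused list laid out as [all Q | all K | all V]."""
    q = fused[:num_q]
    k = fused[num_q:num_q + num_kv]
    v = fused[num_q + num_kv:num_q + 2 * num_kv]
    return q, k, v


def split_interleaved(fused, num_q, num_kv):
    """Split a fused list laid out per KV group as [Q..., K, V], repeated."""
    q_per_kv = num_q // num_kv
    group = q_per_kv + 2  # queries of the group, then one K and one V
    q, k, v = [], [], []
    for g in range(num_kv):
        chunk = fused[g * group:(g + 1) * group]
        q.extend(chunk[:q_per_kv])
        k.append(chunk[q_per_kv])
        v.append(chunk[q_per_kv + 1])
    return q, k, v


# Toy "checkpoint": 4 query heads and 2 KV heads, stored contiguously.
contiguous = ["q0", "q1", "q2", "q3", "k0", "k1", "v0", "v1"]

# Matching split: every entry lands in the right bucket.
right = split_contiguous(contiguous, num_q=4, num_kv=2)

# Mismatched split: the interleaved reader scrambles the buckets.
wrong = split_interleaved(contiguous, num_q=4, num_kv=2)

print(right)
print(wrong)
```

In this toy case the mismatched split raises no error and returns lists of the expected lengths, which is why a layout mismatch of this kind tends to show up as degraded model output rather than as a crash.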
