fix convert_internevo2hf for internlm2 model
season0528 committed Jan 3, 2025
1 parent d03c6f9 commit a22c157
Showing 1 changed file with 8 additions and 19 deletions.
internlm/model/modeling_internlm2.py: 8 additions & 19 deletions
@@ -5,7 +5,6 @@
 from typing import Optional
 
 import torch
-from einops import rearrange
 from torch import nn
 from tqdm import tqdm
@@ -771,18 +770,6 @@ def unique_kv_index(i):
 
     @staticmethod
     def convert_internevo2hf_weights(src: str, tgt: str) -> None:
-        def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True):
-            if adapt_hf:
-                return qkv
-            q_per_kv = num_heads // num_kv_heads
-            qkv = rearrange(qkv.T, "o (g n i) -> o g n i", n=q_per_kv + 2, i=head_dim)
-            q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2:-1, :], qkv[..., -1:, :]
-            q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)
-            k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)
-            qkv = torch.cat((q, k, v), dim=2)
-            qkv = rearrange(qkv, "o g n i -> o (g n i)").T
-            return qkv
-
         model_config = gpc.config.model
         tp_mode = gpc.config.parallel.tensor["mode"]
         row_dim = 0 if tp_mode == "isp" else 1
@@ -808,12 +795,14 @@ def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True):
                 }
             )
             # attn
-            state_dict[f"model.layers.{layer_i}.attention.wqkv.weight"] = permute(
-                torch.cat([states[i][f"layers.{layer_i}.attention.wqkv.weight"] for i in range(num_shards)], dim=0),
-                num_heads=model_config["num_attention_heads"],
-                num_kv_heads=model_config["num_kv_attention_heads"],
-                head_dim=model_config["hidden_size"] // model_config["num_attention_heads"],
-                adapt_hf=model_config.get("adapt_hf", True),
+            state_dict[f"model.layers.{layer_i}.attention.wq.weight"] = torch.cat(
+                [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)], dim=0
             )
+            state_dict[f"model.layers.{layer_i}.attention.wk.weight"] = torch.cat(
+                [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)], dim=0
+            )
+            state_dict[f"model.layers.{layer_i}.attention.wv.weight"] = torch.cat(
+                [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)], dim=0
+            )
             state_dict[f"model.layers.{layer_i}.attention.wo.weight"] = torch.cat(
                 [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim

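For readers merging checkpoints by hand, here is a minimal standalone sketch of the new per-projection merge. The helper name merge_layer_attention_weights and the bare states/num_shards/row_dim inputs are illustrative assumptions, not part of the commit; the commit itself only contains the diff above.

    import torch

    # Illustrative sketch only (assumed helper, not from the commit): merge
    # tensor-parallel shards of one layer's attention weights the way the new
    # converter does -- separate wq/wk/wv weights concatenated along dim=0
    # (the output dimension), instead of permuting a fused wqkv tensor.
    def merge_layer_attention_weights(states, layer_i, num_shards, row_dim=1):
        merged = {}
        for proj in ("wq", "wk", "wv"):
            shards = [states[i][f"layers.{layer_i}.attention.{proj}.weight"] for i in range(num_shards)]
            merged[f"model.layers.{layer_i}.attention.{proj}.weight"] = torch.cat(shards, dim=0)
        # The output projection is concatenated along row_dim (0 for isp mode,
        # 1 otherwise), matching the wo handling in the diff above.
        wo_shards = [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)]
        merged[f"model.layers.{layer_i}.attention.wo.weight"] = torch.cat(wo_shards, dim=row_dim)
        return merged

The removal of the permute helper and the einops import suggests the separate wq/wk/wv weights are already stored in the layout Hugging Face expects, so no rotary reordering is needed during conversion.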