From a22c15735f42d7c4f701ccf420fba2412ce54772 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 3 Jan 2025 11:49:45 +0800 Subject: [PATCH] fix convert_internevo2hf for internlm2 model --- internlm/model/modeling_internlm2.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 69da0837..6084d88b 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -5,7 +5,6 @@ from typing import Optional import torch -from einops import rearrange from torch import nn from tqdm import tqdm @@ -771,18 +770,6 @@ def unique_kv_index(i): @staticmethod def convert_internevo2hf_weights(src: str, tgt: str) -> None: - def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True): - if adapt_hf: - return qkv - q_per_kv = num_heads // num_kv_heads - qkv = rearrange(qkv.T, "o (g n i) -> o g n i", n=q_per_kv + 2, i=head_dim) - q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2:-1, :], qkv[..., -1:, :] - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - qkv = torch.cat((q, k, v), dim=2) - qkv = rearrange(qkv, "o g n i -> o (g n i)").T - return qkv - model_config = gpc.config.model tp_mode = gpc.config.parallel.tensor["mode"] row_dim = 0 if tp_mode == "isp" else 1 @@ -808,12 +795,14 @@ def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True): } ) # attn - state_dict[f"model.layers.{layer_i}.attention.wqkv.weight"] = permute( - torch.cat([states[i][f"layers.{layer_i}.attention.wqkv.weight"] for i in range(num_shards)], dim=0), - num_heads=model_config["num_attention_heads"], - num_kv_heads=model_config["num_kv_attention_heads"], - head_dim=model_config["hidden_size"] // model_config["num_attention_heads"], - adapt_hf=model_config.get("adapt_hf", True), + state_dict[f"model.layers.{layer_i}.attention.wq.weight"] = torch.cat( 
[states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.attention.wk.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.attention.wv.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)], dim=0 ) state_dict[f"model.layers.{layer_i}.attention.wo.weight"] = torch.cat( [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim