[Tool]: Support converting InternLM2 to Llama format (InternLM#627)
Co-authored-by: x54-729 <[email protected]>
gaoyang07 and x54-729 authored Jan 19, 2024
1 parent 5d9ef21 commit 4281caf
Showing 2 changed files with 150 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tools/README.md
@@ -0,0 +1,14 @@
# InternLM2 tools

## 1. Convert to LLaMA

We provide `convert2llama.py`, which converts an InternLM2 checkpoint in HF format into LLaMA (HF) format. Here, HF refers to the format used by Hugging Face Transformers.

### Usage
```
python convert2llama.py --src /path/to/internlm2/ckpt --tgt /path/to/target/ckpt
```
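
Once converted, the checkpoint should load with the stock LLaMA classes from Transformers. A minimal sketch, assuming the conversion above has been run (the target path and the prompt are placeholders):

```
from transformers import LlamaForCausalLM, LlamaTokenizer

model = LlamaForCausalLM.from_pretrained("/path/to/target/ckpt")
tokenizer = LlamaTokenizer.from_pretrained("/path/to/target/ckpt")

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```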

### Note

While the `convert2llama.py` tool is available, we still advise using InternLM2 directly when practical, chiefly because of its superior training efficiency. InternLM2, which is adapted from LLaMA, fuses the `Wq`, `Wk`, and `Wv` weight matrices into a single matrix `Wqkv`; this fusion yields roughly a **5%** speed increase during training. Given the substantial cost of pre-training, that efficiency gain can translate into significant savings.
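
For intuition, here is a shape walk-through of the fused layout. The dimensions below are assumed, InternLM2-7B-like values (hidden size 4096, 32 query heads, 8 key/value heads), not read from an actual checkpoint:

```
hidden_size = 4096
num_attention_heads = 32        # query heads
num_key_value_heads = 8         # key/value heads (grouped-query attention)
head_dim = hidden_size // num_attention_heads            # 128
q_per_kv = num_attention_heads // num_key_value_heads    # 4 query heads per KV head

# Per key/value head, Wqkv stacks q_per_kv query slices plus one key and one value slice.
wqkv_rows = num_key_value_heads * (q_per_kv + 2) * head_dim  # 6144
q_proj_rows = num_attention_heads * head_dim                 # 4096
kv_proj_rows = num_key_value_heads * head_dim                # 1024 each for Wk and Wv
print(wqkv_rows, q_proj_rows, kv_proj_rows)                  # 6144 4096 1024
```

`convert2llama.py` simply slices `Wqkv` back into these per-head blocks and renames the keys to the LLaMA layout.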
136 changes: 136 additions & 0 deletions tools/convert2llama.py
@@ -0,0 +1,136 @@
# Copyright (c) InternLM. All rights reserved.
import argparse
import json
import os

import torch
from einops import rearrange
from tqdm import tqdm
from transformers import AutoConfig, LlamaConfig, LlamaTokenizer


def save_config(config, tgt):
    """Save the InternLM2 config as a LlamaConfig, dropping InternLM-specific keys."""
    config_dict = config.to_dict()
    unnecessary_keys = [
        "_name_or_path",
        "auto_map",
        "transformers_version",
        "model_type",
        "architectures",
        "tokenizer_class",
        "attn_implementation",
    ]
    for k in unnecessary_keys:
        config_dict.pop(k, None)
    config_dict["attention_bias"] = config_dict.pop("bias")
    config_dict["architectures"] = ["LlamaForCausalLM"]
    llama_config = LlamaConfig(**config_dict)
    llama_config.save_pretrained(tgt)


def convert(src, tgt):
    """Convert InternLM2 huggingface checkpoints to Llama-style."""

    print("Convert InternLM2 huggingface checkpoints to Llama...")

    config = AutoConfig.from_pretrained(src, trust_remote_code=True)
    assert not config.bias, "Cannot convert InternLM Model with bias to LLaMA."

    head_dim = config.hidden_size // config.num_attention_heads
    num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

    # load index json file
    index_file = os.path.join(src, "pytorch_model.bin.index.json")
    if os.path.exists(index_file):
        with open(index_file) as fp:
            index_dict = json.load(fp)
            index_dict["weight_map"] = {}
    else:
        index_dict = None

    os.makedirs(tgt, exist_ok=True)
    for filename in tqdm(os.listdir(src)):
        if not filename.endswith(".bin"):
            continue
        states = torch.load(os.path.join(src, filename))
        llama_states = {}
        # Remap each InternLM2 tensor name to its LLaMA counterpart.
        for k, v in states.copy().items():
            if "wqkv" in k:
                # The fused wqkv weight packs, for every key/value head, its
                # query slices followed by one key and one value slice.
                v = rearrange(
                    v,
                    "(h gs d) dim -> h gs d dim",
                    gs=2 + num_key_value_groups,
                    d=head_dim,
                )
                wq, wk, wv = torch.split(v, [num_key_value_groups, 1, 1], dim=1)
                wq = rearrange(wq, "h gs d dim -> (h gs d) dim")
                wk = rearrange(wk, "h gs d dim -> (h gs d) dim")
                wv = rearrange(wv, "h gs d dim -> (h gs d) dim")
                _prefix = k.split("attention")[0]
                wq_key = _prefix + "self_attn.q_proj.weight"
                wk_key = _prefix + "self_attn.k_proj.weight"
                wv_key = _prefix + "self_attn.v_proj.weight"
                llama_states[wq_key] = wq.clone()
                llama_states[wk_key] = wk.clone()
                llama_states[wv_key] = wv.clone()

            elif "attention.wo" in k:
                new_k = k.replace("attention.wo", "self_attn.o_proj")
                llama_states[new_k] = v
            elif "feed_forward.w1" in k:
                new_k = k.replace("feed_forward.w1", "mlp.gate_proj")
                llama_states[new_k] = v
            elif "feed_forward.w2" in k:
                new_k = k.replace("feed_forward.w2", "mlp.down_proj")
                llama_states[new_k] = v
            elif "feed_forward.w3" in k:
                new_k = k.replace("feed_forward.w3", "mlp.up_proj")
                llama_states[new_k] = v
            elif "attention_norm" in k:
                new_k = k.replace("attention_norm", "input_layernorm")
                llama_states[new_k] = v
            elif "ffn_norm" in k:
                new_k = k.replace("ffn_norm", "post_attention_layernorm")
                llama_states[new_k] = v
            elif "tok_embeddings" in k:
                llama_states["model.embed_tokens.weight"] = v
            elif "output" in k:
                llama_states["lm_head.weight"] = v
            else:
                llama_states[k] = v

        # Record the renamed keys in the index and save this shard.
        if index_dict is not None:
            for k in llama_states:
                index_dict["weight_map"][k] = filename
        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
        torch.save(llama_states, os.path.join(tgt, filename))
        del states

    print("Saving config and tokenizer...")
    # index.json
    if index_dict is not None:
        with open(os.path.join(tgt, "pytorch_model.bin.index.json"), "w") as fp:
            json.dump(index_dict, fp, indent=2)
    # tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(src)
    tokenizer.init_kwargs.pop("auto_map", None)
    tokenizer.save_pretrained(tgt)
    # config
    save_config(config, tgt)
    print("Done!")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, help="Input folder")
    parser.add_argument("--tgt", type=str, help="Output folder")

    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = parse_args()

    convert(args.src, args.tgt)
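
A quick way to sanity-check a conversion is to compare logits from the source and converted checkpoints on the same input. A minimal sketch, assuming both models fit in memory and the source checkpoint ships its remote modeling code (paths and the prompt are placeholders):

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM

src, tgt = "/path/to/internlm2/ckpt", "/path/to/target/ckpt"
tokenizer = AutoTokenizer.from_pretrained(src, trust_remote_code=True)
inputs = tokenizer("Hello", return_tensors="pt")

with torch.no_grad():
    ref = AutoModelForCausalLM.from_pretrained(src, trust_remote_code=True)(**inputs).logits
    new = LlamaForCausalLM.from_pretrained(tgt)(**inputs).logits

# Small numerical differences are expected from the fused vs. split projections.
print(torch.allclose(ref, new, atol=1e-4))
```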
