diff --git a/README.md b/README.md
index 18ea573927..7e0d656b0f 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,8 @@
+
+
 # ⚡ Lit-GPT
@@ -38,6 +40,7 @@ Supports the following popular model checkpoints:
 | NousResearch Nous-Hermes | 7B, 13B, 70B | [Org page](https://huggingface.co/NousResearch) |
 | OpenLM Research [OpenLLaMA](tutorials/download_openllama.md) | 3B, 7B, 13B | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
 | Platypus | 7B, 13B, 70B | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
+| Qwen | 7B | [Bai et al. 2023](https://arxiv.org/abs/2309.16609) |
 | Stability AI StableCode | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | Stability AI [FreeWilly2](tutorials/download_freewilly_2.md) (Stable Beluga 2) | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
 | Stability AI [StableLM](tutorials/download_stablelm.md) | 3B, 7B | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
diff --git a/lit_gpt/config.py b/lit_gpt/config.py
index dd79d20530..92eef88bea 100644
--- a/lit_gpt/config.py
+++ b/lit_gpt/config.py
@@ -27,6 +27,8 @@ class Config:
     rotary_percentage: float = 0.25
     parallel_residual: bool = True
     bias: bool = True
+    # only used by Qwen, whose attention QKV projection keeps a bias even when `bias` is False
+    is_Qwen: Optional[bool] = None
     lm_head_bias: bool = False
     # to use multi-head attention (MHA), set this to `n_head` (default)
     # to use multi-query attention (MQA), set this to 1
@@ -1277,4 +1279,30 @@ def norm_class(self) -> Type:
 configs.extend(llama_2_function_calling)
 
+
+#############
+# Qwen
+#############
+Qwen = [
+    # https://huggingface.co/Qwen/Qwen-7B/blob/main/config.json
+    dict(
+        name="Qwen-7B",
+        hf_config=dict(org="Qwen", name="Qwen-7B"),
+        vocab_size=151936,
+        padded_vocab_size=151936,
+        block_size=4096,
+        n_layer=32,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        is_Qwen=True,
+        _norm_class="RMSNorm",
+        norm_eps=1e-06,
+        _mlp_class="LLaMAMLP",
+        intermediate_size=11008,  # config.json lists 22016, but the MLP implementation divides it by 2
+    ),
+]
+configs.extend(Qwen)
+
 
 name_to_config = {config["name"]: config for config in configs}
diff --git a/lit_gpt/model.py b/lit_gpt/model.py
index 2fda57a6c5..1e66875748 100644
--- a/lit_gpt/model.py
+++ b/lit_gpt/model.py
@@ -174,7 +174,7 @@ def __init__(self, config: Config) -> None:
         super().__init__()
         shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         # key, query, value projections for all heads, but in a batch
-        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
+        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias or config.is_Qwen)
         # output projection
         self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
         # disabled by default
diff --git a/lit_gpt/tokenization_qwen.py b/lit_gpt/tokenization_qwen.py
new file mode 100644
index 0000000000..2a526d66c3
--- /dev/null
+++ b/lit_gpt/tokenization_qwen.py
@@ -0,0 +1,276 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+ +"""Tokenization classes for QWen.""" + +import base64 +import logging +import os +import unicodedata +from typing import Collection, Dict, List, Set, Tuple, Union + +import tiktoken +from transformers import PreTrainedTokenizer, AddedToken + +logger = logging.getLogger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"} + +PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" +ENDOFTEXT = "<|endoftext|>" +IMSTART = "<|im_start|>" +IMEND = "<|im_end|>" +# as the default behavior is changed to allow special tokens in +# regular texts, the surface forms of special tokens need to be +# as different as possible to minimize the impact +EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) +# changed to use actual index to avoid misconfiguration with vocabulary expansion +SPECIAL_START_ID = 151643 +SPECIAL_TOKENS = tuple( + enumerate( + ( + ( + ENDOFTEXT, + IMSTART, + IMEND, + ) + + EXTRAS + ), + start=SPECIAL_START_ID, + ) +) +SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS) + + +def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: + with open(tiktoken_bpe_file, "rb") as f: + contents = f.read() + return { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in contents.splitlines() if line) + } + + +class QWenTokenizer(PreTrainedTokenizer): + """QWen tokenizer.""" + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + errors="replace", + extra_vocab_file=None, + **kwargs, + ): + super().__init__(**kwargs) + + # how to handle errors in decoding UTF-8 byte sequences + # use ignore if you are in streaming inference + self.errors = errors + + self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int] + self.special_tokens = { + token: index + for index, token in SPECIAL_TOKENS + } + + # try load extra vocab from file + if extra_vocab_file is not None: + used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values()) + extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file) + for token, index in extra_mergeable_ranks.items(): + if token in self.mergeable_ranks: + logger.info(f"extra token {token} exists, skipping") + continue + if index in used_ids: + logger.info(f'the index {index} for extra token {token} exists, skipping') + continue + self.mergeable_ranks[token] = index + # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this + + enc = tiktoken.Encoding( + "Qwen", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + assert ( + len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab + ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding" + + self.decoder = { + v: k for k, v in self.mergeable_ranks.items() + } # type: dict[int, bytes|str] + self.decoder.update({v: k for k, v in self.special_tokens.items()}) + + self.tokenizer = enc # type: tiktoken.Encoding + + self.eod_id = self.tokenizer.eot_token + self.im_start_id = self.special_tokens[IMSTART] + self.im_end_id = self.special_tokens[IMEND] + + def __getstate__(self): + # for pickle lovers + state = self.__dict__.copy() + del state["tokenizer"] + return state + + def __setstate__(self, state): + # tokenizer is not python native; don't pass it; rebuild it + self.__dict__.update(state) + enc = tiktoken.Encoding( + "Qwen", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + 
+            special_tokens=self.special_tokens,
+        )
+        self.tokenizer = enc
+
+    def __len__(self) -> int:
+        return self.tokenizer.n_vocab
+
+    def get_vocab(self) -> Dict[bytes, int]:
+        return self.mergeable_ranks
+
+    def convert_tokens_to_ids(
+        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
+    ) -> List[int]:
+        ids = []
+        if isinstance(tokens, (str, bytes)):
+            if tokens in self.special_tokens:
+                return self.special_tokens[tokens]
+            else:
+                return self.mergeable_ranks.get(tokens)
+        for token in tokens:
+            if token in self.special_tokens:
+                ids.append(self.special_tokens[token])
+            else:
+                ids.append(self.mergeable_ranks.get(token))
+        return ids
+
+    def _add_tokens(
+        self,
+        new_tokens: Union[List[str], List[AddedToken]],
+        special_tokens: bool = False,
+    ) -> int:
+        if not special_tokens and new_tokens:
+            raise ValueError("Adding regular tokens is not supported")
+        for token in new_tokens:
+            surface_form = token.content if isinstance(token, AddedToken) else token
+            if surface_form not in SPECIAL_TOKENS_SET:
+                raise ValueError("Adding unknown special tokens is not supported")
+        return 0
+
+    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
+        """
+        Save only the vocabulary of the tokenizer (vocabulary).
+
+        Returns:
+            `Tuple(str)`: Paths to the files saved.
+        """
+        file_path = os.path.join(save_directory, "qwen.tiktoken")
+        with open(file_path, "w", encoding="utf8") as w:
+            for k, v in self.mergeable_ranks.items():
+                line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
+                w.write(line)
+        return (file_path,)
+
+    def tokenize(
+        self,
+        text: str,
+        allowed_special: Union[Set, str] = "all",
+        disallowed_special: Union[Collection, str] = (),
+        **kwargs,
+    ) -> List[Union[bytes, str]]:
+        """
+        Converts a string into a sequence of tokens.
+
+        Args:
+            text (`str`):
+                The sequence to be encoded.
+            allowed_special (`Literal["all"]` or `set`):
+                The surface forms of the tokens to be encoded as special tokens in regular texts.
+                Defaults to "all".
+            disallowed_special (`Literal["all"]` or `Collection`):
+                The surface forms of the tokens that should not be in regular texts and trigger errors.
+                Defaults to an empty tuple.
+
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific encode method.
+
+        Returns:
+            `List[bytes|str]`: The list of tokens.
+        """
+        tokens = []
+        text = unicodedata.normalize("NFC", text)
+
+        # this implementation takes a detour: text -> token id -> token surface forms
+        for t in self.tokenizer.encode(
+            text, allowed_special=allowed_special, disallowed_special=disallowed_special
+        ):
+            tokens.append(self.decoder[t])
+        return tokens
+
+    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
+        """
+        Converts a sequence of tokens into a single string.
+ """ + text = "" + temp = b"" + for t in tokens: + if isinstance(t, str): + if temp: + text += temp.decode("utf-8", errors=self.errors) + temp = b"" + text += t + elif isinstance(t, bytes): + temp += t + else: + raise TypeError("token should only be of type types or str") + if temp: + text += temp.decode("utf-8", errors=self.errors) + return text + + @property + def vocab_size(self): + return self.tokenizer.n_vocab + + def _convert_id_to_token(self, index: int) -> Union[bytes, str]: + """Converts an id to a token, special tokens included""" + if index in self.decoder: + return self.decoder[index] + raise ValueError("unknown ids") + + def _convert_token_to_id(self, token: Union[bytes, str]) -> int: + """Converts a token to an id using the vocab, special tokens included""" + if token in self.special_tokens: + return self.special_tokens[token] + if token in self.mergeable_ranks: + return self.mergeable_ranks[token] + raise ValueError("unknown token") + + def _tokenize(self, text: str, **kwargs): + """ + Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based + vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). + + Do NOT take care of added tokens. + """ + raise NotImplementedError + + def _decode( + self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + errors: str = None, + **kwargs, + ) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + if skip_special_tokens: + token_ids = [i for i in token_ids if i < self.eod_id] + return self.tokenizer.decode(token_ids, errors=errors or self.errors) diff --git a/lit_gpt/tokenizer.py b/lit_gpt/tokenizer.py index 3a6758eb62..7e58d28ada 100644 --- a/lit_gpt/tokenizer.py +++ b/lit_gpt/tokenizer.py @@ -46,6 +46,10 @@ def __init__(self, checkpoint_dir: Union[Path, str]) -> None: self.bos_id = config.get("bos_token_id") if self.eos_id is None: self.eos_id = config.get("eos_token_id") + elif (vocabulary_path := checkpoint_dir / "qwen.tiktoken").is_file(): + from lit_gpt.tokenization_qwen import QWenTokenizer + self.processor = QWenTokenizer(vocabulary_path) + self.backend = "tiktoken" else: raise NotImplementedError @@ -91,6 +95,8 @@ def encode( tokens = self.processor.encode(string).ids elif self.backend == "sentencepiece": tokens = self.processor.encode(string) + elif self.backend == "tiktoken": + tokens = self.processor.encode(string) else: raise RuntimeError if bos or (bos is None and self.use_bos): diff --git a/lit_gpt/utils.py b/lit_gpt/utils.py index 590610fc4b..614b623f15 100644 --- a/lit_gpt/utils.py +++ b/lit_gpt/utils.py @@ -61,9 +61,9 @@ def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: files = { "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), - "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or ( + "tokenizer.json OR tokenizer.model OR qwen.tiktoken": (checkpoint_dir / "tokenizer.json").is_file() or ( checkpoint_dir / "tokenizer.model" - ).is_file(), + ).is_file() or (checkpoint_dir / "qwen.tiktoken").is_file(), "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), } if checkpoint_dir.is_dir(): diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 46faef7136..7d7c895403 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -237,6 +237,45 @@ def copy_weights_phi( param = saver.store_early(param) 
         state_dict[to_name] = param
 
 
+def copy_weights_Qwen(
+    config: Config,
+    state_dict: Dict[str, torch.Tensor],
+    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    saver: Optional[incremental_save] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> None:
+    weight_map = {
+        "transformer.wte.weight": "transformer.wte.weight",
+        "transformer.h.{}.ln_1.weight": "transformer.h.{}.norm_1.weight",
+        "transformer.h.{}.ln_2.weight": "transformer.h.{}.norm_2.weight",
+        "transformer.h.{}.attn.c_attn.bias": "transformer.h.{}.attn.attn.bias",
+        "transformer.h.{}.attn.c_attn.weight": "transformer.h.{}.attn.attn.weight",
+        "transformer.h.{}.attn.c_proj.weight": "transformer.h.{}.attn.proj.weight",
+        "transformer.h.{}.mixer.rotary_emb.inv_freq": None,
+        "transformer.h.{}.mlp.w1.weight": "transformer.h.{}.mlp.fc_2.weight",
+        "transformer.h.{}.mlp.w2.weight": "transformer.h.{}.mlp.fc_1.weight",
+        "transformer.h.{}.mlp.c_proj.weight": "transformer.h.{}.mlp.proj.weight",
+        "transformer.ln_f.weight": "transformer.ln_f.weight",
+        "lm_head.weight": "lm_head.weight",
+    }
+
+    for name, param in hf_weights.items():
+        if name.startswith("transformer.h."):
+            from_name, number = layer_template(name, 2)
+            to_name = weight_map[from_name].format(number)
+        else:
+            to_name = weight_map[name]
+        param = load_param(param, name, dtype)
+        if "c_attn" in name:
+            # reorder the concatenated QKV projection to match lit-gpt's grouped qkv layout
+            q_per_kv = config.n_head // config.n_query_groups
+            total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
+            param = param.view(total_qkv, config.n_query_groups, -1).transpose(0, 1)
+            param = param.reshape(config.n_embd * 3, -1)
+            if "bias" in name:
+                param = param.squeeze()
+        if saver is not None:
+            param = saver.store_early(param)
+        state_dict[to_name] = param
+
 
 def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
     split = layer_name.split(".")
@@ -277,6 +316,8 @@ def convert_hf_checkpoint(
 
     if "falcon" in model_name:
         copy_fn = partial(copy_weights_falcon, model_name)
+    elif "Qwen" in model_name:
+        copy_fn = partial(copy_weights_Qwen, config)
     elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE"):
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
diff --git a/scripts/download.py b/scripts/download.py
index 6011fcb7e3..05897db801 100644
--- a/scripts/download.py
+++ b/scripts/download.py
@@ -40,6 +40,8 @@ def download_from_hub(
     )
 
     download_files = ["tokenizer*", "generation_config.json"]
+    if "Qwen" in repo_id:
+        download_files.append("qwen.tiktoken")
     if not tokenizer_only:
         if from_safetensors:
             if not _SAFETENSORS_AVAILABLE:
diff --git a/tutorials/download_Qwen.md b/tutorials/download_Qwen.md
new file mode 100644
index 0000000000..e9814dccb8
--- /dev/null
+++ b/tutorials/download_Qwen.md
@@ -0,0 +1,38 @@
+## Download [Qwen](https://github.com/QwenLM/Qwen) weights
+
+Qwen-7B is the 7B-parameter version of the Qwen (abbr. Tongyi Qianwen) large language model series proposed by Alibaba Cloud.
+
+For more info on the models, please see the [Qwen repository](https://github.com/QwenLM/Qwen).
+
+To see all the available checkpoints for Qwen, run:
+
+```bash
+python scripts/download.py | grep Qwen
+```
+
+which will print
+
+```text
+Qwen/Qwen-7B
+```
+
+In order to use a specific Qwen checkpoint, for instance [Qwen-7B](https://huggingface.co/Qwen/Qwen-7B), download the weights and convert the checkpoint to the lit-gpt format:
+
+```bash
+pip install huggingface_hub
+
+python scripts/download.py --repo_id Qwen/Qwen-7B --from_safetensors true
+
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/Qwen/Qwen-7B
+```
+
+By default, `convert_hf_checkpoint.py` will use the data type of the HF checkpoint's parameters. In cases where RAM
+or disk size is constrained, it might be useful to pass `--dtype bfloat16` to convert all parameters into this smaller precision before continuing (see the example at the end of this tutorial).
+
+You're done! To generate text with the model, run:
+
+```bash
+pip install tiktoken
+
+python generate/base.py --prompt "中国的首都是" --checkpoint_dir checkpoints/Qwen/Qwen-7B/
+```
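+
+For reference, this is a sketch of the reduced-precision variant of the conversion step mentioned above; it assumes the default `checkpoints/Qwen/Qwen-7B` directory created by the download step and only adds the `--dtype` flag:
+
+```bash
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/Qwen/Qwen-7B --dtype bfloat16
+```
+
+This runs the same conversion as before but stores the parameters as bfloat16, which roughly halves the size of the resulting `lit_model.pth` compared to a float32 checkpoint.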