Add Qwen support #850

Draft · wants to merge 3 commits into main
3 changes: 3 additions & 0 deletions README.md
@@ -18,6 +18,8 @@

</div>



&nbsp;

# ⚡ Lit-GPT
@@ -38,6 +40,7 @@ Supports the following popular model checkpoints:
| NousResearch Nous-Hermes | 7B, 13B, 70B | [Org page](https://huggingface.co/NousResearch) |
| OpenLM Research [OpenLLaMA](tutorials/download_openllama.md) | 3B, 7B, 13B | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Platypus | 7B, 13B, 70B | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Qwen | 7B | [Bai et al. 2023](https://arxiv.org/abs/2309.16609) |
| Stability AI StableCode | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| Stability AI [FreeWilly2](tutorials/download_freewilly_2.md) (Stable Beluga 2) | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Stability AI [StableLM](tutorials/download_stablelm.md) | 3B, 7B | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
28 changes: 28 additions & 0 deletions lit_gpt/config.py
@@ -27,6 +27,8 @@ class Config:
rotary_percentage: float = 0.25
parallel_residual: bool = True
bias: bool = True
    # Qwen only: forces a bias on the attention (c_attn) projection
is_Qwen: Optional[bool] = None
Contributor

I think we should avoid this in favor of something that characterizes Qwen, like having bias only in c_attn.
For the time being we could rename this to attn_bias, and then in the future turn bias into a Union[bool, List[str]] if there's a need for it.
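
A minimal sketch of that suggestion (the `attn_bias` name is hypothetical, not part of this PR):

```python
from dataclasses import dataclass

@dataclass
class Config:
    bias: bool = True        # bias for every linear layer, as today
    attn_bias: bool = False  # Qwen-style: bias on the attention projection only

# CausalSelfAttention.__init__ would then lose the model-specific flag:
# self.attn = nn.Linear(config.n_embd, shape, bias=config.bias or config.attn_bias)
```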

lm_head_bias: bool = False
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
@@ -1277,4 +1279,30 @@ def norm_class(self) -> Type:

configs.extend(llama_2_function_calling)


#############
# Qwen
#############
Qwen = [
    # https://huggingface.co/Qwen/Qwen-7B/blob/main/config.json
    dict(
        name="Qwen-7B",
        hf_config=dict(org="Qwen", name="Qwen-7B"),
        vocab_size=151936,
        padded_vocab_size=151936,
        block_size=4096,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        is_Qwen=True,
        _norm_class="RMSNorm",
        norm_eps=1e-06,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,  # config.json says 22016, but Qwen's MLP halves it, so 11008 here
    ),
]
configs.extend(Qwen)


name_to_config = {config["name"]: config for config in configs}
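
For context, a config registered here is looked up by name; a quick sanity check could be (assuming lit_gpt's existing `Config.from_name` helper):

```python
from lit_gpt import Config

config = Config.from_name("Qwen-7B")  # resolved through name_to_config
print(config.intermediate_size)       # 11008
```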
2 changes: 1 addition & 1 deletion lit_gpt/model.py
@@ -174,7 +174,7 @@ def __init__(self, config: Config) -> None:
super().__init__()
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
-        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
+        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias or config.is_Qwen)
# output projection
self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
# disabled by default
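To make the QKV shape concrete for this change: Qwen-7B keeps the defaults `n_embd=4096`, `n_head=32`, and `n_query_groups=n_head` (values assumed from the Config defaults, not stated in this diff), so:

```python
n_head, n_query_groups, head_size = 32, 32, 4096 // 32
shape = (n_head + 2 * n_query_groups) * head_size
print(shape)  # 12288 == 3 * 4096: stacked query, key, and value projections
```
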
276 changes: 276 additions & 0 deletions lit_gpt/tokenization_qwen.py
@@ -0,0 +1,276 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
Comment on lines +1 to +4

Contributor

If you didn't write the code in this file (which I assume you didn't, since you added this license header), you should link to the original source. Is the original from PaddlePaddle? https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/qwen/tokenizer.py

I would advise creating a version that only implements the few methods required by tokenizer.py

Author

I just copied it from their Hugging Face repo: tokenization_qwen.py


"""Tokenization classes for QWen."""

import base64
import logging
import os
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union

import tiktoken
from transformers import PreTrainedTokenizer, AddedToken
Contributor

This project doesn't have transformers as a dependency, so this import is not possible


logger = logging.getLogger(__name__)


VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}

PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
# changed to use actual index to avoid misconfiguration with vocabulary expansion
SPECIAL_START_ID = 151643
SPECIAL_TOKENS = tuple(
enumerate(
(
(
ENDOFTEXT,
IMSTART,
IMEND,
)
+ EXTRAS
),
start=SPECIAL_START_ID,
)
)
SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)


def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
with open(tiktoken_bpe_file, "rb") as f:
contents = f.read()
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
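
# File format: one token per line, as base64-encoded bytes followed by the
# integer rank, e.g. a (hypothetical) line "aGVsbG8= 52" maps b"hello" to 52.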


class QWenTokenizer(PreTrainedTokenizer):
"""QWen tokenizer."""

vocab_files_names = VOCAB_FILES_NAMES

def __init__(
self,
vocab_file,
errors="replace",
extra_vocab_file=None,
**kwargs,
):
super().__init__(**kwargs)

# how to handle errors in decoding UTF-8 byte sequences
# use ignore if you are in streaming inference
self.errors = errors

self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
self.special_tokens = {
token: index
for index, token in SPECIAL_TOKENS
}

# try load extra vocab from file
if extra_vocab_file is not None:
used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
for token, index in extra_mergeable_ranks.items():
if token in self.mergeable_ranks:
logger.info(f"extra token {token} exists, skipping")
continue
if index in used_ids:
logger.info(f'the index {index} for extra token {token} exists, skipping')
continue
self.mergeable_ranks[token] = index
            # the indices may be sparse after this, but don't worry, tiktoken.Encoding will handle it

enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
assert (
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"

self.decoder = {
v: k for k, v in self.mergeable_ranks.items()
} # type: dict[int, bytes|str]
self.decoder.update({v: k for k, v in self.special_tokens.items()})

self.tokenizer = enc # type: tiktoken.Encoding

self.eod_id = self.tokenizer.eot_token
self.im_start_id = self.special_tokens[IMSTART]
self.im_end_id = self.special_tokens[IMEND]

def __getstate__(self):
# for pickle lovers
state = self.__dict__.copy()
del state["tokenizer"]
return state

def __setstate__(self, state):
# tokenizer is not python native; don't pass it; rebuild it
self.__dict__.update(state)
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
self.tokenizer = enc

def __len__(self) -> int:
return self.tokenizer.n_vocab

def get_vocab(self) -> Dict[bytes, int]:
return self.mergeable_ranks

def convert_tokens_to_ids(
self, tokens: Union[bytes, str, List[Union[bytes, str]]]
) -> List[int]:
ids = []
if isinstance(tokens, (str, bytes)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.mergeable_ranks.get(tokens)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.mergeable_ranks.get(token))
return ids

def _add_tokens(
self,
new_tokens: Union[List[str], List[AddedToken]],
special_tokens: bool = False,
) -> int:
if not special_tokens and new_tokens:
raise ValueError("Adding regular tokens is not supported")
for token in new_tokens:
surface_form = token.content if isinstance(token, AddedToken) else token
if surface_form not in SPECIAL_TOKENS_SET:
raise ValueError("Adding unknown special tokens is not supported")
return 0

def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
"""
        Save only the vocabulary of the tokenizer.

Returns:
`Tuple(str)`: Paths to the files saved.
"""
file_path = os.path.join(save_directory, "qwen.tiktoken")
with open(file_path, "w", encoding="utf8") as w:
for k, v in self.mergeable_ranks.items():
line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
w.write(line)
return (file_path,)

def tokenize(
self,
text: str,
allowed_special: Union[Set, str] = "all",
disallowed_special: Union[Collection, str] = (),
**kwargs,
) -> List[Union[bytes, str]]:
"""
        Converts a string into a sequence of tokens.

Args:
text (`str`):
The sequence to be encoded.
allowed_special (`Literal["all"]` or `set`):
The surface forms of the tokens to be encoded as special tokens in regular texts.
Default to "all".
disallowed_special (`Literal["all"]` or `Collection`):
The surface forms of the tokens that should not be in regular texts and trigger errors.
Default to an empty tuple.

kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific encode method.

Returns:
`List[bytes|str]`: The list of tokens.
"""
tokens = []
text = unicodedata.normalize("NFC", text)

# this implementation takes a detour: text -> token id -> token surface forms
for t in self.tokenizer.encode(
text, allowed_special=allowed_special, disallowed_special=disallowed_special
):
tokens.append(self.decoder[t])
return tokens
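
    # Illustrative example (actual splits depend on the vocab):
    # self.tokenize("hello world") -> [b"hello", b" world"]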

def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
"""
        Converts a sequence of tokens into a single string.
"""
text = ""
temp = b""
for t in tokens:
if isinstance(t, str):
if temp:
text += temp.decode("utf-8", errors=self.errors)
temp = b""
text += t
elif isinstance(t, bytes):
temp += t
else:
                raise TypeError("token should only be of type bytes or str")
if temp:
text += temp.decode("utf-8", errors=self.errors)
return text

@property
def vocab_size(self):
return self.tokenizer.n_vocab

def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
"""Converts an id to a token, special tokens included"""
if index in self.decoder:
return self.decoder[index]
raise ValueError("unknown ids")

def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
"""Converts a token to an id using the vocab, special tokens included"""
if token in self.special_tokens:
return self.special_tokens[token]
if token in self.mergeable_ranks:
return self.mergeable_ranks[token]
raise ValueError("unknown token")

def _tokenize(self, text: str, **kwargs):
"""
        Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for word-based
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).

Do NOT take care of added tokens.
"""
raise NotImplementedError

def _decode(
self,
token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
errors: str = None,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
if skip_special_tokens:
token_ids = [i for i in token_ids if i < self.eod_id]
return self.tokenizer.decode(token_ids, errors=errors or self.errors)
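
A quick smoke test of the copied tokenizer could look like this (a sketch: the checkpoint path is hypothetical, and it needs tiktoken plus transformers, which, as noted above, is not a lit-gpt dependency):

```python
from lit_gpt.tokenization_qwen import QWenTokenizer

tok = QWenTokenizer("checkpoints/Qwen/Qwen-7B/qwen.tiktoken")  # hypothetical path
tokens = tok.tokenize("Hello, Qwen!")    # surface forms (bytes), not ids
ids = tok.convert_tokens_to_ids(tokens)  # ranks in the tiktoken vocab
assert tok.convert_tokens_to_string(tokens) == "Hello, Qwen!"
```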
6 changes: 6 additions & 0 deletions lit_gpt/tokenizer.py
@@ -46,6 +46,10 @@ def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
self.bos_id = config.get("bos_token_id")
if self.eos_id is None:
self.eos_id = config.get("eos_token_id")
elif (vocabulary_path := checkpoint_dir / "qwen.tiktoken").is_file():
from lit_gpt.tokenization_qwen import QWenTokenizer
self.processor = QWenTokenizer(vocabulary_path)
self.backend = "tiktoken"
else:
raise NotImplementedError

@@ -91,6 +95,8 @@ def encode(
tokens = self.processor.encode(string).ids
elif self.backend == "sentencepiece":
tokens = self.processor.encode(string)
elif self.backend == "tiktoken":
tokens = self.processor.encode(string)
Contributor

It doesn't seem like the new processor implements this method. Also, what about decoding?
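
One possible fix (a patch-style sketch, not part of this PR): go through the tiktoken.Encoding that QWenTokenizer stores as `self.tokenizer`, and mirror it on the decode side:

```python
# Hypothetical shim for Tokenizer.encode; `processor.tokenizer` is the
# tiktoken.Encoding built inside QWenTokenizer.__init__
elif self.backend == "tiktoken":
    tokens = self.processor.tokenizer.encode(string)

# ...and in Tokenizer.decode:
# if self.backend == "tiktoken":
#     return self.processor.tokenizer.decode(tensor.tolist())
```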

else:
raise RuntimeError
if bos or (bos is None and self.use_bos):
4 changes: 2 additions & 2 deletions lit_gpt/utils.py
@@ -61,9 +61,9 @@ def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
files = {
"lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
"lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
"tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (
"tokenizer.json OR tokenizer.model OR qwen.tiktoken": (checkpoint_dir / "tokenizer.json").is_file() or (
checkpoint_dir / "tokenizer.model"
).is_file(),
).is_file() or (checkpoint_dir / "qwen.tiktoken").is_file(),
"tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
}
if checkpoint_dir.is_dir():