231 changes: 8 additions & 223 deletions ai_edge_torch/generative/tools/tokenizer_to_sentencepiece.py
@@ -29,14 +29,12 @@
"""

import logging
import random
from typing import List

from absl import app
from absl import flags
from ai_edge_torch.generative.tools import tokenizer_to_sentencepiece_lib as lib
import transformers

from sentencepiece import sentencepiece_model_pb2 as spm_model
import sentencepiece as spm

_CHECKPOINT = flags.DEFINE_string(
@@ -78,231 +76,18 @@
)


def _bytes_to_unicode():
"""Returns list of utf-8 byte and a corresponding list of unicode strings.

It's a copy of https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))


# An inverse map of _bytes_to_unicode() to decode unicode tokens in a HF
# transformers tokenizer into utf-8 tokens in the SentencePiece model.
_BYTE_DECODE_MAP = {v: k for k, v in _bytes_to_unicode().items()}
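
# Illustrative note (not in the original file): _bytes_to_unicode() maps every
# byte to a printable unicode character, shifting non-printable bytes to code
# points 256 and above; e.g. byte 0x20 (space) becomes U+0120 ("Ġ"), which is
# why GPT-2 vocab entries like "Ġhello" denote " hello". The inverse map
# recovers the raw byte:
#   >>> _BYTE_DECODE_MAP["\u0120"]
#   32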


def _normalize_gpt2(token: str) -> str:
"""Normalizes a unicode character to a utf-8 character.

It's a semantic copy of
https://github.com/openai/gpt-2/blob/master/src/encoder.py#L105.
"""
return bytearray(
[_BYTE_DECODE_MAP[c] if c in _BYTE_DECODE_MAP else ord(c) for c in token]
).decode("utf-8", "replace")


_NORMALIZE_FUNCS = {
"none": lambda x, id, _: x,
"gpt2": lambda x, id, _: _normalize_gpt2(x),
"decode": lambda _, id, tokenizer: tokenizer.decode([id]),
}
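
# Usage sketch: the _NORMALIZE_TOKENS flag (defined in the collapsed hunk
# above) selects one of these, applied per vocab entry as
# _NORMALIZE_FUNCS[_NORMALIZE_TOKENS.value](token, id_, tokenizer);
# "none" is the identity, "gpt2" undoes the byte-level encoding, and "decode"
# trusts tokenizer.decode([id_]) for the piece text.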


def _add_token(
token: str,
id_: int,
tokenizer: transformers.PreTrainedTokenizer,
sp_model: spm_model.ModelProto,
tokens_seen: set[str],
counts: dict[spm_model.ModelProto.SentencePiece.Type, int],
):
"""Adds a token to the SentencePieceModel protobuf with a derived type."""
unk_token = tokenizer.unk_token or tokenizer.pad_token or tokenizer.eos_token
if token == unk_token:
type_ = spm_model.ModelProto.SentencePiece.UNKNOWN
elif token in tokenizer.special_tokens_map.values():
type_ = spm_model.ModelProto.SentencePiece.CONTROL
sp_model.trainer_spec.control_symbols.append(token)
elif token in tokenizer.get_added_vocab():
type_ = spm_model.ModelProto.SentencePiece.USER_DEFINED
sp_model.trainer_spec.user_defined_symbols.append(token)
else:
type_ = spm_model.ModelProto.SentencePiece.NORMAL

count_type = type_
normalized = _NORMALIZE_FUNCS[_NORMALIZE_TOKENS.value](token, id_, tokenizer)
if normalized == token:
pass
elif normalized in tokens_seen:
logging.debug(
'DUPLICATE: token "%s"(id=%d) normalized to "%s"',
token,
id_,
normalized,
)
normalized = token
# Change only the type used in the logging counts. Actually setting UNUSED on
# the SPM model seems to have a negative impact, i.e. the ratio of mismatched
# ID pairs gets slightly higher.
count_type = spm_model.ModelProto.SentencePiece.Type.UNUSED
else:
tokens_seen.add(normalized)
sp_model.pieces.add(piece=normalized, score=-id_, type=type_)
counts[count_type] = counts.get(count_type, 0) + 1

# Fill special meta token info. One token can be used for multiple purposes.
if token == tokenizer.unk_token:
sp_model.trainer_spec.unk_id = id_
sp_model.trainer_spec.unk_piece = normalized
logging.info("Found unk_id: %d, unk_piece: %s", id_, normalized)
if token == tokenizer.bos_token:
sp_model.trainer_spec.bos_id = id_
sp_model.trainer_spec.bos_piece = normalized
logging.info("Found bos_id: %d, bos_piece: %s", id_, normalized)
if token == tokenizer.eos_token:
sp_model.trainer_spec.eos_id = id_
sp_model.trainer_spec.eos_piece = normalized
logging.info("Found eos_id: %d, eos_piece: %s", id_, normalized)
if token == tokenizer.pad_token:
sp_model.trainer_spec.pad_id = id_
sp_model.trainer_spec.pad_piece = normalized
logging.info("Found pad_id: %d, pad_piece: %s", id_, normalized)


def _build_spm_model_from_tokenizer(
tokenizer: transformers.PreTrainedTokenizer,
) -> spm_model.ModelProto:
"""Builds a SentencePieceModel protobuf from a tokenizer."""
sp_model = spm_model.ModelProto()
sp_model.trainer_spec.model_type = spm_model.TrainerSpec.BPE
sp_model.trainer_spec.vocab_size = len(tokenizer.vocab)
sp_model.normalizer_spec.add_dummy_prefix = False
sp_model.normalizer_spec.remove_extra_whitespaces = False
sp_model.normalizer_spec.escape_whitespaces = False
sp_model.denormalizer_spec.CopyFrom(sp_model.normalizer_spec)

id_to_token = {id: tk for tk, id in tokenizer.vocab.items()}
tokens_seen = set(tokenizer.vocab.keys())
counts = {}
for id_ in range(len(tokenizer.vocab)):
_add_token(id_to_token[id_], id_, tokenizer, sp_model, tokens_seen, counts)

logging.info("number of tokens: %d", len(sp_model.pieces))
for type_ in counts:
logging.info(
"number of %s: %d",
spm_model.ModelProto.SentencePiece.Type.Name(type_),
counts[type_],
)

return sp_model
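
# Hypothetical end-to-end sketch (the checkpoint name is illustrative only):
#   tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
#   sp_model = _build_spm_model_from_tokenizer(tokenizer)
#   spm_tokenizer = spm.SentencePieceProcessor()
#   spm_tokenizer.LoadFromSerializedProto(sp_model.SerializeToString())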


def _is_same_ids(ids_by_tokenizer: List[int], ids: List[int]) -> bool:
"""Checks if the IDs are the same to ones by transformer tokenizer."""
# Transformer tokenizer may insert BOS token at the beginning.
return ids_by_tokenizer == ids or ids_by_tokenizer[1:] == ids
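
# Example: a tokenizer that prepends BOS still matches (IDs are hypothetical):
#   >>> _is_same_ids([2, 105, 7], [105, 7])
#   True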


def _log_not_matched(
num_not_matched_strict: int, num_not_matched_loose: int, total: int
):
"""Logs the number of not matched pairs."""
logging.info(
"Not matched strictly %d/%d pairs: %.2f%%, loosely %d/%d pairs: %.2f%%",
num_not_matched_strict,
total,
100 * num_not_matched_strict / total,
num_not_matched_loose,
total,
100 * num_not_matched_loose / total,
)


def _encode_by_spm(
spm_tokenizer: spm.SentencePieceProcessor, string: str
) -> List[int]:
"""Encodes a string by the SentencePiece tokenizer."""
ids = spm_tokenizer.Encode(string)
if isinstance(ids, list):
return ids
# SentencePieceText
return [p.id for p in ids.pieces]
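
# Sketch: Encode usually returns plain IDs, so (IDs are hypothetical):
#   >>> _encode_by_spm(spm_tokenizer, "hello world")
#   [31373, 995]
# The pieces branch handles processors configured to emit SentencePieceText.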


def _verify_spm_tokenizer(
tokenizer: transformers.PreTrainedTokenizer,
spm_tokenizer: spm.SentencePieceProcessor,
):
"""Verifies the SentencePiece tokenizer."""
# First, check if the token IDs encoded by the original tokenizer are the same
# as the token IDs encoded by the SentencePiece tokenizer.
for string in _STRINGS_TO_VERIFY.value:
ids_by_tokenizer = tokenizer.encode(string)
ids_by_spm = _encode_by_spm(spm_tokenizer, string)
logging.info("String to verify: %s", string)
logging.info("Token IDs by the oringal tokenizer: %s", ids_by_tokenizer)
logging.info("Token IDs by the SentencePiece tokenizer: %s", ids_by_spm)
if _is_same_ids(ids_by_tokenizer, ids_by_spm):
logging.info("PASS")
else:
logging.warning("FAIL")

# Second, check how many strings, decoded from random token-ID pairs by the
# original tokenizer, are encoded back to the same token IDs by the
# SentencePiece tokenizer.
total = _NUM_PAIRS_TO_VERIFY.value
num_not_matched_strict = 0
num_not_matched_loose = 0
for i in range(total):
id_pair = random.sample(list(range(len(tokenizer.vocab))), 2)
string = tokenizer.decode(id_pair)
ids_by_tokenizer = tokenizer.encode(string)
ids_by_spm = _encode_by_spm(spm_tokenizer, string)
if not _is_same_ids(ids_by_tokenizer, ids_by_spm):
num_not_matched_strict += 1
if _is_same_ids(ids_by_tokenizer, id_pair):
num_not_matched_loose += 1
logging.debug(
'NOT MATCHED: "%s", ids=%s, tok=%s, spm=%s',
string,
id_pair,
ids_by_tokenizer,
ids_by_spm,
)
if (i + 1) % 100 == 0:
_log_not_matched(num_not_matched_strict, num_not_matched_loose, i + 1)
_log_not_matched(num_not_matched_strict, num_not_matched_loose, total)
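
# Reading the counters above: "strict" counts any divergence between the two
# tokenizers on a decoded string; "loose" counts only divergences where the
# original tokenizer itself re-encodes the string back to the sampled pair,
# i.e. mismatches the SentencePiece model cannot blame on lossy decoding.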


def main(_):
tokenizer = transformers.AutoTokenizer.from_pretrained(_CHECKPOINT.value)
if hasattr(tokenizer, "vocab_file") and tokenizer.vocab_file:
logging.info("vocab_file exists: %s", tokenizer.vocab_file)
with open(tokenizer.vocab_file, "rb") as f:
sp_model = spm_model.ModelProto.FromString(f.read())
else:
logging.info("vocab_file does not exist. Try to build a new one.")
sp_model = _build_spm_model_from_tokenizer(tokenizer)
spm_serialized = lib.convert(tokenizer)

spm_serialized = sp_model.SerializeToString()
spm_tokenizer = spm.SentencePieceProcessor()
spm_tokenizer.LoadFromSerializedProto(spm_serialized)
_verify_spm_tokenizer(tokenizer, spm_tokenizer)
lib.verify_spm_tokenizer(
tokenizer,
spm_tokenizer,
_STRINGS_TO_VERIFY.value,
_NUM_PAIRS_TO_VERIFY.value,
)

logging.info(
"Writing the SentencePieceModel protobuf file to: %s", _OUTPUT_PATH.value