Skip to content

Commit c0cce90

Browse files
committed
Hint Generator type
1 parent 8ccb10a commit c0cce90

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tokenizer/rwkv_tokenizer.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -217,14 +217,15 @@ def printTokens(self, tokens):
217217
# Tokenizer #4 (fast) https://github.com/LoganDark
218218
########################################################################################################
219219

220+
from typing import Generator
220221
from ast import literal_eval
221222

222223
class FastTokenizer:
223224
__slots__ = ('tok2val', 'tok2len', 'root')
224225

225226
def __init__(self, file_name):
226-
self.tok2val = [b''] * 65536
227-
self.tok2len = [0] * 65536
227+
self.tok2val = {}
228+
self.tok2len = {}
228229
self.root = {}
229230

230231
with open(file_name, 'rt', encoding = 'utf-8') as file:
@@ -255,7 +256,7 @@ def next_token(self, src: bytes) -> int:
255256
break
256257
return last_token
257258

258-
def encode_bytes(self, src: bytes) -> list[int]:
259+
def encode_bytes(self, src: bytes) -> Generator[int, None, None]:
259260
start, stop = 0, len(src)
260261
while start < stop:
261262
last_token, last = None, self.root
@@ -274,7 +275,7 @@ def encode_bytes(self, src: bytes) -> list[int]:
274275
def decode_bytes(self, tokens: list[int]) -> bytes:
275276
return b''.join(map(self.tok2val.__getitem__, tokens))
276277

277-
def encode(self, src: str) -> list[int]:
278+
def encode(self, src: str) -> Generator[int, None, None]:
278279
return self.encode_bytes(src.encode('utf-8'))
279280

280281
def decode(self, tokens: list[int]) -> str:

0 commit comments

Comments
 (0)