
Commit b139280

[add] main_llm_tokenizer experiment
1 parent 770fa41 commit b139280

File tree

4 files changed: +191, -0 lines changed


projects/llm_framework/main_llm_tokenizer/Kconfig

Whitespace-only changes.
Lines changed: 34 additions & 0 deletions
import os

Import('env')
with open(env['PROJECT_TOOL_S']) as f:
    exec(f.read())

SRCS = Glob('src/*.c*')
INCLUDE = []
PRIVATE_INCLUDE = []
REQUIREMENTS = ['pthread']
STATIC_LIB = []
DYNAMIC_LIB = []
DEFINITIONS = []
DEFINITIONS_PRIVATE = []
LDFLAGS = []
LINK_SEARCH_PATH = []
STATIC_FILES = []

# Ship the Python tokenizer service alongside the compiled launcher.
STATIC_FILES += [AFile('_tokenizer.py')]

env['COMPONENTS'].append({'target': 'llm_tokenizer',
                          'SRCS': SRCS,
                          'INCLUDE': INCLUDE,
                          'PRIVATE_INCLUDE': PRIVATE_INCLUDE,
                          'REQUIREMENTS': REQUIREMENTS,
                          'STATIC_LIB': STATIC_LIB,
                          'DYNAMIC_LIB': DYNAMIC_LIB,
                          'DEFINITIONS': DEFINITIONS,
                          'DEFINITIONS_PRIVATE': DEFINITIONS_PRIVATE,
                          'LDFLAGS': LDFLAGS,
                          'LINK_SEARCH_PATH': LINK_SEARCH_PATH,
                          'STATIC_FILES': STATIC_FILES,
                          'REGISTER': 'project'})
Lines changed: 132 additions & 0 deletions
#!/bin/env python3
# Tokenizer service: tokenizers are created/closed over a ZeroMQ REP socket,
# and each loaded tokenizer is then queried over its own HTTP port.
import zmq
from tokenizers import Tokenizer
import http.server
import socketserver
import threading
import socket
import json
from http.server import HTTPServer, BaseHTTPRequestHandler

tokenizers_obj = {}       # tokenizer path -> Tokenizer instance
tokenizers_content = {}   # tokenizer path -> extra content passed at creation
server_obj = {}           # tokenizer path -> per-tokenizer HTTP server
url_path_map = {}         # "http://localhost:<port>" -> tokenizer path

context = zmq.Context()
zmq_socket = context.socket(zmq.REP)
zmq_socket.bind("ipc:///tmp/rpc.tokenizer")


class Request(BaseHTTPRequestHandler):
    # Per-tokenizer HTTP handler: GET /bos_id and /eos_id, POST /encode and /decode.
    timeout = 5
    server_version = "Apache"

    def do_GET(self):
        # Map this server's port back to the tokenizer it serves.
        server_ip, server_port = self.server.server_address
        val = 'http://localhost:{}'.format(server_port)
        token_obj = tokenizers_obj[url_path_map[val]]

        print(self.path)
        self.send_response(200)
        self.send_header("type", "get")
        self.end_headers()
        if self.path == "/bos_id":
            bos_id = token_obj.bos_id
            if bos_id is None:
                msg = json.dumps({"bos_id": -1})
            else:
                msg = json.dumps({"bos_id": bos_id})
        elif self.path == "/eos_id":
            eos_id = token_obj.eos_id
            if eos_id is None:
                msg = json.dumps({"eos_id": -1})
            else:
                msg = json.dumps({"eos_id": eos_id})
        else:
            msg = "error"
        print(msg)
        self.wfile.write(str(msg).encode())

    def do_POST(self):
        server_ip, server_port = self.server.server_address
        val = 'http://localhost:{}'.format(server_port)
        token_obj = tokenizers_obj[url_path_map[val]]
        data = self.rfile.read(int(self.headers["content-length"])).decode()
        self.send_response(200)
        self.send_header("type", "post")
        self.end_headers()
        if self.path == "/encode":
            req = json.loads(data)
            prompt = req['text']
            token_ids = token_obj.encode(prompt, tokenizers_content[url_path_map[val]])
            if token_ids is None:
                msg = json.dumps({"token_ids": -1})
            else:
                msg = json.dumps({"token_ids": token_ids})
        elif self.path == "/decode":
            req = json.loads(data)
            token_ids = req["token_ids"]
            text = token_obj.decode(token_ids)
            if text is None:
                msg = json.dumps({"text": ""})
            else:
                msg = json.dumps({"text": text})
        else:
            msg = "error"
        print(msg)
        self.wfile.write(str(msg).encode())


def start_server(httpd):
    httpd.serve_forever()


def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]


def rpc_forever():
    while True:
        try:
            # Each request is a two-part message: [action, payload].
            message_parts = zmq_socket.recv_multipart()
            action = message_parts[0].decode('utf-8')
            rawmsg = message_parts[1].decode('utf-8')

            val = 'None'
            if action == 'creat_tokenizer':
                # Load the tokenizer and start a dedicated HTTP server for it
                # on an OS-assigned port; reply with that server's URL.
                json_args = json.loads(rawmsg)
                tokenizer_path = json_args['path']
                tokenizer_content = json_args['content']
                tokenizers_content[tokenizer_path] = tokenizer_content
                tokenizers_obj[tokenizer_path] = Tokenizer.from_file(tokenizer_path)
                server_obj[tokenizer_path] = socketserver.TCPServer(("", 0), Request)
                server_ip, server_port = server_obj[tokenizer_path].server_address
                val = 'http://localhost:{}'.format(server_port)
                url_path_map[val] = tokenizer_path
                thread = threading.Thread(target=start_server, args=(server_obj[tokenizer_path],))
                thread.daemon = True
                thread.start()

            if action == 'close_tokenizer':
                # Tear down the HTTP server and forget everything about this tokenizer.
                tokenizer_path = rawmsg
                server_obj[tokenizer_path].shutdown()
                server_obj[tokenizer_path].server_close()
                del server_obj[tokenizer_path]
                del tokenizers_obj[tokenizer_path]
                del tokenizers_content[tokenizer_path]
                for key, value in list(url_path_map.items()):
                    if value == tokenizer_path:
                        del url_path_map[key]
            zmq_socket.send(val.encode('utf-8'))
        except:
            pass


if __name__ == "__main__":
    rpc_forever()
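
The script above is the whole tokenizer service: tokenizers are created and closed through the ZeroMQ REP socket at ipc:///tmp/rpc.tokenizer, and each loaded tokenizer is then queried over its own ad-hoc HTTP port (GET /bos_id and /eos_id, POST /encode and /decode). A minimal client sketch of that flow follows; it is not part of the commit, and /path/to/tokenizer.json is a placeholder for a real HuggingFace tokenizers JSON file.

# Minimal client sketch for the service above -- assumes the service is already
# running and that /path/to/tokenizer.json is replaced with a real tokenizer file.
import json
import urllib.request

import zmq

ctx = zmq.Context()
req = ctx.socket(zmq.REQ)
req.connect("ipc:///tmp/rpc.tokenizer")

# 1. Ask the service to load a tokenizer; the reply is that tokenizer's HTTP base URL.
req.send_multipart([
    b"creat_tokenizer",
    json.dumps({"path": "/path/to/tokenizer.json", "content": ""}).encode("utf-8"),
])
base_url = req.recv().decode("utf-8")   # e.g. "http://localhost:38215"

# 2. Decode some token ids through the per-tokenizer HTTP endpoint.
post = urllib.request.Request(
    base_url + "/decode",
    data=json.dumps({"token_ids": [1, 2, 3]}).encode("utf-8"),
    method="POST",
)
with urllib.request.urlopen(post) as resp:
    print(json.loads(resp.read()))      # {"text": "..."}

# 3. Close the tokenizer, which also shuts down its HTTP server.
req.send_multipart([b"close_tokenizer", b"/path/to/tokenizer.json"])
req.recv()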
Lines changed: 25 additions & 0 deletions
/*
 * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD
 *
 * SPDX-License-Identifier: MIT
 */
#include <stdio.h>
#include <unistd.h>

/* Thin launcher: exec the Python tokenizer service, preferring a local copy. */
int main() {
    if (access("./_tokenizer.py", F_OK) == 0) {
        char *args[] = {"python3", "./_tokenizer.py", NULL};
        if (execvp("python3", args) == -1) {
            perror("execvp");
            return 1;
        }
    } else if (access("/opt/m5stack/share/_tokenizer.py", F_OK) == 0) {
        char *args[] = {"python3", "/opt/m5stack/share/_tokenizer.py", NULL};
        if (execvp("python3", args) == -1) {
            perror("execvp");
            return 1;
        }
    }
    perror("_tokenizer.py missing");
    return 0;
}
