From f58ddce59b422a26f851cd3630f9a51cf2d1c205 Mon Sep 17 00:00:00 2001
From: hanfeng
Date: Tue, 17 Oct 2023 17:06:09 +0800
Subject: [PATCH] fix bug in online chat models

---
 .../chat_models/interface/minimax.py          |   2 +-
 .../chat_models/interface/openai.py           |   2 +-
 .../chat_models/interface/types.py            |   2 +-
 .../chat_models/interface/wenxin.py           |   2 +-
 .../chat_models/interface/xunfei.py           | 105 +++++++++---------
 .../chat_models/interface/zhipuai.py          |   2 +-
 .../bisheng_langchain/chat_models/minimax.py  |   2 +-
 .../bisheng_langchain/chat_models/wenxin.py   |   2 +-
 .../bisheng_langchain/chat_models/xunfeiai.py |   4 +-
 .../bisheng_langchain/chat_models/zhipuai.py  |   2 +-
 .../tests/test_chat_sparkai.py                |  35 ++++++
 11 files changed, 100 insertions(+), 60 deletions(-)
 create mode 100644 src/bisheng-langchain/tests/test_chat_sparkai.py

diff --git a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/minimax.py b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/minimax.py
index 19dba240e..28c319e24 100644
--- a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/minimax.py
+++ b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/minimax.py
@@ -82,7 +82,7 @@ def __call__(self, inp: ChatInput, verbose=False):
         status_code = response.status_code
         created = get_ts()
         choices = []
-        usage = None
+        usage = Usage()
         if status_code == 200:
             try:
                 info = json.loads(response.text)
diff --git a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/openai.py b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/openai.py
index 3e7ddb815..c5d2a6e2e 100644
--- a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/openai.py
+++ b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/openai.py
@@ -42,7 +42,7 @@ def __call__(self, inp: ChatInput, verbose=False):
         req_type = 'chat.completion'
         status_message = 'success'
         choices = []
-        usage = None
+        usage = Usage()
         try:
             resp = openai.ChatCompletion.create(**payload)
             status_code = 200
diff --git a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/types.py b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/types.py
index 418810317..40f7b6984 100644
--- a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/types.py
+++ b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/types.py
@@ -36,7 +36,7 @@ class Choice(BaseModel):
 class Usage(BaseModel):
     prompt_tokens: int = 0
     completion_tokens: int = 0
-    total_tokens: int
+    total_tokens: int = 0
 
 
 class ChatOutput(BaseModel):
diff --git a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/wenxin.py b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/wenxin.py
index 3972321d5..a6b7f06b5 100644
--- a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/wenxin.py
+++ b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/wenxin.py
@@ -78,7 +78,7 @@ def __call__(self, inp: ChatInput, verbose=False):
         status_code = response.status_code
         created = get_ts()
         choices = []
-        usage = None
+        usage = Usage()
         if status_code == 200:
             try:
                 info = json.loads(response.text)
diff --git a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/xunfei.py b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/xunfei.py
index 5e25c1082..bedaa963d 100644
--- a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/xunfei.py
+++ b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/xunfei.py
@@ -2,6 +2,8 @@
 import hashlib
 import hmac
 import json
+# import ssl
+# import threading
 from datetime import datetime
 from time import mktime
 from urllib.parse import urlencode, urlparse
@@ -15,9 +17,6 @@
 from .types import ChatInput, ChatOutput, Choice, Message, Usage
 from .utils import get_ts
 
-# import ssl
-# import threading
-
 
 class Ws_Param(object):
     # initialization
@@ -41,23 +40,28 @@ def create_url(self):
         signature_origin += 'GET ' + self.path + ' HTTP/1.1'
 
         # sign with hmac-sha256
-        signature_sha = hmac.new(self.APISecret.encode('utf-8'),
-                                 signature_origin.encode('utf-8'),
-                                 digestmod=hashlib.sha256).digest()
+        signature_sha = hmac.new(
+            self.APISecret.encode('utf-8'),
+            signature_origin.encode('utf-8'),
+            digestmod=hashlib.sha256).digest()
 
-        signature_sha_base64 = base64.b64encode(signature_sha).decode(
-            encoding='utf-8')
+        signature_sha_base64 = base64.b64encode(
+            signature_sha).decode(encoding='utf-8')
 
         authorization_origin = (
-                f'api_key="{self.APIKey}", '
-                f'algorithm="hmac-sha256", headers="host date request-line",'
-                f' signature="{signature_sha_base64}"')
+            f'api_key="{self.APIKey}", '
+            f'algorithm="hmac-sha256", headers="host date request-line",'
+            f' signature="{signature_sha_base64}"')
 
         authorization = base64.b64encode(
             authorization_origin.encode('utf-8')).decode(encoding='utf-8')
 
         # assemble the request auth params into a dict
-        v = {'authorization': authorization, 'date': date, 'host': self.host}
+        v = {
+            'authorization': authorization,
+            'date': date,
+            'host': self.host
+        }
         # append the auth params onto the url
         url = self.gpt_url + '?' + urlencode(v)
         # this prints the url used to establish the connection; when following this demo, uncomment the print above,
@@ -77,7 +81,7 @@ def on_close(ws):
 
 # handle the established websocket connection
 def on_open(ws):
-    thread.start_new_thread(run, (ws, ))
+    thread.start_new_thread(run, (ws,))
 
 
 def run(ws, *args):
@@ -118,10 +122,9 @@ def gen_params(appid, question):
         },
         'payload': {
             'message': {
-                'text': [{
-                    'role': 'user',
-                    'content': question
-                }]
+                'text': [
+                    {'role': 'user', 'content': question}
+                ]
             }
         }
     }
@@ -129,20 +132,25 @@
 
 class ChatCompletion(object):
-
     def __init__(self, appid, api_key, api_secret, **kwargs):
-        gpt_url = 'ws://spark-api.xf-yun.com/v1.1/chat'
-        self.wsParam = Ws_Param(appid, api_key, api_secret, gpt_url)
+        gpt_url1 = 'ws://spark-api.xf-yun.com/v1.1/chat'
+        gpt_url2 = 'ws://spark-api.xf-yun.com/v2.1/chat'
+        self.wsParam1 = Ws_Param(appid, api_key, api_secret, gpt_url1)
+        self.wsParam2 = Ws_Param(appid, api_key, api_secret, gpt_url2)
+
         websocket.enableTrace(False)
-        # wsUrl = wsParam.create_url()  # todo: modify to the ws pool
-        # self.mutex = threading.Lock()
         # self.ws = websocket.WebSocket()
-        # self.ws.connect(wsUrl)
+        # self.ws.connect(self.wsUrl)
+        # self.mutex = threading.Lock()
         self.header = {'app_id': appid, 'uid': 'elem'}
 
+    def __del__(self):
+        pass
+        # self.ws.close()
+
     def __call__(self, inp: ChatInput, verbose=False):
         messages = inp.messages
         model = inp.model
@@ -161,21 +169,18 @@ def __call__(self, inp: ChatInput, verbose=False):
                 role = 'user'
             new_messages.append({'role': role, 'content': m.content})
 
+        domain = 'generalv2' if model == 'spark-v2.0' else 'general'
         created = get_ts()
         payload = {
             'header': self.header,
-            'payload': {
-                'message': {
-                    'text': new_messages
-                }
-            },
+            'payload': {'message': {'text': new_messages}},
             'parameter': {
-            'chat': {
-                'domain': 'general',
-                'temperature': temperature,
-                'max_tokens': max_tokens,
-                'auditing': 'default'
-            }
+                'chat': {
+                    'domain': domain,
+                    'temperature': temperature,
+                    'max_tokens': max_tokens,
+                    'auditing': 'default'
+                }
             }
         }
 
@@ -186,48 +191,48 @@ def __call__(self, inp: ChatInput, verbose=False):
         status_code = 200
         status_message = 'success'
         choices = []
-        usage = None
+        usage = Usage()
         texts = []
         ws = None
         try:
             # self.mutex.acquire()
-            wsUrl = self.wsParam.create_url()
+            if model == 'spark-v2.0':
+                wsUrl = self.wsParam2.create_url()
+            else:
+                wsUrl = self.wsParam1.create_url()
             ws = create_connection(wsUrl)
             ws.send(json.dumps(payload))
+            texts = []
             while True:
                 raw_data = ws.recv()
                 if not raw_data:
                     break
-
                 resp = json.loads(raw_data)
                 if resp['header']['code'] == 0:
                     texts.append(
                         resp['payload']['choices']['text'][0]['content'])
-                if resp['header']['code'] == 0 and resp['header'][
-                        'status'] == 2:
+                if resp['header']['code'] == 0 and resp['header']['status'] == 2:
                     usage_dict = resp['payload']['usage']['text']
                     usage_dict.pop('question_tokens')
                     usage = Usage(**usage_dict)
         except Exception as e:
+            print('exception', e)
             status_code = 401
             status_message = str(e)
         finally:
             if ws:
                 ws.close()
+            # self.mutex.release()
 
         if texts:
             finish_reason = 'default'
             msg = Message(role='assistant', content=''.join(texts))
-            cho = Choice(index=0, message=msg, finish_reason=finish_reason)
+            cho = Choice(index=0, message=msg,
+                         finish_reason=finish_reason)
             choices.append(cho)
 
-        if status_code != 200:
-            raise Exception(status_message)
-
-        return ChatOutput(status_code=status_code,
-                          status_message=status_message,
-                          model=model,
-                          object=req_type,
-                          created=created,
-                          choices=choices,
-                          usage=usage)
+        return ChatOutput(
+            status_code=status_code,
+            status_message=status_message,
+            model=model, object=req_type, created=created,
+            choices=choices, usage=usage)
diff --git a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/zhipuai.py b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/zhipuai.py
index 664485f8d..36cc64746 100644
--- a/src/bisheng-langchain/bisheng_langchain/chat_models/interface/zhipuai.py
+++ b/src/bisheng-langchain/bisheng_langchain/chat_models/interface/zhipuai.py
@@ -49,7 +49,7 @@ def __call__(self, inp: ChatInput, verbose=False):
         req_type = 'chat.completion'
         status_message = 'success'
         choices = []
-        usage = None
+        usage = Usage()
         try:
             resp = zhipuai.model_api.invoke(**payload)
             status_code = resp['code']
diff --git a/src/bisheng-langchain/bisheng_langchain/chat_models/minimax.py b/src/bisheng-langchain/bisheng_langchain/chat_models/minimax.py
index a236ee0a7..3ab6cf3a5 100644
--- a/src/bisheng-langchain/bisheng_langchain/chat_models/minimax.py
+++ b/src/bisheng-langchain/bisheng_langchain/chat_models/minimax.py
@@ -129,7 +129,7 @@ class ChatMinimaxAI(BaseChatModel):
     """Whether to stream the results or not."""
     n: Optional[int] = 1
     """Number of chat completions to generate for each prompt."""
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 1024
     """Maximum number of tokens to generate."""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
diff --git a/src/bisheng-langchain/bisheng_langchain/chat_models/wenxin.py b/src/bisheng-langchain/bisheng_langchain/chat_models/wenxin.py
index 01e3f038a..d8a2ce1ca 100644
--- a/src/bisheng-langchain/bisheng_langchain/chat_models/wenxin.py
+++ b/src/bisheng-langchain/bisheng_langchain/chat_models/wenxin.py
@@ -127,7 +127,7 @@ class ChatWenxin(BaseChatModel):
     """Whether to stream the results or not."""
     n: Optional[int] = 1
     """Number of chat completions to generate for each prompt."""
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 1024
     """Maximum number of tokens to generate."""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
diff --git a/src/bisheng-langchain/bisheng_langchain/chat_models/xunfeiai.py b/src/bisheng-langchain/bisheng_langchain/chat_models/xunfeiai.py
index 4e67e627b..0c01d5584 100644
--- a/src/bisheng-langchain/bisheng_langchain/chat_models/xunfeiai.py
+++ b/src/bisheng-langchain/bisheng_langchain/chat_models/xunfeiai.py
@@ -109,7 +109,7 @@ class ChatXunfeiAI(BaseChatModel):
 
     client: Optional[Any]  #: :meta private:
     """Model name to use."""
-    model_name: str = Field('spark', alias='model')
+    model_name: str = Field('spark-1.5', alias='model')
     temperature: float = 0.5
     top_p: float = 0.7
 
@@ -128,7 +128,7 @@ class ChatXunfeiAI(BaseChatModel):
     """Whether to stream the results or not."""
     n: Optional[int] = 1
     """Number of chat completions to generate for each prompt."""
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 1024
     """Maximum number of tokens to generate."""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
diff --git a/src/bisheng-langchain/bisheng_langchain/chat_models/zhipuai.py b/src/bisheng-langchain/bisheng_langchain/chat_models/zhipuai.py
index 55f42d982..3adcf0090 100644
--- a/src/bisheng-langchain/bisheng_langchain/chat_models/zhipuai.py
+++ b/src/bisheng-langchain/bisheng_langchain/chat_models/zhipuai.py
@@ -136,7 +136,7 @@ class ChatZhipuAI(BaseChatModel):
     """Whether to stream the results or not."""
     n: Optional[int] = 1
     """Number of chat completions to generate for each prompt."""
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 1024
     """Maximum number of tokens to generate."""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
diff --git a/src/bisheng-langchain/tests/test_chat_sparkai.py b/src/bisheng-langchain/tests/test_chat_sparkai.py
new file mode 100644
index 000000000..ba49e0e6a
--- /dev/null
+++ b/src/bisheng-langchain/tests/test_chat_sparkai.py
@@ -0,0 +1,35 @@
+import os
+
+from bisheng_langchain.chat_models import ChatXunfeiAI
+
+
+def test_chat_spark_v1():
+    xunfeiai_appid = os.environ['xunfeiai_appid']
+    xunfeiai_api_key = os.environ['xunfeiai_api_key']
+    xunfeiai_api_secret = os.environ['xunfeiai_api_secret']
+    chat = ChatXunfeiAI(
+        model='spark-v1.5',
+        xunfeiai_appid=xunfeiai_appid,
+        xunfeiai_api_key=xunfeiai_api_key,
+        xunfeiai_api_secret=xunfeiai_api_secret)
+
+    resp = chat.predict('你好')
+    print(resp)
+
+
+def test_chat_spark_v2():
+    xunfeiai_appid = os.environ['xunfeiai_appid']
+    xunfeiai_api_key = os.environ['xunfeiai_api_key']
+    xunfeiai_api_secret = os.environ['xunfeiai_api_secret']
+    chat = ChatXunfeiAI(
+        model='spark-v2.0',
+        xunfeiai_appid=xunfeiai_appid,
+        xunfeiai_api_key=xunfeiai_api_key,
+        xunfeiai_api_secret=xunfeiai_api_secret)
+
+    resp = chat.predict('你好')
+    print(resp)
+
+
+test_chat_spark_v1()
+test_chat_spark_v2()
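
Note on the usage fix: the error paths used to return usage=None. A minimal
sketch of why that crashed, assuming ChatOutput declares a non-Optional
`usage: Usage` field (ChatOutput's field list is not shown in this patch):

    from pydantic import BaseModel


    class Usage(BaseModel):
        prompt_tokens: int = 0
        completion_tokens: int = 0
        total_tokens: int = 0  # had no default before this patch


    class ChatOutput(BaseModel):
        status_code: int
        usage: Usage  # assumed non-Optional, so usage=None failed validation


    # Before: Usage() itself raised (total_tokens had no default) and
    # usage=None failed ChatOutput validation on every error path.
    # After: both construct cleanly.
    out = ChatOutput(status_code=401, usage=Usage())
    print(out.usage.total_tokens)  # 0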
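
The Ws_Param hunks above are formatting-only; for reference, a condensed,
runnable sketch of the signing scheme create_url() implements: HMAC-SHA256
over "host date request-line" with the API secret, base64-packed into an
authorization query parameter. The RFC 1123 date formatting follows the
common Spark demo and is not visible in the hunks; credentials below are
placeholders.

    import base64
    import hashlib
    import hmac
    from datetime import datetime
    from time import mktime
    from urllib.parse import urlencode, urlparse
    from wsgiref.handlers import format_date_time


    def create_spark_url(api_key, api_secret,
                         gpt_url='ws://spark-api.xf-yun.com/v1.1/chat'):
        u = urlparse(gpt_url)
        date = format_date_time(mktime(datetime.now().timetuple()))
        # sign "host date request-line" with the api secret (hmac-sha256)
        signature_origin = (f'host: {u.netloc}\ndate: {date}\n'
                            f'GET {u.path} HTTP/1.1')
        signature_sha = hmac.new(api_secret.encode('utf-8'),
                                 signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature = base64.b64encode(signature_sha).decode('utf-8')
        # pack key, algorithm, signed headers and signature into one field
        authorization_origin = (
            f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"')
        authorization = base64.b64encode(
            authorization_origin.encode('utf-8')).decode('utf-8')
        # the websocket endpoint takes the auth material as query params
        return gpt_url + '?' + urlencode(
            {'authorization': authorization, 'date': date, 'host': u.netloc})


    print(create_spark_url('demo-key', 'demo-secret'))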
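
Behavioral note: after this patch, ChatCompletion.__call__ no longer raises
when status_code != 200; failures come back in-band on the ChatOutput. A
hedged caller-side sketch (`client` stands for an interface ChatCompletion
instance and `inp` for a ChatInput, both assumed to be in scope):

    out = client(inp)
    if out.status_code != 200:
        print('spark call failed:', out.status_message)
    else:
        print(out.choices[0].message.content, out.usage.total_tokens)

To run the new tests, export xunfeiai_appid, xunfeiai_api_key and
xunfeiai_api_secret, then run: python src/bisheng-langchain/tests/test_chat_sparkai.py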