fix bug in online chat models (#87)
yaojin3616 authored Oct 17, 2023
2 parents 47d97d9 + f58ddce commit 4f9d31a
Showing 11 changed files with 100 additions and 60 deletions.
@@ -82,7 +82,7 @@ def __call__(self, inp: ChatInput, verbose=False):
         status_code = response.status_code
         created = get_ts()
         choices = []
-        usage = None
+        usage = Usage()
         if status_code == 200:
             try:
                 info = json.loads(response.text)
@@ -42,7 +42,7 @@ def __call__(self, inp: ChatInput, verbose=False):
         req_type = 'chat.completion'
         status_message = 'success'
         choices = []
-        usage = None
+        usage = Usage()
         try:
             resp = openai.ChatCompletion.create(**payload)
             status_code = 200
@@ -36,7 +36,7 @@ class Choice(BaseModel):
 class Usage(BaseModel):
     prompt_tokens: int = 0
     completion_tokens: int = 0
-    total_tokens: int
+    total_tokens: int = 0


 class ChatOutput(BaseModel):
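This change is the core of the fix: once every counter has a default, a bare Usage() validates, so the provider wrappers in this commit can initialise usage = Usage() instead of None and ChatOutput no longer fails when a provider returns no token counts. A minimal sketch of that behaviour, with the fields copied from the diff and pydantic as the only assumption:

    from pydantic import BaseModel


    class Usage(BaseModel):
        prompt_tokens: int = 0
        completion_tokens: int = 0
        total_tokens: int = 0


    # Before this commit total_tokens had no default, so a bare Usage() raised a
    # validation error; now an empty Usage is a valid "no usage reported" value.
    empty = Usage()
    filled = Usage(prompt_tokens=12, completion_tokens=30, total_tokens=42)
    print(empty.total_tokens, filled.total_tokens)  # 0 42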
@@ -78,7 +78,7 @@ def __call__(self, inp: ChatInput, verbose=False):
         status_code = response.status_code
         created = get_ts()
         choices = []
-        usage = None
+        usage = Usage()
         if status_code == 200:
             try:
                 info = json.loads(response.text)
@@ -2,6 +2,8 @@
 import hashlib
 import hmac
 import json
+# import ssl
+# import threading
 from datetime import datetime
 from time import mktime
 from urllib.parse import urlencode, urlparse
@@ -15,9 +17,6 @@
 from .types import ChatInput, ChatOutput, Choice, Message, Usage
 from .utils import get_ts

-# import ssl
-# import threading
-

 class Ws_Param(object):
     # initialization
@@ -41,23 +40,28 @@ def create_url(self):
         signature_origin += 'GET ' + self.path + ' HTTP/1.1'

         # sign with hmac-sha256
-        signature_sha = hmac.new(self.APISecret.encode('utf-8'),
-                                 signature_origin.encode('utf-8'),
-                                 digestmod=hashlib.sha256).digest()
+        signature_sha = hmac.new(
+            self.APISecret.encode('utf-8'),
+            signature_origin.encode('utf-8'),
+            digestmod=hashlib.sha256).digest()

-        signature_sha_base64 = base64.b64encode(signature_sha).decode(
-            encoding='utf-8')
+        signature_sha_base64 = base64.b64encode(
+            signature_sha).decode(encoding='utf-8')

         authorization_origin = (
-            f'api_key="{self.APIKey}", '
-            f'algorithm="hmac-sha256", headers="host date request-line",'
-            f' signature="{signature_sha_base64}"')
+            f'api_key="{self.APIKey}", '
+            f'algorithm="hmac-sha256", headers="host date request-line",'
+            f' signature="{signature_sha_base64}"')

         authorization = base64.b64encode(
             authorization_origin.encode('utf-8')).decode(encoding='utf-8')

         # assemble the request's auth parameters into a dict
-        v = {'authorization': authorization, 'date': date, 'host': self.host}
+        v = {
+            'authorization': authorization,
+            'date': date,
+            'host': self.host
+        }
         # concatenate the auth parameters to build the url
         url = self.gpt_url + '?' + urlencode(v)
         # prints the url used when establishing the connection; when following this demo you can uncomment the print above,
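For readers unfamiliar with the Spark websocket handshake that create_url implements, here is a self-contained sketch of the same signing scheme: HMAC-SHA256 over the host, date, and request line, base64-encoded into an authorization query parameter. The credential values are placeholders, and the exact date format plus the "host:"/"date:" prefix lines are assumptions, since only the request-line part appears in the hunk above.

    import base64
    import hashlib
    import hmac
    from datetime import datetime
    from time import mktime
    from urllib.parse import urlencode
    from wsgiref.handlers import format_date_time

    api_key, api_secret = 'my-api-key', 'my-api-secret'   # placeholders, not real credentials
    host, path = 'spark-api.xf-yun.com', '/v1.1/chat'

    # RFC 1123 date plus the signed "request line", as in create_url above
    date = format_date_time(mktime(datetime.now().timetuple()))
    signature_origin = f'host: {host}\ndate: {date}\nGET {path} HTTP/1.1'
    signature_sha = hmac.new(api_secret.encode('utf-8'),
                             signature_origin.encode('utf-8'),
                             digestmod=hashlib.sha256).digest()
    signature = base64.b64encode(signature_sha).decode('utf-8')
    authorization_origin = (f'api_key="{api_key}", algorithm="hmac-sha256", '
                            f'headers="host date request-line", signature="{signature}"')
    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')

    # The same three parameters the diff assembles into the dict `v`.
    url = 'ws://' + host + path + '?' + urlencode(
        {'authorization': authorization, 'date': date, 'host': host})
    print(url)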
@@ -77,7 +81,7 @@ def on_close(ws):

 # handler invoked once the websocket connection is established
 def on_open(ws):
-    thread.start_new_thread(run, (ws, ))
+    thread.start_new_thread(run, (ws,))


 def run(ws, *args):
@@ -118,31 +122,35 @@ def gen_params(appid, question):
         },
         'payload': {
             'message': {
-                'text': [{
-                    'role': 'user',
-                    'content': question
-                }]
+                'text': [
+                    {'role': 'user', 'content': question}
+                ]
             }
         }
     }
     return data


 class ChatCompletion(object):

     def __init__(self, appid, api_key, api_secret, **kwargs):
-        gpt_url = 'ws://spark-api.xf-yun.com/v1.1/chat'
-        self.wsParam = Ws_Param(appid, api_key, api_secret, gpt_url)
+        gpt_url1 = 'ws://spark-api.xf-yun.com/v1.1/chat'
+        gpt_url2 = 'ws://spark-api.xf-yun.com/v2.1/chat'
+        self.wsParam1 = Ws_Param(appid, api_key, api_secret, gpt_url1)
+        self.wsParam2 = Ws_Param(appid, api_key, api_secret, gpt_url2)

         websocket.enableTrace(False)
         # wsUrl = wsParam.create_url()

+        # todo: modify to the ws pool
         # self.mutex = threading.Lock()
         # self.ws = websocket.WebSocket()
-        # self.ws.connect(wsUrl)
+        # self.ws.connect(self.wsUrl)

+        # self.mutex = threading.Lock()
         self.header = {'app_id': appid, 'uid': 'elem'}

     def __del__(self):
         pass
         # self.ws.close()

     def __call__(self, inp: ChatInput, verbose=False):
         messages = inp.messages
         model = inp.model
@@ -161,21 +169,18 @@ def __call__(self, inp: ChatInput, verbose=False):
                 role = 'user'
             new_messages.append({'role': role, 'content': m.content})

+        domain = 'generalv2' if model == 'spark-v2.0' else 'general'
         created = get_ts()
         payload = {
             'header': self.header,
-            'payload': {
-                'message': {
-                    'text': new_messages
-                }
-            },
+            'payload': {'message': {'text': new_messages}},
             'parameter': {
-                'chat': {
-                    'domain': 'general',
-                    'temperature': temperature,
-                    'max_tokens': max_tokens,
-                    'auditing': 'default'
-                }
+                'chat': {
+                    'domain': domain,
+                    'temperature': temperature,
+                    'max_tokens': max_tokens,
+                    'auditing': 'default'
+                }
             }
         }

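This hunk and the __init__ hunk further up are where the new version switch takes effect: __init__ now prepares one Ws_Param per endpoint, and the domain variable replaces the hard-coded 'general'. A hedged sketch of the resulting routing (spark_route is a hypothetical helper name; falling back to the v1.5 settings for any other model string is an assumption about intent):

    def spark_route(model: str):
        """Map a model name to the (domain, websocket url) pair used above."""
        if model == 'spark-v2.0':
            return 'generalv2', 'ws://spark-api.xf-yun.com/v2.1/chat'
        return 'general', 'ws://spark-api.xf-yun.com/v1.1/chat'


    print(spark_route('spark-v2.0'))  # ('generalv2', 'ws://spark-api.xf-yun.com/v2.1/chat')
    print(spark_route('spark-v1.5'))  # ('general', 'ws://spark-api.xf-yun.com/v1.1/chat')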
@@ -186,48 +191,48 @@
         status_code = 200
         status_message = 'success'
         choices = []
-        usage = None
+        usage = Usage()
         texts = []
         ws = None
         try:
             # self.mutex.acquire()
-            wsUrl = self.wsParam.create_url()
+            if model == 'spark-v2.0':
+                wsUrl = self.wsParam2.create_url()
+            else:
+                wsUrl = self.wsParam1.create_url()
             ws = create_connection(wsUrl)
             ws.send(json.dumps(payload))
             texts = []
             while True:
                 raw_data = ws.recv()
                 if not raw_data:
                     break

                 resp = json.loads(raw_data)
                 if resp['header']['code'] == 0:
                     texts.append(
                         resp['payload']['choices']['text'][0]['content'])
-                if resp['header']['code'] == 0 and resp['header'][
-                        'status'] == 2:
+                if resp['header']['code'] == 0 and resp['header']['status'] == 2:
                     usage_dict = resp['payload']['usage']['text']
                     usage_dict.pop('question_tokens')
                     usage = Usage(**usage_dict)
         except Exception as e:
             print('exception', e)
             status_code = 401
             status_message = str(e)
         finally:
             if ws:
                 ws.close()
             # self.mutex.release()

         if texts:
             finish_reason = 'default'
             msg = Message(role='assistant', content=''.join(texts))
-            cho = Choice(index=0, message=msg, finish_reason=finish_reason)
+            cho = Choice(index=0, message=msg,
+                         finish_reason=finish_reason)
             choices.append(cho)

         if status_code != 200:
             raise Exception(status_message)

-        return ChatOutput(status_code=status_code,
-                          status_message=status_message,
-                          model=model,
-                          object=req_type,
-                          created=created,
-                          choices=choices,
-                          usage=usage)
+        return ChatOutput(
+            status_code=status_code,
+            status_message=status_message,
+            model=model, object=req_type, created=created,
+            choices=choices, usage=usage)
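A stripped-down sketch of the receive loop this hunk touches, with ws.recv() replaced by stubbed frames so it runs offline: chunks are accumulated while header.code is 0, and the final frame (header.status == 2) also carries the usage block, from which question_tokens is dropped before building Usage. The frame contents are illustrative stand-ins, not captured API responses.

    import json

    # Stubbed frames standing in for successive ws.recv() results.
    frames = [
        json.dumps({'header': {'code': 0, 'status': 1},
                    'payload': {'choices': {'text': [{'content': '你'}]}}}),
        json.dumps({'header': {'code': 0, 'status': 2},
                    'payload': {'choices': {'text': [{'content': '好'}]},
                                'usage': {'text': {'question_tokens': 2,
                                                   'prompt_tokens': 2,
                                                   'completion_tokens': 2,
                                                   'total_tokens': 4}}}}),
    ]

    texts, usage = [], None
    for raw_data in frames:                      # stands in for `while True: ws.recv()`
        resp = json.loads(raw_data)
        if resp['header']['code'] == 0:
            texts.append(resp['payload']['choices']['text'][0]['content'])
        if resp['header']['code'] == 0 and resp['header']['status'] == 2:
            usage_dict = resp['payload']['usage']['text']
            usage_dict.pop('question_tokens')    # Usage has no such field
            usage = usage_dict
            break

    print(''.join(texts), usage)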
@@ -49,7 +49,7 @@ def __call__(self, inp: ChatInput, verbose=False):
         req_type = 'chat.completion'
         status_message = 'success'
         choices = []
-        usage = None
+        usage = Usage()
         try:
             resp = zhipuai.model_api.invoke(**payload)
             status_code = resp['code']
@@ -129,7 +129,7 @@ class ChatMinimaxAI(BaseChatModel):
     """Whether to stream the results or not."""
     n: Optional[int] = 1
     """Number of chat completions to generate for each prompt."""
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 1024
     """Maximum number of tokens to generate."""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
@@ -127,7 +127,7 @@ class ChatWenxin(BaseChatModel):
     """Whether to stream the results or not."""
     n: Optional[int] = 1
     """Number of chat completions to generate for each prompt."""
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 1024
     """Maximum number of tokens to generate."""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
@@ -109,7 +109,7 @@ class ChatXunfeiAI(BaseChatModel):

     client: Optional[Any]  #: :meta private:
     """Model name to use."""
-    model_name: str = Field('spark', alias='model')
+    model_name: str = Field('spark-1.5', alias='model')

     temperature: float = 0.5
     top_p: float = 0.7
@@ -128,7 +128,7 @@ class ChatXunfeiAI(BaseChatModel):
     """Whether to stream the results or not."""
     n: Optional[int] = 1
     """Number of chat completions to generate for each prompt."""
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 1024
     """Maximum number of tokens to generate."""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
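With the two defaults above, a ChatXunfeiAI instance targets the spark-1.5 model with a 1024-token cap unless overridden, and the model alias is what the new test file further down passes to select v2.0. A hedged construction sketch mirroring that test; the credential environment variable names come from the test, and the max_tokens override is only there to show the field remains settable:

    import os

    from bisheng_langchain.chat_models import ChatXunfeiAI

    chat = ChatXunfeiAI(
        model='spark-v2.0',                  # alias for model_name; default is now 'spark-1.5'
        max_tokens=256,                      # override the new 1024 default if needed
        xunfeiai_appid=os.environ['xunfeiai_appid'],
        xunfeiai_api_key=os.environ['xunfeiai_api_key'],
        xunfeiai_api_secret=os.environ['xunfeiai_api_secret'])
    print(chat.predict('你好'))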
@@ -136,7 +136,7 @@ class ChatZhipuAI(BaseChatModel):
     """Whether to stream the results or not."""
     n: Optional[int] = 1
     """Number of chat completions to generate for each prompt."""
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 1024
     """Maximum number of tokens to generate."""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
35 changes: 35 additions & 0 deletions src/bisheng-langchain/tests/test_chat_sparkai.py
@@ -0,0 +1,35 @@
+import os
+
+from bisheng_langchain.chat_models import ChatXunfeiAI
+
+
+def test_chat_spark_v1():
+    xunfeiai_appid = os.environ['xunfeiai_appid']
+    xunfeiai_api_key = os.environ['xunfeiai_api_key']
+    xunfeiai_api_secret = os.environ['xunfeiai_api_secret']
+    chat = ChatXunfeiAI(
+        model='spark-v1.5',
+        xunfeiai_appid=xunfeiai_appid,
+        xunfeiai_api_key=xunfeiai_api_key,
+        xunfeiai_api_secret=xunfeiai_api_secret)
+
+    resp = chat.predict('你好')
+    print(resp)
+
+
+def test_chat_spark_v2():
+    xunfeiai_appid = os.environ['xunfeiai_appid']
+    xunfeiai_api_key = os.environ['xunfeiai_api_key']
+    xunfeiai_api_secret = os.environ['xunfeiai_api_secret']
+    chat = ChatXunfeiAI(
+        model='spark-v2.0',
+        xunfeiai_appid=xunfeiai_appid,
+        xunfeiai_api_key=xunfeiai_api_key,
+        xunfeiai_api_secret=xunfeiai_api_secret)
+
+    resp = chat.predict('你好')
+    print(resp)
+
+
+test_chat_spark_v1()
+test_chat_spark_v2()
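Because of the two module-level calls at the bottom, this file also runs as a plain script rather than only under pytest; either way the three xunfeiai_* environment variables must be set. A hedged runner sketch (the skip-when-unconfigured behaviour is an assumption, not part of the commit):

    import os
    import subprocess

    required = ('xunfeiai_appid', 'xunfeiai_api_key', 'xunfeiai_api_secret')
    if all(name in os.environ for name in required):
        subprocess.run(
            ['python', 'src/bisheng-langchain/tests/test_chat_sparkai.py'],
            check=True)
    else:
        print('Skipping Spark chat tests: xunfeiai_* credentials not configured')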
