diff --git a/chat.py b/chat.py index a0f5d580..e4b12ce1 100644 --- a/chat.py +++ b/chat.py @@ -410,17 +410,35 @@ def on_message(message): temperature=x_temp, top_p=x_top_p, ) - out = run_rnn([token], newline_adj=newline_adj) + + tokens = [187, 187] if token == 0 else [token] + out = run_rnn(tokens, newline_adj=newline_adj) xxx = tokenizer.decode(model_tokens[out_last:]) if '\ufffd' not in xxx: # avoid utf-8 display issues print(xxx, end='', flush=True) out_last = begin + i + 1 - send_msg = tokenizer.decode(model_tokens[begin:]) + send_msg: str = tokenizer.decode(model_tokens[begin:]) if '\n\n' in send_msg: send_msg = send_msg.strip() break + + idx = send_msg.find(f'{user}{interface}') + if idx >= 0: + send_msg = f' {send_msg[:idx].strip()}\n\n' + tokens = tokenizer.encode(send_msg) + out = load_all_stat(srv, 'chat_pre') + out = run_rnn(tokens) + send_msg = send_msg.strip() + + idx = send_msg.find(f'{bot}{interface}') + if idx >= 0: + send_msg = f' {send_msg[:idx].strip()}\n\n' + tokens = tokenizer.encode(send_msg) + out = load_all_stat(srv, 'chat_pre') + out = run_rnn(tokens) + send_msg = send_msg.strip() # send_msg = tokenizer.decode(model_tokens[begin:]).strip() # if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!! diff --git a/v2/chat.py b/v2/chat.py index b6283240..7e64583b 100644 --- a/v2/chat.py +++ b/v2/chat.py @@ -356,13 +356,14 @@ def on_message(message): ) # if token == END_OF_TEXT: # break + tokens = [END_OF_LINE, END_OF_LINE] if token == END_OF_TEXT else [token] if token not in occurrence: occurrence[token] = 1 else: occurrence[token] += 1 out = run_rnn([token], newline_adj=newline_adj) - out[END_OF_TEXT] = -999999999 # disable <|endoftext|> + # out[END_OF_TEXT] = -999999999 # disable <|endoftext|> xxx = pipeline.decode(model_tokens[out_last:]) if '\ufffd' not in xxx: # avoid utf-8 display issues @@ -374,6 +375,22 @@ def on_message(message): send_msg = send_msg.strip() break + idx = send_msg.find(f'{user}{interface}') + if idx >= 0: + send_msg = f' {send_msg[:idx].strip()}\n\n' + tokens = pipeline.encode(send_msg) + out = load_all_stat(srv, 'chat_pre') + out = run_rnn(tokens) + send_msg = send_msg.strip() + + idx = send_msg.find(f'{bot}{interface}') + if idx >= 0: + send_msg = f' {send_msg[:idx].strip()}\n\n' + tokens = pipeline.encode(send_msg) + out = load_all_stat(srv, 'chat_pre') + out = run_rnn(tokens) + send_msg = send_msg.strip() + # send_msg = pipeline.decode(model_tokens[begin:]).strip() # if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!! # send_msg = send_msg[:-len(f'{user}{interface}')].strip()