Skip to content

Commit f905496

Browse files
committed
feat: add stopword checker + iterable generate function
1 parent 18c847d commit f905496

File tree

1 file changed

+89
-7
lines changed

1 file changed

+89
-7
lines changed

rwkv_pip_package/src/rwkv/utils.py

+89-7
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,38 @@
77
import torch
88
from torch.nn import functional as F
99

10+
11+
def end_overlap(a, b):
    """Return the length of the longest suffix of `a` that is also a
    prefix of `b`, or 0 when no suffix of `a` starts `b`.

    Used by the stop-word checker to decide how much generated text must
    be held back because it might be the beginning of a stop word.
    """
    # Scan candidate lengths from longest to shortest; the first hit is
    # by construction the longest overlap.
    return next(
        (n for n in range(len(a), 0, -1) if b.startswith(a[-n:])),
        0,
    )
16+
1017
class PIPELINE_ARGS():
    """Sampling and stopping configuration consumed by PIPELINE.generate().

    All parameters are optional; the defaults give mildly-penalized
    nucleus sampling with no banned tokens and no stop conditions.
    """
    def __init__(self,
                 temperature=1.0,
                 top_p=0.85,
                 top_k=0,
                 alpha_frequency=0.2,
                 alpha_presence=0.2,
                 token_ban=None,
                 token_stop=None,
                 stop_words=None,
                 chunk_len=256
                 ):
        # Use `is None` rather than `x or []`: the truthiness test would
        # silently replace a caller-supplied empty list with a fresh one
        # (losing aliasing), while `is None` only fills in the default.
        # None defaults avoid the shared-mutable-default-argument pitfall.
        token_ban = [] if token_ban is None else token_ban
        token_stop = [] if token_stop is None else token_stop
        stop_words = [] if stop_words is None else stop_words

        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
        self.token_ban = token_ban # ban the generation of some tokens
        self.token_stop = token_stop # stop generation whenever you see any token here
        self.stop_words = stop_words # stop generation whenever the output ends with any of these strings
        self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
2043

2144
class PIPELINE():
@@ -77,12 +100,23 @@ def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
77100
probs = probs ** (1.0 / temperature)
78101
out = torch.multinomial(probs, num_samples=1)[0]
79102
return int(out)
80-
81-
def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
103+
104+
def generate(self, *args, callback=None, **kwargs):
    """Eagerly drain self.igenerate(), forwarding every decoded chunk to
    `callback` (when given) and returning the full concatenated text.

    All positional/keyword arguments other than `callback` are passed
    straight through to igenerate().
    """
    collected = []
    for chunk in self.igenerate(*args, **kwargs):
        collected.append(chunk)
        if callback:
            callback(chunk)
    return ''.join(collected)
111+
112+
def igenerate(self, ctx, token_count=100, args=PIPELINE_ARGS(), state=None):
82113
all_tokens = []
83114
out_last = 0
84115
out_str = ''
85116
occurrence = {}
117+
118+
stopword_checker = self.check_stopwords(args.stop_words)
119+
next(stopword_checker)
86120
for i in range(token_count):
87121

88122
# forward & adjust prob.
@@ -108,9 +142,57 @@ def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, st
108142

109143
# output
110144
tmp = self.decode(all_tokens[out_last:])
145+
if len(all_tokens)==1:
146+
tmp = tmp[1:] # strip leading space
147+
if tmp == '':
148+
continue
111149
if '\ufffd' not in tmp: # is valid utf-8 string?
112-
if callback:
113-
callback(tmp)
114-
out_str += tmp
150+
151+
try:
152+
tmp = stopword_checker.send(tmp)
153+
except StopIteration:
154+
break
115155
out_last = i + 1
116-
return out_str
156+
157+
if tmp is None:
158+
continue
159+
yield tmp
160+
out_str += tmp
161+
out_last = i + 1
162+
163+
@staticmethod
def check_stopwords(stop_words):
    """Coroutine that filters a stream of decoded text against `stop_words`.

    Protocol: prime with next(gen), then gen.send(delta) for each newly
    decoded piece of text. Each send() yields either
      * a string that is now safe to emit (it can no longer be the start
        of any stop word), or
      * None, meaning the text is being withheld until it can be
        classified.
    The coroutine returns (so the caller's send() raises StopIteration)
    as soon as the buffered text begins with any stop word — the caller
    treats that as "stop generating".
    """

    # The longest stop word bounds how much text ever needs withholding.
    longest_stopword = 0 if len(stop_words)==0 else max(map(len, stop_words))
    chunk = ""      # text withheld so far, pending classification
    delta = True    # truthy sentinel so the first loop test passes before any send()
    # yield
    to_yield = None # value handed back on the next yield
    while delta:    # a falsy delta sent in (e.g. '') would end the coroutine
        delta = yield to_yield
        chunk = chunk + delta

        if longest_stopword == 0:
            # nothing to check just passthrough
            to_yield = delta
            continue
        if chunk == '':
            # Nothing buffered; yield None rather than an empty string.
            to_yield = None
            continue
        if any(map(lambda stop_word: chunk.startswith(stop_word), stop_words)):
            # A stop word has been produced: finish the generator so the
            # caller's send() raises StopIteration. The buffered chunk
            # (the stop word itself) is deliberately discarded.
            return

        # Does the tail of `chunk` look like the start of a stop word?
        # If so, withhold that tail and release only the prefix before it.
        if start_idx := max(map(lambda stop_word: end_overlap(chunk, stop_word), stop_words)):
            if start_idx > longest_stopword:
                start_idx = longest_stopword # can no longer be a stopword so cut it down
            good, chunk = chunk[:-start_idx], chunk[-start_idx:]
            if good:
                to_yield = good
                continue

            # Everything buffered is still a possible stop-word prefix.
            to_yield = None
            continue

        # No overlap with any stop word: flush the whole buffer.
        out = chunk
        chunk = ""
        to_yield = out

0 commit comments

Comments
 (0)