7
7
import torch
8
8
from torch .nn import functional as F
9
9
10
+
11
def end_overlap(a, b):
    """Return the length of the longest suffix of ``a`` that is a prefix of ``b``.

    Returns 0 when no suffix of ``a`` starts ``b`` (including when ``a`` is empty).
    Used to detect a stop word that is only partially emitted so far.
    """
    # Try the longest candidate first so the first hit is the maximum overlap.
    for length in range(len(a), 0, -1):
        if b[:length] == a[-length:]:
            return length
    return 0
16
+
10
17
class PIPELINE_ARGS():
    """Sampling / decoding configuration for PIPELINE generation.

    Pure data holder: every argument is stored on the instance unchanged,
    except that the three list-valued options default to a fresh empty list
    per instance (avoiding the shared-mutable-default pitfall).
    """

    def __init__(self,
                 temperature=1.0,
                 top_p=0.85,
                 top_k=0,
                 alpha_frequency=0.2,
                 alpha_presence=0.2,
                 token_ban=None,
                 token_stop=None,
                 stop_words=None,
                 chunk_len=256
                 ):
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
        # `is None` (not `or`) so a caller-supplied empty/falsy container is
        # kept as-is rather than silently replaced by a new list.
        self.token_ban = [] if token_ban is None else token_ban # ban the generation of some tokens
        self.token_stop = [] if token_stop is None else token_stop # stop generation whenever you see any token here
        self.stop_words = [] if stop_words is None else stop_words # stop generation whenever any of these strings appears in the output
        self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
20
43
21
44
class PIPELINE ():
@@ -77,12 +100,23 @@ def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
77
100
probs = probs ** (1.0 / temperature )
78
101
out = torch .multinomial (probs , num_samples = 1 )[0 ]
79
102
return int (out )
80
-
81
- def generate (self , ctx , token_count = 100 , args = PIPELINE_ARGS (), callback = None , state = None ):
103
+
104
+ def generate (self , * args , callback = None , ** kwargs ):
105
+ outstr = []
106
+ for delta in self .igenerate (* args , ** kwargs ):
107
+ outstr += [delta ]
108
+ if callback :
109
+ callback (delta )
110
+ return '' .join (outstr )
111
+
112
+ def igenerate (self , ctx , token_count = 100 , args = PIPELINE_ARGS (), state = None ):
82
113
all_tokens = []
83
114
out_last = 0
84
115
out_str = ''
85
116
occurrence = {}
117
+
118
+ stopword_checker = self .check_stopwords (args .stop_words )
119
+ next (stopword_checker )
86
120
for i in range (token_count ):
87
121
88
122
# forward & adjust prob.
@@ -108,9 +142,57 @@ def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, st
108
142
109
143
# output
110
144
tmp = self .decode (all_tokens [out_last :])
145
+ if len (all_tokens )== 1 :
146
+ tmp = tmp [1 :] # strip leading space
147
+ if tmp == '' :
148
+ continue
111
149
if '\ufffd ' not in tmp : # is valid utf-8 string?
112
- if callback :
113
- callback (tmp )
114
- out_str += tmp
150
+
151
+ try :
152
+ tmp = stopword_checker .send (tmp )
153
+ except StopIteration :
154
+ break
115
155
out_last = i + 1
116
- return out_str
156
+
157
+ if tmp is None :
158
+ continue
159
+ yield tmp
160
+ out_str += tmp
161
+ out_last = i + 1
162
+
163
    @staticmethod
    def check_stopwords(stop_words):
        """Generator-coroutine that filters a stream of text deltas for stop words.

        Protocol: prime with ``next()``, then ``.send(delta)`` for each decoded
        chunk. Each send yields either text that is safe to emit, or ``None``
        while the buffered tail could still be the start of a stop word. The
        coroutine returns (caller sees ``StopIteration``) once the buffer
        begins with any stop word.
        """
        # Upper bound on how much tail text may need to be withheld.
        longest_stopword = 0 if len(stop_words) == 0 else max(map(len, stop_words))
        chunk = ""        # buffered text not yet cleared for emission
        delta = True      # truthy seed so the loop body runs at least once
        # yield
        to_yield = None   # first primed yield produces None
        while delta:
            delta = yield to_yield
            chunk = chunk + delta

            if longest_stopword == 0:
                # nothing to check just passthrough
                to_yield = delta
                continue
            if chunk == '':
                # nothing buffered yet; emit nothing this round
                to_yield = None
                continue
            if any(map(lambda stop_word: chunk.startswith(stop_word), stop_words)):
                # a stop word has fully appeared: end the stream
                return

            # Does the buffer end with the *beginning* of some stop word?
            # If so, hold that tail back and emit only the safe prefix.
            if start_idx := max(map(lambda stop_word: end_overlap(chunk, stop_word), stop_words)):
                if start_idx > longest_stopword:
                    start_idx = longest_stopword  # can no longer be a stopword so cut it down
                    # NOTE(review): end_overlap(chunk, w) <= len(w) <= longest_stopword,
                    # so this clamp looks unreachable/defensive — confirm.
                good, chunk = chunk[:-start_idx], chunk[-start_idx:]
                if good:
                    to_yield = good
                    continue

                # entire buffer is a potential stop-word prefix: emit nothing
                to_yield = None
                continue

            # No partial match pending: flush the whole buffer.
            out = chunk
            chunk = ""
            to_yield = out
0 commit comments