@@ -147,6 +147,188 @@ def make_quant(module, names, bits, groupsize, name=''):
     for name1, child in module.named_children():
         make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
 
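+# For every attribute name listed in `names`, derive the bias key and quant layer
+# name by substituting 'w', and attach a QuantLinear_custom sized from the weight's shape.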
+def make_quant_custom(module, names, bits, groupsize, name=''):
+    if isinstance(module, QuantLinear):
+        return
+    for attr in dir(module):
+        tmp = getattr(module, attr)
+        name1 = name + '.' + attr if name != '' else attr
+        if name1 in names:
+            bias_name = attr.replace('w', 'b')
+            layer_name = attr.replace('w', 'quant')
+            setattr(module, layer_name, QuantLinear_custom(bits, groupsize, tmp.shape[0], tmp.shape[1], module.w[bias_name] is not None))
+
+
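+# QuantLinear_custom stores GPTQ-quantized weights: integer weights and zero points
+# are bit-packed into int32 buffers (qweight/qzeros) alongside per-group fp16 scales,
+# and are dequantized on the fly in forward().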
+class QuantLinear_custom(nn.Module):
+    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
+        super().__init__()
+        if bits not in [2, 3, 4, 8]:
+            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
+        self.bits = bits
+        self.groupsize = groupsize if groupsize != -1 else infeatures
+        self.maxq = 2 ** self.bits - 1
+
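+        # Buffer layout: qweight packs `bits`-bit weights along the input dimension
+        # into int32 words; qzeros packs one zero point per group along the output
+        # dimension; scales holds one fp16 scale per group; g_idx maps each input
+        # channel to its quantization group.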
+        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
+        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+        if bias:
+            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+        else:
+            self.bias = None
+
+        # kernel_switch_threshold is the cutoff on the flattened input size; above it
+        # the matmul is performed by unpacking the weights and using torch.matmul
+        if self.bits in [2, 4, 8]:
+            self.register_buffer('wf', torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0), persistent=False)
+        elif self.bits == 3:
+            self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0],
+                                                     [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31],
+                                                     [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0]], dtype=torch.int32).reshape(1, 3, 12), persistent=False)
+
+        self.kernel_switch_threshold = kernel_switch_threshold
+        self.is_cuda = is_cuda
+
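+    # pack() converts a float weight matrix plus per-group scales/zero points into
+    # the packed integer buffers: each weight is rounded to an integer in [0, maxq]
+    # using its group's parameters, then the integers are bit-packed into 32-bit words.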
+    def pack(self, weight, bias, scales, zeros, g_idx=None):
+        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
+        self.scales = scales.clone().half()
+        if bias is not None:
+            self.bias = bias.clone().half()
+
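+        # Quantize each input channel against its group's parameters:
+        # q = round((w + zero * scale) / scale)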
+        intweight = []
+        for idx in range(self.infeatures):
+            intweight.append(torch.round((weight[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.numpy().astype(np.uint32)
+        qweight = np.zeros(
+            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
+        )
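+        # Bit-pack the integer weights row-wise into 32-bit words. For 2/4/8 bits the
+        # fields divide 32 evenly; for 3 bits, 32 values are spread across three words,
+        # with two values straddling the word boundaries.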
+        i = 0
+        row = 0
+        while row < qweight.shape[0]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qweight[row] |= intweight[j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                row += 1
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i))
+                i += 10
+                qweight[row] |= intweight[i] << 30
+                row += 1
+                qweight[row] |= (intweight[i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
+                i += 10
+                qweight[row] |= intweight[i] << 31
+                row += 1
+                qweight[row] |= (intweight[i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
+                i += 10
+                row += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qweight = qweight.astype(np.int32)
+        self.qweight = torch.from_numpy(qweight)
+
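+        # Zero points are stored minus one and packed along the output dimension
+        # (columns) with the same per-bit-width layout as qweight.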
+        zeros -= 1
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                col += 1
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 30
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 31
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
+                i += 10
+                col += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)
+
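+    # forward() flattens the input to 2D, then either calls the fused CUDA
+    # dequant+matmul kernel (small inputs) or unpacks the weights to fp16 and
+    # falls back to torch.matmul (large inputs or no CUDA extension).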
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.outfeatures,)
+        x = x.reshape(-1, x.shape[-1])
+        if self.is_cuda is True and (self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold):
+            out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
+            if self.bits == 2:
+                quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            elif self.bits == 3:
+                quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            elif self.bits == 4:
+                quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            elif self.bits == 8:
+                quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            out = out.half()
+        else:
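+            # Fallback: shift and mask the packed int32 words back to per-element
+            # integers, then dequantize with the stored scales and zero points.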
+            if self.bits in [2, 4, 8]:
+                zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
+                torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
+
+                zeros = zeros + 1
+                zeros = zeros.reshape(self.scales.shape)
+
+                weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8)
+                torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
+            elif self.bits == 3:
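+                # 3-bit values straddle int32 boundaries, so the two split values in
+                # each 3-word group are stitched back together before masking to 3 bits.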
+                zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand(-1, -1, -1, 12)
+                zeros = (zeros >> self.wf.unsqueeze(0))
+                zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4)
+                zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6)
+                zeros = zeros & 0x7
+                zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2)
+
+                zeros = zeros + 1
+                zeros = zeros.reshape(self.scales.shape)
+
+                weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1)
+                weight = (weight >> self.wf.unsqueeze(-1)) & 0x7
+                weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4)
+                weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6)
+                weight = weight & 0x7
+                weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1)
+
+            weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
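+            # Dequantize per input channel and multiply: W = scale * (q - zero)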
+            weights = (self.scales[self.g_idx] * (weight - zeros[self.g_idx]))
+            out = torch.matmul(x.half(), weights)
+        out = out.reshape(out_shape)
+        out = out + self.bias if self.bias is not None else out
+        return out
+
 class QuantLinear(nn.Module):
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
         super().__init__()