import torch.nn.functional as F
import torch.nn as nn
import time
- import gc
import math
import re
@@ -132,27 +131,21 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
### begin GPTQ_RWKV
def __init__(self, model, strategy):
    super().__init__(model, strategy)
-   #TODO: add assert to only quantize in CPU FP32 mode
+   for i in range(self.args.n_layer):
+       assert self.strategy[i].device == "cpu"

def _fill_subset(self, layer_id):
    # Keep only layer within block layer_id
-   dd = self.strategy[layer_id]
-   dev = dd.device
-
-   for name in self.w.keys():
-       if re.match(f'^blocks\.{layer_id}\..*\.weight$', name):
-           tensor = self.w[name]
-
-           #TODO: Skip 1D tensors for now
-           if len(tensor.shape) == 1:
-               continue
-
-           print(f"{name} = {self.w[name].shape}")
-
-           if re.match(f'^blocks\.{layer_id}\.(?:att|ffn)\.(?:key|value|output|receptance)\.weight$', name):
-               tensor = tensor.to(device=dev, non_blocking=True)
-
-           self.subset[name] = tensor
+   is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$')
+   for name in self.w.keys():
+       if is_weight.match(name):
+           if len(self.w[name].shape) == 1: continue #TODO: Skip 1D tensors for now
+           self.subset[name] = self.w[name]
+
+   is_last_layer = (layer_id == self.args.n_layer - 1)
+   if is_last_layer:
+       self.subset["head.weight"] = self.w["head.weight"]

def alloc_gptq(self, layer_id):
    self.subset = {}
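For orientation, here is a minimal, self-contained sketch of the selection logic that the rewritten `_fill_subset` applies; the `w` dict, shapes, and layer count below are placeholders, not the real RWKV state dict:

import re
import torch

# Illustrative stand-in for self.w -- names and shapes are made up.
layer_id, n_layer, n_embd, vocab = 0, 12, 768, 50277
w = {
    f"blocks.{layer_id}.att.key.weight": torch.empty(n_embd, n_embd),
    f"blocks.{layer_id}.ln1.weight": torch.empty(n_embd),   # 1D tensor -> skipped
    "head.weight": torch.empty(n_embd, vocab),
}

subset = {}
is_weight = re.compile(rf'^blocks\.{layer_id}\..*\.weight$')
for name in w.keys():
    if is_weight.match(name):
        if len(w[name].shape) == 1:
            continue  # 1D tensors (e.g. layer norms) are skipped for now
        subset[name] = w[name]

if layer_id == n_layer - 1:          # the last block also grabs the LM head
    subset["head.weight"] = w["head.weight"]

print(list(subset.keys()))           # -> ['blocks.0.att.key.weight']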
@@ -178,7 +171,6 @@ def fasterquant(self, layer_id, quantizers):
        self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False)
        # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)
        quantizers[name] = self.gptq[name].quantizer
-       # TODO: may be free gptq here to save memory


### end GPTQ_RWKV
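The call above hardcodes `percdamp=0.01, groupsize=-1, actorder=False`, while the commented-out line reads them from `args`. A possible (hypothetical) argparse wiring with the same flag names and the current values as defaults could look like this:

import argparse

# Hypothetical CLI flags mirroring the commented-out call above; defaults match the
# currently hardcoded values, so behaviour is unchanged unless flags are passed.
parser = argparse.ArgumentParser()
parser.add_argument("--percdamp", type=float, default=0.01,
                    help="fraction of the mean Hessian diagonal added as damping")
parser.add_argument("--groupsize", type=int, default=-1,
                    help="quantization group size (-1 = a single group per row)")
parser.add_argument("--act-order", dest="act_order", action="store_true",
                    help="quantize columns by decreasing activation magnitude")
args = parser.parse_args()

# self.gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)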
@@ -272,7 +264,7 @@ def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr
    vw.add_batch(vx)
    return x + out, xx[-1,:]

- def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False):
+ def forward_block(self, x, state, i, seq_mode, full_output=False):
    with torch.no_grad():
        args = self.args

@@ -312,6 +304,12 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
        rw = self.gptq[f'{att}receptance.weight']
        ow = self.gptq[f'{att}output.weight']

+       if dd.stream:
+           kw = kw.to(device=dev, non_blocking=True)
+           vw = vw.to(device=dev, non_blocking=True)
+           rw = rw.to(device=dev, non_blocking=True)
+           ow = ow.to(device=dev, non_blocking=True)
+
        kmx = self.w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x
        krx = self.w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x
        kmy = self.w[f'{att}key.weight_my'] if wtype == torch.uint8 else x
@@ -341,6 +339,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
            omx=omx, orx=orx, omy=omy, ory=ory,
        )

+       # Deactivate add_batch() after quantization is applied
        kw.deactivate_add_batch_call = True
        vw.deactivate_add_batch_call = True
        rw.deactivate_add_batch_call = True
@@ -352,6 +351,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
        kw = self.gptq[f'{ffn}key.weight']
        vw = self.gptq[f'{ffn}value.weight']
        rw = self.gptq[f'{ffn}receptance.weight']
+
        if dd.stream:
            kw = kw.to(device=dev, non_blocking=True)
            vw = vw.to(device=dev, non_blocking=True)
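Both `dd.stream` branches above copy the block's weights to the target device just before use. A small standalone illustration of that non-blocking transfer pattern (the device choice and tensor are placeholders; with a CPU-only strategy `non_blocking=True` is simply a no-op):

import torch

dev = "cuda" if torch.cuda.is_available() else "cpu"

# non_blocking=True only overlaps the copy with compute when the source tensor
# lives in pinned (page-locked) host memory and the destination is a CUDA device.
kw = torch.empty(768, 768, pin_memory=torch.cuda.is_available())
kw = kw.to(device=dev, non_blocking=True)
print(kw.device)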
@@ -391,43 +391,46 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
        if (i+1) % self.RESCALE_LAYER == 0:
            x = x / 2

-       if is_last_layer:
+       is_last_layer = i == (args.n_layer - 1)
+
+       if is_last_layer:
            dd = self.strategy[args.n_layer]
            x = x[-1,:] if (seq_mode and (not full_output)) else x
            x = x.to(dtype=dd.atype, device=dd.device)

-           #TODO: Add GPTQ support for head & ln_out
+           #TODO: ln_out.weight is 1D tensor
            x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias'])
+
            if self.w['head.weight'].dtype != torch.uint8:
-               x = x @ self.w['head.weight']
-           else:
-               if seq_mode and full_output:
-                   x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
-               else:
-                   x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
+               x = x @ self.gptq['head.weight'].weight
+               self.gptq['head.weight'].add_batch(x)
+               self.gptq['head.weight'].deactivate_add_batch_call = True

        return x.float()

### end RWKV

+ model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
+
NSAMPLES = 2
- HIDDEN_SIZE = 768
- SEQLEN = 2048 # TODO: this is chosen by the model
+ HIDDEN_SIZE = model.args.n_embd
+ SEQLEN = 1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m

# train_tokens, test_tokens = get_loaders(
#   dataset_name="wikitext2",
#   nsamples=NSAMPLES,
#   seed=42,
#   seqlen=SEQLEN,
- #   model=None
+ #   model=model
# )

# tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64)
print("tokens.shape", tokens.shape)

- model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
- is_last_layer = [False] * (model.args.n_layer - 1) + [True]
+ is_last_layer = lambda x: x == (model.args.n_layer - 1)
+
+ start_time = time.time()

#TODO: Do the same in GPU side
with torch.no_grad():
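With the change above, the unquantized `head.weight` path now goes through the GPTQ wrapper: the matmul uses the wrapper's `.weight`, one batch is recorded via `add_batch`, and `deactivate_add_batch_call` then freezes statistics collection on later passes. A toy illustration of that gating (the class below is a stand-in for the real GPTQ object, not its implementation):

import torch

class ToyGPTQ:
    """Stand-in that only demonstrates the add_batch on/off switch."""
    def __init__(self, weight):
        self.weight = weight
        self.deactivate_add_batch_call = False
        self.nsamples = 0

    def add_batch(self, inp):
        if self.deactivate_add_batch_call:
            return                      # stats are frozen once quantization is done
        self.nsamples += inp.shape[0]   # the real GPTQ accumulates a Hessian here

head = ToyGPTQ(torch.randn(768, 50277))
x = torch.randn(16, 768)

x = x @ head.weight                     # same call pattern as the diff above
head.add_batch(x)
head.deactivate_add_batch_call = True
head.add_batch(x)                       # ignored from now on
print(head.nsamples)                    # -> 16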
@@ -442,23 +445,28 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
        model.alloc_gptq(layer_id)

        for j in range(NSAMPLES):
-           if not is_last_layer[layer_id]:
-               outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+           if not is_last_layer(layer_id):
+               outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
            else:
-               _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+               _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)

        model.fasterquant(layer_id, quantizers)

        for j in range(NSAMPLES):
-           if not is_last_layer[layer_id]:
-               outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+           if not is_last_layer(layer_id):
+               outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
            else:
-               _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+               _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
+
        model.free_gptq()

-       if not is_last_layer[layer_id]:
-           # We need to pass the outputs of block i as input of block i+1 (except for last block)
+       # We need to pass the outputs of block i as input of block i+1 (except for last block)
+       if not is_last_layer(layer_id):
            inps, outs = outs, inps

- # TODO: create a function that check if all weights were properly quantized
- print("Done")
+ end_time = time.time()
+
+ print(f"Done in {end_time - start_time:.2f} seconds")
+
+ # TODO: Do something with quantizers dictionary
+ # TODO: pack3 save model
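On the two trailing TODOs: a first step could be a quick inspection of what was collected. Assuming the quantizer objects follow the reference GPTQ implementation and expose `scale` and `zero` tensors (an assumption, not verified here), something like this could run right after the loop:

# Assumes the `quantizers` dict from the loop above and reference-GPTQ-style attributes.
for name, quantizer in quantizers.items():
    print(name, "scale", tuple(quantizer.scale.shape), "zero", tuple(quantizer.zero.shape))

# Packing the low-bit weights and saving the model (the "pack3 save model" TODO)
# would come after this; it is left out here since it depends on the packing code used.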