Commit c92f72f

fix(quantize): add missing part in forward block + support head.weight quantization

1 parent: 57079e7
File tree: 1 file changed (+52, -44)

quantize/tmp_rwkv.py: 52 additions, 44 deletions

@@ -7,7 +7,6 @@
 import torch.nn.functional as F
 import torch.nn as nn
 import time
-import gc
 import math
 import re

@@ -132,27 +131,21 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
 ### begin GPTQ_RWKV
     def __init__(self, model, strategy):
         super().__init__(model, strategy)
-        #TODO: add assert to only quantize in CPU FP32 mode
+        for i in range(self.args.n_layer):
+            assert self.strategy[i].device == "cpu"

     def _fill_subset(self, layer_id):
         # Keep only layer within block layer_id
-        dd = self.strategy[layer_id]
-        dev = dd.device
-
-        for name in self.w.keys():
-            if re.match(f'^blocks\.{layer_id}\..*\.weight$', name):
-                tensor = self.w[name]
-
-                #TODO: Skip 1D tensors for now
-                if len(tensor.shape) == 1:
-                    continue
-
-                print(f"{name} = {self.w[name].shape}")
-
-                if re.match(f'^blocks\.{layer_id}\.(?:att|ffn)\.(?:key|value|output|receptance)\.weight$', name):
-                    tensor = tensor.to(device=dev, non_blocking=True)
-
-                self.subset[name] = tensor
+        is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$')
+        for name in self.w.keys():
+            if is_weight.match(name):
+                if len(self.w[name].shape) == 1: continue #TODO: Skip 1D tensors for now
+                self.subset[name] = self.w[name]
+
+        is_last_layer = (layer_id == self.args.n_layer - 1)
+        if is_last_layer:
+            self.subset["head.weight"] = self.w["head.weight"]
+

     def alloc_gptq(self, layer_id):
         self.subset = {}

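For reference, a standalone sketch of the new filtering logic, runnable on its own. The weight dict, layer count, and the fill_subset helper are made-up stand-ins for self.w, self.args.n_layer, and _fill_subset:

import re
import torch

n_layer = 2  # stand-in for self.args.n_layer
w = {        # stand-in for self.w
    "blocks.0.att.key.weight": torch.empty(4, 4),
    "blocks.1.ln1.weight": torch.empty(4),         # 1D, skipped for now
    "blocks.1.ffn.value.weight": torch.empty(4, 4),
    "head.weight": torch.empty(4, 8),
}

def fill_subset(layer_id):
    subset = {}
    is_weight = re.compile(rf'^blocks\.{layer_id}\..*\.weight$')
    for name, tensor in w.items():
        if is_weight.match(name):
            if tensor.dim() == 1:
                continue  # TODO in the commit: 1D tensors are skipped for now
            subset[name] = tensor
    # the last block additionally picks up the language-model head
    if layer_id == n_layer - 1:
        subset["head.weight"] = w["head.weight"]
    return subset

print(sorted(fill_subset(1)))  # ['blocks.1.ffn.value.weight', 'head.weight']
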
@@ -178,7 +171,6 @@ def fasterquant(self, layer_id, quantizers):
             self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False)
             # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)
             quantizers[name] = self.gptq[name].quantizer
-            # TODO: may be free gptq here to save memory

 ### end GPTQ_RWKV

@@ -272,7 +264,7 @@ def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr
         vw.add_batch(vx)
         return x + out, xx[-1,:]

-    def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False):
+    def forward_block(self, x, state, i, seq_mode, full_output=False):
         with torch.no_grad():
             args = self.args

@@ -312,6 +304,12 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
             rw = self.gptq[f'{att}receptance.weight']
             ow = self.gptq[f'{att}output.weight']

+            if dd.stream:
+                kw = kw.to(device=dev, non_blocking=True)
+                vw = vw.to(device=dev, non_blocking=True)
+                rw = rw.to(device=dev, non_blocking=True)
+                ow = ow.to(device=dev, non_blocking=True)
+
             kmx = self.w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x
             krx = self.w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x
             kmy = self.w[f'{att}key.weight_my'] if wtype == torch.uint8 else x

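The new dd.stream branch matches the one already present on the ffn side (two hunks below): weights parked on the CPU are copied to the execution device just in time. A tensor-level illustration of the idiom; the helper name and device choice are assumptions:

import torch

def stream_to(w, dev, stream=True):
    # non_blocking only overlaps the copy with compute when the source is in
    # pinned CPU memory and the target is a CUDA device; otherwise it is a hint
    return w.to(device=dev, non_blocking=True) if stream else w

dev = "cuda" if torch.cuda.is_available() else "cpu"
kw = stream_to(torch.randn(768, 768), dev)
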
@@ -341,6 +339,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
                 omx=omx, orx=orx, omy=omy, ory=ory,
                 )

+            # Deactivate add_batch() after quantization is applied
             kw.deactivate_add_batch_call = True
             vw.deactivate_add_batch_call = True
             rw.deactivate_add_batch_call = True

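deactivate_add_batch_call guards the GPTQ wrapper so calibration statistics stop accumulating once a block has been quantized. A minimal sketch of what such a wrapper might look like, assuming the usual GPTQ recipe of accumulating an input Hessian estimate; only add_batch and the flag name come from the diff, the rest is hypothetical:

import torch

class GPTQLinearSketch:
    def __init__(self, weight):  # weight: (in_features, out_features), applied as x @ W
        self.weight = weight
        n = weight.shape[0]
        self.H = torch.zeros(n, n)  # running input Hessian estimate
        self.nsamples = 0
        self.deactivate_add_batch_call = False

    def add_batch(self, inp):
        if self.deactivate_add_batch_call:
            return  # already quantized: ignore further calibration data
        if inp.dim() == 1:
            inp = inp.unsqueeze(0)  # (in,) -> (1, in)
        self.H += inp.t() @ inp
        self.nsamples += inp.shape[0]

layer = GPTQLinearSketch(torch.randn(16, 32))
layer.add_batch(torch.randn(4, 16))    # calibration pass: statistics accumulate
layer.deactivate_add_batch_call = True
layer.add_batch(torch.randn(4, 16))    # post-quantization pass: ignored
print(layer.nsamples)                  # 4
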
@@ -352,6 +351,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
             kw = self.gptq[f'{ffn}key.weight']
             vw = self.gptq[f'{ffn}value.weight']
             rw = self.gptq[f'{ffn}receptance.weight']
+
             if dd.stream:
                 kw = kw.to(device=dev, non_blocking=True)
                 vw = vw.to(device=dev, non_blocking=True)

@@ -391,43 +391,46 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
             if (i+1) % self.RESCALE_LAYER == 0:
                 x = x / 2

-            if is_last_layer:
+            is_last_layer = i == (args.n_layer - 1)
+
+            if is_last_layer:
                 dd = self.strategy[args.n_layer]
                 x = x[-1,:] if (seq_mode and (not full_output)) else x
                 x = x.to(dtype=dd.atype, device=dd.device)

-                #TODO: Add GPTQ support for head & ln_out
+                #TODO: ln_out.weight is 1D tensor
                 x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias'])
+
                 if self.w['head.weight'].dtype != torch.uint8:
-                    x = x @ self.w['head.weight']
-                else:
-                    if seq_mode and full_output:
-                        x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
-                    else:
-                        x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
+                    x = x @ self.gptq['head.weight'].weight
+                    self.gptq['head.weight'].add_batch(x)
+                    self.gptq['head.weight'].deactivate_add_batch_call = True

             return x.float()

 ### end RWKV

+model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
+
 NSAMPLES=2
-HIDDEN_SIZE=768
-SEQLEN=2048 # TODO: this is chosen by the model
+HIDDEN_SIZE=model.args.n_embd
+SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m

 # train_tokens, test_tokens = get_loaders(
 #     dataset_name="wikitext2",
 #     nsamples=NSAMPLES,
 #     seed=42,
 #     seqlen=SEQLEN,
-#     model=None
+#     model=model
 # )

 # tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
 tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64)
 print("tokens.shape", tokens.shape)

-model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
-is_last_layer = [False] * (model.args.n_layer - 1) + [True]
+is_last_layer = lambda x: x == (model.args.n_layer - 1)
+
+start_time = time.time()

 #TODO: Do the same in GPU side
 with torch.no_grad():

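With this hunk the head projection flows through the same GPTQ wrapper as the block weights instead of the fp32/uint8 mm8_* branches. A self-contained sketch of that calibration step; the wrapper class is hypothetical, and sizes follow the 169M config (n_embd=768, vocab=50277):

import torch

class HeadWrapperSketch:
    # hypothetical stand-in for the object behind self.gptq['head.weight']
    def __init__(self, weight):
        self.weight = weight
        self.batches = []
        self.deactivate_add_batch_call = False

    def add_batch(self, t):
        if not self.deactivate_add_batch_call:
            self.batches.append(t)

head = HeadWrapperSketch(torch.randn(768, 50277))  # n_embd x vocab
x = torch.randn(768)                   # final hidden state after ln_out
x = x @ head.weight                    # project to logits, as in the diff
head.add_batch(x)                      # the diff records the projected tensor
head.deactivate_add_batch_call = True  # freeze once quantization has run
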
@@ -442,23 +445,28 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
         model.alloc_gptq(layer_id)

         for j in range(NSAMPLES):
-            if not is_last_layer[layer_id]:
-                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+            if not is_last_layer(layer_id):
+                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
             else:
-                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)

         model.fasterquant(layer_id, quantizers)

         for j in range(NSAMPLES):
-            if not is_last_layer[layer_id]:
-                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+            if not is_last_layer(layer_id):
+                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
             else:
-                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
+
         model.free_gptq()

-        if not is_last_layer[layer_id]:
-            # We need to pass the outputs of block i as input of block i+1 (except for last block)
+        # We need to pass the outputs of block i as input of block i+1 (except for last block)
+        if not is_last_layer(layer_id):
             inps, outs = outs, inps

-# TODO: create a function that check if all weights were properly quantized
-print("Done")
+end_time = time.time()
+
+print(f"Done in {end_time - start_time:.2f} seconds")
+
+# TODO: Do something with quantizers dictionary
+# TODO: pack3 save model

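Putting the driver together: the loop now makes two passes per block, one so add_batch() sees calibration activations and one after fasterquant() so block i+1 receives quantized outputs. A condensed sketch of that schedule; the model object and sample buffers are assumed to exist as in the script, and seq_mode is fixed to True here:

import time

def run_pipeline(model, inps, outs, nsamples, quantizers):
    n_layer = model.args.n_layer
    is_last_layer = lambda i: i == n_layer - 1
    start_time = time.time()
    for layer_id in range(n_layer):
        model.alloc_gptq(layer_id)
        # pass 1: feed calibration samples so add_batch() collects statistics
        for j in range(nsamples):
            out = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=True)
            if not is_last_layer(layer_id):
                outs[j] = out
        model.fasterquant(layer_id, quantizers)
        # pass 2: re-run with quantized weights to produce inputs for block i+1
        for j in range(nsamples):
            out = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=True)
            if not is_last_layer(layer_id):
                outs[j] = out
        model.free_gptq()
        if not is_last_layer(layer_id):
            inps, outs = outs, inps
    print(f"Done in {time.time() - start_time:.2f} seconds")
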
