
Commit 4233522

feat(sanity-check): add custom implementation of GPTQ
1 parent 8a37fb4 commit 4233522

2 files changed: 271 additions & 13 deletions

quantize/gptq/quant.py

Lines changed: 182 additions & 0 deletions
@@ -147,6 +147,188 @@ def make_quant(module, names, bits, groupsize, name=''):
     for name1, child in module.named_children():
         make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
 
+def make_quant_custom(module, names, bits, groupsize, name=''):
+    if isinstance(module, QuantLinear):
+        return
+    for attr in dir(module):
+        tmp = getattr(module, attr)
+        name1 = name + '.' + attr if name != '' else attr
+        if name1 in names:
+
+            bias_name = attr.replace('w', 'b')
+            layer_name = attr.replace('w', 'quant')
+            setattr(module, layer_name, QuantLinear_custom(bits, groupsize, tmp.shape[0], tmp.shape[1], module.w[bias_name] is not None))
+
+
+class QuantLinear_custom(nn.Module):
+    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
+        super().__init__()
+        if bits not in [2,3,4,8]:
+            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
+        self.bits = bits
+        self.groupsize = groupsize if groupsize != -1 else infeatures
+        self.maxq = 2 ** self.bits - 1
+
+        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
+        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32))
+        if bias:
+            self.register_buffer('bias', torch.zeros((outfeatures),dtype=torch.float16))
+        else:
+            self.bias = None
+
+        # is performed by unpacking the weights and using torch.matmul
+        if self.bits in [2,4,8]:
+            self.register_buffer('wf',torch.tensor(list(range(0,32,self.bits)), dtype=torch.int32).unsqueeze(0),persistent=False)
+        elif self.bits == 3:
+            self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0],
+                                                     [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31],
+                                                     [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0],], dtype=torch.int32).reshape(1,3,12), persistent=False)
+
+        self.kernel_switch_threshold = kernel_switch_threshold
+        self.is_cuda = is_cuda
+
+    def pack(self, weight, bias, scales, zeros, g_idx = None):
+        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
+        self.scales = scales.clone().half()
+        if bias is not None:
+            self.bias = bias.clone().half()
+
+        intweight = []
+        for idx in range(self.infeatures):
+            intweight.append(torch.round((weight[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None])
+        intweight = torch.cat(intweight,dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.numpy().astype(np.uint32)
+        qweight = np.zeros(
+            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
+        )
+        i = 0
+        row = 0
+        while row < qweight.shape[0]:
+            if self.bits in [2,4,8]:
+                for j in range(i, i + (32//self.bits)):
+                    qweight[row] |= intweight[j] << (self.bits * (j - i))
+                i += 32//self.bits
+                row += 1
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i))
+                i += 10
+                qweight[row] |= intweight[i] << 30
+                row += 1
+                qweight[row] |= (intweight[i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
+                i += 10
+                qweight[row] |= intweight[i] << 31
+                row += 1
+                qweight[row] |= (intweight[i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
+                i += 10
+                row += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qweight = qweight.astype(np.int32)
+        self.qweight = torch.from_numpy(qweight)
+
+        zeros -= 1
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [2,4,8]:
+                for j in range(i, i + (32//self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32//self.bits
+                col += 1
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 30
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 31
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
+                i += 10
+                col += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)
+
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.outfeatures, )
+        x = x.reshape(-1,x.shape[-1])
+        if self.is_cuda is True and (self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold):
+            out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
+            if self.bits == 2:
+                quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            elif self.bits == 3:
+                quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            elif self.bits == 4:
+                quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            elif self.bits == 8:
+                quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            out = out.half()
+        else:
+            if self.bits in [2,4,8]:
+                zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
+                torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
+
+                zeros = zeros + 1
+                zeros = zeros.reshape(self.scales.shape)
+
+                weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8)
+                torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight)
+            elif self.bits == 3:
+                zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12)
+                zeros = (zeros >> self.wf.unsqueeze(0))
+                zeros[:,:,0,10] = (zeros[:,:,0,10]&0x3) | ((zeros[:,:,1,0] << 2)&0x4)
+                zeros[:,:,1,11] = (zeros[:,:,1,11]&0x1) | ((zeros[:,:,2,0] << 1)&0x6)
+                zeros = zeros & 0x7
+                zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2)
+
+                zeros = zeros + 1
+                zeros = zeros.reshape(self.scales.shape)
+
+                weight = self.qweight.reshape(self.qweight.shape[0]//3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1)
+                weight = (weight >> self.wf.unsqueeze(-1))&0x7
+                weight[:,0,10] = (weight[:,0,10]&0x3) | ((weight[:,1,0] << 2)&0x4)
+                weight[:,1,11] = (weight[:,1,11]&0x1) | ((weight[:,2,0] << 1)&0x6)
+                weight = weight & 0x7
+                weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1)
+
+            weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+            weights = (self.scales[self.g_idx] * (weight - zeros[self.g_idx]))
+            out = torch.matmul(x.half(), weights)
+        out = out.reshape(out_shape)
+        out = out + self.bias if self.bias is not None else out
+        return out
+
 class QuantLinear(nn.Module):
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
         super().__init__()
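A minimal, hypothetical round-trip check of the new layer (not part of the commit): it builds a 4-bit weight that lies exactly on the quantization grid, packs it with pack(), and runs the unpack-and-matmul fallback of forward() against a plain fp16 matmul. The import path, shapes, and scale/zero construction are assumptions for illustration only.

import torch
from quant import QuantLinear_custom  # assumed: run from quantize/gptq/

bits, groupsize, infeat, outfeat = 4, 16, 64, 32
maxq = 2 ** bits - 1
n_groups = infeat // groupsize

codes = torch.randint(0, maxq + 1, (outfeat, infeat))          # integer codes q
scales = torch.rand(outfeat, n_groups) * 0.01 + 0.005          # per-group scales, (out, groups) as pack() expects
zeros = torch.randint(1, maxq, (outfeat, n_groups)).float()    # per-group zero points, kept >= 1 since pack() stores zero - 1
g_idx = torch.tensor([i // groupsize for i in range(infeat)])  # int64 so it can index scales/zeros in forward()
weight = scales[:, g_idx] * (codes.float() - zeros[:, g_idx])  # dequantized reference weight, shape (out, in)
bias = torch.randn(outfeat)

qlin = QuantLinear_custom(bits, groupsize, infeat, outfeat, bias=True, is_cuda=False)  # is_cuda=False forces the matmul fallback
qlin.pack(weight, bias, scales, zeros, g_idx)

x = torch.randn(8, infeat)
ref = x.half() @ weight.t().half() + bias.half()
print((qlin(x.half()) - ref).abs().max())                      # difference should be on the order of fp16 rounding error

Since the fallback multiplies in fp16, running this sketch on CPU needs a PyTorch build with half-precision matmul support; otherwise move qlin and x to a CUDA device first.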

quantize/gptq/sanity_check_main.py

Lines changed: 89 additions & 13 deletions
@@ -4,6 +4,8 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from collections import OrderedDict
+import torch.nn.functional as F
 
 from sanity_check_utils import seed_everything, MNISTloader, SimpleNet, train, evaluate, SimpleNet_V2
 from gptq import *
@@ -34,9 +36,8 @@ def load_quant(model, checkpoint, wbits, groupsize):
 
     # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way)
     # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification)
-    for name in ["linear4"]:
-        if name in layers:
-            del layers[name]
+    if "linear4" in layers:
+        del layers["linear4"]
 
     make_quant(model, layers, wbits, groupsize)
     model.load_state_dict(torch.load(checkpoint))
@@ -258,8 +259,8 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
     ### begin GPTQ_CUSTOM
     def __init__(self, checkpoint_path):
         super().__init__()
-        self.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
-
+        self.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
+
     def _fill_subset(self, layer_id):
         is_last_layer = (layer_id == self.nb_layers - 1)
         if is_last_layer:
@@ -292,7 +293,7 @@ def fasterquant(self, layer_id, quantizers):
             print(layer_id, name)
             print('Quantizing ...')
             scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False)
-            quantizers[f"linear{layer_id + 1}"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
+            quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
 
     ## end GPTQ_CUSTOM
 
@@ -301,6 +302,19 @@ def my_linear(self, x, weight, bias):
         out = x @ weight.weight + bias
         weight.add_batch(x)
         return out
+
+    def forward(self, x):
+        if len(x.shape) == 4:
+            x = x.view(x.size(0), -1)
+
+        residual = x
+        x = F.relu(self.linear0_quant(x))
+        x = self.linear1_quant(x)
+        x = F.relu(x) + residual
+        x = self.linear2_quant(x)
+        x = F.relu(x) + residual
+        x = super().my_linear(x, self.linear3_w, self.linear3_b)
+        return x
     ## End SimpleNet_V2
 
 
@@ -321,9 +335,11 @@ def quantize_gptq_custom(model, train_loader):
     quantizers = {}
 
     for layer_id in range(nb_layers):
-
+
         if not is_last_layer(layer_id):
-
+
+            print(f"Quantizing layer {layer_id} ...")
+
             model.alloc_gptq(layer_id)
 
             for i in range(nsamples):
@@ -342,12 +358,56 @@ def quantize_gptq_custom(model, train_loader):
 
     return quantizers
 
-
 def model_pack_custom(model, quantizers, wbits, groupsize):
-    pass
+    # Extract weights and bias from model
+    is_weight = re.compile(r'^linear\d+_w$')
+    weights, bias = OrderedDict(), OrderedDict()
+    for name, param in model.w.items():
+        if is_weight.match(name):
+            weights[name] = param
+        else:
+            bias[name] = param
+
+    make_quant_custom(model, quantizers, wbits, groupsize)
+    qlayers = find_layers(model, [QuantLinear_custom])
+
+    print('Packing ...')
+    for i in range(len(qlayers)):
+        name_w, name_b, layer_quant_name = f'linear{i}_w', f'linear{i}_b', f'linear{i}_quant'
+        quantizers[name_w],scale,zero,g_idx = quantizers[name_w]
+        qlayers[layer_quant_name].pack(weights[name_w], bias[name_b], scale, zero, g_idx)
+    print('Done.')
+    return model
+
+def load_quant_custom(model, checkpoint, wbits, groupsize):
+    print('Loading model ...')
+    model = model.eval()
+    # Extract weights and bias from model
+    is_weight = re.compile(r'^linear\d+_w$')
+    weights, bias = OrderedDict(), OrderedDict()
+    for name, param in model.w.items():
+        if is_weight.match(name):
+            weights[name] = param
+        else:
+            bias[name] = param
+
+    # Create linear layer out of weights and bias
+    layers = {}
+    for (w_name, w_param), (_, b_param) in zip(weights.items(), bias.items()):
+        layers[w_name] = nn.Linear(w_param.shape[1], w_param.shape[0], bias=True)
+        layers[w_name].weight.data = w_param
+        layers[w_name].bias.data = b_param
+
+    # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way)
+    # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification)
+    if "linear3_w" in layers:
+        del layers["linear3_w"]
+
+    make_quant_custom(model, layers, wbits, groupsize)
+    model.load_state_dict(torch.load(checkpoint))
+    print('Done.')
+    return model
 
-def load_quant_custom(model, quantizers, wbits, groupsize):
-    pass
 
 def assert_parameters(model, model_custom):
     is_weight = re.compile(r'^linear\d+.weight$')
@@ -371,6 +431,7 @@ def assert_parameters(model, model_custom):
 parser.add_argument("--eval_gptq", action="store_true")
 parser.add_argument("--train_custom", action="store_true")
 parser.add_argument("--gptq_custom", action="store_true")
+parser.add_argument("--eval_gptq_custom", action="store_true")
 parser.add_argument("--pyquant", action="store_true")
 
 args = parser.parse_args()
@@ -381,7 +442,9 @@ def assert_parameters(model, model_custom):
 criterion = nn.CrossEntropyLoss()
 train_loader, _, _ = MNISTloader(train_val_split=0.95).load()
 
-#TODO: Do Custom packing
+#TODO: Do custom eval gptq
+#TODO: Is reference GPTQ quantizing bias as well ?
+#TODO: Add seed everywhere in GPT for reproducibility
 
 ## ================== REFERENCE ==================
 if args.train:
@@ -430,6 +493,19 @@ def assert_parameters(model, model_custom):
     model_pack_custom(model, quantizers, WBITS, GROUPSIZE)
     torch.save(model.state_dict(), "model_quantized_custom.pt")
     print("Done Custom GPTQ")
+elif args.eval_gptq_custom:
+    model = GPTQ_CUSTOM("./model_custom.pt")
+    device = torch.device("cuda:0")
+    model = load_quant_custom(model, "model_quantized_custom.pt", WBITS, GROUPSIZE)
+    model = model.to(device)
+
+    start = time.time()
+    val_loss, val_acc = evaluate(device, model, criterion, train_loader)
+    end = time.time()
+
+    print(f"wbits = {WBITS} using {device}")
+    print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}")
+    print(f"Latency: {end - start}")
 ## ================== MISC ==================
 elif args.pyquant:
     # Baseline post-training quantization from Pytorch
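With these two functions filled in, the sanity check can presumably be driven end to end from the script's flags. The exact invocation and the model_custom.pt name are assumptions inferred from the argparse additions and the eval_gptq_custom branch above:

python sanity_check_main.py --train_custom        # train the custom model, presumably saving model_custom.pt
python sanity_check_main.py --gptq_custom         # run custom GPTQ + packing, saves model_quantized_custom.pt
python sanity_check_main.py --eval_gptq_custom    # rebuild QuantLinear_custom layers, load the checkpoint, evaluate on MNIST

Note that the eval_gptq_custom branch hard-codes cuda:0, so the last step requires a GPU.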
