
Commit 53661d0

Support Olive new uint8 quantization format (#1916)
This pull request updates the handling of the Olive quantization format in `quantized_model.py` to match the latest specification and improve code clarity. The main changes correct how in/out features are computed for Olive quantized layers, document the Olive format, and update the repacking logic for compatibility with ONNX Runtime (ORT).

**Olive quantization format support and documentation:**

* Updated the computation of `in_features` and `out_features` for Olive quantized layers to match the new format, which packs weights along the last dimension (`qweight` is now `(out_features, packed_in_features)`), and adjusted all relevant projections in the self-attention and MLP modules.
* Added a docstring to the `OliveModel` class explaining the Olive quantization format for weights, scales, and zero points.

**Repacking and compatibility improvements:**

* Implemented a new `repack` method for Olive quantized modules that reshapes tensors for ONNX Runtime (ORT) compatibility: `qweight` is reshaped into blocks, while `scales` and `qzeros` are flattened.
* Added placeholder `handle_qzeros` and `unpack` methods for the Olive format to make explicit that no offset handling or unpacking is required.
1 parent cec93a0 commit 53661d0
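The shape arithmetic the message describes is easy to sanity-check. A minimal sketch, assuming 4-bit values packed two per uint8 byte along the last dimension (the nibble order here is an illustrative assumption, not taken from the repo):

```python
import numpy as np

# Pack a toy 4-bit weight matrix along the last dim, per the commit message:
# qweight goes from (out_features, in_features) values to
# (out_features, in_features * bits // 8) uint8 bytes.
bits = 4
out_features, in_features = 2, 8
w = np.random.randint(0, 2**bits, size=(out_features, in_features), dtype=np.uint8)

packed = w[:, 0::2] | (w[:, 1::2] << bits)  # low nibble first (assumed order)
assert packed.shape == (out_features, in_features * bits // 8)

# Recovering in_features from the packed shape, as set_properties() now does:
assert packed.shape[1] * 8 // bits == in_features
```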

File tree

1 file changed: +138 −56 lines


src/python/py/models/quantized_model.py

Lines changed: 138 additions & 56 deletions
```diff
@@ -366,11 +366,18 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.self_attention.query_key_value.qweight
     # model.layers.layer_id.self_attn.qkv_proj.weight
     # model.layers.layer_id.self_attention.query_key_value.weight
-    q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "quark"} else q_size
-    kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "quark"} else kv_size
-    tensor_map["self_attn.q_proj.qweight"] = tensor[:, :q_dim]
-    tensor_map["self_attn.k_proj.qweight"] = tensor[:, q_dim : q_dim + kv_dim]
-    tensor_map["self_attn.v_proj.qweight"] = tensor[:, q_dim + kv_dim :]
+    if quant_type == "olive":
+        # Olive: (out_features, in_features), split on dim=0
+        tensor_map["self_attn.q_proj.qweight"] = tensor[:q_size, :]
+        tensor_map["self_attn.k_proj.qweight"] = tensor[q_size : q_size + kv_size, :]
+        tensor_map["self_attn.v_proj.qweight"] = tensor[q_size + kv_size :, :]
+    else:
+        # AWQ/GPTQ/Quark: (in_features, out_features), split on dim=1
+        q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "quark"} else q_size
+        kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "quark"} else kv_size
+        tensor_map["self_attn.q_proj.qweight"] = tensor[:, :q_dim]
+        tensor_map["self_attn.k_proj.qweight"] = tensor[:, q_dim : q_dim + kv_dim]
+        tensor_map["self_attn.v_proj.qweight"] = tensor[:, q_dim + kv_dim :]
 elif bool(
     re.match(
         r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.(scales|weight_scale)$",
```
```diff
@@ -381,9 +388,16 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.self_attention.query_key_value.scales
     # model.layers.layer_id.self_attn.qkv_proj.weight_scale
     # model.layers.layer_id.self_attention.query_key_value.weight_scale
-    tensor_map["self_attn.q_proj.scales"] = tensor[:, :q_size]
-    tensor_map["self_attn.k_proj.scales"] = tensor[:, q_size : q_size + kv_size]
-    tensor_map["self_attn.v_proj.scales"] = tensor[:, q_size + kv_size :]
+    if quant_type == "olive":
+        # Olive: (out_features, num_groups), split on dim=0
+        tensor_map["self_attn.q_proj.scales"] = tensor[:q_size, :]
+        tensor_map["self_attn.k_proj.scales"] = tensor[q_size : q_size + kv_size, :]
+        tensor_map["self_attn.v_proj.scales"] = tensor[q_size + kv_size :, :]
+    else:
+        # AWQ/GPTQ/Quark: split on dim=1
+        tensor_map["self_attn.q_proj.scales"] = tensor[:, :q_size]
+        tensor_map["self_attn.k_proj.scales"] = tensor[:, q_size : q_size + kv_size]
+        tensor_map["self_attn.v_proj.scales"] = tensor[:, q_size + kv_size :]
 elif bool(
     re.match(
         r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.(qzeros|weight_zero_point)$",
```
```diff
@@ -394,19 +408,28 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.self_attention.query_key_value.qzeros
     # model.layers.layer_id.self_attn.qkv_proj.weight_zero_point
     # model.layers.layer_id.self_attention.query_key_value.weight_zero_point
-    q_dim = (
-        q_size // (32 // local_bits)
-        if quant_type in {"awq", "gptq", "olive", "quark"}
-        else q_size
-    )
-    kv_dim = (
-        kv_size // (32 // local_bits)
-        if quant_type in {"awq", "gptq", "olive", "quark"}
-        else kv_size
-    )
-    tensor_map["self_attn.q_proj.qzeros"] = tensor[:, :q_dim]
-    tensor_map["self_attn.k_proj.qzeros"] = tensor[:, q_dim : q_dim + kv_dim]
-    tensor_map["self_attn.v_proj.qzeros"] = tensor[:, q_dim + kv_dim :]
+    if quant_type == "olive":
+        # Olive: (out_features, packed_num_groups) uint8, split on dim=0
+        q_dim = q_size // (8 // local_bits)
+        kv_dim = kv_size // (8 // local_bits)
+        tensor_map["self_attn.q_proj.qzeros"] = tensor[:q_dim, :]
+        tensor_map["self_attn.k_proj.qzeros"] = tensor[q_dim : q_dim + kv_dim, :]
+        tensor_map["self_attn.v_proj.qzeros"] = tensor[q_dim + kv_dim :, :]
+    else:
+        # AWQ/GPTQ/Quark: int32 packing, split on dim=1
+        q_dim = (
+            q_size // (32 // local_bits)
+            if quant_type in {"awq", "gptq", "quark"}
+            else q_size
+        )
+        kv_dim = (
+            kv_size // (32 // local_bits)
+            if quant_type in {"awq", "gptq", "quark"}
+            else kv_size
+        )
+        tensor_map["self_attn.q_proj.qzeros"] = tensor[:, :q_dim]
+        tensor_map["self_attn.k_proj.qzeros"] = tensor[:, q_dim : q_dim + kv_dim]
+        tensor_map["self_attn.v_proj.qzeros"] = tensor[:, q_dim + kv_dim :]
 elif bool(
     re.match(
         r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.g_idx$", name
```
```diff
@@ -434,13 +457,19 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.mlp.dense_h_to_4h.qweight
     # model.layers.layer_id.mlp.gate_up_proj.weight
     # model.layers.layer_id.mlp.dense_h_to_4h.weight
-    intermediate_dim = (
-        intermediate_size // (32 // local_bits)
-        if quant_type in {"awq", "quark"}
-        else intermediate_size
-    )
-    tensor_map["mlp.gate_proj.qweight"] = tensor[:, :intermediate_dim]
-    tensor_map["mlp.up_proj.qweight"] = tensor[:, intermediate_dim:]
+    if quant_type == "olive":
+        # Olive: (out_features, in_features), split on dim=0
+        tensor_map["mlp.gate_proj.qweight"] = tensor[:intermediate_size, :]
+        tensor_map["mlp.up_proj.qweight"] = tensor[intermediate_size:, :]
+    else:
+        # AWQ/GPTQ/Quark: (in_features, out_features), split on dim=1
+        intermediate_dim = (
+            intermediate_size // (32 // local_bits)
+            if quant_type in {"awq", "quark"}
+            else intermediate_size
+        )
+        tensor_map["mlp.gate_proj.qweight"] = tensor[:, :intermediate_dim]
+        tensor_map["mlp.up_proj.qweight"] = tensor[:, intermediate_dim:]
 elif bool(
     re.match(
         r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.(scales|weight_scale)$",
```
```diff
@@ -451,8 +480,14 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.mlp.dense_h_to_4h.scales
     # model.layers.layer_id.mlp.gate_up_proj.weight_scale
     # model.layers.layer_id.mlp.dense_h_to_4h.weight_scale
-    tensor_map["mlp.gate_proj.scales"] = tensor[:, :intermediate_size]
-    tensor_map["mlp.up_proj.scales"] = tensor[:, intermediate_size:]
+    if quant_type == "olive":
+        # Olive: (out_features, num_groups), split on dim=0
+        tensor_map["mlp.gate_proj.scales"] = tensor[:intermediate_size, :]
+        tensor_map["mlp.up_proj.scales"] = tensor[intermediate_size:, :]
+    else:
+        # AWQ/GPTQ/Quark: split on dim=1
+        tensor_map["mlp.gate_proj.scales"] = tensor[:, :intermediate_size]
+        tensor_map["mlp.up_proj.scales"] = tensor[:, intermediate_size:]
 elif bool(
     re.match(
         r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.(qzeros|weight_zero_point)$",
```
```diff
@@ -463,13 +498,20 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.mlp.dense_h_to_4h.qzeros
     # model.layers.layer_id.mlp.gate_up_proj.weight_zero_point
     # model.layers.layer_id.mlp.dense_h_to_4h.weight_zero_point
-    intermediate_dim = (
-        intermediate_size // (32 // local_bits)
-        if quant_type in {"awq", "gptq", "quark", "olive"}
-        else intermediate_size
-    )
-    tensor_map["mlp.gate_proj.qzeros"] = tensor[:, :intermediate_dim]
-    tensor_map["mlp.up_proj.qzeros"] = tensor[:, intermediate_dim:]
+    if quant_type == "olive":
+        # Olive: (out_features, packed_num_groups) uint8, split on dim=0
+        intermediate_dim = intermediate_size // (8 // local_bits)
+        tensor_map["mlp.gate_proj.qzeros"] = tensor[:intermediate_dim, :]
+        tensor_map["mlp.up_proj.qzeros"] = tensor[intermediate_dim:, :]
+    else:
+        # AWQ/GPTQ/Quark: int32 packing, split on dim=1
+        intermediate_dim = (
+            intermediate_size // (32 // local_bits)
+            if quant_type in {"awq", "gptq", "quark"}
+            else intermediate_size
+        )
+        tensor_map["mlp.gate_proj.qzeros"] = tensor[:, :intermediate_dim]
+        tensor_map["mlp.up_proj.qzeros"] = tensor[:, intermediate_dim:]
 elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h)\.g_idx$", name)):
     # model.layers.layer_id.mlp.gate_up_proj.g_idx
     # model.layers.layer_id.mlp.dense_h_to_4h.g_idx
```
```diff
@@ -554,10 +596,10 @@ def set_properties(self):
     self.lm_head.out_features = self.lm_head.qweight.shape[1]
     self.lm_head.in_features = self.lm_head.g_idx.shape[0]
 elif self.quant_type == "olive":
-    self.lm_head.out_features = self.lm_head.qweight.shape[1]
-    # expects in_features to be divisible by the packing factor (32 // bits)
-    # not a new assumption since no code here accounts for padded packed weights
-    self.lm_head.in_features = self.lm_head.qweight.shape[0] * 32 // self.lm_head.bits
+    # Olive format: qweight is (out_features, packed_in_features) uint8
+    # packed_in_features = in_features * bits / 8
+    self.lm_head.out_features = self.lm_head.qweight.shape[0]
+    self.lm_head.in_features = self.lm_head.qweight.shape[1] * 8 // self.lm_head.bits
 else:
     raise NotImplementedError(f"The {self.quant_type} quantization method is not recognized.")
 for module in self.layers:
```
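A worked example of the corrected arithmetic, using hypothetical dimensions (a 4-bit `lm_head` with `in_features=3072`, `out_features=32064`):

```python
# packed_in_features = in_features * bits // 8
bits = 4
qweight_shape = (32064, 3072 * bits // 8)   # (32064, 1536) uint8

out_features = qweight_shape[0]             # 32064
in_features = qweight_shape[1] * 8 // bits  # 3072
assert (out_features, in_features) == (32064, 3072)
```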
```diff
@@ -654,32 +696,31 @@ def set_properties(self):
     module.mlp.down_proj.in_features = module.mlp.down_proj.g_idx.shape[0]

 elif self.quant_type == "olive":
-    # Set in_features and out_features
-    module.self_attn.q_proj.out_features = module.self_attn.q_proj.qweight.shape[1]
+    module.self_attn.q_proj.out_features = module.self_attn.q_proj.qweight.shape[0]
     module.self_attn.q_proj.in_features = (
-        module.self_attn.q_proj.qweight.shape[0] * 32 // module.self_attn.q_proj.bits
+        module.self_attn.q_proj.qweight.shape[1] * 8 // module.self_attn.q_proj.bits
     )
-    module.self_attn.k_proj.out_features = module.self_attn.k_proj.qweight.shape[1]
+    module.self_attn.k_proj.out_features = module.self_attn.k_proj.qweight.shape[0]
     module.self_attn.k_proj.in_features = (
-        module.self_attn.k_proj.qweight.shape[0] * 32 // module.self_attn.k_proj.bits
+        module.self_attn.k_proj.qweight.shape[1] * 8 // module.self_attn.k_proj.bits
     )
-    module.self_attn.v_proj.out_features = module.self_attn.v_proj.qweight.shape[1]
+    module.self_attn.v_proj.out_features = module.self_attn.v_proj.qweight.shape[0]
     module.self_attn.v_proj.in_features = (
-        module.self_attn.v_proj.qweight.shape[0] * 32 // module.self_attn.v_proj.bits
+        module.self_attn.v_proj.qweight.shape[1] * 8 // module.self_attn.v_proj.bits
     )
-    module.self_attn.o_proj.out_features = module.self_attn.o_proj.qweight.shape[1]
+    module.self_attn.o_proj.out_features = module.self_attn.o_proj.qweight.shape[0]
     module.self_attn.o_proj.in_features = (
-        module.self_attn.o_proj.qweight.shape[0] * 32 // module.self_attn.o_proj.bits
+        module.self_attn.o_proj.qweight.shape[1] * 8 // module.self_attn.o_proj.bits
     )
-    module.mlp.gate_proj.out_features = module.mlp.gate_proj.qweight.shape[1]
+    module.mlp.gate_proj.out_features = module.mlp.gate_proj.qweight.shape[0]
     module.mlp.gate_proj.in_features = (
-        module.mlp.gate_proj.qweight.shape[0] * 32 // module.mlp.gate_proj.bits
+        module.mlp.gate_proj.qweight.shape[1] * 8 // module.mlp.gate_proj.bits
     )
-    module.mlp.up_proj.out_features = module.mlp.up_proj.qweight.shape[1]
-    module.mlp.up_proj.in_features = module.mlp.up_proj.qweight.shape[0] * 32 // module.mlp.up_proj.bits
-    module.mlp.down_proj.out_features = module.mlp.down_proj.qweight.shape[1]
+    module.mlp.up_proj.out_features = module.mlp.up_proj.qweight.shape[0]
+    module.mlp.up_proj.in_features = module.mlp.up_proj.qweight.shape[1] * 8 // module.mlp.up_proj.bits
+    module.mlp.down_proj.out_features = module.mlp.down_proj.qweight.shape[0]
     module.mlp.down_proj.in_features = (
-        module.mlp.down_proj.qweight.shape[0] * 32 // module.mlp.down_proj.bits
+        module.mlp.down_proj.qweight.shape[1] * 8 // module.mlp.down_proj.bits
     )

 else:
```
```diff
@@ -1138,6 +1179,13 @@ def reverse_reorder_tensor(self, tensor, bits):


 class OliveModel(GPTQModel):
+    """
+    Olive quantization format:
+    - qweight: (out_features, packed_in_features) uint8, packed along last dim
+    - scales: (out_features, num_groups) float
+    - qzeros: (out_features, packed_num_groups) uint8, packed along last dim
+    """
+
     def _load_quant_config(self, quant_attrs):
         super()._load_quant_config(quant_attrs)
         self.overrides = quant_attrs["config"]["overrides"] or {}
```
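For readers unfamiliar with the layout, a minimal dequantization sketch following the docstring above; the nibble order and zero-point packing details are assumptions for illustration, not the repo's kernels:

```python
import numpy as np

bits, group_size = 4, 32
out_features, in_features = 4, 64

# Tensors shaped per the OliveModel docstring (zero-filled placeholders).
qweight = np.zeros((out_features, in_features * bits // 8), dtype=np.uint8)
scales = np.ones((out_features, in_features // group_size), dtype=np.float32)
qzeros = np.full((out_features, (in_features // group_size) * bits // 8), 0x88, dtype=np.uint8)

# Unpack two 4-bit values per byte (low nibble first, by assumption).
w = np.empty((out_features, in_features), dtype=np.uint8)
w[:, 0::2] = qweight & 0x0F
w[:, 1::2] = qweight >> 4

z = np.empty((out_features, in_features // group_size), dtype=np.uint8)
z[:, 0::2] = qzeros & 0x0F
z[:, 1::2] = qzeros >> 4

# Dequantize per group: weight = (q - zero) * scale.
groups = np.repeat(np.arange(in_features // group_size), group_size)
deq = (w.astype(np.float32) - z[:, groups]) * scales[:, groups]
```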
```diff
@@ -1150,6 +1198,40 @@ def get_layer_group_size(self, layer_name):
         name = ".".join(layer_name.split(".")[:-1])
         return self.overrides.get(name, {}).get("group_size", self.global_group_size)

+    def handle_qzeros(self, module):
+        """Olive uses unsigned quantization, no offset needed."""
+        pass
+
+    def unpack(self, module):
+        """Skip unpack for Olive format."""
+        pass
+
+    def repack(self, module):
+        """
+        Olive format:
+        - qweight: (out_features, packed_in_features) uint8
+        - scales: (out_features, num_groups) float
+        - qzeros: (out_features, packed_num_groups) uint8
+
+        ORT format:
+        - qweight: (out_features, k_blocks, blob_size) uint8
+        - scales: (out_features * num_groups,) float, flattened
+        - qzeros: (out_features * packed_num_groups,) uint8, flattened
+        """
+        kpack = 8 // module.bits
+        k_blocks = module.in_features // module.group_size
+        blob_size = module.group_size // kpack
+
+        # qweight: (out_features, packed_in_features) -> (out_features, k_blocks, blob_size)
+        module.qweight = module.qweight.reshape(module.out_features, k_blocks, blob_size).contiguous()
+
+        # scales: (out_features, num_groups) -> flatten to 1D
+        module.scales = module.scales.reshape(-1).contiguous()
+
+        # qzeros: (out_features, packed_num_groups) -> flatten to 1D
+        if module.qzeros is not None and module.qzeros.numel() > 0:
+            module.qzeros = module.qzeros.reshape(-1).contiguous()
+

 class QuantModel:
     @staticmethod
```
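The repack is a pure relayout: `packed_in_features == k_blocks * blob_size`, so the reshape only regroups bytes within each output row. A standalone shape check under the same assumptions (`module` here is a stand-in object, not the repo's class):

```python
import torch
from types import SimpleNamespace

bits, group_size, out_features, in_features = 4, 32, 8, 128
kpack = 8 // bits
k_blocks = in_features // group_size  # 4
blob_size = group_size // kpack       # 16

module = SimpleNamespace(
    qweight=torch.zeros(out_features, in_features * bits // 8, dtype=torch.uint8),
    scales=torch.ones(out_features, k_blocks),
    qzeros=torch.zeros(out_features, k_blocks * bits // 8, dtype=torch.uint8),
)

# Mirror the repack above: the packed row width factors exactly into blocks.
assert module.qweight.shape[1] == k_blocks * blob_size
qweight = module.qweight.reshape(out_features, k_blocks, blob_size)
scales = module.scales.reshape(-1)   # (out_features * k_blocks,)
qzeros = module.qzeros.reshape(-1)
print(qweight.shape, scales.shape, qzeros.shape)
```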
