@@ -366,11 +366,18 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.self_attention.query_key_value.qweight
     # model.layers.layer_id.self_attn.qkv_proj.weight
     # model.layers.layer_id.self_attention.query_key_value.weight
-    q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "quark"} else q_size
-    kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "quark"} else kv_size
-    tensor_map["self_attn.q_proj.qweight"] = tensor[:, :q_dim]
-    tensor_map["self_attn.k_proj.qweight"] = tensor[:, q_dim : q_dim + kv_dim]
-    tensor_map["self_attn.v_proj.qweight"] = tensor[:, q_dim + kv_dim :]
+    if quant_type == "olive":
+        # Olive: (out_features, in_features), split on dim=0
+        tensor_map["self_attn.q_proj.qweight"] = tensor[:q_size, :]
+        tensor_map["self_attn.k_proj.qweight"] = tensor[q_size : q_size + kv_size, :]
+        tensor_map["self_attn.v_proj.qweight"] = tensor[q_size + kv_size :, :]
+    else:
+        # AWQ/GPTQ/Quark: (in_features, out_features), split on dim=1
+        q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "quark"} else q_size
+        kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "quark"} else kv_size
+        tensor_map["self_attn.q_proj.qweight"] = tensor[:, :q_dim]
+        tensor_map["self_attn.k_proj.qweight"] = tensor[:, q_dim : q_dim + kv_dim]
+        tensor_map["self_attn.v_proj.qweight"] = tensor[:, q_dim + kv_dim :]
 elif bool(
     re.match(
         r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.(scales|weight_scale)$",
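To make the two layouts handled in this hunk concrete, here is a small standalone sketch (not part of the patch; the toy sizes and zero tensors are made up for illustration, assuming 4-bit quantization) that splits a fused QKV `qweight` under both conventions:

```python
# Illustrative sketch only: split a fused QKV qweight for the two layouts above.
import torch

local_bits = 4
q_size, kv_size, in_features = 16, 8, 32

# AWQ-style layout: (in_features, packed_out_features), packed into int32 along dim=1,
# so the split points shrink by the packing factor 32 // local_bits.
awq_qweight = torch.zeros(in_features, (q_size + 2 * kv_size) // (32 // local_bits), dtype=torch.int32)
q_dim, kv_dim = q_size // (32 // local_bits), kv_size // (32 // local_bits)
q, k, v = awq_qweight[:, :q_dim], awq_qweight[:, q_dim : q_dim + kv_dim], awq_qweight[:, q_dim + kv_dim :]
print(q.shape, k.shape, v.shape)  # torch.Size([32, 2]) torch.Size([32, 1]) torch.Size([32, 1])

# Olive-style layout as described in the patch: (out_features, packed_in_features) uint8,
# packed along the last dim, so rows are whole output channels and the split is on dim=0.
olive_qweight = torch.zeros(q_size + 2 * kv_size, in_features // (8 // local_bits), dtype=torch.uint8)
q, k, v = olive_qweight[:q_size, :], olive_qweight[q_size : q_size + kv_size, :], olive_qweight[q_size + kv_size :, :]
print(q.shape, k.shape, v.shape)  # torch.Size([16, 16]) torch.Size([8, 16]) torch.Size([8, 16])
```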
@@ -381,9 +388,16 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.self_attention.query_key_value.scales
     # model.layers.layer_id.self_attn.qkv_proj.weight_scale
     # model.layers.layer_id.self_attention.query_key_value.weight_scale
-    tensor_map["self_attn.q_proj.scales"] = tensor[:, :q_size]
-    tensor_map["self_attn.k_proj.scales"] = tensor[:, q_size : q_size + kv_size]
-    tensor_map["self_attn.v_proj.scales"] = tensor[:, q_size + kv_size :]
+    if quant_type == "olive":
+        # Olive: (out_features, num_groups), split on dim=0
+        tensor_map["self_attn.q_proj.scales"] = tensor[:q_size, :]
+        tensor_map["self_attn.k_proj.scales"] = tensor[q_size : q_size + kv_size, :]
+        tensor_map["self_attn.v_proj.scales"] = tensor[q_size + kv_size :, :]
+    else:
+        # AWQ/GPTQ/Quark: split on dim=1
+        tensor_map["self_attn.q_proj.scales"] = tensor[:, :q_size]
+        tensor_map["self_attn.k_proj.scales"] = tensor[:, q_size : q_size + kv_size]
+        tensor_map["self_attn.v_proj.scales"] = tensor[:, q_size + kv_size :]
 elif bool(
     re.match(
         r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.(qzeros|weight_zero_point)$",
@@ -394,19 +408,28 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.self_attention.query_key_value.qzeros
     # model.layers.layer_id.self_attn.qkv_proj.weight_zero_point
     # model.layers.layer_id.self_attention.query_key_value.weight_zero_point
-    q_dim = (
-        q_size // (32 // local_bits)
-        if quant_type in {"awq", "gptq", "olive", "quark"}
-        else q_size
-    )
-    kv_dim = (
-        kv_size // (32 // local_bits)
-        if quant_type in {"awq", "gptq", "olive", "quark"}
-        else kv_size
-    )
-    tensor_map["self_attn.q_proj.qzeros"] = tensor[:, :q_dim]
-    tensor_map["self_attn.k_proj.qzeros"] = tensor[:, q_dim : q_dim + kv_dim]
-    tensor_map["self_attn.v_proj.qzeros"] = tensor[:, q_dim + kv_dim :]
+    if quant_type == "olive":
+        # Olive: (out_features, packed_num_groups) uint8, split on dim=0
+        q_dim = q_size // (8 // local_bits)
+        kv_dim = kv_size // (8 // local_bits)
+        tensor_map["self_attn.q_proj.qzeros"] = tensor[:q_dim, :]
+        tensor_map["self_attn.k_proj.qzeros"] = tensor[q_dim : q_dim + kv_dim, :]
+        tensor_map["self_attn.v_proj.qzeros"] = tensor[q_dim + kv_dim :, :]
+    else:
+        # AWQ/GPTQ/Quark: int32 packing, split on dim=1
+        q_dim = (
+            q_size // (32 // local_bits)
+            if quant_type in {"awq", "gptq", "quark"}
+            else q_size
+        )
+        kv_dim = (
+            kv_size // (32 // local_bits)
+            if quant_type in {"awq", "gptq", "quark"}
+            else kv_size
+        )
+        tensor_map["self_attn.q_proj.qzeros"] = tensor[:, :q_dim]
+        tensor_map["self_attn.k_proj.qzeros"] = tensor[:, q_dim : q_dim + kv_dim]
+        tensor_map["self_attn.v_proj.qzeros"] = tensor[:, q_dim + kv_dim :]
 elif bool(
     re.match(
         r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.g_idx$", name
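The different split widths in the qzeros branch come from how many quantized values fit in one storage element: 32 // bits per int32 versus 8 // bits per uint8. A quick arithmetic sketch (illustrative only, reusing the patch's `local_bits` naming):

```python
# Illustrative arithmetic for the packing factors used in the hunk above.
for local_bits in (4, 8):
    int32_pack = 32 // local_bits  # values per int32: 8 at 4 bits, 4 at 8 bits
    uint8_pack = 8 // local_bits   # values per uint8: 2 at 4 bits, 1 at 8 bits
    q_size = 4096
    print(local_bits, q_size // int32_pack, q_size // uint8_pack)
    # bits=4: a 4096-wide q split point becomes 512 (int32 packing) or 2048 (uint8 packing)
    # bits=8: 1024 (int32 packing) or 4096 (uint8 packing)
```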
@@ -434,13 +457,19 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.mlp.dense_h_to_4h.qweight
     # model.layers.layer_id.mlp.gate_up_proj.weight
     # model.layers.layer_id.mlp.dense_h_to_4h.weight
-    intermediate_dim = (
-        intermediate_size // (32 // local_bits)
-        if quant_type in {"awq", "quark"}
-        else intermediate_size
-    )
-    tensor_map["mlp.gate_proj.qweight"] = tensor[:, :intermediate_dim]
-    tensor_map["mlp.up_proj.qweight"] = tensor[:, intermediate_dim:]
+    if quant_type == "olive":
+        # Olive: (out_features, in_features), split on dim=0
+        tensor_map["mlp.gate_proj.qweight"] = tensor[:intermediate_size, :]
+        tensor_map["mlp.up_proj.qweight"] = tensor[intermediate_size:, :]
+    else:
+        # AWQ/GPTQ/Quark: (in_features, out_features), split on dim=1
+        intermediate_dim = (
+            intermediate_size // (32 // local_bits)
+            if quant_type in {"awq", "quark"}
+            else intermediate_size
+        )
+        tensor_map["mlp.gate_proj.qweight"] = tensor[:, :intermediate_dim]
+        tensor_map["mlp.up_proj.qweight"] = tensor[:, intermediate_dim:]
 elif bool(
     re.match(
         r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.(scales|weight_scale)$",
@@ -451,8 +480,14 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.mlp.dense_h_to_4h.scales
     # model.layers.layer_id.mlp.gate_up_proj.weight_scale
     # model.layers.layer_id.mlp.dense_h_to_4h.weight_scale
-    tensor_map["mlp.gate_proj.scales"] = tensor[:, :intermediate_size]
-    tensor_map["mlp.up_proj.scales"] = tensor[:, intermediate_size:]
+    if quant_type == "olive":
+        # Olive: (out_features, num_groups), split on dim=0
+        tensor_map["mlp.gate_proj.scales"] = tensor[:intermediate_size, :]
+        tensor_map["mlp.up_proj.scales"] = tensor[intermediate_size:, :]
+    else:
+        # AWQ/GPTQ/Quark: split on dim=1
+        tensor_map["mlp.gate_proj.scales"] = tensor[:, :intermediate_size]
+        tensor_map["mlp.up_proj.scales"] = tensor[:, intermediate_size:]
 elif bool(
     re.match(
         r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.(qzeros|weight_zero_point)$",
@@ -463,13 +498,20 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
     # model.layers.layer_id.mlp.dense_h_to_4h.qzeros
     # model.layers.layer_id.mlp.gate_up_proj.weight_zero_point
     # model.layers.layer_id.mlp.dense_h_to_4h.weight_zero_point
-    intermediate_dim = (
-        intermediate_size // (32 // local_bits)
-        if quant_type in {"awq", "gptq", "quark", "olive"}
-        else intermediate_size
-    )
-    tensor_map["mlp.gate_proj.qzeros"] = tensor[:, :intermediate_dim]
-    tensor_map["mlp.up_proj.qzeros"] = tensor[:, intermediate_dim:]
+    if quant_type == "olive":
+        # Olive: (out_features, packed_num_groups) uint8, split on dim=0
+        intermediate_dim = intermediate_size // (8 // local_bits)
+        tensor_map["mlp.gate_proj.qzeros"] = tensor[:intermediate_dim, :]
+        tensor_map["mlp.up_proj.qzeros"] = tensor[intermediate_dim:, :]
+    else:
+        # AWQ/GPTQ/Quark: int32 packing, split on dim=1
+        intermediate_dim = (
+            intermediate_size // (32 // local_bits)
+            if quant_type in {"awq", "gptq", "quark"}
+            else intermediate_size
+        )
+        tensor_map["mlp.gate_proj.qzeros"] = tensor[:, :intermediate_dim]
+        tensor_map["mlp.up_proj.qzeros"] = tensor[:, intermediate_dim:]
 elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h)\.g_idx$", name)):
     # model.layers.layer_id.mlp.gate_up_proj.g_idx
     # model.layers.layer_id.mlp.dense_h_to_4h.g_idx
@@ -554,10 +596,10 @@ def set_properties(self):
     self.lm_head.out_features = self.lm_head.qweight.shape[1]
     self.lm_head.in_features = self.lm_head.g_idx.shape[0]
 elif self.quant_type == "olive":
-    self.lm_head.out_features = self.lm_head.qweight.shape[1]
-    # expects in_features to be divisible by the packing factor (32 // bits)
-    # not a new assumption since no code here accounts for padded packed weights
-    self.lm_head.in_features = self.lm_head.qweight.shape[0] * 32 // self.lm_head.bits
+    # Olive format: qweight is (out_features, packed_in_features) uint8
+    # packed_in_features = in_features * bits / 8
+    self.lm_head.out_features = self.lm_head.qweight.shape[0]
+    self.lm_head.in_features = self.lm_head.qweight.shape[1] * 8 // self.lm_head.bits
 else:
     raise NotImplementedError(f"The {self.quant_type} quantization method is not recognized.")
 for module in self.layers:
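The reversed shape bookkeeping for `lm_head` can be sanity-checked in isolation. A hypothetical sketch (a `SimpleNamespace` stands in for the real module object and the sizes are arbitrary), assuming the `(out_features, packed_in_features)` uint8 layout the comments above describe:

```python
# Hypothetical sketch: recover (in_features, out_features) from an Olive-style
# uint8-packed qweight, mirroring the arithmetic in the hunk above.
from types import SimpleNamespace
import torch

bits, in_features, out_features = 4, 3072, 4096
lm_head = SimpleNamespace(
    bits=bits,
    # packed along the last dim: in_features * bits / 8 bytes per output row
    qweight=torch.zeros(out_features, in_features * bits // 8, dtype=torch.uint8),
)
lm_head.out_features = lm_head.qweight.shape[0]                     # 4096
lm_head.in_features = lm_head.qweight.shape[1] * 8 // lm_head.bits  # 3072
assert (lm_head.out_features, lm_head.in_features) == (out_features, in_features)
```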
@@ -654,32 +696,31 @@ def set_properties(self):
     module.mlp.down_proj.in_features = module.mlp.down_proj.g_idx.shape[0]
 
 elif self.quant_type == "olive":
-    # Set in_features and out_features
-    module.self_attn.q_proj.out_features = module.self_attn.q_proj.qweight.shape[1]
+    module.self_attn.q_proj.out_features = module.self_attn.q_proj.qweight.shape[0]
     module.self_attn.q_proj.in_features = (
-        module.self_attn.q_proj.qweight.shape[0] * 32 // module.self_attn.q_proj.bits
+        module.self_attn.q_proj.qweight.shape[1] * 8 // module.self_attn.q_proj.bits
     )
-    module.self_attn.k_proj.out_features = module.self_attn.k_proj.qweight.shape[1]
+    module.self_attn.k_proj.out_features = module.self_attn.k_proj.qweight.shape[0]
     module.self_attn.k_proj.in_features = (
-        module.self_attn.k_proj.qweight.shape[0] * 32 // module.self_attn.k_proj.bits
+        module.self_attn.k_proj.qweight.shape[1] * 8 // module.self_attn.k_proj.bits
     )
-    module.self_attn.v_proj.out_features = module.self_attn.v_proj.qweight.shape[1]
+    module.self_attn.v_proj.out_features = module.self_attn.v_proj.qweight.shape[0]
     module.self_attn.v_proj.in_features = (
-        module.self_attn.v_proj.qweight.shape[0] * 32 // module.self_attn.v_proj.bits
+        module.self_attn.v_proj.qweight.shape[1] * 8 // module.self_attn.v_proj.bits
     )
-    module.self_attn.o_proj.out_features = module.self_attn.o_proj.qweight.shape[1]
+    module.self_attn.o_proj.out_features = module.self_attn.o_proj.qweight.shape[0]
     module.self_attn.o_proj.in_features = (
-        module.self_attn.o_proj.qweight.shape[0] * 32 // module.self_attn.o_proj.bits
+        module.self_attn.o_proj.qweight.shape[1] * 8 // module.self_attn.o_proj.bits
     )
-    module.mlp.gate_proj.out_features = module.mlp.gate_proj.qweight.shape[1]
+    module.mlp.gate_proj.out_features = module.mlp.gate_proj.qweight.shape[0]
     module.mlp.gate_proj.in_features = (
-        module.mlp.gate_proj.qweight.shape[0] * 32 // module.mlp.gate_proj.bits
+        module.mlp.gate_proj.qweight.shape[1] * 8 // module.mlp.gate_proj.bits
     )
-    module.mlp.up_proj.out_features = module.mlp.up_proj.qweight.shape[1]
-    module.mlp.up_proj.in_features = module.mlp.up_proj.qweight.shape[0] * 32 // module.mlp.up_proj.bits
-    module.mlp.down_proj.out_features = module.mlp.down_proj.qweight.shape[1]
+    module.mlp.up_proj.out_features = module.mlp.up_proj.qweight.shape[0]
+    module.mlp.up_proj.in_features = module.mlp.up_proj.qweight.shape[1] * 8 // module.mlp.up_proj.bits
+    module.mlp.down_proj.out_features = module.mlp.down_proj.qweight.shape[0]
     module.mlp.down_proj.in_features = (
-        module.mlp.down_proj.qweight.shape[0] * 32 // module.mlp.down_proj.bits
+        module.mlp.down_proj.qweight.shape[1] * 8 // module.mlp.down_proj.bits
     )
 
 else:
@@ -1138,6 +1179,13 @@ def reverse_reorder_tensor(self, tensor, bits):
 
 
 class OliveModel(GPTQModel):
+    """
+    Olive quantization format:
+    - qweight: (out_features, packed_in_features) uint8, packed along last dim
+    - scales: (out_features, num_groups) float
+    - qzeros: (out_features, packed_num_groups) uint8, packed along last dim
+    """
+
     def _load_quant_config(self, quant_attrs):
         super()._load_quant_config(quant_attrs)
         self.overrides = quant_attrs["config"]["overrides"] or {}
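As a rough illustration of the layout in this docstring, the sketch below packs a matrix of 4-bit codes into the `(out_features, packed_in_features)` uint8 shape. The low-nibble-first byte order is an assumption made for the example; the patch itself does not specify the bit order:

```python
# Illustrative only: pack 4-bit integer codes (0..15) into an Olive-style
# (out_features, packed_in_features) uint8 tensor, two values per byte.
# Low-nibble-first order is assumed here for demonstration purposes.
import torch

bits = 4
codes = torch.randint(0, 2**bits, (8, 16), dtype=torch.uint8)  # (out_features, in_features)
lo, hi = codes[:, 0::2], codes[:, 1::2]                        # pair up adjacent columns
qweight = (lo | (hi << bits)).to(torch.uint8)                  # (8, 8): in_features * bits // 8 bytes per row
assert qweight.shape == (codes.shape[0], codes.shape[1] * bits // 8)
```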
@@ -1150,6 +1198,40 @@ def get_layer_group_size(self, layer_name):
         name = ".".join(layer_name.split(".")[:-1])
         return self.overrides.get(name, {}).get("group_size", self.global_group_size)
 
+    def handle_qzeros(self, module):
+        """Olive uses unsigned quantization, no offset needed."""
+        pass
+
+    def unpack(self, module):
+        """Skip unpack for Olive format."""
+        pass
+
+    def repack(self, module):
+        """
+        Olive format:
+        - qweight: (out_features, packed_in_features) uint8
+        - scales: (out_features, num_groups) float
+        - qzeros: (out_features, packed_num_groups) uint8
+
+        ORT format:
+        - qweight: (out_features, k_blocks, blob_size) uint8
+        - scales: (out_features * num_groups,) float, flattened
+        - qzeros: (out_features * packed_num_groups,) uint8, flattened
+        """
+        kpack = 8 // module.bits
+        k_blocks = module.in_features // module.group_size
+        blob_size = module.group_size // kpack
+
+        # qweight: (out_features, packed_in_features) -> (out_features, k_blocks, blob_size)
+        module.qweight = module.qweight.reshape(module.out_features, k_blocks, blob_size).contiguous()
+
+        # scales: (out_features, num_groups) -> flatten to 1D
+        module.scales = module.scales.reshape(-1).contiguous()
+
+        # qzeros: (out_features, packed_num_groups) -> flatten to 1D
+        if module.qzeros is not None and module.qzeros.numel() > 0:
+            module.qzeros = module.qzeros.reshape(-1).contiguous()
+
 
 class QuantModel:
     @staticmethod
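To see what `repack` does with concrete numbers, here is a shape-only sketch (a `SimpleNamespace` with zero tensors stands in for the real quantized module; the sizes are arbitrary), assuming 4-bit weights with `group_size = 32`:

```python
# Shape-only sketch of the repack step: (out_features, packed_in_features) uint8
# -> (out_features, k_blocks, blob_size) uint8, plus flattened scales/qzeros.
from types import SimpleNamespace
import torch

bits, group_size, in_features, out_features = 4, 32, 128, 8
kpack = 8 // bits                     # 2 weights per byte
k_blocks = in_features // group_size  # 4 groups along the K dimension
blob_size = group_size // kpack       # 16 bytes per group

m = SimpleNamespace(
    bits=bits, group_size=group_size, in_features=in_features, out_features=out_features,
    qweight=torch.zeros(out_features, in_features // kpack, dtype=torch.uint8),
    scales=torch.zeros(out_features, k_blocks),
    qzeros=torch.zeros(out_features, k_blocks // kpack, dtype=torch.uint8),
)
m.qweight = m.qweight.reshape(m.out_features, k_blocks, blob_size).contiguous()
m.scales = m.scales.reshape(-1).contiguous()
m.qzeros = m.qzeros.reshape(-1).contiguous()
print(m.qweight.shape, m.scales.shape, m.qzeros.shape)
# torch.Size([8, 4, 16]) torch.Size([32]) torch.Size([16])
```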