 from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE
 from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
 from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloatMSE
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloatMSE
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloatMSE
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloatMSE
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE
+from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act
+from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight
+from brevitas.quant.experimental.mx_quant_ocp import MXInt8WeightMSE
+from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8Weight
+from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8WeightMSE
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
 from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -96,12 +110,16 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
                 'per_tensor': {
                     'sym': Int8WeightPerTensorFixedPoint},
                 'per_channel': {
-                    'sym': Int8WeightPerChannelFixedPoint},},
+                    'sym': Int8WeightPerChannelFixedPoint},
+                'per_group': {
+                    'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}},
             'mse': {
                 'per_tensor': {
                     'sym': Int8WeightPerTensorFixedPointMSE},
                 'per_channel': {
-                    'sym': Int8WeightPerChannelFixedPointMSE}},}},
+                    'sym': Int8WeightPerChannelFixedPointMSE},
+                'per_group': {
+                    'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}},}},
     'float': {
         'float_scale': {
             'stats': {
@@ -113,7 +131,26 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
                 'per_tensor': {
                     'sym': Fp8e4m3WeightPerTensorFloatMSE},
                 'per_channel': {
-                    'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}}
+                    'sym': Fp8e4m3WeightPerChannelFloatMSE}}}},
+    'float_ocp': {
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Fp8e4m3OCPWeightPerTensorFloat},
+                'per_channel': {
+                    'sym': Fp8e4m3OCPWeightPerChannelFloat}},
+            'mse': {
+                'per_tensor': {
+                    'sym': Fp8e4m3OCPWeightPerTensorFloatMSE},
+                'per_channel': {
+                    'sym': Fp8e4m3OCPWeightPerChannelFloatMSE}}},
+        'po2_scale': {
+            'stats': {
+                'per_group': {
+                    'sym': MXFloat8e4m3Weight}},
+            'mse': {
+                'per_group': {
+                    'sym': MXFloat8e4m3WeightMSE}}}}}
 
 INPUT_QUANT_MAP = {
     'int': {
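
Commentary (not part of the diff): the new WEIGHT_QUANT_MAP entries add a 'per_group' granularity for the MX (microscaling) integer quantizers and a new 'float_ocp' branch covering the OCP FP8 and MX FP8 weight quantizers. A minimal lookup sketch, using only keys visible in the hunks above; the variable names are illustrative:

```python
# Resolve a weight quantizer class from the nested map:
# format -> scale type -> param method -> granularity -> sym/asym.
ocp_fp8_weight = WEIGHT_QUANT_MAP['float_ocp']['float_scale']['mse']['per_channel']['sym']
# -> Fp8e4m3OCPWeightPerChannelFloatMSE

mx_fp8_weight = WEIGHT_QUANT_MAP['float_ocp']['po2_scale']['stats']['per_group']['sym']
# -> MXFloat8e4m3Weight
```
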
@@ -139,7 +176,10 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
             'stats': {
                 'per_tensor': {
                     'sym': CNNInt8DynamicActPerTensorFloat,
-                    'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}},
+                    'asym': CNNShiftedUint8DynamicActPerTensorFloat}}},
+            'po2_scale': {
+                'stats': {
+                    'per_group': MXInt8Act}}}},
     'float': {
         'static': {
             'float_scale': {
@@ -148,7 +188,21 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
                         'sym': Fp8e4m3ActPerTensorFloat}},
                 'mse': {
                     'per_tensor': {
-                        'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}}
+                        'sym': Fp8e4m3ActPerTensorFloatMSE}}}}},
+    'float_ocp': {
+        'static': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Fp8e4m3OCPActPerTensorFloat}},
+                'mse': {
+                    'per_tensor': {
+                        'sym': Fp8e4m3OCPActPerTensorFloatMSE}}}},
+        'dynamic': {
+            'po2_scale': {
+                'stats': {
+                    'per_group': {
+                        'sym': MXFloat8e4m3Act}}}}}}
 
 
 def quantize_model(
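
Commentary (not part of the diff): INPUT_QUANT_MAP gains matching entries on the activation side, including a dynamic per-group MX path. A lookup sketch mirroring the map structure above; the variable names are illustrative:

```python
# Activation map levels: format -> static/dynamic -> scale type -> param method -> granularity.
ocp_fp8_act = INPUT_QUANT_MAP['float_ocp']['static']['float_scale']['stats']['per_tensor']['sym']
# -> Fp8e4m3OCPActPerTensorFloat

mx_fp8_act = INPUT_QUANT_MAP['float_ocp']['dynamic']['po2_scale']['stats']['per_group']['sym']
# -> MXFloat8e4m3Act
```
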
@@ -252,14 +306,14 @@ def layerwise_bit_width_fn_weight(module):
     weight_bit_width_dict['weight_bit_width'] = weight_bit_width
     act_bit_width_dict['act_bit_width'] = act_bit_width
 
-    if quant_format == 'float' and backend == 'layerwise':
+    if 'float' in quant_format and backend == 'layerwise':
         weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
         act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
         weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa
         weight_bit_width_dict['weight_exponent_bit_width'] = layerwise_bit_width_fn_weight_exponent
         act_bit_width_dict['act_mantissa_bit_width'] = layerwise_bit_width_fn_act_mantissa
         act_bit_width_dict['act_exponent_bit_width'] = layerwise_bit_width_fn_act_exponent
-    elif quant_format == 'float' and backend != 'layerwise':
+    elif 'float' in quant_format and backend != 'layerwise':
         weight_bit_width_dict['weight_bit_width'] = weight_bit_width
         act_bit_width_dict['act_bit_width'] = act_bit_width
         weight_bit_width_dict['weight_mantissa_bit_width'] = weight_mantissa_bit_width
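
Commentary (not part of the diff): the equality checks are loosened to membership tests so that the new 'float_ocp' format follows the same float branches as 'float'. A quick illustration of why the substring test covers both:

```python
for quant_format in ('int', 'float', 'float_ocp'):
    print(quant_format, 'float' in quant_format)
# int False
# float True
# float_ocp True
```
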
@@ -334,12 +388,12 @@ def kwargs_prefix(prefix, weight_kwargs):
         return {prefix + k: v for k, v in weight_kwargs.items()}
 
     weight_bit_width_dict = {'bit_width': weight_bit_width}
-    if weight_quant_format == 'float':
+    if 'float' in weight_quant_format:
         weight_bit_width_dict['exponent_bit_width'] = weight_exponent_bit_width
         weight_bit_width_dict['mantissa_bit_width'] = weight_mantissa_bit_width
 
     act_bit_width_dict = {'bit_width': act_bit_width}
-    if act_quant_format == 'float':
+    if 'float' in act_quant_format:
         act_bit_width_dict['exponent_bit_width'] = act_exponent_bit_width
         act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width
 
@@ -355,16 +409,12 @@ def kwargs_prefix(prefix, weight_kwargs):
         # Some activations in MHA should always be symmetric
         sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
             act_scale_type][act_param_method][act_quant_granularity]['sym']
-        # Linear layers with 2d input should always be per tensor
-        per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
-            act_scale_type][act_param_method]['per_tensor'][act_quant_type]
+
         act_quant = act_quant.let(**act_bit_width_dict)
         sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
-        per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
     else:
         act_quant = None
         sym_act_quant = None
-        per_tensor_act_quant = None
 
     # Modify the weight quantizer based on the arguments passed in
     weight_quant = weight_quant.let(
@@ -383,13 +433,6 @@ def kwargs_prefix(prefix, weight_kwargs):
         sym_act_quant = sym_act_quant.let(
             **{
                 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device})
-    if per_tensor_act_quant is not None:
-        per_tensor_act_quant = per_tensor_act_quant.let(
-            **{
-                'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device})
-        if act_quant_type == 'asym' and act_quant_percentile is not None:
-            per_tensor_act_quant = per_tensor_act_quant.let(
-                **{'low_percentile_q': 100 - act_quant_percentile})
 
     weight_quant_dict = {'weight_quant': weight_quant}
 
@@ -431,9 +474,9 @@ def kwargs_prefix(prefix, weight_kwargs):
         unsigned_quant_act_kwargs['signed'] = False
 
     # Layerwise is basic quant kwargs + input_quant
-    layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant}
+    layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': act_quant}
 
-    layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': per_tensor_act_quant}
+    layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': act_quant}
 
     quant_layer_map = {
         torch.nn.Linear: (qnn.QuantLinear, quant_wbiol_kwargs),
@@ -526,7 +569,7 @@ def apply_gptq(calib_loader, model, act_order=False):
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
     with torch.no_grad():
-        with gptq_mode(model, act_order=act_order, use_quant_activations=False) as gptq:
+        with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
             gptq_model = gptq.model
             for i in tqdm(range(gptq.num_layers)):
                 for i, (images, target) in enumerate(calib_loader):
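
Commentary (not part of the diff): apply_gptq now runs GPTQ with use_quant_activations=True, so the calibration forward passes see quantized rather than float activations. A hypothetical call, assuming quant_model and calib_loader were prepared by the surrounding example code (both names are placeholders):

```python
# Only the signature apply_gptq(calib_loader, model, act_order=False) comes from the hunk above.
apply_gptq(calib_loader, quant_model, act_order=True)
```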