Skip to content

Commit eccba14

Browse files
committed
PTQ update
1 parent 1426481 commit eccba14

File tree

5 files changed

+82
-35
lines changed

5 files changed

+82
-35
lines changed

src/brevitas_examples/common/generative/quant_blocks.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,12 @@
33
# SPDX-License-Identifier: BSD-3-Clause
44
"""
55

6-
from typing import Callable, List, Optional, Tuple
6+
from typing import Callable
77

88
import torch
99
from torch import Tensor
1010
import torch.nn as nn
1111

12-
import brevitas
13-
from brevitas.core.function_wrapper.shape import PermuteDims
14-
from brevitas.core.utils import SliceTensor
1512
from brevitas.core.zero_point import _ScaleShiftZeroPoint
1613
from brevitas.function.ops_ste import abs_binary_sign_grad
1714

src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def unique(sequence):
8989
'act_bit_width': [8], # Act bit width
9090
'bias_bit_width': [32], # Bias Bit-Width for Po2 scale
9191
'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel
92+
'act_quant_granularity': ['per_tensor'], # Act Quant Granularity (Scaling Per Tensor)
9293
'act_quant_type': ['sym'], # Act Quant Type
94+
'act_scale_computation_type': ['static'], # Act Scale Computation Type
9395
'act_param_method': ['stats'], # Act Param Method
9496
'weight_param_method': ['mse'], # Weight Param Method
9597
'bias_corr': [True], # Bias Correction
@@ -240,7 +242,9 @@ def ptq_torchvision_models(args):
240242
weight_param_method=config_namespace.weight_param_method,
241243
act_param_method=config_namespace.act_param_method,
242244
bias_bit_width=config_namespace.bias_bit_width,
245+
act_scale_computation_type=config_namespace.act_scale_computation_type,
243246
weight_quant_granularity=config_namespace.weight_quant_granularity,
247+
act_quant_granularity=config_namespace.act_quant_granularity,
244248
act_quant_percentile=config_namespace.act_quant_percentile,
245249
act_quant_type=config_namespace.act_quant_type,
246250
scale_factor_type=config_namespace.scale_factor_type,

src/brevitas_examples/imagenet_classification/ptq/ptq_common.py

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@
2929
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE
3030
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
3131
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloatMSE
32+
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
33+
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloatMSE
34+
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
35+
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloatMSE
36+
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
37+
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloatMSE
38+
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
39+
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
40+
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE
41+
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act
42+
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight
43+
from brevitas.quant.experimental.mx_quant_ocp import MXInt8WeightMSE
44+
from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8Weight
45+
from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8WeightMSE
3246
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
3347
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
3448
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -96,12 +110,16 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
96110
'per_tensor': {
97111
'sym': Int8WeightPerTensorFixedPoint},
98112
'per_channel': {
99-
'sym': Int8WeightPerChannelFixedPoint},},
113+
'sym': Int8WeightPerChannelFixedPoint},
114+
'per_group': {
115+
'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}},
100116
'mse': {
101117
'per_tensor': {
102118
'sym': Int8WeightPerTensorFixedPointMSE},
103119
'per_channel': {
104-
'sym': Int8WeightPerChannelFixedPointMSE}},}},
120+
'sym': Int8WeightPerChannelFixedPointMSE},
121+
'per_group': {
122+
'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}},}},
105123
'float': {
106124
'float_scale': {
107125
'stats': {
@@ -113,7 +131,26 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
113131
'per_tensor': {
114132
'sym': Fp8e4m3WeightPerTensorFloatMSE},
115133
'per_channel': {
116-
'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}}
134+
'sym': Fp8e4m3WeightPerChannelFloatMSE}}}},
135+
'float_ocp': {
136+
'float_scale': {
137+
'stats': {
138+
'per_tensor': {
139+
'sym': Fp8e4m3OCPWeightPerTensorFloat},
140+
'per_channel': {
141+
'sym': Fp8e4m3OCPWeightPerChannelFloat}},
142+
'mse': {
143+
'per_tensor': {
144+
'sym': Fp8e4m3OCPWeightPerTensorFloatMSE},
145+
'per_channel': {
146+
'sym': Fp8e4m3OCPWeightPerChannelFloatMSE}}},
147+
'po2_scale': {
148+
'stats': {
149+
'per_group': {
150+
'sym': MXFloat8e4m3Weight}},
151+
'mse': {
152+
'per_group': {
153+
'sym': MXFloat8e4m3WeightMSE}}}}}
117154

118155
INPUT_QUANT_MAP = {
119156
'int': {
@@ -139,7 +176,10 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
139176
'stats': {
140177
'per_tensor': {
141178
'sym': CNNInt8DynamicActPerTensorFloat,
142-
'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}},
179+
'asym': CNNShiftedUint8DynamicActPerTensorFloat}}},
180+
'po2_scale': {
181+
'stats': {
182+
'per_group': MXInt8Act}}}},
143183
'float': {
144184
'static': {
145185
'float_scale': {
@@ -148,7 +188,21 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
148188
'sym': Fp8e4m3ActPerTensorFloat}},
149189
'mse': {
150190
'per_tensor': {
151-
'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}}
191+
'sym': Fp8e4m3ActPerTensorFloatMSE}}}}},
192+
'float_ocp': {
193+
'static': {
194+
'float_scale': {
195+
'stats': {
196+
'per_tensor': {
197+
'sym': Fp8e4m3OCPActPerTensorFloat}},
198+
'mse': {
199+
'per_tensor': {
200+
'sym': Fp8e4m3OCPActPerTensorFloatMSE}}}},
201+
'dynamic': {
202+
'po2_scale': {
203+
'stats': {
204+
'per_group': {
205+
'sym': MXFloat8e4m3Act}}}}}}
152206

153207

154208
def quantize_model(
@@ -252,14 +306,14 @@ def layerwise_bit_width_fn_weight(module):
252306
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
253307
act_bit_width_dict['act_bit_width'] = act_bit_width
254308

255-
if quant_format == 'float' and backend == 'layerwise':
309+
if 'float' in quant_format and backend == 'layerwise':
256310
weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
257311
act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
258312
weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa
259313
weight_bit_width_dict['weight_exponent_bit_width'] = layerwise_bit_width_fn_weight_exponent
260314
act_bit_width_dict['act_mantissa_bit_width'] = layerwise_bit_width_fn_act_mantissa
261315
act_bit_width_dict['act_exponent_bit_width'] = layerwise_bit_width_fn_act_exponent
262-
elif quant_format == 'float' and backend != 'layerwise':
316+
elif 'float' in quant_format and backend != 'layerwise':
263317
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
264318
act_bit_width_dict['act_bit_width'] = act_bit_width
265319
weight_bit_width_dict['weight_mantissa_bit_width'] = weight_mantissa_bit_width
@@ -334,12 +388,12 @@ def kwargs_prefix(prefix, weight_kwargs):
334388
return {prefix + k: v for k, v in weight_kwargs.items()}
335389

336390
weight_bit_width_dict = {'bit_width': weight_bit_width}
337-
if weight_quant_format == 'float':
391+
if 'float' in weight_quant_format:
338392
weight_bit_width_dict['exponent_bit_width'] = weight_exponent_bit_width
339393
weight_bit_width_dict['mantissa_bit_width'] = weight_mantissa_bit_width
340394

341395
act_bit_width_dict = {'bit_width': act_bit_width}
342-
if act_quant_format == 'float':
396+
if 'float' in act_quant_format:
343397
act_bit_width_dict['exponent_bit_width'] = act_exponent_bit_width
344398
act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width
345399

@@ -355,16 +409,12 @@ def kwargs_prefix(prefix, weight_kwargs):
355409
# Some activations in MHA should always be symmetric
356410
sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
357411
act_scale_type][act_param_method][act_quant_granularity]['sym']
358-
# Linear layers with 2d input should always be per tensor
359-
per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
360-
act_scale_type][act_param_method]['per_tensor'][act_quant_type]
412+
361413
act_quant = act_quant.let(**act_bit_width_dict)
362414
sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
363-
per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
364415
else:
365416
act_quant = None
366417
sym_act_quant = None
367-
per_tensor_act_quant = None
368418

369419
# Modify the weight quantizer based on the arguments passed in
370420
weight_quant = weight_quant.let(
@@ -383,13 +433,6 @@ def kwargs_prefix(prefix, weight_kwargs):
383433
sym_act_quant = sym_act_quant.let(
384434
**{
385435
'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device})
386-
if per_tensor_act_quant is not None:
387-
per_tensor_act_quant = per_tensor_act_quant.let(
388-
**{
389-
'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device})
390-
if act_quant_type == 'asym' and act_quant_percentile is not None:
391-
per_tensor_act_quant = per_tensor_act_quant.let(
392-
**{'low_percentile_q': 100 - act_quant_percentile})
393436

394437
weight_quant_dict = {'weight_quant': weight_quant}
395438

@@ -431,9 +474,9 @@ def kwargs_prefix(prefix, weight_kwargs):
431474
unsigned_quant_act_kwargs['signed'] = False
432475

433476
# Layerwise is basic quant kwargs + input_quant
434-
layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant}
477+
layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': act_quant}
435478

436-
layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': per_tensor_act_quant}
479+
layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': act_quant}
437480

438481
quant_layer_map = {
439482
torch.nn.Linear: (qnn.QuantLinear, quant_wbiol_kwargs),
@@ -526,7 +569,7 @@ def apply_gptq(calib_loader, model, act_order=False):
526569
dtype = next(model.parameters()).dtype
527570
device = next(model.parameters()).device
528571
with torch.no_grad():
529-
with gptq_mode(model, act_order=act_order, use_quant_activations=False) as gptq:
572+
with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
530573
gptq_model = gptq.model
531574
for i in tqdm(range(gptq.num_layers)):
532575
for i, (images, target) in enumerate(calib_loader):

src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,12 @@ def parse_type(v, default_type):
120120
parser.add_argument(
121121
'--weight-quant-granularity',
122122
default='per_tensor',
123-
choices=['per_tensor', 'per_channel'],
123+
choices=['per_tensor', 'per_channel', 'per_group'],
124+
help='Weight quantization type (default: per_tensor)')
125+
parser.add_argument(
126+
'--act-quant-granularity',
127+
default='per_tensor',
128+
choices=['per_tensor', 'per_group'],
124129
help='Activation quantization type (default: per_tensor)')
125130
parser.add_argument(
126131
'--weight-quant-calibration-type',
@@ -168,11 +173,7 @@ def parse_type(v, default_type):
168173
'--export-torch-qcdq',
169174
action='store_true',
170175
help='If true, export the model in torch qcdq format')
171-
add_bool_arg(
172-
parser,
173-
'scaling-per-output-channel',
174-
default=True,
175-
help='Weight scaling per output channel (default: enabled)')
176+
176177
add_bool_arg(
177178
parser, 'bias-corr', default=True, help='Bias correction after calibration (default: enabled)')
178179
add_bool_arg(
@@ -189,7 +190,7 @@ def parse_type(v, default_type):
189190
parser.add_argument(
190191
'--quant-format',
191192
default='int',
192-
choices=['int', 'float'],
193+
choices=['int', 'float', 'float_ocp'],
193194
help='Quantization format to use for weights and activations (default: int)')
194195
parser.add_argument(
195196
'--layerwise-first-last-mantissa-bit-width',
@@ -409,6 +410,7 @@ def main():
409410
weight_narrow_range=args.weight_narrow_range,
410411
weight_param_method=args.weight_quant_calibration_type,
411412
weight_quant_granularity=args.weight_quant_granularity,
413+
act_quant_granularity=args.act_quant_granularity,
412414
weight_quant_type=args.weight_quant_type,
413415
layerwise_first_last_bit_width=args.layerwise_first_last_bit_width,
414416
act_bit_width=args.act_bit_width,

tests/brevitas/graph/equalization_fixtures.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ def forward(self, x):
387387
input_quant, weight_quant = pytest_cases.param_fixtures("input_quant, weight_quant", [(None, Int8WeightPerTensorFloat), (Int8ActPerTensorFloat, Int8WeightPerTensorFloat), (MXInt8Act, MXInt8Weight), (MXFloat8e4m3Act, MXFloat8e4m3Weight)])
388388

389389

390+
390391
@pytest_cases.fixture
391392
def quant_conv_with_input_quant_model(input_quant, weight_quant):
392393

0 commit comments

Comments
 (0)