Skip to content

Commit 7f1d4b2

Browse files
authored
[tosa] : Move casting to integer domain after all computations for quantize_per_tensor. (#4487)
The existing lowering of `quantize_per_tensor` to TOSA is incorrect: it casts to the integer domain before the zero-point (ZP) addition, which produces an incorrect result. Here's a numerical example illustrating the bug (thanks to AI for formatting it nicely). Consider an example input = 3.4501, Scale = 1/66.933334, Zero Point = -128. BUGGY Code (Original) ``` Step 1: Scale → 3.4501 × 66.933334 = 230.95 Step 2: Round → 231.0 Step 3: Cast to i8 → 231 → 127 (clamped by int8 range!) Step 4: Add ZP → 127 + (-128) = -1 ❌ Step 5: Clamp → clamp(-1, -128, 127) = -1 ``` FIXED Code ``` Step 1: Scale → 3.4501 × 66.933334 = 230.95 Step 2: Round → 231.0 Step 3: Add ZP → 231.0 + (-128.0) = 103.0 Step 4: Clamp → clamp(103.0, -128, 127) = 103.0 Step 5: Cast to i8 → 103 ✓ ```
1 parent 56e635e commit 7f1d4b2

File tree

4 files changed

+86
-43
lines changed

4 files changed

+86
-43
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10226,36 +10226,44 @@ LogicalResult ConvertAtenOp<AtenQuantizePerTensorOp>::matchAndRewriteImpl(
1022610226
op, "failed to implement round-half-to-even with TOSA ops");
1022710227
}
1022810228

10229-
// Cast to the destination integer type.
10230-
auto intermediateIntTy = resultTy.clone(resultElemTy);
10231-
Value castToInt =
10232-
tosa::CastOp::create(rewriter, loc, intermediateIntTy, *rounded);
10233-
10234-
// Add the zero point.
10235-
Value zpTensor =
10236-
tosa::createZeroPointTensor(rewriter, loc, intermediateIntTy, zpConst)
10229+
// Add the zero point
10230+
Value zpTensorFloat =
10231+
tosa::getConstTensor<float>(rewriter, op, static_cast<float>(zpConst), {},
10232+
inputElemTy)
1023710233
.value();
10238-
if (mlir::tosa::EqualizeRanks(rewriter, loc, castToInt, zpTensor).failed())
10234+
if (mlir::tosa::EqualizeRanks(rewriter, loc, *rounded, zpTensorFloat)
10235+
.failed())
1023910236
return failure();
10240-
Value withZp = tosa::AddOp::create(rewriter, loc, intermediateIntTy,
10241-
castToInt, zpTensor);
10242-
10243-
// Clamp the result to the valid range of the quantized type.
10244-
std::optional<int64_t> minInt,
10245-
maxInt; // no initialization needed as we want to clamp to the numeric
10246-
// limits of the type
10247-
IntegerAttr minIntAttr, maxIntAttr;
10237+
Value withZp =
10238+
tosa::AddOp::create(rewriter, loc, inputTy, *rounded, zpTensorFloat);
10239+
10240+
// Clamp the result to the valid range of the result/quantized type
10241+
std::optional<int64_t> minInt, maxInt;
10242+
IntegerAttr minIntAttr, maxIntAttr; // no initialization needed as we want to
10243+
// clamp to the numeric limits of the type
1024810244
if (failed(tosa::getIntegerClampAttrs(rewriter, op, resultElemTy, minInt,
1024910245
maxInt, minIntAttr, maxIntAttr))) {
1025010246
return failure();
1025110247
}
10248+
10249+
// Create float clamp attributes (clamp happens with integer range based on
10250+
// the result/quantized type but in the domain of the input type to preserve
10251+
// numeric)
10252+
auto minFloat = static_cast<float>(minIntAttr.getInt());
10253+
auto maxFloat = static_cast<float>(maxIntAttr.getInt());
10254+
auto minFloatAttr = rewriter.getFloatAttr(inputElemTy, minFloat);
10255+
auto maxFloatAttr = rewriter.getFloatAttr(inputElemTy, maxFloat);
10256+
1025210257
Value clamped = tosa::ClampOp::create(
10253-
rewriter, loc, resultTy, withZp, minIntAttr, maxIntAttr,
10258+
rewriter, loc, inputTy, withZp, minFloatAttr, maxFloatAttr,
1025410259
/*nan_mode=*/
1025510260
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
1025610261
tosa::NanPropagationMode::PROPAGATE));
1025710262

10258-
rewriter.replaceOp(op, clamped);
10263+
// Cast to the destination integer type
10264+
Value castToInt = tosa::CastOp::create(rewriter, loc, resultTy, clamped);
10265+
10266+
rewriter.replaceOp(op, castToInt);
1025910267
return success();
1026010268
}
1026110269

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@
829829
"QuantizedMLP_basic",
830830
"QuantizedNoLayer_basic",
831831
"QuantizedSingleLayer_basic",
832+
"QuantizePerTensorModule_basic",
832833
"RandnDtypeDeviceModule_basic",
833834
"RandnGeneratorF64Module_basic",
834835
"RandnGeneratorModule_basic",
@@ -3195,6 +3196,7 @@
31953196
"QuantizedReluInt8_basic",
31963197
"QuantizedReluInt32_basic",
31973198
"QuantizedReluUint8_basic",
3199+
"QuantizePerTensorModule_basic",
31983200
"RandIntDtypeModule_basic",
31993201
"RandIntModule_basic",
32003202
"RandIntPinMemoryModule_basic",
@@ -4802,6 +4804,7 @@
48024804
"QuantizedReluInt8_basic",
48034805
"QuantizedReluUint8_basic",
48044806
"QuantizedSingleLayer_basic",
4807+
"QuantizePerTensorModule_basic",
48054808
"RandIntDtypeModule_basic",
48064809
"RandIntModule_basic",
48074810
"RandIntPinMemoryModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import torch
77
from torch import nn
8+
import torch.ao.quantization.fx._decomposed
89

910
from torch_mlir_e2e_test.framework import TestUtils
1011
from torch_mlir_e2e_test.registry import register_test_case
@@ -206,3 +207,33 @@ def forward(self, a):
206207
@register_test_case(module_factory=lambda: FakeQuantizePerTensorAffineCachemaskModule())
207208
def FakeQuantizePerTensorAffineCachemaskModule_basic(module, tu: TestUtils):
208209
module.forward(tu.rand(6, 4))
210+
211+
212+
# ==============================================================================
213+
214+
215+
class QuantizePerTensorModule(torch.nn.Module):
216+
def __init__(self):
217+
super().__init__()
218+
219+
@export
220+
@annotate_args(
221+
[
222+
None,
223+
([1, 64, 112, 112], torch.float32, True),
224+
]
225+
)
226+
def forward(self, x):
227+
scale = 0.014940238557755947
228+
zp = -128
229+
quant_min = -128
230+
quant_max = 127
231+
return torch.ops.quantized_decomposed.quantize_per_tensor.default(
232+
x, scale, zp, quant_min, quant_max, torch.int8
233+
)
234+
235+
236+
@register_test_case(module_factory=lambda: QuantizePerTensorModule())
237+
def QuantizePerTensorModule_basic(module, tu: TestUtils):
238+
# use values within [-5, 5] to ensure we run into overflow/underflow
239+
module.forward(10 * torch.rand(1, 64, 112, 112) - 5)

test/Conversion/TorchToTosa/quantization.mlir

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -45,30 +45,31 @@ func.func @AtenMmQint8(%arg0: !torch.vtensor<[3,4],si8>, %arg1: !torch.vtensor<[
4545

4646
// -----
4747
// CHECK-LABEL: func.func @quantization_per_tensor(
48-
// CHECK-SAME: %[[IN:.*]]: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[2,4,4],!torch.qint8> {
49-
// CHECK: %[[ZP:.*]] = "tosa.const"() <{values = dense<3> : tensor<1x1x1xi8>}> : () -> tensor<1x1x1xi8>
50-
// CHECK: %[[C2:.*]] = "tosa.const"() <{values = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
51-
// CHECK: %[[CHALF:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
52-
// CHECK: %[[C10:.*]] = "tosa.const"() <{values = dense<1.000000e+01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
53-
// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
54-
// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[IN]] : !torch.vtensor<[2,4,4],f32> -> tensor<2x4x4xf32>
55-
// CHECK: %[[RESCALE:.*]] = tosa.mul %[[IN_TENSOR]], %[[C10]], %[[MUL_SHIFT]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
56-
// CHECK: %[[FLOOR:.*]] = tosa.floor %[[RESCALE]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
57-
// CHECK: %[[FRAC:.*]] = tosa.sub %[[RESCALE]], %[[FLOOR]] : (tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
58-
// CHECK: %[[CEIL:.*]] = tosa.ceil %[[RESCALE]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
59-
// CHECK: %[[FLOOR_DIV_BY_2:.*]] = tosa.mul %[[FLOOR]], %[[CHALF]], %[[MUL_SHIFT]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
60-
// CHECK: %[[FLOOR_DIV:.*]] = tosa.floor %[[FLOOR_DIV_BY_2]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
61-
// CHECK: %[[EVEN_COMP:.*]] = tosa.mul %[[FLOOR_DIV]], %[[C2]], %[[MUL_SHIFT]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
62-
// CHECK: %[[FLOOR_INPUT_EVEN:.*]] = tosa.equal %[[FLOOR]], %[[EVEN_COMP]] : (tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xi1>
63-
// CHECK: %[[FRAC_EQ_HALF:.*]] = tosa.equal %[[FRAC]], %[[CHALF]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>) -> tensor<2x4x4xi1>
64-
// CHECK: %[[GRTR:.*]] = tosa.greater %[[CHALF]], %[[FRAC]] : (tensor<1x1x1xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xi1>
65-
// CHECK: %[[AND:.*]] = tosa.logical_and %[[FRAC_EQ_HALF]], %[[FLOOR_INPUT_EVEN]] : (tensor<2x4x4xi1>, tensor<2x4x4xi1>) -> tensor<2x4x4xi1>
66-
// CHECK: %[[OR:.*]] = tosa.logical_or %[[GRTR]], %[[AND]] : (tensor<2x4x4xi1>, tensor<2x4x4xi1>) -> tensor<2x4x4xi1>
67-
// CHECK: %[[SELECT:.*]] = tosa.select %[[OR]], %[[FLOOR]], %[[CEIL]] : (tensor<2x4x4xi1>, tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
68-
// CHECK: %[[CAST:.*]] = tosa.cast %[[SELECT]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xi8>
69-
// CHECK: %[[ADD:.*]] = tosa.add %[[CAST]], %[[ZP]] : (tensor<2x4x4xi8>, tensor<1x1x1xi8>) -> tensor<2x4x4xi8>
70-
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[ADD]] : tensor<2x4x4xi8> -> !torch.vtensor<[2,4,4],!torch.qint8>
71-
// CHECK: return %[[RES]]
48+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[2,4,4],!torch.qint8> {
49+
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<3.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
50+
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
51+
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
52+
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<1.000000e+01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
53+
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
54+
// CHECK: %[[TO_BUILTIN_TENSOR_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,4],f32> -> tensor<2x4x4xf32>
55+
// CHECK: %[[MUL_0:.*]] = tosa.mul %[[TO_BUILTIN_TENSOR_0]], %[[VAL_3]], %[[VAL_4]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
56+
// CHECK: %[[FLOOR_0:.*]] = tosa.floor %[[MUL_0]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
57+
// CHECK: %[[SUB_0:.*]] = tosa.sub %[[MUL_0]], %[[FLOOR_0]] : (tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
58+
// CHECK: %[[CEIL_0:.*]] = tosa.ceil %[[MUL_0]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
59+
// CHECK: %[[MUL_1:.*]] = tosa.mul %[[FLOOR_0]], %[[VAL_2]], %[[VAL_4]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
60+
// CHECK: %[[FLOOR_1:.*]] = tosa.floor %[[MUL_1]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
61+
// CHECK: %[[MUL_2:.*]] = tosa.mul %[[FLOOR_1]], %[[VAL_1]], %[[VAL_4]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
62+
// CHECK: %[[EQUAL_0:.*]] = tosa.equal %[[FLOOR_0]], %[[MUL_2]] : (tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xi1>
63+
// CHECK: %[[EQUAL_1:.*]] = tosa.equal %[[SUB_0]], %[[VAL_2]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>) -> tensor<2x4x4xi1>
64+
// CHECK: %[[GREATER_0:.*]] = tosa.greater %[[VAL_2]], %[[SUB_0]] : (tensor<1x1x1xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xi1>
65+
// CHECK: %[[LOGICAL_AND_0:.*]] = tosa.logical_and %[[EQUAL_1]], %[[EQUAL_0]] : (tensor<2x4x4xi1>, tensor<2x4x4xi1>) -> tensor<2x4x4xi1>
66+
// CHECK: %[[LOGICAL_OR_0:.*]] = tosa.logical_or %[[GREATER_0]], %[[LOGICAL_AND_0]] : (tensor<2x4x4xi1>, tensor<2x4x4xi1>) -> tensor<2x4x4xi1>
67+
// CHECK: %[[SELECT_0:.*]] = tosa.select %[[LOGICAL_OR_0]], %[[FLOOR_0]], %[[CEIL_0]] : (tensor<2x4x4xi1>, tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
68+
// CHECK: %[[ADD_0:.*]] = tosa.add %[[SELECT_0]], %[[VAL_0]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>) -> tensor<2x4x4xf32>
69+
// CHECK: %[[CLAMP_0:.*]] = tosa.clamp %[[ADD_0]] {max_val = 1.270000e+02 : f32, min_val = -1.280000e+02 : f32} : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
70+
// CHECK: %[[CAST_0:.*]] = tosa.cast %[[CLAMP_0]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xi8>
71+
// CHECK: %[[FROM_BUILTIN_TENSOR_0:.*]] = torch_c.from_builtin_tensor %[[CAST_0]] : tensor<2x4x4xi8> -> !torch.vtensor<[2,4,4],!torch.qint8>
72+
// CHECK: return %[[FROM_BUILTIN_TENSOR_0]]
7273
func.func @quantization_per_tensor(%arg0: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[2,4,4],!torch.qint8> {
7374
%dtype = torch.constant.int 12
7475
%scale = torch.constant.float 0.1

0 commit comments

Comments
 (0)