@@ -45,30 +45,31 @@ func.func @AtenMmQint8(%arg0: !torch.vtensor<[3,4],si8>, %arg1: !torch.vtensor<[
4545
4646// -----
4747// CHECK-LABEL: func.func @quantization_per_tensor(
48- // CHECK-SAME: %[[IN:.*]]: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[2,4,4],!torch.qint8> {
49- // CHECK: %[[ZP:.*]] = "tosa.const"() <{values = dense<3> : tensor<1x1x1xi8>}> : () -> tensor<1x1x1xi8>
50- // CHECK: %[[C2:.*]] = "tosa.const"() <{values = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
51- // CHECK: %[[CHALF:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
52- // CHECK: %[[C10:.*]] = "tosa.const"() <{values = dense<1.000000e+01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
53- // CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
54- // CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[IN]] : !torch.vtensor<[2,4,4],f32> -> tensor<2x4x4xf32>
55- // CHECK: %[[RESCALE:.*]] = tosa.mul %[[IN_TENSOR]], %[[C10]], %[[MUL_SHIFT]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
56- // CHECK: %[[FLOOR:.*]] = tosa.floor %[[RESCALE]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
57- // CHECK: %[[FRAC:.*]] = tosa.sub %[[RESCALE]], %[[FLOOR]] : (tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
58- // CHECK: %[[CEIL:.*]] = tosa.ceil %[[RESCALE]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
59- // CHECK: %[[FLOOR_DIV_BY_2:.*]] = tosa.mul %[[FLOOR]], %[[CHALF]], %[[MUL_SHIFT]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
60- // CHECK: %[[FLOOR_DIV:.*]] = tosa.floor %[[FLOOR_DIV_BY_2]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
61- // CHECK: %[[EVEN_COMP:.*]] = tosa.mul %[[FLOOR_DIV]], %[[C2]], %[[MUL_SHIFT]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
62- // CHECK: %[[FLOOR_INPUT_EVEN:.*]] = tosa.equal %[[FLOOR]], %[[EVEN_COMP]] : (tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xi1>
63- // CHECK: %[[FRAC_EQ_HALF:.*]] = tosa.equal %[[FRAC]], %[[CHALF]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>) -> tensor<2x4x4xi1>
64- // CHECK: %[[GRTR:.*]] = tosa.greater %[[CHALF]], %[[FRAC]] : (tensor<1x1x1xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xi1>
65- // CHECK: %[[AND:.*]] = tosa.logical_and %[[FRAC_EQ_HALF]], %[[FLOOR_INPUT_EVEN]] : (tensor<2x4x4xi1>, tensor<2x4x4xi1>) -> tensor<2x4x4xi1>
66- // CHECK: %[[OR:.*]] = tosa.logical_or %[[GRTR]], %[[AND]] : (tensor<2x4x4xi1>, tensor<2x4x4xi1>) -> tensor<2x4x4xi1>
67- // CHECK: %[[SELECT:.*]] = tosa.select %[[OR]], %[[FLOOR]], %[[CEIL]] : (tensor<2x4x4xi1>, tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
68- // CHECK: %[[CAST:.*]] = tosa.cast %[[SELECT]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xi8>
69- // CHECK: %[[ADD:.*]] = tosa.add %[[CAST]], %[[ZP]] : (tensor<2x4x4xi8>, tensor<1x1x1xi8>) -> tensor<2x4x4xi8>
70- // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[ADD]] : tensor<2x4x4xi8> -> !torch.vtensor<[2,4,4],!torch.qint8>
71- // CHECK: return %[[RES]]
48+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[2,4,4],!torch.qint8> {
49+ // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<3.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
50+ // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
51+ // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
52+ // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<1.000000e+01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
53+ // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
54+ // CHECK: %[[TO_BUILTIN_TENSOR_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,4],f32> -> tensor<2x4x4xf32>
55+ // CHECK: %[[MUL_0:.*]] = tosa.mul %[[TO_BUILTIN_TENSOR_0]], %[[VAL_3]], %[[VAL_4]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
56+ // CHECK: %[[FLOOR_0:.*]] = tosa.floor %[[MUL_0]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
57+ // CHECK: %[[SUB_0:.*]] = tosa.sub %[[MUL_0]], %[[FLOOR_0]] : (tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
58+ // CHECK: %[[CEIL_0:.*]] = tosa.ceil %[[MUL_0]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
59+ // CHECK: %[[MUL_1:.*]] = tosa.mul %[[FLOOR_0]], %[[VAL_2]], %[[VAL_4]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
60+ // CHECK: %[[FLOOR_1:.*]] = tosa.floor %[[MUL_1]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
61+ // CHECK: %[[MUL_2:.*]] = tosa.mul %[[FLOOR_1]], %[[VAL_1]], %[[VAL_4]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32>
62+ // CHECK: %[[EQUAL_0:.*]] = tosa.equal %[[FLOOR_0]], %[[MUL_2]] : (tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xi1>
63+ // CHECK: %[[EQUAL_1:.*]] = tosa.equal %[[SUB_0]], %[[VAL_2]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>) -> tensor<2x4x4xi1>
64+ // CHECK: %[[GREATER_0:.*]] = tosa.greater %[[VAL_2]], %[[SUB_0]] : (tensor<1x1x1xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xi1>
65+ // CHECK: %[[LOGICAL_AND_0:.*]] = tosa.logical_and %[[EQUAL_1]], %[[EQUAL_0]] : (tensor<2x4x4xi1>, tensor<2x4x4xi1>) -> tensor<2x4x4xi1>
66+ // CHECK: %[[LOGICAL_OR_0:.*]] = tosa.logical_or %[[GREATER_0]], %[[LOGICAL_AND_0]] : (tensor<2x4x4xi1>, tensor<2x4x4xi1>) -> tensor<2x4x4xi1>
67+ // CHECK: %[[SELECT_0:.*]] = tosa.select %[[LOGICAL_OR_0]], %[[FLOOR_0]], %[[CEIL_0]] : (tensor<2x4x4xi1>, tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
68+ // CHECK: %[[ADD_0:.*]] = tosa.add %[[SELECT_0]], %[[VAL_0]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>) -> tensor<2x4x4xf32>
69+ // CHECK: %[[CLAMP_0:.*]] = tosa.clamp %[[ADD_0]] {max_val = 1.270000e+02 : f32, min_val = -1.280000e+02 : f32} : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32>
70+ // CHECK: %[[CAST_0:.*]] = tosa.cast %[[CLAMP_0]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xi8>
71+ // CHECK: %[[FROM_BUILTIN_TENSOR_0:.*]] = torch_c.from_builtin_tensor %[[CAST_0]] : tensor<2x4x4xi8> -> !torch.vtensor<[2,4,4],!torch.qint8>
72+ // CHECK: return %[[FROM_BUILTIN_TENSOR_0]]
7273func.func @quantization_per_tensor (%arg0: !torch.vtensor <[2 ,4 ,4 ],f32 >) -> !torch.vtensor <[2 ,4 ,4 ],!torch.qint8 > {
7374 %dtype = torch.constant.int 12
7475 %scale = torch.constant.float 0.1
0 commit comments