@@ -25,9 +25,37 @@ func.func @test_cast_int4_and_uint4_to_from_int8_uint8(%arg0: tensor<1xi4>, %arg
25
25
%2 = " onnx.Cast" (%arg1 ) {saturate = 1 : si64 , to = ui8 } : (tensor <1 xui4 >) -> tensor <1 xui8 >
26
26
%3 = " onnx.Cast" (%2 ) {saturate = 1 : si64 , to = ui4 } : (tensor <1 xui8 >) -> tensor <1 xui4 >
27
27
onnx.Return %1 , %3 : tensor <1 xi4 >, tensor <1 xui4 >
28
- // CHECK-LABEL: func.func @test_cast_int4_and_uint4_to_from_int8_uint8(
29
- // TOSA does not support int4 casting
30
- // CHECK-NOT: tosa.cast
28
+ // CHECK-LABEL: func.func @test_cast_int4_and_uint4_to_from_int8_uint8
29
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi4>, [[PARAM_1_:%.+]]: tensor<1xui4>) -> (tensor<1xi4>, tensor<1xui4>) {
30
+ // CHECK-DAG: [[VAR_0_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor<1xi4>) -> tensor<1xi8>
31
+ // CHECK-DAG: [[VAR_1_:%.+]] = tosa.cast [[VAR_0_]] : (tensor<1xi8>) -> tensor<1xi4>
32
+ // CHECK-DAG: [[VAR_2_:%.+]] = tosa.cast [[PARAM_1_]] : (tensor<1xui4>) -> tensor<1xui8>
33
+ // CHECK-DAG: [[VAR_3_:%.+]] = tosa.cast [[VAR_2_]] : (tensor<1xui8>) -> tensor<1xui4>
34
+ // CHECK: onnx.Return [[VAR_1_]], [[VAR_3_]] : tensor<1xi4>, tensor<1xui4>
35
+ // CHECK: }
36
+ }
37
+
38
+ // -----
39
+
40
+ func.func @test_cast_int4_and_uint4_to_float_and_back (%arg0: tensor <1 xi4 >, %arg1: tensor <1 xui4 >) -> (tensor <1 xi4 >, tensor <1 xui4 >) {
41
+ %0 = " onnx.Cast" (%arg0 ) {saturate = 1 : si64 , to = f32 } : (tensor <1 xi4 >) -> tensor <1 xf32 >
42
+ %1 = " onnx.Cast" (%0 ) {saturate = 1 : si64 , to = i4 } : (tensor <1 xf32 >) -> tensor <1 xi4 >
43
+ %2 = " onnx.Cast" (%arg1 ) {saturate = 1 : si64 , to = f32 } : (tensor <1 xui4 >) -> tensor <1 xf32 >
44
+ %3 = " onnx.Cast" (%2 ) {saturate = 1 : si64 , to = ui4 } : (tensor <1 xf32 >) -> tensor <1 xui4 >
45
+ onnx.Return %1 , %3 : tensor <1 xi4 >, tensor <1 xui4 >
46
+ // CHECK-LABEL: func.func @test_cast_int4_and_uint4_to_float_and_back
47
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi4>, [[PARAM_1_:%.+]]: tensor<1xui4>) -> (tensor<1xi4>, tensor<1xui4>) {
48
+ // CHECK-DAG: [[VAR_0_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor<1xi4>) -> tensor<1xf32>
49
+ // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
50
+ // CHECK-DAG: [[VAR_2_:%.+]] = tosa.greater_equal [[VAR_0_]], [[VAR_1_]] : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
51
+ // CHECK-DAG: [[VAR_3_:%.+]] = tosa.floor [[VAR_0_]] : (tensor<1xf32>) -> tensor<1xf32>
52
+ // CHECK-DAG: [[VAR_4_:%.+]] = tosa.ceil [[VAR_0_]] : (tensor<1xf32>) -> tensor<1xf32>
53
+ // CHECK-DAG: [[VAR_5_:%.+]] = tosa.select [[VAR_2_]], [[VAR_3_]], [[VAR_4_]] : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
54
+ // CHECK-DAG: [[VAR_6_:%.+]] = tosa.cast [[VAR_5_]] : (tensor<1xf32>) -> tensor<1xi4>
55
+ // CHECK-DAG: [[VAR_7_:%.+]] = tosa.cast [[PARAM_1_]] : (tensor<1xui4>) -> tensor<1xf32>
56
+ // CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Cast"([[VAR_7_]]) {saturate = 1 : si64, to = ui4} : (tensor<1xf32>) -> tensor<1xui4>
57
+ // CHECK: onnx.Return [[VAR_6_]], [[VAR_8_]] : tensor<1xi4>, tensor<1xui4>
58
+ // CHECK: }
31
59
}
32
60
33
61
// -----
0 commit comments