Skip to content

Commit 1a66819

Browse files
authored
[Codegen] Test Cleanup 4/8: Dialect tests (#22747)
Result of a scan over all tests in Codegen to cleanup common issues in tests. A summary of the results + a preamble approximating the issues to look for can be found here: https://gist.github.com/qedawkins/40f9e604fd83745bf1ac20fd63a7a61f
1 parent 3b7ff2d commit 1a66819

File tree

13 files changed

+62
-56
lines changed

13 files changed

+62
-56
lines changed

compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/roundtrip.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ func.func @test_full_lowering_config_with_scalable_vector() attributes {
4848
}
4949
// Order matters because it is sorted.
5050
// CHECK: #[[$CONFIG:.+]] = #iree_cpu.lowering_config<
51-
// CHECK-SAME{LITERAL}: cache_parallel = [64, 64, 0]
52-
// CHECK-SAME{LITERAL}: cache_reduction = [0, 0, 16]
53-
// CHECK-SAME{LITERAL}: distribution = [128, 128, 0]
51+
// CHECK-SAME: cache_parallel = [64, 64, 0]
52+
// CHECK-SAME: cache_reduction = [0, 0, 16]
53+
// CHECK-SAME: distribution = [128, 128, 0]
5454
// CHECK-SAME{LITERAL}: vector_common_parallel = [[4], [4], 0]
55-
// CHECK-SAME{LITERAL}: vector_inner_parallel = [0, 0, 0]
55+
// CHECK-SAME: vector_inner_parallel = [0, 0, 0]
5656
// CHECK-SAME{LITERAL}: vector_reduction = [0, 0, [4]]
57-
// CHECK-LABEL: @test_full_lowering_config_with_scalable_vector()
57+
// CHECK-LABEL: @test_full_lowering_config_with_scalable_vector()
5858
// CHECK-SAME: lowering_config = #[[$CONFIG]]
5959

6060
// -----

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/lowering_config_attr.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ module {
6666
return
6767
}
6868
}
69-
// CHECK: #iree_codegen.export_config<workgroup_size = [4, 1]
69+
// CHECK: #iree_codegen.export_config<workgroup_size = [4, 1]>
7070

7171
// -----
7272

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/ukernel_ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func.func @ukernel_generic_optional_input(
6363
return %0#0, %0#1 : tensor<?xf32>, tensor<?x?xf32>
6464
}
6565
// CHECK: func @ukernel_generic_optional_input(
66-
// CHECK: %[[RESULT:.+]]:2 = iree_codegen.ukernel.generic
66+
// CHECK: %{{.+}}:2 = iree_codegen.ukernel.generic
6767
// CHECK-NOT: ins
6868

6969
// -----
@@ -92,7 +92,7 @@ func.func @ukernel_generic_optional_other_operands(
9292
return %0#0, %0#1 : tensor<?xf32>, tensor<?x?xf32>
9393
}
9494
// CHECK: func @ukernel_generic_optional_other_operands(
95-
// CHECK: %[[RESULT:.+]]:2 = iree_codegen.ukernel.generic
95+
// CHECK: %{{.+}}:2 = iree_codegen.ukernel.generic
9696
// CHECK-SAME: outs(%{{.+}}, %{{.+}} : tensor<?xf32>, tensor<?x?xf32>) ->
9797

9898
// -----

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ module {
114114
return
115115
}
116116
}
117-
// CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x4_f32
117+
// CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x4_f32_subgroups_k
118118
// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, intrinsics_m = 4, subgroups_k = 2, operands_interleaving_intrinsics_k = [0, 1]>
119119

120120

compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/convert_to_multi_mma.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ module attributes { transform.with_named_sequence } {
105105
// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
106106
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
107107

108-
// CHECK-LABEL: func @convert_to_mfma_16x16x16
108+
// CHECK-LABEL: func @convert_to_mfma_16x16x16_transpose_b
109109
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<2x16x16xf16>
110110
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<2x16x16xf16>
111111
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: tensor<16x16xf32>

compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_inner_tiled.mlir

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,6 @@ module attributes { transform.with_named_sequence } {
8080
transform.yield
8181
}
8282
}
83-
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
84-
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
85-
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
8683

8784
// CHECK-LABEL: func @distribute_inner_tiled_I8_16x16x32_I32
8885
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x16x32xi8>

compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/unroll_multi_mma.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ module attributes { transform.with_named_sequence } {
9292
}
9393
}
9494

95-
// CHECK-LABEL: func @unroll_multi_mma_count
96-
// CHECK-COUNT-30: %[[MMA:.+]] = iree_codegen.inner_tiled {{.*}} : vector<1x1x4xf16>, vector<1x1x4xf16> into vector<1x1x4xf32>
95+
// CHECK-LABEL: func @unroll_multi_mma_count
96+
// CHECK-COUNT-30: {{.+}} = iree_codegen.inner_tiled {{.*}} : vector<1x1x4xf16>, vector<1x1x4xf16> into vector<1x1x4xf32>
9797
// CHECK-COUNT-10: vector.insert_strided_slice {{.*}} : vector<1x1x4xf32> into vector<2x5x4xf32>
9898

9999
// -----
@@ -130,7 +130,7 @@ module attributes { transform.with_named_sequence } {
130130
}
131131

132132
// CHECK-LABEL: func @unroll_scaled_multi_mma
133-
// CHECK-SAME: %[[LHS_SCALE:[A-Za-z0-9]+]]: vector<1x2x1xf8E8M0FNU>
133+
// CHECK-SAME: %[[LHS_SCALE:[A-Za-z0-9]+]]: vector<1x2x1xf8E8M0FNU>
134134
// CHECK-COUNT-2: vector.extract_strided_slice %[[LHS_SCALE]] {offsets = [0, 0]
135135
// CHECK-NOT: vector.extract_strided_slice %[[LHS_SCALE]] {offsets = [0, 0]
136136
// CHECK-COUNT-2: vector.extract_strided_slice %[[LHS_SCALE]] {offsets = [0, 1]

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_inner_tiled_to_lanes.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,9 +1185,9 @@ func.func @fuse_producer_slice(%arg1 : tensor<4x2x16x16xbf16>, %arg2 : tensor<1x
11851185
}
11861186

11871187
// CHECK-LABEL: func @fuse_producer_slice
1188-
// CHECK : scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<4x1x16x16xf32>)
1189-
// CHECK : %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]]
1190-
// CHECK : %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[ACC_SLICE]] : tensor<4x1x4x1xf32>) -> tensor<4x1x4x1xf32>
1191-
// CHECK : iree_codegen.inner_tiled
1192-
// CHECK-SAME : outs(%[[FILL]])
1193-
// CHECK : mapping = [#iree_gpu.lane_id<0>]
1188+
// CHECK: scf.forall ({{.+}}) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<4x1x16x16xf32>)
1189+
// CHECK: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]]
1190+
// CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[ACC_SLICE]] : tensor<4x1x4x1xf32>) -> tensor<4x1x4x1xf32>
1191+
// CHECK: iree_codegen.inner_tiled
1192+
// CHECK-SAME: outs(%[[FILL]])
1193+
// CHECK: mapping = [#iree_gpu.lane_id<0>]

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/canonicalize.mlir

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,14 @@ func.func @transfer_gather_fold_single_element(%scalar: vector<1xindex>,
136136
return %out : vector<64x1xf16>
137137
}
138138

139-
// CHECK-LABEL: transfer_gather_fold_single_element
139+
// CHECK-LABEL: @transfer_gather_fold_single_element
140140
// CHECK-SAME: %{{.*}}: vector<1xindex>, %[[ARG1:.*]]: vector<64x1xindex>
141141
// CHECK: transfer_gather
142142
// CHECK-SAME: [None, %[[ARG1]]
143143

144144
// -----
145145

146-
func.func @transfer_gather_fold_contigious_load(%scalar: vector<64x1xindex>,
146+
func.func @transfer_gather_fold_contiguous_load(%scalar: vector<64x1xindex>,
147147
%indices: vector<64x1xindex>,
148148
%source: tensor<4096x64xf16>)
149149
-> vector<64x1xf16> {
@@ -157,8 +157,6 @@ func.func @transfer_gather_fold_contigious_load(%scalar: vector<64x1xindex>,
157157
return %out : vector<64x1xf16>
158158
}
159159

160-
// CHECK-LABEL: @transfer_gather_fold_contigious_load
160+
// CHECK-LABEL: @transfer_gather_fold_contiguous_load
161161
// CHECK: vector.transfer_read
162162
// CHECK-NOT: transfer_gather
163-
164-
// -----

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/invalid.mlir

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,10 @@
1010
subgroup_strides = [0, 0],
1111
thread_strides = [0, 0]>
1212

13-
func.func @invalid_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
14-
%cst_0 = arith.constant 0.0 : f16
15-
%c0 = arith.constant 0 : index
16-
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
13+
func.func @invalid_layout(%arg0: vector<32x32xf16>) -> vector<32x32xf16> {
1714
// expected-error @+1 {{Vector shape: [32, 32] does not match the layout (nested_layout<subgroup_tile = [1, 1], batch_tile = [1, 1], outer_tile = [1, 1], thread_tile = [1, 1], element_tile = [1, 1], subgroup_strides = [0, 0], thread_strides = [0, 0]>) at dim 0. Dimension expected by layout: 1 actual: 32}}
18-
%2 = iree_vector_ext.to_layout %result to layout(#layout1) : vector<32x32xf16>
19-
return %2 : vector<32x32xf16>
15+
%0 = iree_vector_ext.to_layout %arg0 to layout(#layout1) : vector<32x32xf16>
16+
return %0 : vector<32x32xf16>
2017
}
2118

2219
// -----
@@ -69,7 +66,7 @@ func.func @indexing_map_mismatch(%indices: vector<128xindex>,
6966

7067
// -----
7168

72-
func.func @indexing_map_mismatch(%indices: vector<128x64xindex>,
69+
func.func @indexing_map_invalid_index_vector_shape(%indices: vector<128x64xindex>,
7370
%source: tensor<128x64xf16>)
7471
-> vector<128x64xf16> {
7572

0 commit comments

Comments
 (0)