From 6a9ca642d6a22a021fca30eb27942ab42617ea7d Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Tue, 15 Mar 2022 18:44:55 -0700 Subject: [PATCH] Account all the element types to determine vector sizes. (#8552) The assumption was that all the element types have the same bitwidth. However, there are cases that element types do not match, e.g., matmul i8xi8 -> i32. It caused issues that large tiling sizes were selected, which kicked in heavy optimization in LLVM. This commit chooses the smallest vector size over all the element types. This also updates the logic of first level tiling, which follows what we've done for generic ops. The commit reduce compilation time from hours to 5 mins for mobilebert-baseline-tf2-quant.mlir when targeting ARM. Fixes https://github.com/google/iree/issues/8540 --- .../Codegen/LLVMCPU/KernelDispatch.cpp | 62 ++++++++++++------- .../materialize_launch_configuration.mlir | 62 +++++++++++++++++-- 2 files changed, 97 insertions(+), 27 deletions(-) diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp index 726cdd8a7de0..db883f20ac09 100644 --- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp +++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp @@ -99,6 +99,33 @@ static int64_t getVectorSize(FuncOp entryPointFn, ShapedType shapedType) { return getVectorSize(entryPointFn, byteWidth); } +/// Returns minimum tiling sizes for each dimension. One dimension is possible +/// to access at different element types. It determines the tiling sizes by +/// looking into all the operands. +static SmallVector getMinTilingSizesForEachDim(FuncOp entryPointFn, + linalg::LinalgOp op) { + unsigned numLoops = op.getNumLoops(); + SmallVector minTileSizes(numLoops, 1); + auto inputOutputOpOperands = op.getInputAndOutputOperands(); + for (auto map : llvm::enumerate(op.getIndexingMaps())) { + // Check the fastest varying dimension of the operand. Set the vector size + // of the corresponding loop to the vector size. + if (map.value().getNumResults() == 0) continue; + auto fastestVaryingDimExpr = + map.value().getResults().back().dyn_cast(); + if (!fastestVaryingDimExpr) continue; + unsigned fastestVaryingDim = fastestVaryingDimExpr.getPosition(); + + // If the indexing map has result it has to be a shaped type. + auto operandType = + inputOutputOpOperands[map.index()]->get().getType().cast(); + minTileSizes[fastestVaryingDim] = + std::max(minTileSizes[fastestVaryingDim], + getVectorSize(entryPointFn, operandType)); + } + return minTileSizes; +} + /// Returns the type length in bytes. Looks through all the interface binding /// ops to see the ABI types and guess-timates the type size to use. This is /// used to convert the vector size in bytes to vector size in number of @@ -409,11 +436,20 @@ static LogicalResult setRootConfig( FuncOp entryPointFn, linalg::ContractionOpInterface contractionOp, ArrayRef tiledLoops) { auto linalgOp = cast(contractionOp.getOperation()); + // Consider all element types and use the smallest vector size. The tiling + // sizes are chosen based on the vector size. auto lhsShapedType = contractionOp.lhs().getType().cast(); + auto rhsShapedType = contractionOp.rhs().getType().cast(); + auto resShapedType = + linalgOp.getOutputOperand(0)->get().getType().cast(); + int64_t vectorSize = getVectorSize(entryPointFn, lhsShapedType); + vectorSize = std::min(vectorSize, getVectorSize(entryPointFn, rhsShapedType)); + vectorSize = std::min(vectorSize, getVectorSize(entryPointFn, resShapedType)); + // Use the default distribution for the matmul loops. unsigned numLoops = linalgOp.getNumLoops(); - int64_t vectorSize = getVectorSize(entryPointFn, lhsShapedType); - SmallVector minTileSizes(numLoops, vectorSize); + SmallVector minTileSizes = + getMinTilingSizesForEachDim(entryPointFn, linalgOp); SmallVector maxTileSizes(numLoops, defaultWorkgroupTileSize); if (numLoops > 3) { minTileSizes[0] = 1; @@ -539,25 +575,9 @@ static LogicalResult setRootConfig( unsigned numLoops = genericOp.getNumLoops(); if (numLoops == 0) return success(); - SmallVector minTileSizes(numLoops, 1), - maxTileSizes(numLoops, defaultWorkgroupTileSize); - auto inputOutputOpOperands = genericOp.getInputAndOutputOperands(); - for (auto map : llvm::enumerate(genericOp.getIndexingMaps())) { - // Check the fastest varying dimension of the operand. Set the vector size - // of the corresponding loop to the vector size. - if (map.value().getNumResults() == 0) continue; - auto fastestVaryingDimExpr = - map.value().getResults().back().dyn_cast(); - if (!fastestVaryingDimExpr) continue; - unsigned fastestVaryingDim = fastestVaryingDimExpr.getPosition(); - - // If the indexing map has result it has to be a shaped type. - auto operandType = - inputOutputOpOperands[map.index()]->get().getType().cast(); - minTileSizes[fastestVaryingDim] = - std::max(minTileSizes[fastestVaryingDim], - getVectorSize(entryPointFn, operandType)); - } + SmallVector minTileSizes = + getMinTilingSizesForEachDim(entryPointFn, genericOp); + SmallVector maxTileSizes(numLoops, defaultWorkgroupTileSize); if (llvm::all_of(minTileSizes, [](int64_t vs) { return vs == 1; })) { // Nothing to vectorize just lower to loops. return success(); diff --git a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir index 3445c6048214..4f219c4633fa 100644 --- a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir +++ b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir @@ -718,7 +718,7 @@ hal.executable private @matmul_static { } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.entry_point public @matmul_static // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -936,7 +936,7 @@ hal.executable private @reduction { #hal.descriptor_set.binding<2, storage_buffer> ]> ]> -hal.executable private @matmul_i8_i8_i32 { +hal.executable private @matmul_x86_i8_i8_i32 { hal.executable.variant public @embedded_elf_x86_64, target = #hal.executable.target< "llvm", "embedded-elf-x86_64", { @@ -944,9 +944,9 @@ hal.executable private @matmul_i8_i8_i32 { native_vector_size = 4 : index, target_triple = "x86_64-unknown-unknown-eabi-elf" }> { - hal.executable.entry_point public @matmul_i8_i8_i32 layout(#executable_layout) + hal.executable.entry_point public @matmul_x86_i8_i8_i32 layout(#executable_layout) builtin.module { - func @matmul_i8_i8_i32() { + func @matmul_x86_i8_i8_i32() { %c0 = arith.constant 0 : index %M = hal.interface.constant.load[0] : index %N = hal.interface.constant.load[1] : index @@ -974,7 +974,57 @@ hal.executable private @matmul_i8_i8_i32 { // CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info -// CHECK: hal.executable.entry_point public @matmul_i8_i8_i32 +// CHECK: hal.executable.entry_point public @matmul_x86_i8_i8_i32 +// CHECK-SAME: translation_info = #[[TRANSLATION]] +// CHECK: linalg.matmul +// CHECK-SAME: lowering_config = #[[CONFIG]] + +// ----- + +#executable_layout = #hal.executable.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +hal.executable private @matmul_aarch_i8_i8_i32 { + hal.executable.variant public @system_elf_arm_64, target = <"llvm", "system-elf-arm_64", { + data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", + native_vector_size = 16 : index, + target_triple = "aarch64-none-linux-android30" + }> { + hal.executable.entry_point public @matmul_aarch_i8_i8_i32 layout(#executable_layout) + builtin.module { + func @matmul_aarch_i8_i8_i32() { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load[0] : index + %N = hal.interface.constant.load[1] : index + %K = hal.interface.constant.load[2] : index + %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) + : !flow.dispatch.tensor{%M, %K} + %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) + : !flow.dispatch.tensor{%K, %N} + %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) + : !flow.dispatch.tensor{%M, %N} + %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor{%M, %K} -> tensor + %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor{%K, %N} -> tensor + %init = flow.dispatch.tensor.load %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor{%M, %N} -> tensor + %gemm = linalg.matmul ins(%lhs, %rhs : tensor, tensor) outs(%init : tensor) -> tensor + flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor -> !flow.dispatch.tensor{%M, %N} + return + } + } + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: hal.executable.entry_point public @matmul_aarch_i8_i8_i32 // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK: linalg.matmul // CHECK-SAME: lowering_config = #[[CONFIG]] @@ -1118,7 +1168,7 @@ hal.executable private @matmul_odd { } } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.entry_point public @matmul_odd // CHECK-SAME: translation_info = #[[TRANSLATION]]