Skip to content

Commit

Permalink
Account all the element types to determine vector sizes. (#8552)
Browse files Browse the repository at this point in the history
The assumption was that all the element types have the same bitwidth.
However, there are cases where the element types do not match, e.g.,
matmul i8xi8 -> i32. This caused overly large tiling sizes to be
selected, which triggered heavy optimization work in LLVM. This commit
chooses the smallest vector size over all the element types.

This also updates the logic of first level tiling, which follows what we've
done for generic ops.

This commit reduces compilation time from hours to about 5 minutes for
mobilebert-baseline-tf2-quant.mlir when targeting ARM.

Fixes #8540
  • Loading branch information
hanhanW authored Mar 16, 2022
1 parent 4b889d9 commit 6a9ca64
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 27 deletions.
62 changes: 41 additions & 21 deletions iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,33 @@ static int64_t getVectorSize(FuncOp entryPointFn, ShapedType shapedType) {
return getVectorSize(entryPointFn, byteWidth);
}

/// Returns minimum tiling sizes for each dimension. One dimension is possible
/// to access at different element types. It determines the tiling sizes by
/// looking into all the operands.
/// Returns the minimum tiling size for each loop dimension of `op`.
///
/// A single loop dimension may be accessed through operands of different
/// element types (e.g. an i8 x i8 -> i32 matmul), so every input/output
/// operand is inspected: for each operand, the loop driving its fastest
/// varying dimension must be tiled by at least that operand's native
/// vector size. Dimensions not constrained by any operand default to 1.
static SmallVector<int64_t> getMinTilingSizesForEachDim(FuncOp entryPointFn,
                                                        linalg::LinalgOp op) {
  SmallVector<int64_t> minTileSizes(op.getNumLoops(), 1);
  auto operands = op.getInputAndOutputOperands();
  auto indexingMaps = op.getIndexingMaps();
  for (unsigned idx = 0, numMaps = indexingMaps.size(); idx < numMaps; ++idx) {
    AffineMap indexingMap = indexingMaps[idx];
    // Operands with a rank-0 indexing map have no fastest varying dimension.
    if (indexingMap.getNumResults() == 0) continue;
    // The innermost map result identifies the operand's fastest varying
    // dimension; skip it unless it is a plain loop dimension.
    auto innerDimExpr = indexingMap.getResults().back().dyn_cast<AffineDimExpr>();
    if (!innerDimExpr) continue;
    unsigned loopDim = innerDimExpr.getPosition();

    // An indexing map with results implies the operand is a shaped type.
    auto shapedType = operands[idx]->get().getType().cast<ShapedType>();
    int64_t operandVectorSize = getVectorSize(entryPointFn, shapedType);
    minTileSizes[loopDim] = std::max<int64_t>(minTileSizes[loopDim], operandVectorSize);
  }
  return minTileSizes;
}

/// Returns the type length in bytes. Looks through all the interface binding
/// ops to see the ABI types and guess-timates the type size to use. This is
/// used to convert the vector size in bytes to vector size in number of
Expand Down Expand Up @@ -409,11 +436,20 @@ static LogicalResult setRootConfig(
FuncOp entryPointFn, linalg::ContractionOpInterface contractionOp,
ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
auto linalgOp = cast<linalg::LinalgOp>(contractionOp.getOperation());
// Consider all element types and use the smallest vector size. The tiling
// sizes are chosen based on the vector size.
auto lhsShapedType = contractionOp.lhs().getType().cast<ShapedType>();
auto rhsShapedType = contractionOp.rhs().getType().cast<ShapedType>();
auto resShapedType =
linalgOp.getOutputOperand(0)->get().getType().cast<ShapedType>();
int64_t vectorSize = getVectorSize(entryPointFn, lhsShapedType);
vectorSize = std::min(vectorSize, getVectorSize(entryPointFn, rhsShapedType));
vectorSize = std::min(vectorSize, getVectorSize(entryPointFn, resShapedType));

// Use the default distribution for the matmul loops.
unsigned numLoops = linalgOp.getNumLoops();
int64_t vectorSize = getVectorSize(entryPointFn, lhsShapedType);
SmallVector<int64_t> minTileSizes(numLoops, vectorSize);
SmallVector<int64_t> minTileSizes =
getMinTilingSizesForEachDim(entryPointFn, linalgOp);
SmallVector<int64_t> maxTileSizes(numLoops, defaultWorkgroupTileSize);
if (numLoops > 3) {
minTileSizes[0] = 1;
Expand Down Expand Up @@ -539,25 +575,9 @@ static LogicalResult setRootConfig(
unsigned numLoops = genericOp.getNumLoops();
if (numLoops == 0) return success();

SmallVector<int64_t> minTileSizes(numLoops, 1),
maxTileSizes(numLoops, defaultWorkgroupTileSize);
auto inputOutputOpOperands = genericOp.getInputAndOutputOperands();
for (auto map : llvm::enumerate(genericOp.getIndexingMaps())) {
// Check the fastest varying dimension of the operand. Set the vector size
// of the corresponding loop to the vector size.
if (map.value().getNumResults() == 0) continue;
auto fastestVaryingDimExpr =
map.value().getResults().back().dyn_cast<AffineDimExpr>();
if (!fastestVaryingDimExpr) continue;
unsigned fastestVaryingDim = fastestVaryingDimExpr.getPosition();

// If the indexing map has result it has to be a shaped type.
auto operandType =
inputOutputOpOperands[map.index()]->get().getType().cast<ShapedType>();
minTileSizes[fastestVaryingDim] =
std::max<int64_t>(minTileSizes[fastestVaryingDim],
getVectorSize(entryPointFn, operandType));
}
SmallVector<int64_t> minTileSizes =
getMinTilingSizesForEachDim(entryPointFn, genericOp);
SmallVector<int64_t> maxTileSizes(numLoops, defaultWorkgroupTileSize);
if (llvm::all_of(minTileSizes, [](int64_t vs) { return vs == 1; })) {
// Nothing to vectorize just lower to loops.
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ hal.executable private @matmul_static {
}
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[28, 8, 0], [4, 4, 60], [4, 4, 4]{{\]}}>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[49, 8, 0], [7, 4, 60], [4, 4, 4]{{\]}}>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUTileFuseAndVectorize>
// CHECK: hal.executable.entry_point public @matmul_static
// CHECK-SAME: translation_info = #[[TRANSLATION]]
Expand Down Expand Up @@ -936,17 +936,17 @@ hal.executable private @reduction {
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
hal.executable private @matmul_i8_i8_i32 {
hal.executable private @matmul_x86_i8_i8_i32 {
hal.executable.variant public @embedded_elf_x86_64, target = #hal.executable.target<
"llvm",
"embedded-elf-x86_64", {
data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
native_vector_size = 4 : index,
target_triple = "x86_64-unknown-unknown-eabi-elf"
}> {
hal.executable.entry_point public @matmul_i8_i8_i32 layout(#executable_layout)
hal.executable.entry_point public @matmul_x86_i8_i8_i32 layout(#executable_layout)
builtin.module {
func @matmul_i8_i8_i32() {
func @matmul_x86_i8_i8_i32() {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load[0] : index
%N = hal.interface.constant.load[1] : index
Expand Down Expand Up @@ -974,7 +974,57 @@ hal.executable private @matmul_i8_i8_i32 {

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [8, 32, 0], [0, 0, 16]{{\]}}>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
// CHECK: hal.executable.entry_point public @matmul_i8_i8_i32
// CHECK: hal.executable.entry_point public @matmul_x86_i8_i8_i32
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.matmul
// CHECK-SAME: lowering_config = #[[CONFIG]]

// -----

// Mixed-element-type contraction test: an i8 x i8 -> i32 dynamic matmul
// targeting AArch64 (system-elf-arm_64, native_vector_size = 16). The CHECK
// lines below pin the lowering config (tile sizes) and the translation info
// (CPUTileFuseAndVectorize) chosen for this dispatch.
#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
hal.executable private @matmul_aarch_i8_i8_i32 {
hal.executable.variant public @system_elf_arm_64, target = <"llvm", "system-elf-arm_64", {
data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
native_vector_size = 16 : index,
target_triple = "aarch64-none-linux-android30"
}> {
hal.executable.entry_point public @matmul_aarch_i8_i8_i32 layout(#executable_layout)
builtin.module {
func @matmul_aarch_i8_i8_i32() {
// M, N, K are runtime constants, so all matmul shapes are dynamic.
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load[0] : index
%N = hal.interface.constant.load[1] : index
%K = hal.interface.constant.load[2] : index
// LHS/RHS are i8 buffers; the result buffer is i32 (wider accumulator).
%lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
: !flow.dispatch.tensor<readonly:?x?xi8>{%M, %K}
%rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
: !flow.dispatch.tensor<readonly:?x?xi8>{%K, %N}
%result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32)
: !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N}
%lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
: !flow.dispatch.tensor<readonly:?x?xi8>{%M, %K} -> tensor<?x?xi8>
%rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
: !flow.dispatch.tensor<readonly:?x?xi8>{%K, %N} -> tensor<?x?xi8>
%init = flow.dispatch.tensor.load %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
: !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N} -> tensor<?x?xi32>
%gemm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%init : tensor<?x?xi32>) -> tensor<?x?xi32>
flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
: tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N}
return
}
}
}
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [16, 4, 64], [4, 4, 4]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUTileFuseAndVectorize>
// CHECK: hal.executable.entry_point public @matmul_aarch_i8_i8_i32
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.matmul
// CHECK-SAME: lowering_config = #[[CONFIG]]
Expand Down Expand Up @@ -1118,7 +1168,7 @@ hal.executable private @matmul_odd {
}
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[3, 7, 0], [3, 7, 0], [0, 0, 16]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[11, 7, 0], [1, 7, 0], [0, 0, 16]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
// CHECK: hal.executable.entry_point public @matmul_odd
// CHECK-SAME: translation_info = #[[TRANSLATION]]
Expand Down

0 comments on commit 6a9ca64

Please sign in to comment.