Skip to content

Commit

Permalink
Add sub-byte emulation support to llvm-cpu and vulkan backend. (#14488)
Browse files Browse the repository at this point in the history
Based off of
ea49d2b

Fixes #14481

---------

Co-authored-by: Hanhan Wang <[email protected]>
Co-authored-by: yzhang93 <[email protected]>
  • Loading branch information
3 people authored Jul 28, 2023
1 parent 73b00a0 commit a83e543
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 1 deletion.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ iree_compiler_cc_library(
"DecomposeConvolutionToLowerDimOps.cpp",
"DecomposeLinalgGeneric.cpp",
"DecomposePackUnPackOps.cpp",
"EmulateNarrowType.cpp",
"EraseDeadAllocAndStores.cpp",
"EraseHALDescriptorTypeFromMemRef.cpp",
"ExtractAddressComputation.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ iree_cc_library(
"DecomposeConvolutionToLowerDimOps.cpp"
"DecomposeLinalgGeneric.cpp"
"DecomposePackUnPackOps.cpp"
"EmulateNarrowType.cpp"
"EraseDeadAllocAndStores.cpp"
"EraseHALDescriptorTypeFromMemRef.cpp"
"ExtractAddressComputation.cpp"
Expand Down
120 changes: 120 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace iree_compiler {

namespace {

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

struct ConvertHalInterfaceBindingSubspan final
: OpConversionPattern<IREE::HAL::InterfaceBindingSubspanOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type newResultTy = getTypeConverter()->convertType(op.getType());
if (!newResultTy)
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to legalize memref type: {0}", op.getType()));

rewriter.replaceOpWithNewOp<IREE::HAL::InterfaceBindingSubspanOp>(
op, newResultTy, adaptor.getSet(), adaptor.getBinding(),
adaptor.getDescriptorType(), adaptor.getByteOffset(),
adaptor.getDynamicDims(), adaptor.getAlignmentAttr(),
adaptor.getDescriptorFlagsAttr());
return success();
}
};

static void populateIreeNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &converter,
RewritePatternSet &patterns) {
patterns.add<ConvertHalInterfaceBindingSubspan>(converter,
patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

struct EmulateNarrowTypePass
: public EmulateNarrowTypeBase<EmulateNarrowTypePass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, func::FuncDialect,
memref::MemRefDialect, vector::VectorDialect,
affine::AffineDialect, IREE::HAL::HALDialect>();
}

void runOnOperation() override {
// The number of bits used in a load/store op.
constexpr unsigned kLoadStoreEmulateBitwidth = 8;
static_assert(
llvm::isPowerOf2_32(kLoadStoreEmulateBitwidth) &&
"only power of 2 is supported for narrow type load/store emulation");

MLIRContext *ctx = &getContext();

arith::NarrowTypeEmulationConverter typeConverter(
kLoadStoreEmulateBitwidth);
memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

ConversionTarget target(*ctx);
target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
});
auto opLegalCallback = [&typeConverter](Operation *op) {
return typeConverter.isLegal(op);
};
target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
target.addDynamicallyLegalDialect<
arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
affine::AffineDialect, IREE::HAL::HALDialect>(opLegalCallback);

RewritePatternSet patterns(ctx);
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
populateIreeNarrowTypeEmulationPatterns(typeConverter, patterns);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

//===----------------------------------------------------------------------===//
// Public interface
//===----------------------------------------------------------------------===//

std::unique_ptr<OperationPass<ModuleOp>> createEmulateNarrowTypePass() {
return std::make_unique<EmulateNarrowTypePass>();
}

} // namespace iree_compiler
} // namespace mlir
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ createDecomposePackUnPackOpsPass(bool tileOuterToOne = false);
/// during bufferization.
std::unique_ptr<OperationPass<ModuleOp>> createEliminateEmptyTensorsPass();

/// A pass to emulate memref load operations that use narrow integer types
/// with equivalent operations on supported wide integer types.
std::unique_ptr<OperationPass<ModuleOp>> createEmulateNarrowTypePass();

/// Creates a pass to erase dead alloc ops where all uses are just store ops.
std::unique_ptr<OperationPass<func::FuncOp>>
createEraseDeadAllocAndStoresPass();
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ def EliminateEmptyTensors :
let constructor = "mlir::iree_compiler::createEliminateEmptyTensorsPass()";
}

def EmulateNarrowType :
Pass<"iree-codegen-emulate-narrow-type", "ModuleOp"> {
let summary = "Emulate narrow integer operations using wide integer operations";
let constructor = "mlir::iree_compiler::createEmulateNarrowTypePass()";
}

def EraseDeadAllocAndStores :
Pass<"iree-codegen-erase-dead-alloc-and-stores", "func::FuncOp"> {
let summary = "Erase alloc ops if all the uses are just stores";
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ iree_lit_test_suite(
"decompose_linalg_generic.mlir",
"decompose_pack_unpack_ops.mlir",
"eliminate_empty_tensors.mlir",
"emulate_narrow_type.mlir",
"erase_hal_descriptor_type.mlir",
"extract_address_computation.mlir",
"flatten_memref_subspan.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ iree_lit_test_suite(
"decompose_linalg_generic.mlir"
"decompose_pack_unpack_ops.mlir"
"eliminate_empty_tensors.mlir"
"emulate_narrow_type.mlir"
"erase_dead_alloc_and_stores.mlir"
"erase_hal_descriptor_type.mlir"
"extract_address_computation.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: iree-opt --split-input-file --iree-codegen-emulate-narrow-type %s | FileCheck %s

func.func @memref_i4_to_i8() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<8xi4>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<8xf32>
return
}
// CHECK-LABEL: func.func @memref_i4_to_i8
// CHECK: hal.interface.binding.subspan {{.+}} memref<8xi8>
// CHECK: hal.interface.binding.subspan {{.+}} memref<8xf32>
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ static void addLowerToLLVMPasses(OpPassManager &passManager) {
// (HAL, IREE, Linalg, CF) -> LLVM
passManager.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
passManager.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
passManager.addPass(createEmulateNarrowTypePass());
if (clInstrumentMemoryAccesses) {
passManager.addNestedPass<func::FuncOp>(
createInstrumentMemoryAccessesPass());
Expand Down
7 changes: 6 additions & 1 deletion tests/e2e/linalg/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package(
LLVM_SRCS = enforce_glob(
[
"conv2d.mlir",
"i4_to_f32.mlir",
],
include = ["*.mlir"],
exclude = ["large_linalg_matmul.mlir"],
Expand Down Expand Up @@ -48,7 +49,10 @@ VMVX_SRCS = enforce_glob(
"conv2d.mlir",
],
include = ["*.mlir"],
exclude = ["large_linalg_matmul.mlir"],
exclude = [
"large_linalg_matmul.mlir",
"i4_to_f32.mlir",
],
)

iree_check_single_backend_test_suite(
Expand All @@ -61,6 +65,7 @@ iree_check_single_backend_test_suite(
VULKAN_SRCS = enforce_glob(
[
"conv2d.mlir",
"i4_to_f32.mlir",
],
include = ["*.mlir"],
exclude = ["large_linalg_matmul.mlir"],
Expand Down
4 changes: 4 additions & 0 deletions tests/e2e/linalg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_check_single_backend_test_suite(
check_llvm-cpu_local-task
SRCS
"conv2d.mlir"
"i4_to_f32.mlir"
TARGET_BACKEND
"llvm-cpu"
DRIVER
Expand All @@ -26,6 +27,7 @@ iree_check_single_backend_test_suite(
check_winograd_llvm-cpu_local-task
SRCS
"conv2d.mlir"
"i4_to_f32.mlir"
TARGET_BACKEND
"llvm-cpu"
DRIVER
Expand All @@ -50,6 +52,7 @@ iree_check_single_backend_test_suite(
check_vulkan-spirv_vulkan
SRCS
"conv2d.mlir"
"i4_to_f32.mlir"
TARGET_BACKEND
"vulkan-spirv"
DRIVER
Expand All @@ -61,6 +64,7 @@ iree_check_single_backend_test_suite(
check_winograd_vulkan-spirv_vulkan
SRCS
"conv2d.mlir"
"i4_to_f32.mlir"
TARGET_BACKEND
"vulkan-spirv"
DRIVER
Expand Down
14 changes: 14 additions & 0 deletions tests/e2e/linalg/i4_to_f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#map = affine_map<(d0) -> (d0)>
func.func @i4_to_f32() {
%input = util.unfoldable_constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi4>
%0 = tensor.empty() : tensor<8xf32>
%res = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
ins(%input : tensor<8xi4>) outs(%0 : tensor<8xf32>) {
^bb0(%in: i4, %out: f32):
%2 = arith.extui %in : i4 to i32
%3 = arith.uitofp %2 : i32 to f32
linalg.yield %3 : f32
} -> tensor<8xf32>
check.expect_eq_const(%res, dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]> : tensor<8xf32>) : tensor<8xf32>
return
}

0 comments on commit a83e543

Please sign in to comment.