[MLIR] Support i1 datatypes
* Turned on by the `--iree-enable-i1` option.
lialan committed Oct 15, 2024
1 parent 7622770 commit 36561a9
Showing 8 changed files with 303 additions and 2 deletions.
197 changes: 197 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
@@ -95,6 +95,202 @@ static void populateIreeNarrowTypeEmulationPatterns(
patterns.getContext());
}

/// Returns true when the total size of `type` in bits is a multiple of a
/// byte. For example, vector<8xi1> (8 bits) is byte aligned, while
/// vector<6xi1> (6 bits) is not and needs padding.
static bool isByteAligned(ShapedType type) {
unsigned elementBits = type.getElementType().getIntOrFloatBitWidth();
auto numElements = type.getNumElements();
return (numElements * elementBits) % 8 == 0;
}

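// Sketch of the rewrite performed by the pattern below, on example i1 types
// (the shapes and SSA names are illustrative, not taken from this commit).
// A sub-byte write such as
//
//   vector.transfer_write %v, %t[%c0] : vector<6xi1>, tensor<6xi1>
//
// is padded to the next byte boundary by inserting %v into a zero vector:
//
//   %zero = arith.constant dense<false> : vector<8xi1>
//   %pad = vector.insert_strided_slice %v, %zero
//            {offsets = [0], strides = [1]} : vector<6xi1> into vector<8xi1>
//   %w = vector.transfer_write %pad, %t[%c0] : vector<8xi1>, tensor<6xi1>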
struct PadSubbyteTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const final {
auto target = writeOp.getVector();
auto targetType = cast<VectorType>(target.getType());
if (isByteAligned(targetType)) {
return failure();
}

auto elemType = targetType.getElementType();
unsigned elementBits = elemType.getIntOrFloatBitWidth();
auto numElements = targetType.getNumElements();

// Offsets/strides for the insert below; they must match the rank of the
// vector being padded, not the rank of the destination tensor/memref.
SmallVector<int64_t> strides(targetType.getRank(), 1);
SmallVector<int64_t> offsets(targetType.getRank(), 0);

// TODO: Combine with the destination's existing contents; as written, the
// zero padding below overwrites part of the destination.

SmallVector<int64_t> newShape(targetType.getShape());
newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits;
auto newTargetType = VectorType::get(newShape, elemType);

// Create a zero-filled vector of the padded size.
SmallVector<bool> zeroValues(newTargetType.getNumElements(), false);
auto zeroVector = rewriter.create<arith::ConstantOp>(
writeOp.getLoc(), DenseIntElementsAttr::get(newTargetType, zeroValues));

auto extendedOp = rewriter.create<vector::InsertStridedSliceOp>(
writeOp->getLoc(), target, zeroVector, offsets, strides);

writeOp.getVectorMutable().assign(extendedOp);
return success();
}
};

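// Sketch of the rewrite performed by the pattern below, on example i1 types.
// A sub-byte read such as
//
//   %v = vector.transfer_read %t[%c0], %false : tensor<6xi1>, vector<6xi1>
//
// is widened to a byte-aligned read whose original lanes are then extracted:
//
//   %pad = vector.transfer_read %t[%c0], %false : tensor<6xi1>, vector<8xi1>
//   %v = vector.extract_strided_slice %pad
//          {offsets = [0], sizes = [6], strides = [1]}
//          : vector<8xi1> to vector<6xi1>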
struct PadSubbyteTransferReadPattern
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const final {
auto resultType = cast<VectorType>(readOp.getResult().getType());
if (isByteAligned(resultType)) {
return failure();
}

unsigned elementBits = resultType.getElementType().getIntOrFloatBitWidth();
auto numElements = resultType.getNumElements();

// Pad the result type to the next byte boundary.
SmallVector<int64_t> newShape(resultType.getShape());
newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits;
auto newType = VectorType::get(newShape, resultType.getElementType());

// Zero padding value for the out-of-bounds lanes of the widened read.
auto paddingValue = rewriter.create<arith::ConstantOp>(
readOp.getLoc(), resultType.getElementType(),
rewriter.getZeroAttr(resultType.getElementType()));

// Offsets/strides for extracting the original vector from the padded result.
SmallVector<int64_t> offsets(resultType.getRank(), 0);
SmallVector<int64_t> strides(resultType.getRank(), 1);

auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), newType, readOp.getSource(), readOp.getIndices(),
paddingValue);

rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
readOp, newTransferReadOp, offsets, resultType.getShape(), strides);
return success();
}
};

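// Sketch of the rewrite performed by the pattern below, on example i1 types.
// A sub-byte load such as
//
//   %v = vector.load %m[%c0] : memref<8xi1>, vector<6xi1>
//
// becomes a byte-aligned load followed by a strided slice:
//
//   %pad = vector.load %m[%c0] : memref<8xi1>, vector<8xi1>
//   %v = vector.extract_strided_slice %pad
//          {offsets = [0], sizes = [6], strides = [1]}
//          : vector<8xi1> to vector<6xi1>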
struct PadSubbyteVectorLoadPattern : public OpRewritePattern<vector::LoadOp> {
using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::LoadOp loadOp,
PatternRewriter &rewriter) const final {
auto result = loadOp.getResult();
auto resultType = mlir::cast<VectorType>(result.getType());
if (isByteAligned(resultType)) {
return failure();
}

unsigned elementBits = resultType.getElementType().getIntOrFloatBitWidth();
auto numElements = resultType.getNumElements();

SmallVector<int64_t> newShape(resultType.getShape());
newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits;
auto newTargetType = VectorType::get(newShape, resultType.getElementType());

// create a new vector load op with the new type
auto newVectorLoad = rewriter.create<vector::LoadOp>(
loadOp.getLoc(), newTargetType, loadOp.getBase(), loadOp.getIndices());

// Extract the original vector from the padded load result.
SmallVector<int64_t> offsets(resultType.getRank(), 0);
SmallVector<int64_t> strides(resultType.getRank(), 1);

rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
loadOp, newVectorLoad, offsets, resultType.getShape(), strides);
return success();
}
};

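// Sketch of the rewrite performed by the pattern below, on example i1 types.
// A sub-byte store such as
//
//   vector.store %v, %m[%c0] : memref<8xi1>, vector<6xi1>
//
// pads the value into a zero vector and issues a masked store so that only
// the original 6 lanes are written:
//
//   %pad = vector.insert_strided_slice %v, %zero
//            {offsets = [0], strides = [1]} : vector<6xi1> into vector<8xi1>
//   %mask = vector.create_mask %c6 : vector<8xi1>
//   vector.maskedstore %m[%c0], %mask, %pad
//       : memref<8xi1>, vector<8xi1>, vector<8xi1>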
struct PadSubbyteVectorStorePattern : public OpRewritePattern<vector::StoreOp> {
using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::StoreOp storeOp,
PatternRewriter &rewriter) const final {
auto storeValue = storeOp.getValueToStore();
auto valueType = mlir::cast<ShapedType>(storeValue.getType());
if (isByteAligned(valueType)) {
return failure();
}

auto target = storeOp.getBase();

// Compute the padded, byte-aligned vector type for the value to store.
auto elemType = valueType.getElementType();
unsigned elementBits = elemType.getIntOrFloatBitWidth();
auto numElements = valueType.getNumElements();

SmallVector<int64_t> newShape(valueType.getShape());
newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits;
auto newValueType = VectorType::get(newShape, elemType);

// Offsets/strides for the insert below; they must match the rank of the
// vector being padded, not the rank of the target memref.
SmallVector<int64_t> strides(valueType.getRank(), 1);
SmallVector<int64_t> offsets(valueType.getRank(), 0);

// Create a zero-filled vector of the padded size.
SmallVector<bool> zeroValues(newValueType.getNumElements(), false);
auto zeroVector = rewriter.create<arith::ConstantOp>(
storeOp.getLoc(), DenseIntElementsAttr::get(newValueType, zeroValues));

auto extendedOp = rewriter.create<vector::InsertStridedSliceOp>(
storeOp->getLoc(), storeValue, zeroVector, offsets, strides);

// Create a mask covering only the original elements and use a masked store
// so that the padding lanes are not written:
SmallVector<Value> maskShape;
for (auto dim : valueType.getShape()) {
maskShape.push_back(
rewriter.create<arith::ConstantIndexOp>(storeOp.getLoc(), dim));
}
auto mask = rewriter.create<vector::CreateMaskOp>(storeOp.getLoc(),
newValueType, maskShape);

rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
storeOp, target, storeOp.getIndices(), mask, extendedOp);
return success();
}
};

static void populateSubbyteTypeHandlingPatterns(RewritePatternSet &patterns) {
patterns.add<PadSubbyteTransferReadPattern, PadSubbyteTransferWritePattern,
PadSubbyteVectorLoadPattern, PadSubbyteVectorStorePattern>(
patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -133,6 +329,7 @@ struct EmulateNarrowTypePass final
affine::AffineDialect, IREE::HAL::HALDialect>(opLegalCallback);

RewritePatternSet patterns(ctx);
populateSubbyteTypeHandlingPatterns(patterns);
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
populateIREEResolveExtractStridedMetadataPatterns(ctx, patterns);
@@ -0,0 +1,33 @@
// RUN: iree-opt --split-input-file --iree-codegen-emulate-narrow-type %s | FileCheck %s

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>
]>

func.func @i1_datatype_emulation() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<8xi1, strided<[1], offset: ?>>
%3 = vector.load %0[%c0] : memref<8xi1, strided<[1], offset: ?>>, vector<6xi1>
%4 = vector.load %0[%c0] : memref<8xi1, strided<[1], offset: ?>>, vector<6xi1>
%5 = arith.addi %3, %4 : vector<6xi1>
vector.store %5, %0[%c0] : memref<8xi1, strided<[1], offset: ?>>, vector<6xi1>
return
}
// CHECK-LABEL: @i1_datatype_emulation


// CHECK: %[[EMU_LOAD:.+]] = vector.load
// CHECK-SAME: vector<1xi8>
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMU_LOAD]]
// CHECK-SAME: vector<1xi8> to vector<8xi1>
// CHECK: vector.extract_strided_slice %[[BITCAST]]
// CHECK-SAME: vector<8xi1> to vector<6xi1>

// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice
// CHECK-SAME: vector<6xi1> into vector<8xi1>
// CHECK: vector.create_mask
// CHECK-SAME: vector<8xi1>

// CHECK: vector.maskedstore
// CHECK-SAME: vector<1xi1>, vector<1xi8>

@@ -0,0 +1,31 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-generic-vectorization{vectorize-padding=true}))" --split-input-file %s | FileCheck %s

func.func @test_subbyte_6_i1() attributes {translation_info = #iree_codegen.translation_info<CPUDoubleTilingExpert>} {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8xi1>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8xi1>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<8xi1>>
%3 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [6], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<8xi1>> -> tensor<6xi1>
%4 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [6], strides = [1] : !flow.dispatch.tensor<readonly:tensor<8xi1>> -> tensor<6xi1>
%5 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [6], strides = [1] : !flow.dispatch.tensor<readonly:tensor<8xi1>> -> tensor<6xi1>
%6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4, %5 : tensor<6xi1>, tensor<6xi1>) outs(%3 : tensor<6xi1>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[6], [8], [0], [0]]>} {
^bb0(%in: i1, %in_0: i1, %out: i1):
%7 = arith.addi %in, %in_0 : i1
linalg.yield %7 : i1
} -> tensor<6xi1>
flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [6], strides = [1] : tensor<6xi1> -> !flow.dispatch.tensor<writeonly:tensor<8xi1>>
return
}

// CHECK-LABEL: @test_subbyte_6_i1

// CHECK: %[[TR1:.+]] = vector.transfer_read %{{.+}}[%c0], %false : tensor<6xi1>, vector<8xi1>
// CHECK: %[[ESS1:.+]] = vector.extract_strided_slice %[[TR1]] {offsets = [0], sizes = [6], strides = [1]} : vector<8xi1> to vector<6xi1>

// CHECK: %[[TR2:.+]] = vector.transfer_read %{{.+}}[%c0], %false : tensor<6xi1>, vector<8xi1>
// CHECK: %[[ESS2:.+]] = vector.extract_strided_slice %[[TR2]] {offsets = [0], sizes = [6], strides = [1]} : vector<8xi1> to vector<6xi1>

// CHECK: %[[ISS:.+]] = vector.insert_strided_slice
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<6xi1> into vector<8xi1>
// CHECK: vector.transfer_write %[[ISS]], %{{.+}}[%c0] {in_bounds = [true]} : vector<8xi1>, tensor<6xi1>

10 changes: 10 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -2942,6 +2942,16 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
}
}

// Make sure the innermost tile size times the element bit width is a
// multiple of a byte.
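// For example, with i1 elements an innermost tile size of 6 is rounded up to
// llvm::alignTo(6 * 1, 8) / 1 == 8.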
auto elementTypeSize =
cast<ShapedType>(rootOperation->getResultTypes().front())
.getElementType()
.getIntOrFloatBitWidth();
auto innermostTileSize = commonVecTileSizes.back();
commonVecTileSizes.back() =
llvm::alignTo(innermostTileSize * elementTypeSize, 8) / elementTypeSize;

// Set the lowering configs with new tile sizes.
for (auto op : computeOps) {
int numLoops = cast<TilingInterface>(op).getLoopIteratorTypes().size();
@@ -4,7 +4,7 @@
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
-#executable_target_system_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "system-elf-riscv_64", {cpu = "generic-rv64", cpu_features = "+m,+a,+f,+d,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 64 : index, target_triple = "riscv64"}>
+#executable_target_system_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "system-elf-riscv_64", {cpu = "generic-rv64", cpu_features = "+m,+a,+f,+d,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 4 : index, target_triple = "riscv64"}>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1) -> (d0 + d1 * 257)>
func.func @main_dispatch_77_generic_1x257x257x21() attributes {hal.executable.target = #executable_target_system_elf_riscv_64_} {
@@ -202,7 +202,8 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,
Value maskVal = args[0];

// TODO: Replace bool mask condition once treated as i1 (instead of i8)
-    if (maskVal.getType().isInteger()) {
+    auto maskValType = maskVal.getType();
+    if (maskValType.isInteger() && !maskValType.isInteger(1)) {
maskVal =
b.create<arith::TruncIOp>(loc, builder.getI1Type(), maskVal);
maskVal = b.create<arith::SelectOp>(loc, maskVal, zero, negInf);
@@ -937,6 +937,21 @@ static bool insertBindingOp(BlockArgument arg,
}
}

// Align the binding's tensor type so its total bit size is a multiple of 8,
// e.g. a tensor<6xi1> binding is widened to tensor<8xi1>.
auto rankedTensorType = tensorType.asRankedTensorType();
auto elementSize = rankedTensorType.getElementType().getIntOrFloatBitWidth();
auto typeSize = tensorType.getNumElements() * elementSize;

if (typeSize % 8 != 0) {
SmallVector<int64_t> newShape(rankedTensorType.getShape());
newShape.back() = llvm::alignTo(newShape.back(), 8 / elementSize);

auto newTensorType = IREE::Flow::DispatchTensorType::get(
tensorType.getAccess(), newShape,
rankedTensorType.getElementType(), rankedTensorType.getEncoding());
tensorType = newTensorType;
}

auto subspanOp = builder.create<IREE::Stream::BindingSubspanOp>(
arg.getLoc(), tensorType, arg, zero, dynamicDims);
arg.replaceAllUsesExcept(subspanOp.getResult(), subspanOp);
14 changes: 14 additions & 0 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
@@ -15,7 +15,16 @@

namespace mlir::iree_compiler {

llvm::cl::opt<bool> clEnableI1Support(
"iree-enable-i1",
llvm::cl::desc("enable i1 data type codegen"),
llvm::cl::init(false));
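// Usage sketch (an assumption based on the commit description, not shown in
// this diff): tools that link the compiler, e.g.
//   iree-opt --iree-enable-i1 input.mlir
// set this flag to route i1 through the sub-byte packing paths below.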

bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
// Enable i1 support if requested.
if (clEnableI1Support) {
return bitWidth == 1;
}
// Require the original bit width to be some power of two for now to avoid
// trickiness and weirdness of packing and cross-byte access.
// Also disallow boolean values for now--they may require separate interface
@@ -99,6 +108,11 @@ Value calculateStorageElementCountInBytes(Location loc,
}
}

// Make sure the last dimension is byte aligned: for i1, a trailing dimension
// of 6 elements is padded to llvm::alignTo(6, 8) == 8, i.e. one full byte.
if (needToPackSubByteElementBitWidth(elementBits)) {
paddedShape.back() = llvm::alignTo(paddedShape.back(), 8 / elementBits);
}

for (unsigned i = 0; i < shapedType.getRank(); ++i) {
if (!shapedType.isDynamicDim(i))
staticCount *= paddedShape[i];