diff --git a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp index 772faf49a6273..ee14424fe277f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp @@ -95,6 +95,202 @@ static void populateIreeNarrowTypeEmulationPatterns( patterns.getContext()); } +static bool isByteAligned(ShapedType type) { + unsigned elementBits = type.getElementType().getIntOrFloatBitWidth(); + auto numElements = type.getNumElements(); + return (numElements * elementBits) % 8 == 0; +} + +struct PadSubbyteTransferWritePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, + PatternRewriter &rewriter) const final { + auto target = writeOp.getVector(); + auto targetType = cast(target.getType()); + if (isByteAligned(targetType)) { + return failure(); + } + + auto source = writeOp.getSource(); + auto sourceType = cast(source.getType()); + auto elemType = targetType.getElementType(); + unsigned elementBits = targetType.getElementType().getIntOrFloatBitWidth(); + auto numElements = targetType.getNumElements(); + + SmallVector strides; + SmallVector offsets; + for (unsigned i = 0; i < sourceType.getRank(); ++i) { + strides.push_back(1); + offsets.push_back(0); + } + + // TODO: we should keep the source and sink ... otherwise we are + // overwriting some part of the source tensor + + SmallVector newShape = SmallVector(targetType.getShape()); + newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits; + auto newTargetType = VectorType::get(newShape, elemType); + + // create an empty vector of the correct size + SmallVector zeroValues; + for (unsigned i = 0; i < newTargetType.getNumElements(); ++i) { + zeroValues.push_back(false); + } + auto zeroVector = rewriter.create( + writeOp.getLoc(), DenseIntElementsAttr::get(newTargetType, zeroValues)); + + auto extendedOp = rewriter.create( + writeOp->getLoc(), target, zeroVector, offsets, strides); + + writeOp.getVectorMutable().assign(extendedOp); + return success(); + } +}; + +struct PadSubbyteTransferReadPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, + PatternRewriter &rewriter) const final { + auto resultType = cast(readOp.getResult().getType()); + if (isByteAligned(resultType)) { + return failure(); + } + + unsigned elementBits = resultType.getElementType().getIntOrFloatBitWidth(); + auto numElements = resultType.getNumElements(); + + // pad the type to be byte aligned + SmallVector newShape = SmallVector(resultType.getShape()); + newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits; + // Create a new vector type with the padded shape + auto newType = VectorType::get(newShape, resultType.getElementType()); + + // Create a new transfer read op with the new type + auto paddingValue = rewriter.create( + readOp.getLoc(), resultType.getElementType(), + rewriter.getZeroAttr(resultType.getElementType())); + + // use a vector extract to extract the original vector + SmallVector offsets, strides; + for (unsigned i = 0; i < resultType.getRank(); ++i) { + offsets.push_back(0); + strides.push_back(1); + } + + auto newTransferReadOp = rewriter.create( + readOp.getLoc(), newType, readOp.getSource(), readOp.getIndices(), + paddingValue); + + rewriter.replaceOpWithNewOp( + readOp, newTransferReadOp, offsets, resultType.getShape(), strides); + return success(); + } +}; + +struct PadSubbyteVectorLoadPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::LoadOp loadOp, + PatternRewriter &rewriter) const final { + auto result = loadOp.getResult(); + auto resultType = mlir::cast(result.getType()); + if (isByteAligned(resultType)) { + return failure(); + } + + unsigned elementBits = resultType.getElementType().getIntOrFloatBitWidth(); + auto numElements = resultType.getNumElements(); + + SmallVector newShape = SmallVector(resultType.getShape()); + newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits; + auto newTargetType = VectorType::get(newShape, resultType.getElementType()); + + // create a new vector load op with the new type + auto newVectorLoad = rewriter.create( + loadOp.getLoc(), newTargetType, loadOp.getBase(), loadOp.getIndices()); + + auto newNumElements = newTargetType.getNumElements(); + SmallVector zeroValues; + for (unsigned i = 0; i < newNumElements; ++i) { + zeroValues.push_back(false); + } + + // extract strided slice + SmallVector offsets, strides; + for (unsigned i = 0; i < resultType.getRank(); ++i) { + offsets.push_back(0); + strides.push_back(1); + } + + rewriter.replaceOpWithNewOp( + loadOp, newVectorLoad, offsets, resultType.getShape(), strides); + return success(); + } +}; + +struct PadSubbyteVectorStorePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::StoreOp storeOp, + PatternRewriter &rewriter) const final { + auto storeValue = storeOp.getValueToStore(); + auto valueType = mlir::cast(storeValue.getType()); + if (isByteAligned(valueType)) { + return failure(); + } + + auto target = storeOp.getBase(); + auto targetType = mlir::cast(target.getType()); + // check that the type size is byte aligned + auto elemType = valueType.getElementType(); + unsigned elementBits = valueType.getElementType().getIntOrFloatBitWidth(); + auto numElements = valueType.getNumElements(); + + SmallVector newShape = SmallVector(valueType.getShape()); + newShape.back() += (8 - (numElements * elementBits) % 8) / elementBits; + auto newValueType = VectorType::get(newShape, elemType); + + SmallVector strides; + SmallVector offsets; + for (unsigned i = 0; i < targetType.getRank(); ++i) { + strides.push_back(1); + offsets.push_back(0); + } + + // create an empty vector of the correct size + SmallVector zeroValues; + for (unsigned i = 0; i < newValueType.getNumElements(); ++i) { + zeroValues.push_back(false); + } + auto zeroVector = rewriter.create( + storeOp.getLoc(), DenseIntElementsAttr::get(newValueType, zeroValues)); + + auto extendedOp = rewriter.create( + storeOp->getLoc(), storeValue, zeroVector, offsets, strides); + + // create a mask and use masked store: + SmallVector maskShape; + for (auto dim : valueType.getShape()) { + maskShape.push_back( + rewriter.create(storeOp.getLoc(), dim)); + } + auto mask = rewriter.create(storeOp.getLoc(), + newValueType, maskShape); + + rewriter.replaceOpWithNewOp( + storeOp, target, storeOp.getIndices(), mask, extendedOp); + return success(); + } +}; + +static void populateSubbyteTypeHandlingPatterns(RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -133,6 +329,7 @@ struct EmulateNarrowTypePass final affine::AffineDialect, IREE::HAL::HALDialect>(opLegalCallback); RewritePatternSet patterns(ctx); + populateSubbyteTypeHandlingPatterns(patterns); arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns); memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns); populateIREEResolveExtractStridedMetadataPatterns(ctx, patterns); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/emulate_narrow_type_optional.mlir b/compiler/src/iree/compiler/Codegen/Common/test/emulate_narrow_type_optional.mlir new file mode 100644 index 0000000000000..66ed8fdc6edba --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/emulate_narrow_type_optional.mlir @@ -0,0 +1,33 @@ +// RUN: iree-opt --split-input-file --iree-codegen-emulate-narrow-type %s | FileCheck %s + +#pipeline_layout = #hal.pipeline.layout +]> + +func.func @i1_datatype_emulation() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<8xi1, strided<[1], offset: ?>> + %3 = vector.load %0[%c0] : memref<8xi1, strided<[1], offset: ?>>, vector<6xi1> + %4 = vector.load %0[%c0] : memref<8xi1, strided<[1], offset: ?>>, vector<6xi1> + %5 = arith.addi %3, %4 : vector<6xi1> + vector.store %5, %0[%c0] : memref<8xi1, strided<[1], offset: ?>>, vector<6xi1> + return +} +// CHECK-LABEL: @i1_datatype_emulation + + +// CHECK: %[[EMU_LOAD:.+]] = vector.load +// CHECK-SAME: vector<1xi8> +// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMU_LOAD]] +// CHECK-SAME: vector<1xi8> to vector<8xi1> +// CHECK: vector.extract_strided_slice %[[BITCAST]] +// CHECK-SAME: vector<8xi1> to vector<6xi1> + +// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice +// CHECK-SAME: vector<6xi1> into vector<8xi1> +// CHECK: vector.create_mask +// CHECK-SAME: vector<8xi1> + +// CHECK: vector.maskedstore +// CHECK-SAME: vector<1xi1>, vector<1xi8> + diff --git a/compiler/src/iree/compiler/Codegen/Common/test/subbyte_vectorize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/subbyte_vectorize.mlir new file mode 100644 index 0000000000000..0233c1ff2ea68 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/subbyte_vectorize.mlir @@ -0,0 +1,31 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-generic-vectorization{vectorize-padding=true}))" --split-input-file %s | FileCheck %s + +func.func @test_subbyte_6_i1() attributes {translation_info = #iree_codegen.translation_info} { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [6], strides = [1] : !flow.dispatch.tensor> -> tensor<6xi1> + %4 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [6], strides = [1] : !flow.dispatch.tensor> -> tensor<6xi1> + %5 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [6], strides = [1] : !flow.dispatch.tensor> -> tensor<6xi1> + %6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4, %5 : tensor<6xi1>, tensor<6xi1>) outs(%3 : tensor<6xi1>) attrs = {lowering_config = #iree_codegen.lowering_config} { + ^bb0(%in: i1, %in_0: i1, %out: i1): + %7 = arith.addi %in, %in_0 : i1 + linalg.yield %7 : i1 + } -> tensor<6xi1> + flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [6], strides = [1] : tensor<6xi1> -> !flow.dispatch.tensor> + return +} + +// CHECK-LABEL: @test_subbyte_6_i1 + +// CHECK: %[[TR1:.+]] = vector.transfer_read %[[.+]][%c0], %false : tensor<6xi1>, vector<8xi1> +// CHECK: %[[ESS1:.+]] = vector.extract_strided_slice %[[TR1]] {offsets = [0], sizes = [6], strides = [1]} : vector<8xi1> to vector<6xi1> + +// CHECK: %[[TR2:.+]] = vector.transfer_read %[[.+]][%c0], %false : tensor<6xi1>, vector<8xi1> +// CHECK: %[[ESS2:.+]] = vector.extract_strided_slice %[[TR2]] {offsets = [0], sizes = [6], strides = [1]} : vector<8xi1> to vector<6xi1> + +// CHECK: %[[ISS:.+]] = vector.insert_strided_slice +// CHECK-SAME: {offsets = [0], strides = [1]} : vector<6xi1> into vector<8xi1> +// CHECK: vector.transfer_write %[[ISS]], %[[.+]][%c0] {in_bounds = [true]} : vector<8xi1>, tensor<6xi1> + diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp index c68c905c14d15..df7ea63666989 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp @@ -2942,6 +2942,16 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn, } } + // check to make sure the innermost tile size times element size is multiple + // of byte + auto elementTypeSize = + cast(rootOperation->getResultTypes().front()) + .getElementType() + .getIntOrFloatBitWidth(); + auto innermostTileSize = commonVecTileSizes.back(); + commonVecTileSizes.back() = + llvm::alignTo(innermostTileSize * elementTypeSize, 8) / elementTypeSize; + // Set the lowering configs with new tile sizes. for (auto op : computeOps) { int numLoops = cast(op).getLoopIteratorTypes().size(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_vectorize_nd_extract_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_vectorize_nd_extract_tests.mlir index f6782a5f04e75..c8cb247fd2a81 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_vectorize_nd_extract_tests.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_vectorize_nd_extract_tests.mlir @@ -4,7 +4,7 @@ #hal.pipeline.binding, #hal.pipeline.binding ]> -#executable_target_system_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "system-elf-riscv_64", {cpu = "generic-rv64", cpu_features = "+m,+a,+f,+d,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 64 : index, target_triple = "riscv64"}> +#executable_target_system_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "system-elf-riscv_64", {cpu = "generic-rv64", cpu_features = "+m,+a,+f,+d,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 4 : index, target_triple = "riscv64"}> #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1) -> (d0 + d1 * 257)> func.func @main_dispatch_77_generic_1x257x257x21() attributes {hal.executable.target = #executable_target_system_elf_riscv_64_} { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 389eeadb75a9f..9d5f57cb99135 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -202,7 +202,8 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, Value maskVal = args[0]; // TODO: Replace bool mask condition once treated as i1 (instead of i8) - if (maskVal.getType().isInteger()) { + auto maskValType = maskVal.getType(); + if (maskValType.isInteger() && !maskValType.isInteger(1)) { maskVal = b.create(loc, builder.getI1Type(), maskVal); maskVal = b.create(loc, maskVal, zero, negInf); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index 31d61516e3eb2..fbb12ef861643 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -937,6 +937,21 @@ static bool insertBindingOp(BlockArgument arg, } } + // align tensor type to multiple of 8 bits: + auto rankedTensorType = tensorType.asRankedTensorType(); + auto elementSize = rankedTensorType.getElementType().getIntOrFloatBitWidth(); + auto typeSize = tensorType.getNumElements() * elementSize; + + if (typeSize * elementSize % 8 != 0) { + SmallVector newShape(rankedTensorType.getShape()); + newShape.back() = llvm::alignTo(newShape.back(), 8 / elementSize); + + auto newTensorType = IREE::Flow::DispatchTensorType::get( + tensorType.getAccess(), newShape, + rankedTensorType.getElementType(), rankedTensorType.getEncoding()); + tensorType = newTensorType; + } + auto subspanOp = builder.create( arg.getLoc(), tensorType, arg, zero, dynamicDims); arg.replaceAllUsesExcept(subspanOp.getResult(), subspanOp); diff --git a/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp b/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp index a7fa1c3f7b363..0e558b3b7755b 100644 --- a/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp @@ -15,7 +15,16 @@ namespace mlir::iree_compiler { +llvm::cl::opt clEnableI1Support( + "iree-enable-i1", + llvm::cl::desc("enable i1 data type codegen"), + llvm::cl::init(false)); + bool needToPackSubByteElementBitWidth(unsigned bitWidth) { + // Enable i1 support if requested. + if (clEnableI1Support) { + return bitWidth == 1; + } // Require the original bit width to be some power of two for now to avoid // trickiness and weirdness of packing and cross-byte access. // Also disallow boolean values for now--they may require separate interface @@ -99,6 +108,11 @@ Value calculateStorageElementCountInBytes(Location loc, } } + // make sure the last dimension is byte aligned. + if (needToPackSubByteElementBitWidth(elementBits)) { + paddedShape.back() = llvm::alignTo(paddedShape.back(), 8 / elementBits); + } + for (unsigned i = 0; i < shapedType.getRank(); ++i) { if (!shapedType.isDynamicDim(i)) staticCount *= paddedShape[i];