[mlir][vector] Add more patterns to Vector Linearize transformation#136193
[mlir][vector] Add more patterns to Vector Linearize transformation#136193
Conversation
|
@llvm/pr-subscribers-mlir-vector Author: Nishant Patel (nbpatel) Changes: This PR adds linearization patterns for vector.load, vector.store, vector.create_mask, vector.splat, vector.insert_strided_slice & RegionBranchOps. This is because SPIR-V only supports 1D vectors. Patch is 40.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136193.diff — 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..6de5d0c5a101e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -27,6 +28,10 @@
using namespace mlir;
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+ // For BW-0, all operations are legal
+ if (targetBitWidth == 0) {
+ return false;
+ }
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
VectorType vecType = dyn_cast<VectorType>(resType);
@@ -273,6 +278,77 @@ struct LinearizeVectorExtractStridedSlice final
unsigned targetVectorBitWidth;
};
+/// This pattern linearizes the InsertStridedSliceOp by extracting rows from the
+/// source vector using ExtractStridedSliceOp and inserting them into the
+/// destination vector using InsertStridedSliceOp.
+/// Following,
+/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
+/// is converted to :
+/// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+/// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
+/// %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+/// %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+struct LinearizeVectorInsertStridedSlice final
+ : public OpConversionPattern<vector::InsertStridedSliceOp> {
+ using OpConversionPattern<
+ vector::InsertStridedSliceOp>::OpConversionPattern;
+ LinearizeVectorInsertStridedSlice(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto srcTy = op.getSourceVectorType();
+ auto dstTy = op.getDestVectorType();
+
+ if (op.hasNonUnitStrides()) {
+ return rewriter.notifyMatchFailure(
+ op, "InsertStridedSliceOp linearization only supports unit strides.");
+ }
+
+ if (srcTy.getRank() != 2) {
+ return rewriter.notifyMatchFailure(
+ op, "InsertStridedSliceOp linearization only supports 2D source.");
+ }
+
+ if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ op, "InsertStridedSliceOp linerization only supports static shapes.");
+ }
+
+ auto dstShape = dstTy.getShape();
+ auto dstStrides = dstShape.drop_front().vec();
+ dstStrides.push_back(1);
+ int64_t linearizedOffset = 0;
+ for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), dstStrides)) {
+ linearizedOffset += getConstantIntValue(off).value() * stride;
+ }
+
+ // extracts a row from source, and insert it into the destination
+ auto srcShape = srcTy.getShape();
+ Value dstValue = adaptor.getDest();
+ for (auto i = 0; i < srcShape[0]; i++) {
+ auto srcOffset = i * srcShape[1];
+ auto value = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, adaptor.getValueToStore(), srcOffset, srcShape[1], 1);
+
+ auto dstOffset = linearizedOffset + i * dstShape.back();
+ dstValue = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, value, dstValue, dstOffset, 1);
+ }
+
+ rewriter.replaceOp(op, dstValue);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
@@ -369,6 +445,11 @@ struct LinearizeVectorExtract final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // Skip if result is not a vector type
+ if (!isa<VectorType>(extractOp.getType()))
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalar extract is not supported.");
+
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
if (!dstTy)
return rewriter.notifyMatchFailure(extractOp,
@@ -531,12 +612,312 @@ struct LinearizeVectorBitCast final
unsigned targetVectorBitWidth;
};
+/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
+/// that works on a linearized vector.
+/// Following,
+/// vector.load %base[%indices] : vector<4x4xf32>
+/// is converted to :
+/// %result = arith.constant dense<0.0> : vector<4x4xf32>
+/// %slice_0 = vector.load %base[%indices] : vector<4xf32>
+/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+/// %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+/// ...
+/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
+/// them into the result vector. The pattern currently supports only 2D vectors
+struct LinearizeVectorLoad final
+ : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+ LinearizeVectorLoad(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = loadOp->getLoc();
+ auto vecType = loadOp.getVectorType();
+ auto shape = vecType.getShape();
+
+ if (shape.size() != 2) {
+ return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+ }
+ auto unrollCount = shape[0];
+ auto vecSize = shape[1];
+ auto newVecType =
+ VectorType::get({vecSize}, vecType.getElementType());
+
+ llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+ Value xBaseIndex = indices[0];
+
+ // Construct the 2D vector.
+ Value resultVec = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(vecType));
+ // Emit unrolled loads for each 1D vector slice.
+ for (auto i = 0; i < unrollCount; i++) {
+ Value xIndex = xBaseIndex;
+ if (i) {
+ auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ xIndex =
+ rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+ }
+ indices[0] = xIndex;
+ auto vec = rewriter.create<vector::LoadOp>(
+ loc, newVecType, adaptor.getBase(), indices);
+ resultVec =
+ rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+ }
+
+ rewriter.replaceOp(loadOp, resultVec);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
+/// that works on a linearized vector.
+/// Following,
+/// vector.store %source, %base[%indices] : vector<4x4xf32>
+/// is converted to :
+/// %slice_0 = vector.extract %source[0] : vector<4xf32>
+/// vector.store %slice_0, %base[%indices] : vector<4xf32>
+/// %slice_1 = vector.extract %source[1] : vector<4xf32>
+/// vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
+/// ...
+/// This unrolls the 2D vector store into multiple 1D vector stores by extracting
+/// slices from the source vector and storing them into the destination.
+/// The pattern currently supports only 2D vectors
+struct LinearizeVectorStore final
+ : public OpConversionPattern<vector::StoreOp> {
+ using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+ LinearizeVectorStore(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = storeOp->getLoc();
+ auto vecType = storeOp.getVectorType();
+ auto shape = vecType.getShape();
+
+ if (shape.size() != 2) {
+ return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+ }
+
+ auto unrollCount = shape[0];
+ llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+ Value xBaseIndex = indices[0];
+
+ auto vec = rewriter.create<vector::ShapeCastOp>(
+ loc, vecType, adaptor.getValueToStore());
+
+ for (auto i = 0; i < unrollCount; i++) {
+ auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
+ Value xIndex = xBaseIndex;
+ if (i) {
+ auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ xIndex =
+ rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+ }
+ indices[0] = xIndex;
+ rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
+ indices);
+ }
+ rewriter.eraseOp(storeOp);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the SplatOp to work on a linearized vector.
+/// Following,
+/// vector.splat %value : vector<4x4xf32>
+/// is converted to:
+/// %out_1d = vector.splat %value : vector<16xf32>
+/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// It ensures that the operation is compatible with the target vector
+/// bit width and replaces the original operation with a new SplatOp
+/// that operates on the converted type.
+struct LinearizeVectorSplat final
+ : public OpConversionPattern<vector::SplatOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorSplat(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+ rewriter.replaceOpWithNewOp<vector::SplatOp>(
+ splatOp, adaptor.getInput(), dstTy);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the CreateMaskOp to work on a
+/// linearized vector. It ensures that the operation is compatible with the
+/// target vector bit width and replaces the original operation with a new
+/// CreateMaskOp that operates on the converted type. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// Following,
+/// vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+/// %out_1d = vector.create_mask %dims : vector<4xi1>
+/// %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcTy = createMaskOp.getType();
+ auto srcShape = srcTy.getShape();
+ if (srcShape.size() != 2)
+ return rewriter.notifyMatchFailure(createMaskOp,
+ "only 2D mask is supported.");
+
+ if (srcShape[0] != 1)
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "only unit outer dimension is supported.");
+
+ auto dstTy = getTypeConverter()->convertType(srcTy);
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+ rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+ createMaskOp, dstTy, adaptor.getOperands().back());
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts operations implementing the RegionBranchOpInterface
+/// to ensure compatibility with linearized vector types. It updates the
+/// operands, result types, and region types (block arguments and yields) to
+/// match the converted types. Additionally, it processes yields within each
+/// region to ensure that the types of yielded values are compatible with the
+/// target vector bit width. If the result types of the operation are updated,
+/// shape cast operations are inserted to maintain compatibility with the
+/// original types. This pattern ensures that operations with regions are
+/// properly linearized and remain valid after type conversion.
+struct LinearizeRegionBranchOp final
+ : public OpInterfaceConversionPattern<RegionBranchOpInterface> {
+ using OpInterfaceConversionPattern<
+ RegionBranchOpInterface>::OpInterfaceConversionPattern;
+
+ LinearizeRegionBranchOp(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(RegionBranchOpInterface op,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto converter = getTypeConverter();
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.startOpModification(op);
+
+ llvm::SmallVector<Type> convertedTypes;
+ for (Type ty : op->getResultTypes()) {
+ convertedTypes.push_back(converter->convertType(ty));
+ }
+
+ if (convertedTypes == op->getResultTypes() &&
+ op->getOperands() == operands) {
+ return failure();
+ }
+
+ op->setOperands(operands);
+
+ // Convert region types (block arguments and yields)
+    for (Region &region : op->getRegions()) {
+      if (failed(rewriter.convertRegionTypes(&region, *converter))) {
+ return failure();
+ }
+
+ // Process yields within each region
+ for (Block &block : region) {
+ if (auto *terminator = block.getTerminator()) {
+ for (OpOperand &yieldOperand : terminator->getOpOperands()) {
+ Value value = yieldOperand.get();
+ Type type = value.getType();
+ if (!converter->isLegal(type)) {
+ Type newTy = converter->convertType(type);
+ rewriter.setInsertionPoint(terminator);
+ Value newValue =
+ rewriter.create<vector::ShapeCastOp>(loc, newTy, value);
+ yieldOperand.set(newValue);
+ }
+ }
+ }
+ }
+ }
+
+ // Update result types
+ rewriter.setInsertionPointAfter(op);
+ llvm::SmallVector<Value> newResults;
+ for (Value result : op->getResults()) {
+ Type oldTy = result.getType();
+ if (!converter->isLegal(oldTy)) {
+ Type newTy = converter->convertType(oldTy);
+ result.setType(newTy);
+ Operation *castOp =
+ rewriter.create<vector::ShapeCastOp>(loc, oldTy, result);
+ result.replaceAllUsesExcept(castOp->getResult(0), castOp);
+ newResults.push_back(castOp->getResult(0));
+ } else {
+ newResults.push_back(result);
+ }
+ }
+
+ rewriter.finalizeOpModification(op);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth) {
+ typeConverter.addConversion([](Type type) -> Type { return type; });
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
return type;
@@ -555,9 +936,12 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
};
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
+ target.addLegalOp<mlir::vector::ShapeCastOp>();
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<vector::BitCastOp>(op) ||
+ if ((isa<vector::BitCastOp, vector::LoadOp,
+ vector::StoreOp, vector::CreateMaskOp,
+ RegionBranchOpInterface, vector::SplatOp>(op) ||
op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -568,7 +952,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
});
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+ LinearizeVectorBitCast, LinearizeVectorLoad,
+ LinearizeVectorStore, LinearizeVectorSplat,
+ LinearizeVectorCreateMask, LinearizeRegionBranchOp
+ >(typeConverter, patterns.getContext(),
targetBitWidth);
}
@@ -583,7 +970,21 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
.getRank() == 1)
: true;
});
+
+ target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
+ [=](vector::InsertStridedSliceOp op) -> bool {
+ if(isLessThanTargetBitWidth(op, targetBitWidth)) {
+ auto srcTy = op.getSourceVectorType();
+ auto dstTy = op.getDestVectorType();
+ if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+ srcTy.hasStaticShape() && dstTy.hasStaticShape())
+ return false;
+ }
+ return true;
+ });
+
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
- LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+ LinearizeVectorInsertStridedSlice>(
typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..e47e7c4a84d68 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,338 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
return %1 : vector<[4]x4xf16>
}
+
+// -----
+// ALL-LABEL: test_vector_load
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
+func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+ // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+ // BW-128: %[[C1:.*]] = arith.constant 1 : index
+ // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+ // BW-128: %[[C2:.*]] = arith.constant 2 : index
+...
[truncated]
|
|
@llvm/pr-subscribers-mlir Author: Nishant Patel (nbpatel) Changes: This PR adds linearization patterns for vector.load, vector.store, vector.create_mask, vector.splat, vector.insert_strided_slice & RegionBranchOps. This is because SPIR-V only supports 1D vectors. Patch is 40.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136193.diff — 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..6de5d0c5a101e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -27,6 +28,10 @@
using namespace mlir;
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+ // For BW-0, all operations are legal
+ if (targetBitWidth == 0) {
+ return false;
+ }
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
VectorType vecType = dyn_cast<VectorType>(resType);
@@ -273,6 +278,77 @@ struct LinearizeVectorExtractStridedSlice final
unsigned targetVectorBitWidth;
};
+/// This pattern linearizes the InsertStridedSliceOp by extracting rows from the
+/// source vector using ExtractStridedSliceOp and inserting them into the
+/// destination vector using InsertStridedSliceOp.
+/// Following,
+/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
+/// is converted to :
+/// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+/// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
+/// %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+/// %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+struct LinearizeVectorInsertStridedSlice final
+ : public OpConversionPattern<vector::InsertStridedSliceOp> {
+ using OpConversionPattern<
+ vector::InsertStridedSliceOp>::OpConversionPattern;
+ LinearizeVectorInsertStridedSlice(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto srcTy = op.getSourceVectorType();
+ auto dstTy = op.getDestVectorType();
+
+ if (op.hasNonUnitStrides()) {
+ return rewriter.notifyMatchFailure(
+ op, "InsertStridedSliceOp linearization only supports unit strides.");
+ }
+
+ if (srcTy.getRank() != 2) {
+ return rewriter.notifyMatchFailure(
+ op, "InsertStridedSliceOp linearization only supports 2D source.");
+ }
+
+ if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ op, "InsertStridedSliceOp linerization only supports static shapes.");
+ }
+
+ auto dstShape = dstTy.getShape();
+ auto dstStrides = dstShape.drop_front().vec();
+ dstStrides.push_back(1);
+ int64_t linearizedOffset = 0;
+ for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), dstStrides)) {
+ linearizedOffset += getConstantIntValue(off).value() * stride;
+ }
+
+ // extracts a row from source, and insert it into the destination
+ auto srcShape = srcTy.getShape();
+ Value dstValue = adaptor.getDest();
+ for (auto i = 0; i < srcShape[0]; i++) {
+ auto srcOffset = i * srcShape[1];
+ auto value = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, adaptor.getValueToStore(), srcOffset, srcShape[1], 1);
+
+ auto dstOffset = linearizedOffset + i * dstShape.back();
+ dstValue = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, value, dstValue, dstOffset, 1);
+ }
+
+ rewriter.replaceOp(op, dstValue);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
@@ -369,6 +445,11 @@ struct LinearizeVectorExtract final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // Skip if result is not a vector type
+ if (!isa<VectorType>(extractOp.getType()))
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalar extract is not supported.");
+
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
if (!dstTy)
return rewriter.notifyMatchFailure(extractOp,
@@ -531,12 +612,312 @@ struct LinearizeVectorBitCast final
unsigned targetVectorBitWidth;
};
+/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
+/// that works on a linearized vector.
+/// Following,
+/// vector.load %base[%indices] : vector<4x4xf32>
+/// is converted to :
+/// %result = arith.constant dense<0.0> : vector<4x4xf32>
+/// %slice_0 = vector.load %base[%indices] : vector<4xf32>
+/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+/// %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+/// ...
+/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
+/// them into the result vector. The pattern currently supports only 2D vectors
+struct LinearizeVectorLoad final
+ : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+ LinearizeVectorLoad(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = loadOp->getLoc();
+ auto vecType = loadOp.getVectorType();
+ auto shape = vecType.getShape();
+
+ if (shape.size() != 2) {
+ return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+ }
+ auto unrollCount = shape[0];
+ auto vecSize = shape[1];
+ auto newVecType =
+ VectorType::get({vecSize}, vecType.getElementType());
+
+ llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+ Value xBaseIndex = indices[0];
+
+ // Construct the 2D vector.
+ Value resultVec = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(vecType));
+ // Emit unrolled loads for each 1D vector slice.
+ for (auto i = 0; i < unrollCount; i++) {
+ Value xIndex = xBaseIndex;
+ if (i) {
+ auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ xIndex =
+ rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+ }
+ indices[0] = xIndex;
+ auto vec = rewriter.create<vector::LoadOp>(
+ loc, newVecType, adaptor.getBase(), indices);
+ resultVec =
+ rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+ }
+
+ rewriter.replaceOp(loadOp, resultVec);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
+/// that works on a linearized vector.
+/// Following,
+/// vector.store %source, %base[%indices] : vector<4x4xf32>
+/// is converted to :
+/// %slice_0 = vector.extract %source[0] : vector<4xf32>
+/// vector.store %slice_0, %base[%indices] : vector<4xf32>
+/// %slice_1 = vector.extract %source[1] : vector<4xf32>
+/// vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
+/// ...
+/// This unrolls the 2D vector store into multiple 1D vector stores by extracting
+/// slices from the source vector and storing them into the destination.
+/// The pattern currently supports only 2D vectors
+struct LinearizeVectorStore final
+ : public OpConversionPattern<vector::StoreOp> {
+ using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+ LinearizeVectorStore(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = storeOp->getLoc();
+ auto vecType = storeOp.getVectorType();
+ auto shape = vecType.getShape();
+
+ if (shape.size() != 2) {
+ return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+ }
+
+ auto unrollCount = shape[0];
+ llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+ Value xBaseIndex = indices[0];
+
+ auto vec = rewriter.create<vector::ShapeCastOp>(
+ loc, vecType, adaptor.getValueToStore());
+
+ for (auto i = 0; i < unrollCount; i++) {
+ auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
+ Value xIndex = xBaseIndex;
+ if (i) {
+ auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ xIndex =
+ rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+ }
+ indices[0] = xIndex;
+ rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
+ indices);
+ }
+ rewriter.eraseOp(storeOp);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the SplatOp to work on a linearized vector.
+/// Following,
+/// vector.splat %value : vector<4x4xf32>
+/// is converted to:
+/// %out_1d = vector.splat %value : vector<16xf32>
+/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// It ensures that the operation is compatible with the target vector
+/// bit width and replaces the original operation with a new SplatOp
+/// that operates on the converted type.
+struct LinearizeVectorSplat final
+ : public OpConversionPattern<vector::SplatOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorSplat(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+ rewriter.replaceOpWithNewOp<vector::SplatOp>(
+ splatOp, adaptor.getInput(), dstTy);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the CreateMaskOp to work on a
+/// linearized vector. It ensures that the operation is compatible with the
+/// target vector bit width and replaces the original operation with a new
+/// CreateMaskOp that operates on the converted type. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// Following,
+/// vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+/// %out_1d = vector.create_mask %dims : vector<4xi1>
+/// %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcTy = createMaskOp.getType();
+ auto srcShape = srcTy.getShape();
+ if (srcShape.size() != 2)
+ return rewriter.notifyMatchFailure(createMaskOp,
+ "only 2D mask is supported.");
+
+ if (srcShape[0] != 1)
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "only unit outer dimension is supported.");
+
+ auto dstTy = getTypeConverter()->convertType(srcTy);
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+ rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+ createMaskOp, dstTy, adaptor.getOperands().back());
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts operations implementing the RegionBranchOpInterface
+/// to ensure compatibility with linearized vector types. It updates the
+/// operands, result types, and region types (block arguments and yields) to
+/// match the converted types. Additionally, it processes yields within each
+/// region to ensure that the types of yielded values are compatible with the
+/// target vector bit width. If the result types of the operation are updated,
+/// shape cast operations are inserted to maintain compatibility with the
+/// original types. This pattern ensures that operations with regions are
+/// properly linearized and remain valid after type conversion.
+struct LinearizeRegionBranchOp final
+ : public OpInterfaceConversionPattern<RegionBranchOpInterface> {
+ using OpInterfaceConversionPattern<
+ RegionBranchOpInterface>::OpInterfaceConversionPattern;
+
+ LinearizeRegionBranchOp(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(RegionBranchOpInterface op,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto converter = getTypeConverter();
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.startOpModification(op);
+
+ llvm::SmallVector<Type> convertedTypes;
+ for (Type ty : op->getResultTypes()) {
+ convertedTypes.push_back(converter->convertType(ty));
+ }
+
+ if (convertedTypes == op->getResultTypes() &&
+ op->getOperands() == operands) {
+ return failure();
+ }
+
+ op->setOperands(operands);
+
+ // Convert region types (block arguments and yields)
+    for (Region &region : op->getRegions()) {
+      if (failed(rewriter.convertRegionTypes(&region, *converter))) {
+ return failure();
+ }
+
+ // Process yields within each region
+ for (Block &block : region) {
+ if (auto *terminator = block.getTerminator()) {
+ for (OpOperand &yieldOperand : terminator->getOpOperands()) {
+ Value value = yieldOperand.get();
+ Type type = value.getType();
+ if (!converter->isLegal(type)) {
+ Type newTy = converter->convertType(type);
+ rewriter.setInsertionPoint(terminator);
+ Value newValue =
+ rewriter.create<vector::ShapeCastOp>(loc, newTy, value);
+ yieldOperand.set(newValue);
+ }
+ }
+ }
+ }
+ }
+
+ // Update result types
+ rewriter.setInsertionPointAfter(op);
+ llvm::SmallVector<Value> newResults;
+ for (Value result : op->getResults()) {
+ Type oldTy = result.getType();
+ if (!converter->isLegal(oldTy)) {
+ Type newTy = converter->convertType(oldTy);
+ result.setType(newTy);
+ Operation *castOp =
+ rewriter.create<vector::ShapeCastOp>(loc, oldTy, result);
+ result.replaceAllUsesExcept(castOp->getResult(0), castOp);
+ newResults.push_back(castOp->getResult(0));
+ } else {
+ newResults.push_back(result);
+ }
+ }
+
+ rewriter.finalizeOpModification(op);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth) {
+ typeConverter.addConversion([](Type type) -> Type { return type; });
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
return type;
@@ -555,9 +936,12 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
};
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
+ target.addLegalOp<mlir::vector::ShapeCastOp>();
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<vector::BitCastOp>(op) ||
+ if ((isa<vector::BitCastOp, vector::LoadOp,
+ vector::StoreOp, vector::CreateMaskOp,
+ RegionBranchOpInterface, vector::SplatOp>(op) ||
op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -568,7 +952,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
});
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+ LinearizeVectorBitCast, LinearizeVectorLoad,
+ LinearizeVectorStore, LinearizeVectorSplat,
+ LinearizeVectorCreateMask, LinearizeRegionBranchOp
+ >(typeConverter, patterns.getContext(),
targetBitWidth);
}
@@ -583,7 +970,21 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
.getRank() == 1)
: true;
});
+
+ target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
+ [=](vector::InsertStridedSliceOp op) -> bool {
+ if(isLessThanTargetBitWidth(op, targetBitWidth)) {
+ auto srcTy = op.getSourceVectorType();
+ auto dstTy = op.getDestVectorType();
+ if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+ srcTy.hasStaticShape() && dstTy.hasStaticShape())
+ return false;
+ }
+ return true;
+ });
+
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
- LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+ LinearizeVectorInsertStridedSlice>(
typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..e47e7c4a84d68 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,338 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
return %1 : vector<[4]x4xf16>
}
+
+// -----
+// ALL-LABEL: test_vector_load
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
+func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+ // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+ // BW-128: %[[C1:.*]] = arith.constant 1 : index
+ // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+ // BW-128: %[[C2:.*]] = arith.constant 2 : index
+...
[truncated]
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
@Hardcode84 @charithaintc @chencha3 Please take a look as well |
newling
left a comment
There was a problem hiding this comment.
Just a few drive-by comments. I'm no expert on this, so please ignore my suggestions where not appropriate
|
|
||
| static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { | ||
| // For BW-0, all operations are legal | ||
| if (targetBitWidth == 0) |
There was a problem hiding this comment.
targetBitWidth = std::numeric_limits<unsigned>::max() is used in places, please consolidate.
There was a problem hiding this comment.
Apologies, I wasn't clear in this comment. I meant to consolidate the uses of targetBitWidth = 0 and targetBitWidth = 'max'. Is this a workaround for adding 1 to std::numeric_limits<unsigned>::max()?
I would like to commit #136581 which would mean this logic doesn't live here anymore and this comment wouldn't be relevant, could you please take a look at that?
There was a problem hiding this comment.
I briefly looked at it. I'm ok with that change, but can we commit this first if possible?
There was a problem hiding this comment.
I have reached a point where I need to land this to make progress with other development, I hope that's ok @nbpatel
It should be easy to absorb the changes (just remove code related to bitwidth)
There was a problem hiding this comment.
yes go ahead and merge it
| /// vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : | ||
| /// vector<4xf32> from vector<8xf32> %3 = vector.insert_strided_slice %2, %1 | ||
| /// {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32> | ||
| struct LinearizeVectorInsertStridedSlice final |
There was a problem hiding this comment.
Perhaps it's possible to reuse
to convert to shuffle?
There was a problem hiding this comment.
I think since this is not a pass and just a bunch of patterns, users can decide how they want to lower from this point
There was a problem hiding this comment.
I would like it if this were a collection of patterns that the user could opt in or out of one-by-one, but currently there are only 2 APIS exposed to the user (populateVectorLinearizeShuffleLikeOpsPatterns and populateVectorLinearizeTypeConversionsAndLegality)
There was a problem hiding this comment.
I guess the users can run populateVectorLinearizeTypeConversionsAndLegality and populateVectorInsertExtractStridedSliceTransforms to convert to shuffle
| // Skip if result is not a vector type | ||
| if (!isa<VectorType>(extractOp.getType())) | ||
| return rewriter.notifyMatchFailure(extractOp, | ||
| "scalar extract is not supported."); |
There was a problem hiding this comment.
| "scalar extract is not supported."); | |
| "scalar extract is not supported, because ..."); |
might to helpful!
| unsigned targetVectorBitWidth; | ||
| }; | ||
|
|
||
| /// This pattern converts the LoadOp to a series of LoadOp & InsertOp |
| /// Following, | ||
| /// vector.load %base[%indices] : vector<4x4xf32> | ||
| /// is converted to : | ||
| /// %result = arith.constant dense<0.0> : vector<4x4xf32> |
There was a problem hiding this comment.
I wonder if it would be beneficial to flatten/linearize out all contiguous dimensions first. i.e. if the load is actually unstrided, like
%result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8x100xf32>
if this flattened to
%result = vector.load %flat_base[%i] : memref<10000xf32>, vector<800xf32>
the IR generated wouldn't be unrolled.
It seems to me like this is more unrolling than linearizing?
There was a problem hiding this comment.
This is a more general solution for unstrided/strided memrefs and we can always fuse the loads later on as an optimization
|
|
||
| static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { | ||
| // For BW-0, all operations are legal | ||
| if (targetBitWidth == 0) |
There was a problem hiding this comment.
Apologies I wasn't clear in this comment. I meant to consolidate use of targetBitWidth = 0 and targetBitWidth = 'max'. Is this a workaround for adding 1 to std::numeric_limits<unsigned>::max() ?
I would like to commit #136581 which would mean this logic doesn't live here anymore and this comment wouldn't be relevant, could you please take a look at that?
| /// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into | ||
| /// vector<4x4xf32> | ||
| /// is converted to : | ||
| /// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} |
There was a problem hiding this comment.
Please check formatting. Maybe clang-format off and clang-format on will be helpful?
|
Hi @nbpatel I think these patterns will all be useful for users of the vector dialect, but I have a few requests that I would like to make first:
I know they're large requests, my apologies for not suggesting them earlier in the process. I think your contributions will be significantly more useful if done this way. |
banach-space
left a comment
There was a problem hiding this comment.
Thanks!
This is adding 4 unrelated patterns - could you split this into independent PRs?
| auto loc = loadOp->getLoc(); | ||
| auto vecType = loadOp.getVectorType(); | ||
| auto shape = vecType.getShape(); |
There was a problem hiding this comment.
Could you spell out the types here and other places? For reference, here are LLVM's guidelines for using auto:
| // ----- | ||
| // ALL-LABEL: test_vector_load | ||
| // ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>) | ||
| func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> { |
There was a problem hiding this comment.
Please avoid using test in test function names - that's unnecessary noise. I appreciate that you are trying to follow the existing convention in this file, but IMO we should focus on encoding unique information. Here are our guidelines:
So, what makes this test unique?
| // DEFAULT: %[[C1:.*]] = arith.constant 1 : index | ||
| // BW-128: %[[C1:.*]] = arith.constant 1 : index | ||
| // DEFAULT: %[[C2:.*]] = arith.constant 2 : index | ||
| // BW-128: %[[C2:.*]] = arith.constant 2 : index | ||
| // DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16> | ||
| // BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16> | ||
| // DEFAULT: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> | ||
| // BW-128: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> | ||
| // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> | ||
| // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> | ||
| // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index | ||
| // BW-128: %[[C1_0:.*]] = arith.constant 1 : index | ||
| // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index | ||
| // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index | ||
| // DEFAULT: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> | ||
| // BW-128: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> | ||
| // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> | ||
| // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> | ||
| // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index | ||
| // BW-128: %[[C2_1:.*]] = arith.constant 2 : index | ||
| // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index | ||
| // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index | ||
| // DEFAULT: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> | ||
| // BW-128: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> | ||
| // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> | ||
| // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> | ||
| // DEFAULT: %[[C3:.*]] = arith.constant 3 : index | ||
| // BW-128: %[[C3:.*]] = arith.constant 3 : index | ||
| // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index | ||
| // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index | ||
| // DEFAULT: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> | ||
| // BW-128: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> | ||
| // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16> | ||
| // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16> | ||
| // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16> | ||
| // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16> | ||
| // DEFAULT: return %[[CAST]] : vector<4x4xf16> | ||
| // BW-128: return %[[CAST]] : vector<4x4xf16> | ||
|
|
||
| // BW-0: %[[C1:.*]] = arith.constant 1 : index | ||
| // BW-0: %[[C2:.*]] = arith.constant 2 : index | ||
| // BW-0: %[[LOAD:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16> | ||
| // BW-0: return %[[LOAD]] : vector<4x4xf16> |
There was a problem hiding this comment.
It is very hard to follow this. Could you follow the pre-existing convention and split DEFAULT and BW-128 and BW-0 blocks (as opposed to interleaving), Similar comment for other tests.
| /// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} | ||
| /// : vector<4xf32> into vector<16xf32> |
There was a problem hiding this comment.
| /// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} | |
| /// : vector<4xf32> into vector<16xf32> | |
| /// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} | |
| /// : vector<4xf32> into vector<16xf32> |
| /// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into | ||
| /// vector<4x4xf32> |
There was a problem hiding this comment.
| /// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into | |
| /// vector<4x4xf32> | |
| /// vector.insert_strided_slice %s, %d {offsets=[0, 0]} | |
| /// : vector<2x4xf32> into vector<4x4xf32> |
| if (srcTy.getRank() != 2) | ||
| return rewriter.notifyMatchFailure( | ||
| insertOp, | ||
| "InsertStridedSliceOp linearization only supports 2D source."); |
There was a problem hiding this comment.
any reason for supporting only 2D?
So load/store as one PR and splat, create_mask and insert_strided_slice each as independent PR? |
This PR is a breakdown [2 / 4] of the PR #136193. It adds linearization patterns for vector.splat.
This PR is a breakdown [2 / 4] of the PR llvm#136193. It adds linearization patterns for vector.splat.
This PR is a breakdown [2 / 4] of the PR llvm#136193. It adds linearization patterns for vector.splat.
This PR adds linearization patterns for vector.load, vector.store, vector.create_mask, vector.splat, vector.insert_strided_slice & RegionBranchOps. These patterns are needed because SPIR-V only supports 1D vectors, so multi-dimensional vector operations must be linearized before lowering to SPIR-V.