Skip to content

Commit 6e52a12

Browse files
authored
[mlir][vector] Create VectorToLLVMDialectInterface (llvm#121440)
Create `VectorToLLVMDialectInterface` which allows automatic conversion discovery by generic `--convert-to-llvm` pass. This only covers final dialect conversion step and not any previous preparation steps. Also, currently there is no way to pass any additional parameters through this conversion interface, but most users using default parameters anyway.
1 parent 585b75e commit 6e52a12

File tree

5 files changed

+46
-0
lines changed

5 files changed

+46
-0
lines changed

mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ void populateVectorToLLVMConversionPatterns(
2424
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2525
bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
2626

27+
namespace vector {
28+
void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
29+
}
2730
} // namespace mlir
2831

2932
#endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_

mlir/include/mlir/InitAllExtensions.h

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
2727
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
2828
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
29+
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
2930
#include "mlir/Dialect/AMX/Transforms.h"
3031
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
3132
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
@@ -76,6 +77,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
7677
registerConvertAMXToLLVMInterface(registry);
7778
gpu::registerConvertGpuToLLVMInterface(registry);
7879
NVVM::registerConvertGpuToNVVMInterface(registry);
80+
vector::registerConvertVectorToLLVMInterface(registry);
7981

8082
// Register all transform dialect extensions.
8183
affine::registerTransformDialectExtension(registry);

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
1010

1111
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1213
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
1314
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1415
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
@@ -1942,3 +1943,27 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns(
19421943
patterns.add<VectorMatmulOpConversion>(converter);
19431944
patterns.add<VectorFlatTransposeOpConversion>(converter);
19441945
}
1946+
1947+
namespace {
1948+
struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
1949+
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
1950+
void loadDependentDialects(MLIRContext *context) const final {
1951+
context->loadDialect<LLVM::LLVMDialect>();
1952+
}
1953+
1954+
/// Hook for derived dialect interface to provide conversion patterns
1955+
/// and mark dialect legal for the conversion target.
1956+
void populateConvertToLLVMConversionPatterns(
1957+
ConversionTarget &target, LLVMTypeConverter &typeConverter,
1958+
RewritePatternSet &patterns) const final {
1959+
populateVectorToLLVMConversionPatterns(typeConverter, patterns);
1960+
}
1961+
};
1962+
} // namespace
1963+
1964+
void mlir::vector::registerConvertVectorToLLVMInterface(
1965+
DialectRegistry &registry) {
1966+
registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
1967+
dialect->addInterfaces<VectorToLLVMDialectInterface>();
1968+
});
1969+
}

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1515

16+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1617
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
1718
#include "mlir/Dialect/Arith/IR/Arith.h"
1819
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -428,6 +429,7 @@ void VectorDialect::initialize() {
428429
TransferWriteOp>();
429430
declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
430431
declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
432+
declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
431433
}
432434

433435
/// Materialize a single constant operation from a given attribute value with
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Most of the vector lowering is tested in vector-to-llvm.mlir, this file only for the interface smoke test
2+
// RUN: mlir-opt --convert-to-llvm="filter-dialects=vector" --split-input-file %s | FileCheck %s
3+
4+
func.func @bitcast_f32_to_i32_vector_0d(%arg0: vector<f32>) -> vector<i32> {
5+
%0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
6+
return %0 : vector<i32>
7+
}
8+
9+
// CHECK-LABEL: @bitcast_f32_to_i32_vector_0d
10+
// CHECK-SAME: %[[ARG_0:.*]]: vector<f32>
11+
// CHECK: %[[VEC_F32_1D:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector<f32> to vector<1xf32>
12+
// CHECK: %[[VEC_I32_1D:.*]] = llvm.bitcast %[[VEC_F32_1D]] : vector<1xf32> to vector<1xi32>
13+
// CHECK: %[[VEC_I32_0D:.*]] = builtin.unrealized_conversion_cast %[[VEC_I32_1D]] : vector<1xi32> to vector<i32>
14+
// CHECK: return %[[VEC_I32_0D]] : vector<i32>

0 commit comments

Comments
 (0)