[RVV] Optimize Generic RVV Matmul codegen #18986

Open · wants to merge 15 commits into base: main
94 changes: 94 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -99,6 +99,15 @@ static llvm::cl::opt<bool> clDisableArmSMETiling(
"target (i.e., when the +sme feature flag is present)"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clEnableRiscvAggressiveDist(
"iree-llvmcpu-riscv-aggressive-distribution",
llvm::cl::desc(
"Enable an aggressive method for computing distribution tile sizes. "
"For now it is only applied to linalg contraction ops. "
"If distConfig.minTileSizes[i] >= distConfig.maxTileSizes[i], "
"set distConfig.maxTileSizes[i] to 2 * distConfig.minTileSizes[i]."),
llvm::cl::init(false));
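
For reference, the flag defaults to false and is enabled with --iree-llvmcpu-riscv-aggressive-distribution=true; the new lit test at the end of this diff exercises it through iree-opt.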

using IREE::Codegen::DispatchLoweringPassPipeline;

// Encodes the pre-processing strategy to be applied on a Linalg operation
@@ -1289,6 +1298,62 @@ static void getMatmulVectorSizesUsingFullVectorHeuristics(
sizes[1] = std::max<int64_t>(sizes[1], minNumElements);
}

/// Utility to compute the tile sizes for RISC-V Vector.
/// For now it only supports float element types that satisfy
/// nonWideningLinalgElementType.
/// The tile sizes are set to m = 7, n = maxNumberElementsForLMUL4, and k = 1.
///
/// Example: for a pure f32 matmul and a 512-bit vector register,
/// nativeVectorSize equals VLEN * 2 (LMUL=2) / 8 = 512 * 2 / 8 = 128 bytes,
/// and maxNumberElementsForLMUL4 = 128 * 2 * 8 / 32 = 64.
///
/// TODO: Currently it only supports nonWideningLinalgElementType.
static void
getMatmulRISCVVectorSizes(mlir::FunctionOpInterface entryPointFn,
linalg::LinalgOp op, int64_t vectorSize,
SmallVectorImpl<int64_t> &sizes,
SmallVectorImpl<bool> &scalableSizeFlags) {
if (sizes.empty())
getDefaultMatmulVectorSizes(op, vectorSize, sizes, scalableSizeFlags);
// TODO: support widening matmul.
// Determine the n-dimension tile size from VLEN for
// nonWideningLinalgElementType.
FailureOr<Type> elementType = nonWideningLinalgElementType(op);
if (failed(elementType))
return;

// nativeVectorSize is calculated from VLEN with LMUL=2.
int64_t nativeVectorSize = getNativeVectorSizeInBytes(entryPointFn);
int64_t elementSize;
if (elementType->isF16()) {
elementSize = 16;
} else if (elementType->isF32()) {
elementSize = 32;
} else if (elementType->isF64()) {
elementSize = 64;
} else {
// TODO: support int data type
return;
}
// Use 7 x lmul4 to fully utilize vector registers.
sizes[0] = 7;
// Calculate tile size for the main vector dimension (N).
constexpr int64_t kByteSizeInBits = 8;
int64_t maxNumberElementsForLMUL4 =
(nativeVectorSize * 2 * kByteSizeInBits) / elementSize;
sizes[1] = maxNumberElementsForLMUL4;
sizes[2] = 1;
FailureOr<linalg::ContractionDimensions> cDims =
linalg::inferContractionDims(op);
if (failed(cDims))
return;
ArrayRef<int64_t> lhsShape = op.getShape(op.getDpsInputOperand(0));
// If m = 1, set the tile size to 1 x lmul8.
if (lhsShape[cDims->m[0]] == 1) {
sizes[0] = 1;
sizes[1] *= 2;
}
}
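
As a worked example of this arithmetic, here is a minimal standalone sketch, assuming the +zvl1024b target used in the new lit test below (VLEN = 1024); the names are illustrative, not the compiler's API:

#include <cstdint>
#include <cstdio>

int main() {
  // Assumption: VLEN = 1024 bits (+zvl1024b) and f32 elements.
  // getNativeVectorSizeInBytes reports the LMUL=2 group size in bytes:
  // 1024 * 2 / 8 = 256.
  const int64_t nativeVectorSize = 256;
  const int64_t elementSize = 32; // f32
  const int64_t kByteSizeInBits = 8;
  // Double the LMUL=2 byte size to get LMUL=4, then convert bits to
  // elements: (256 * 2 * 8) / 32 = 128 f32 elements per register group.
  const int64_t maxNumberElementsForLMUL4 =
      (nativeVectorSize * 2 * kByteSizeInBits) / elementSize;
  std::printf("tile sizes: m=7 n=%lld k=1\n",
              (long long)maxNumberElementsForLMUL4);
  return 0;
}

This matches the loop steps of 7 and 128 checked by the test.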

/// Utility to compute the tile sizes for AArch64 SME. Unlike other targets, the
/// tile sizes picked here must exactly match multiples of the SME hardware
/// virtual tiles, as there is currently no support for lowering non-standard
@@ -1354,6 +1419,16 @@ getMatmulVectorSizes(mlir::FunctionOpInterface entryPointFn,
}
}

if (isRISCV(targetAttr) && hasAnyVFeature(targetAttr)) {
// Use the default tile sizes for matmul_transpose_b and
// batch_matmul_transpose_b to avoid a performance drop.
if (!isa<linalg::MatmulTransposeBOp, linalg::BatchMatmulTransposeBOp>(op)) {
// Try to maximize the vector register utilization rate for matmul.
getMatmulRISCVVectorSizes(entryPointFn, op, vectorSize, matmulTileSizes,
matmulScalableFlags);
}
}

// If tile sizes were not computed by previous heuristics, use default
// hard-coded tile sizes.
if (matmulTileSizes.empty()) {
@@ -1494,6 +1569,25 @@ setRootConfig(mlir::FunctionOpInterface entryPointFn,
int64_t minTileSize = cacheTileSize != 0 ? cacheTileSize : vecTileSize;
distConfig.minTileSizes.push_back(minTileSize);
}
// FIXME: Apply maxTileSize modification for all targets.
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(entryPointFn);
if (isRISCV(targetAttr) && hasAnyVFeature(targetAttr)) {
LLVM_DEBUG(KD_DBGS() << "RISC-V Aggressive Distribution: "
<< clEnableRiscvAggressiveDist << "\n");
for (auto loopNum :
llvm::seq<unsigned>(static_cast<unsigned>(isBM), numLoops)) {
if (clEnableRiscvAggressiveDist) {
if (distConfig.maxTileSizes[loopNum] <=
distConfig.minTileSizes[loopNum]) {
distConfig.maxTileSizes[loopNum] =
2 * distConfig.minTileSizes[loopNum];
}
} else {
distConfig.maxTileSizes[loopNum] = std::max(
distConfig.maxTileSizes[loopNum], distConfig.minTileSizes[loopNum]);
}
}
}
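
In isolation, the per-loop adjustment above behaves as in this sketch (hypothetical helper name; the two vectors stand in for distConfig.maxTileSizes and distConfig.minTileSizes):

#include <algorithm>
#include <cstdint>
#include <vector>

void adjustRiscvMaxTileSizes(std::vector<int64_t> &maxTileSizes,
                             const std::vector<int64_t> &minTileSizes,
                             bool aggressive) {
  for (size_t i = 0; i < maxTileSizes.size(); ++i) {
    if (aggressive) {
      // Aggressive mode: a max tile size that does not exceed the min is
      // bumped to twice the min, giving distribution more room per tile.
      if (maxTileSizes[i] <= minTileSizes[i])
        maxTileSizes[i] = 2 * minTileSizes[i];
    } else {
      // Default: only raise the max so it never falls below the min.
      maxTileSizes[i] = std::max(maxTileSizes[i], minTileSizes[i]);
    }
  }
}

For example, with minTileSizes = {7, 128} and maxTileSizes = {8, 128}, aggressive mode yields {8, 256}, while the default clamp leaves {8, 128}.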
SmallVector<int64_t> distTileSizes =
getDefaultDistributedLevelTileSizes(linalgOp, distConfig);

6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp
@@ -45,6 +45,12 @@ bool hasZve64xFeature(IREE::HAL::ExecutableTargetAttr targetAttr) {
return hasFeature(targetAttr, "+zve64x");
}

bool hasAnyVFeature(IREE::HAL::ExecutableTargetAttr targetAttr) {
return hasVFeature(targetAttr) || hasZve32xFeature(targetAttr) ||
hasZve32fFeature(targetAttr) || hasZve64xFeature(targetAttr) ||
hasFeature(targetAttr, "+zve64f") || hasFeature(targetAttr, "+zve64d");
}
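
For quick experimentation outside the compiler, the predicate can be approximated over a plain feature string (a hypothetical stand-in; the real helper queries the target attribute via hasFeature rather than substring search):

#include <initializer_list>
#include <string>

// True for full V or any Zve subset extension, mirroring hasAnyVFeature.
// Note: naive substring matching; shown for illustration only.
bool hasAnyVFeatureStr(const std::string &cpuFeatures) {
  for (const char *f :
       {"+v", "+zve32x", "+zve32f", "+zve64x", "+zve64f", "+zve64d"}) {
    if (cpuFeatures.find(f) != std::string::npos)
      return true;
  }
  return false;
}

For instance, the test target's "+m,+a,+f,+d,+zvl1024b,+v" qualifies via "+v".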

bool hasAnySVEFeature(IREE::HAL::ExecutableTargetAttr targetAttr) {
return hasFeature(targetAttr, "+sve") || hasFeature(targetAttr, "+sve2") ||
hasFeature(targetAttr, "+v9a");
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h
@@ -32,6 +32,10 @@ bool hasZve32fFeature(IREE::HAL::ExecutableTargetAttr targetAttr);
/// Returns true if the 'targetAttr' contains '+zve64x' in its cpu features.
bool hasZve64xFeature(IREE::HAL::ExecutableTargetAttr targetAttr);

/// Returns true if the 'targetAttr' contains any RISC-V vector feature in its
/// cpu features.
bool hasAnyVFeature(IREE::HAL::ExecutableTargetAttr targetAttr);

/// Returns true if the 'targetAttr' contains '+sve' or '+sve2' in its cpu
/// features or any other feature flag that includes them.
bool hasAnySVEFeature(IREE::HAL::ExecutableTargetAttr targetAttr);
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel
@@ -41,6 +41,7 @@ iree_lit_test_suite(
"pipeline_pad_conv_tests.mlir",
"pipeline_pad_tests.mlir",
"pipeline_peel_and_vectorize_tests.mlir",
"pipeline_riscv_aggressive_distribution_tests.mlir",
"pipeline_split_reduction_tests.mlir",
"pipeline_tests.mlir",
"pipeline_transpose_avx2_tests.mlir",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -36,6 +36,7 @@ iree_lit_test_suite(
"pipeline_pad_conv_tests.mlir"
"pipeline_pad_tests.mlir"
"pipeline_peel_and_vectorize_tests.mlir"
"pipeline_riscv_aggressive_distribution_tests.mlir"
"pipeline_split_reduction_tests.mlir"
"pipeline_tests.mlir"
"pipeline_transpose_avx2_tests.mlir"
39 changes: 39 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_riscv_aggressive_distribution_tests.mlir
@@ -0,0 +1,39 @@
// RUN: iree-opt --iree-llvmcpu-riscv-aggressive-distribution=true --pass-pipeline='builtin.module(iree-llvmcpu-select-lowering-strategy, func.func(iree-llvmcpu-lower-executable-target))' --split-input-file %s | FileCheck %s

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>

#executable_target_embedded_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-riscv_64", {cpu_features = "+m,+a,+f,+d,+zvl1024b,+v", data_layout = "e-m:e-p:64:64-i64:64-i256:256-n32:64-S256", native_vector_size = 256 : index, target_triple = "riscv64-unknown-unknown-eabi-elf"}>
builtin.module {
func.func @f32_rvv_matmul() attributes {hal.executable.target = #executable_target_embedded_elf_riscv_64_} {
%cst = arith.constant 0.0 : f32
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor<readonly:tensor<384x512xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor<readonly:tensor<512x256xf32>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor<writeonly:tensor<384x256xf32>>
%lhs = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<384x512xf32>> -> tensor<384x512xf32>
%rhs = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<512x256xf32>> -> tensor<512x256xf32>
%init = tensor.empty() : tensor<384x256xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<384x256xf32>) -> tensor<384x256xf32>
%res = linalg.matmul ins(%lhs, %rhs : tensor<384x512xf32>, tensor<512x256xf32>) outs(%fill : tensor<384x256xf32>) -> tensor<384x256xf32>
flow.dispatch.tensor.store %res, %2, offsets = [0, 0], sizes = [384, 256], strides = [1, 1] : tensor<384x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x256xf32>>
return
}
}
// CHECK-LABEL: func.func @f32_rvv_matmul(
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[c128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[c256:.+]] = arith.constant 256 : index
// CHECK-DAG: %[[c512:.+]] = arith.constant 512 : index
// CHECK: scf.for {{.*}} step %[[c7]]
// CHECK: scf.for {{.*}} step %[[c128]]
// CHECK: scf.for {{.*}} step %[[c1]]
// CHECK-COUNT-7: vector.fma
// CHECK-COUNT-7: vector.store
// CHECK: scf.for {{.*}} step %[[c128]]
// CHECK: scf.for {{.*}} step %[[c1]]
// CHECK-COUNT-4: vector.fma
// CHECK-COUNT-4: vector.store
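
These CHECK constants follow from the heuristic above: +zvl1024b gives native_vector_size = 256 bytes (1024 * 2 / 8 at LMUL=2), so maxNumberElementsForLMUL4 = (256 * 2 * 8) / 32 = 128 for f32, yielding loop steps of 7 (M), 128 (N), and 1 (K). The second loop nest, with fewer vector.fma/vector.store ops, covers what appears to be the peeled M remainder (384 is not a multiple of 7).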