diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp index 5111b7668958..4991f222d08b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp @@ -99,6 +99,15 @@ static llvm::cl::opt clDisableArmSMETiling( "target (i.e., when the +sme feature flag is present)"), llvm::cl::init(false)); +static llvm::cl::opt clEnableRiscvAggressiveDist( + "iree-llvmcpu-riscv-aggressive-distribution", + llvm::cl::desc( + "Enable aggressive method for distribution tile size. " + "It is only applied for linalg contraction ops now. " + "If distConfig.minTileSizes[i] >= distConfig.maxTileSizes[i], " + "set distConfig.maxTileSizes[i] to 2 * distConfig.minTileSizes[i]."), + llvm::cl::init(false)); + using IREE::Codegen::DispatchLoweringPassPipeline; // Encodes the pre-processing strategy to be applied on a Linalg operation @@ -1289,6 +1298,62 @@ static void getMatmulVectorSizesUsingFullVectorHeuristics( sizes[1] = std::max(sizes[1], minNumElements); } +/// Utility to compute the tile sizes for RISC-V Vector. +/// For now, it only supports nonWideningLinalgElementType float. +/// TileSize is set to m = 7, n = maxNumberElementsForLMUL4, and k = 1. +/// +/// Example: for an pure f32-matmul and a 512-bit vector register. +/// nativeVectorSize is equal to VLEN * LMUL2 / 8, so it's 128. +/// maxNumberElementsForLMUL4 = 128 * 2 * 8 / 32 = 64. +/// +/// TODO: Currently it only supports for nonWideningLinalgElementType. +static void +getMatmulRISCVVectorSizes(mlir::FunctionOpInterface entryPointFn, + linalg::LinalgOp op, int64_t vectorSize, + SmallVectorImpl &sizes, + SmallVectorImpl &scalableSizeFlags) { + if (sizes.empty()) + getDefaultMatmulVectorSizes(op, vectorSize, sizes, scalableSizeFlags); + // TODO: support widening matmul. + // Determines n dimension tile size with VLEN for + // nonWideningLinalgElementType. + FailureOr elementType = nonWideningLinalgElementType(op); + if (failed(elementType)) + return; + + // nativeVectorSize is cacluated with VLEN and LMUL=2. + int64_t nativeVectorSize = getNativeVectorSizeInBytes(entryPointFn); + int64_t elementSize; + if (elementType->isF16()) { + elementSize = 16; + } else if (elementType->isF32()) { + elementSize = 32; + } else if (elementType->isF64()) { + elementSize = 64; + } else { + // TODO: support int data type + return; + } + // Use 7 x lmul4 to fully utilize vector registers. + sizes[0] = 7; + // Calculate tile size for the main vector dimension (N). + constexpr int64_t kByteSizeInBits = 8; + int64_t maxNumberElementsForLMUL4 = + (nativeVectorSize * 2 * kByteSizeInBits) / elementSize; + sizes[1] = maxNumberElementsForLMUL4; + sizes[2] = 1; + FailureOr cDims = + linalg::inferContractionDims(op); + if (failed(cDims)) + return; + ArrayRef lhsShape = op.getShape(op.getDpsInputOperand(0)); + // If m = 1, set tile size to 1 x lmul8 + if (lhsShape[cDims->m[0]] == 1) { + sizes[0] = 1; + sizes[1] *= 2; + } +} + /// Utility to compute the tile sizes for AArch64 SME. Unlike other targets, the /// tile sizes picked here must exactly match multiples of the SME hardware /// virtual tiles, as there is currently no support for lowering non-standard @@ -1354,6 +1419,16 @@ getMatmulVectorSizes(mlir::FunctionOpInterface entryPointFn, } } + if (isRISCV(targetAttr) && hasAnyVFeature(targetAttr)) { + // Use default tile size for matmul_transpose_b & + // batch_matmul_transpose_b to avoid performance drop. + if (!isa(op)) { + // Try to maximize the vector register utilization rate for matmul. + getMatmulRISCVVectorSizes(entryPointFn, op, vectorSize, matmulTileSizes, + matmulScalableFlags); + } + } + // If tile sizes were not computed by previous heuristics, use default // hard-coded tile sizes. if (matmulTileSizes.empty()) { @@ -1494,6 +1569,25 @@ setRootConfig(mlir::FunctionOpInterface entryPointFn, int64_t minTileSize = cacheTileSize != 0 ? cacheTileSize : vecTileSize; distConfig.minTileSizes.push_back(minTileSize); } + // FIXME: Apply maxTileSize modification for all targets. + auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(entryPointFn); + if (isRISCV(targetAttr) && hasAnyVFeature(targetAttr)) { + LLVM_DEBUG(KD_DBGS() << "RISC-V Aggressive Distribution: " + << clEnableRiscvAggressiveDist << "\n"); + for (auto loopNum : + llvm::seq(static_cast(isBM), numLoops)) { + if (clEnableRiscvAggressiveDist) { + if (distConfig.maxTileSizes[loopNum] <= + distConfig.minTileSizes[loopNum]) { + distConfig.maxTileSizes[loopNum] = + 2 * distConfig.minTileSizes[loopNum]; + } + } else { + distConfig.maxTileSizes[loopNum] = std::max( + distConfig.maxTileSizes[loopNum], distConfig.minTileSizes[loopNum]); + } + } + } SmallVector distTileSizes = getDefaultDistributedLevelTileSizes(linalgOp, distConfig); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp index b8d14dee1782..eefa70a14d0e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp @@ -45,6 +45,12 @@ bool hasZve64xFeature(IREE::HAL::ExecutableTargetAttr targetAttr) { return hasFeature(targetAttr, "+zve64x"); } +bool hasAnyVFeature(IREE::HAL::ExecutableTargetAttr targetAttr) { + return hasVFeature(targetAttr) || hasZve32xFeature(targetAttr) || + hasZve32fFeature(targetAttr) || hasZve64xFeature(targetAttr) || + hasFeature(targetAttr, "+zve64f") || hasFeature(targetAttr, "+zve64d"); +} + bool hasAnySVEFeature(IREE::HAL::ExecutableTargetAttr targetAttr) { return hasFeature(targetAttr, "+sve") || hasFeature(targetAttr, "+sve2") || hasFeature(targetAttr, "+v9a"); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h index 1308f87e8169..bf88f28215c4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h @@ -32,6 +32,10 @@ bool hasZve32fFeature(IREE::HAL::ExecutableTargetAttr targetAttr); /// Returns true if the 'targetAttr' contains '+zve64x' in its cpu features. bool hasZve64xFeature(IREE::HAL::ExecutableTargetAttr targetAttr); +/// Returns true if the 'targetAttr' contains any riscv vector feature in its +/// cpu features. +bool hasAnyVFeature(IREE::HAL::ExecutableTargetAttr targetAttr); + /// Returns true if the 'targetAttr' contains '+sve' or '+sve2' in its cpu /// features or any other feature flag that includes them. bool hasAnySVEFeature(IREE::HAL::ExecutableTargetAttr targetAttr); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel index 1cd905233264..44616988e113 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel @@ -41,6 +41,7 @@ iree_lit_test_suite( "pipeline_pad_conv_tests.mlir", "pipeline_pad_tests.mlir", "pipeline_peel_and_vectorize_tests.mlir", + "pipeline_riscv_aggressive_distribution_tests.mlir", "pipeline_split_reduction_tests.mlir", "pipeline_tests.mlir", "pipeline_transpose_avx2_tests.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt index 46d22db238ab..933a75fb4a91 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt @@ -36,6 +36,7 @@ iree_lit_test_suite( "pipeline_pad_conv_tests.mlir" "pipeline_pad_tests.mlir" "pipeline_peel_and_vectorize_tests.mlir" + "pipeline_riscv_aggressive_distribution_tests.mlir" "pipeline_split_reduction_tests.mlir" "pipeline_tests.mlir" "pipeline_transpose_avx2_tests.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_riscv_aggressive_distribution_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_riscv_aggressive_distribution_tests.mlir new file mode 100644 index 000000000000..edc2f88ab391 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_riscv_aggressive_distribution_tests.mlir @@ -0,0 +1,39 @@ +// RUN: iree-opt --iree-llvmcpu-riscv-aggressive-distribution=true --pass-pipeline='builtin.module(iree-llvmcpu-select-lowering-strategy, func.func(iree-llvmcpu-lower-executable-target))' --split-input-file %s | FileCheck %s + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +#executable_target_embedded_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-riscv_64", {cpu_features = "+m,+a,+f,+d,+zvl1024b,+v", data_layout = "e-m:e-p:64:64-i64:64-i256:256-n32:64-S256", native_vector_size = 256 : index, target_triple = "riscv64-unknown-unknown-eabi-elf"}> +builtin.module { + func.func @f32_rvv_matmul() attributes {hal.executable.target = #executable_target_embedded_elf_riscv_64_} { + %cst = arith.constant 0.0 : f32 + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor> + %lhs = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<384x512xf32> + %rhs = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<512x256xf32> + %init = tensor.empty() : tensor<384x256xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<384x256xf32>) -> tensor<384x256xf32> + %res = linalg.matmul ins(%lhs, %rhs : tensor<384x512xf32>, tensor<512x256xf32>) outs(%fill : tensor<384x256xf32>) -> tensor<384x256xf32> + flow.dispatch.tensor.store %res, %2, offsets = [0, 0], sizes = [384, 256], strides = [1, 1] : tensor<384x256xf32> -> !flow.dispatch.tensor> + return + } +} +// CHECK-LABEL: func.func @f32_rvv_matmul( +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index +// CHECK-DAG: %[[c128:.+]] = arith.constant 128 : index +// CHECK-DAG: %[[c256:.+]] = arith.constant 256 : index +// CHECK-DAG: %[[c512:.+]] = arith.constant 512 : index +// CHECK: scf.for {{.*}} step %[[c7]] +// CHECK: scf.for {{.*}} step %[[c128]] +// CHECK: scf.for {{.*}} step %[[c1]] +// CHECK-COUNT-7: vector.fma +// CHECK-COUNT-7: vector.store +// CHECK: scf.for {{.*}} step %[[c128]] +// CHECK: scf.for {{.*}} step %[[c1]] +// CHECK-COUNT-4: vector.fma +// CHECK-COUNT-4: vector.store diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_riscv_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_riscv_lowering_strategy.mlir index 02095fcec42d..d1a57768e32f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_riscv_lowering_strategy.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_riscv_lowering_strategy.mlir @@ -1,4 +1,5 @@ // RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmcpu-select-lowering-strategy)' --split-input-file %s | FileCheck %s +// RUN: iree-opt --iree-llvmcpu-riscv-aggressive-distribution=true --pass-pipeline='builtin.module(iree-llvmcpu-select-lowering-strategy)' --split-input-file %s | FileCheck %s -check-prefixes=CHECK-AGGRESSIVE #pipeline_layout = #hal.pipeline.layout, @@ -30,6 +31,105 @@ func.func @matmul_riscv() attributes {hal.executable.target = #executable_target // ----- +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +#executable_target_embedded_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-riscv_64", {cpu_features = "+m,+a,+f,+d,+zvl512b,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 128 : index, target_triple = "riscv64-unknown-unknown-eabi-elf"}> +builtin.module { + func.func @matmul_riscv_vl512() attributes {hal.executable.target = #executable_target_embedded_elf_riscv_64_} { + %cst = arith.constant 0.0 : f32 + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor> + %lhs = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<384x512xf32> + %rhs = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 128], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<512x128xf32> + %init = tensor.empty() : tensor<384x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<384x128xf32>) -> tensor<384x128xf32> + %res = linalg.matmul ins(%lhs, %rhs : tensor<384x512xf32>, tensor<512x128xf32>) outs(%fill : tensor<384x128xf32>) -> tensor<384x128xf32> + flow.dispatch.tensor.store %res, %2, offsets = [0, 0], sizes = [384, 128], strides = [1, 1] : tensor<384x128xf32> -> !flow.dispatch.tensor> + return + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: func.func @matmul_riscv_vl512() +// CHECK-SAME: translation_info = #[[TRANSLATION]] +// CHECK: linalg.matmul +// CHECK-SAME: lowering_config = #[[CONFIG2]] + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +#executable_target_embedded_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-riscv_64", {cpu_features = "+m,+a,+f,+d,+zvl1024b,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 256 : index, target_triple = "riscv64-unknown-unknown-eabi-elf"}> +builtin.module { + func.func @matmul_riscv_vl1024() attributes {hal.executable.target = #executable_target_embedded_elf_riscv_64_} { + %cst = arith.constant 0.0 : f32 + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor> + %lhs = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<384x512xf32> + %rhs = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 128], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<512x128xf32> + %init = tensor.empty() : tensor<384x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<384x128xf32>) -> tensor<384x128xf32> + %res = linalg.matmul ins(%lhs, %rhs : tensor<384x512xf32>, tensor<512x128xf32>) outs(%fill : tensor<384x128xf32>) -> tensor<384x128xf32> + flow.dispatch.tensor.store %res, %2, offsets = [0, 0], sizes = [384, 128], strides = [1, 1] : tensor<384x128xf32> -> !flow.dispatch.tensor> + return + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: func.func @matmul_riscv_vl1024() +// CHECK-SAME: translation_info = #[[TRANSLATION]] +// CHECK: linalg.matmul +// CHECK-SAME: lowering_config = #[[CONFIG2]] + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +#executable_target_embedded_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-riscv_64", {cpu_features = "+m,+a,+f,+d,+zvl512b,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 128 : index, target_triple = "riscv64-unknown-unknown-eabi-elf"}> +builtin.module { + func.func @gemv_riscv_vl512() attributes {hal.executable.target = #executable_target_embedded_elf_riscv_64_} { + %cst = arith.constant 0.0 : f32 + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor> + %lhs = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 512], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<1x512xf32> + %rhs = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 128], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<512x128xf32> + %init = tensor.empty() : tensor<1x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x128xf32>) -> tensor<1x128xf32> + %res = linalg.matmul ins(%lhs, %rhs : tensor<1x512xf32>, tensor<512x128xf32>) outs(%fill : tensor<1x128xf32>) -> tensor<1x128xf32> + flow.dispatch.tensor.store %res, %2, offsets = [0, 0], sizes = [1, 128], strides = [1, 1] : tensor<1x128xf32> -> !flow.dispatch.tensor> + return + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: func.func @gemv_riscv_vl512() +// CHECK-SAME: translation_info = #[[TRANSLATION]] +// CHECK: linalg.matmul +// CHECK-SAME: lowering_config = #[[CONFIG2]] + +// ----- + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, @@ -56,3 +156,36 @@ func.func @thin_depthwise_conv_static() attributes {hal.executable.target = #exe // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK: linalg.depthwise_conv_2d_nhwc_hwc // CHECK-SAME: lowering_config = #[[CONFIG]] + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +#executable_target_embedded_elf_riscv_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-riscv_64", {cpu_features = "+m,+a,+f,+d,+zvl1024b,+v", data_layout = "e-m:e-p:64:64-i64:64-i256:256-n32:64-S256", native_vector_size = 256 : index, target_triple = "riscv64-unknown-unknown-eabi-elf"}> +builtin.module { + func.func @matmul_riscv_vl1024() attributes {hal.executable.target = #executable_target_embedded_elf_riscv_64_} { + %cst = arith.constant 0.0 : f32 + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor> + %lhs = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<384x512xf32> + %rhs = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<512x256xf32> + %init = tensor.empty() : tensor<384x256xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<384x256xf32>) -> tensor<384x256xf32> + %res = linalg.matmul ins(%lhs, %rhs : tensor<384x512xf32>, tensor<512x256xf32>) outs(%fill : tensor<384x256xf32>) -> tensor<384x256xf32> + flow.dispatch.tensor.store %res, %2, offsets = [0, 0], sizes = [384, 256], strides = [1, 1] : tensor<384x256xf32> -> !flow.dispatch.tensor> + return + } +} + +// CHECK-AGGRESSIVE-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-AGGRESSIVE-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config +// CHECK-AGGRESSIVE-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-AGGRESSIVE: func.func @matmul_riscv_vl1024() +// CHECK-AGGRESSIVE-SAME: translation_info = #[[TRANSLATION]] +// CHECK-AGGRESSIVE: linalg.matmul +// CHECK-AGGRESSIVE-SAME: lowering_config = #[[CONFIG2]]