From 10ba28dcf142f8b30217f24d7781130d581663c3 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Sat, 17 Aug 2024 14:08:14 -0400
Subject: [PATCH] [Codegen][GPU] Add kernel config for LLVMGPUTileAndFuse (#17791)

This adds kernel configuration logic for simple thread distribution of
linalg-based dispatches on LLVMGPU. The configuration logic is largely
copied from the SPIR-V side, whose heuristics are already well tested
against the kinds of varied target descriptions present there.

Currently this is gated behind the flag
`iree-codegen-llvmgpu-test-tile-and-fuse-vectorize`. Future patches will
add specialized logic for matmul.
---
 .../Dialect/GPU/TargetUtils/BUILD.bazel       |   2 +
 .../Dialect/GPU/TargetUtils/CMakeLists.txt    |   2 +
 .../Dialect/GPU/TargetUtils/ConfigUtils.cpp   | 267 ++++++++++++++++++
 .../Dialect/GPU/TargetUtils/ConfigUtils.h     |   6 +
 .../compiler/Codegen/LLVMGPU/KernelConfig.cpp |  30 +-
 .../Codegen/LLVMGPU/ROCDLKernelConfig.cpp     |  10 +-
 .../test/ROCDL/config_tile_and_fuse.mlir      |  90 +++++-
 .../compiler/Codegen/SPIRV/KernelConfig.cpp   |  35 ---
 .../src/iree/compiler/Codegen/Utils/Utils.cpp |  33 +++
 .../src/iree/compiler/Codegen/Utils/Utils.h   |   8 +
 10 files changed, 435 insertions(+), 48 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/BUILD.bazel
index 45957a4d73c4..9aab5d11ab6e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/BUILD.bazel
@@ -24,10 +24,12 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Codegen/Common/GPU:GPUHeuristics",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
         "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+        "//compiler/src/iree/compiler/Codegen/Utils",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:LinalgUtils",
         "@llvm-project//mlir:Support",
     ],
 )
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/CMakeLists.txt
index 8cf7e05c614f..e0fa5afb0db2 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/CMakeLists.txt
@@ -22,10 +22,12 @@ iree_cc_library(
     MLIRFunctionInterfaces
     MLIRIR
     MLIRLinalgDialect
+    MLIRLinalgUtils
     MLIRSupport
     iree::compiler::Codegen::Common::GPU::GPUHeuristics
     iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
     iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+    iree::compiler::Codegen::Utils
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 143cba102c94..4fc2b67d22a9 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -11,9 +11,11 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
"mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -201,4 +203,269 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, workgroupSize, targetSubgroupSize); } +LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target, + mlir::FunctionOpInterface entryPoint, + Operation *op) { + auto linalgOp = dyn_cast(op); + // Bail out on multi result cases as consumer fusion currently does not + // support multi result ops. + if (!linalgOp || linalgOp.getNumDpsInits() != 1) { + return failure(); + } + + // This pipeline requires tensor semantics. Also fail for gather semantics + // for now to simplify tile + fuse. + if (!linalgOp.hasPureTensorSemantics() || linalgOp.hasIndexSemantics()) { + return failure(); + } + + SmallVector partitionableLoops; + linalgOp.getParallelDims(partitionableLoops); + + // Bail out if op is not tilable. + if (partitionableLoops.empty()) { + return failure(); + } + + const int subgroupSize = target.getPreferredSubgroupSize(); + const unsigned loopDepth = linalgOp.getNumLoops(); + + // Configurations we need to decide. + std::array workgroupSize; + SmallVector workgroupTileSizes; + SmallVector threadTileSizes; + + // Initialize the configuration. + auto initConfiguration = [&]() { + workgroupSize = {subgroupSize, 1, 1}; + workgroupTileSizes.resize(loopDepth, 0); + threadTileSizes.resize(loopDepth, 0); + + // Initialize tiling along all partitioned loops with size 1. + for (int64_t loopIndex : partitionableLoops) { + workgroupTileSizes[loopIndex] = threadTileSizes[loopIndex] = 1; + } + // Override the innermost dimension to distribute to threads in a subgroup. + workgroupTileSizes[partitionableLoops.back()] = subgroupSize; + }; + + // Common case for all linalg ops. + + // The core idea is to distribute the partitioned loops to the workgroup + // dimensions. The goal is to fill up the GPU as much as possible, which means + // 1) distributing to as many threads as possible, and 2) avoid assigning too + // many threads to handle out-of-bound elements (thus idle). + + auto elementHasPowerOfTwoBitwidth = [](Value operand) { + Type elementType = getElementTypeOrSelf(operand.getType()); + return isa(elementType) && + llvm::isPowerOf2_64(IREE::Util::getTypeBitWidth(elementType)); + }; + + // Whether we can try to use the vectorization pipeline. + SmallVector loopBounds = linalgOp.getStaticLoopRanges(); + bool projPerm = + llvm::all_of(linalgOp.getIndexingMapsArray(), + [](AffineMap map) { return map.isProjectedPermutation(); }); + bool powTwo = + llvm::all_of(linalgOp->getOperands(), elementHasPowerOfTwoBitwidth); + bool staticShape = llvm::none_of(loopBounds, ShapedType::isDynamic); + + // Require all affine maps to be projected permutation so that we can + // generate vector transfer ops. + bool vectorizable = projPerm && powTwo && staticShape; + + const unsigned minBitwidth = getMinElementBitwidth(linalgOp); + // Make sure we use a tile size that results in some integral number of bytes. + const unsigned scaleToByte = + std::max(8 / minBitwidth, static_cast(1)); + + // Distribute workload to the given `numThreads` by allowing a potental loss. 
+  auto distributeToThreads = [&](int64_t numThreads,
+                                 std::optional<int64_t> lossFactor =
+                                     std::nullopt) {
+    LDBG("Loss factor: " << lossFactor << "\n");
+    initConfiguration();
+    // If there are more than 3 parallel dims, try to tile the extra
+    // higher-level dimensions to 1.
+    if (isa<linalg::GenericOp>(linalgOp.getOperation())) {
+      for (auto [i, tileSize] : llvm::enumerate(workgroupTileSizes)) {
+        if (tileSize != 0)
+          break;
+        if (loopBounds[i] != 1)
+          tileSize = 1;
+      }
+    }
+    // Scan from the innermost shape dimension and try to deduce the
+    // configuration for the corresponding GPU workgroup dimension.
+    int64_t wgDim = 0;
+    for (auto shapeDim : llvm::reverse(partitionableLoops)) {
+      int64_t loopBound = loopBounds[shapeDim];
+      // Skip dynamic dimensions.
+      if (ShapedType::isDynamic(loopBound))
+        continue;
+
+      // Try to find some power of two that can divide the current shape dim
+      // size. This vector keeps the candidate tile sizes.
+      SmallVector<int64_t> candidates;
+
+      // For the innermost workgroup dim, try to see if we can have 4
+      // elements per thread. This enables vectorization.
+      if (vectorizable && wgDim == 0 && !lossFactor) {
+        candidates.push_back(4 * numThreads);
+      }
+      // Try all powers of two up to the subgroup size.
+      for (unsigned i = numThreads; i >= 1; i >>= 1) {
+        candidates.push_back(i);
+      }
+      LLVM_DEBUG({
+        llvm::dbgs() << "Base candidate tile sizes: [";
+        llvm::interleaveComma(candidates, llvm::dbgs());
+        llvm::dbgs() << "]\n";
+      });
+
+      for (int64_t candidate : candidates) {
+        int64_t scaledTileSize = candidate * scaleToByte;
+        if (loopBound % scaledTileSize != 0) {
+          if (!lossFactor)
+            continue;
+          // Skip this candidate if it causes many threads to be idle.
+          int64_t idleThreads = candidate - (loopBound % scaledTileSize);
+          if (idleThreads > candidate / *lossFactor)
+            continue;
+        }
+        // If the workload is too small and we cannot distribute to more than 2
+        // workgroups, try a smaller tile size to increase parallelism.
+        if (partitionableLoops.size() == 1 && candidate > subgroupSize &&
+            llvm::divideCeil(loopBound, scaledTileSize) <= 2) {
+          continue;
+        }
+
+        // Found a suitable candidate. Try to let each thread handle 4
+        // elements if this is the workgroup x dimension.
+        // TODO: Try to take into account element type bit width to get
+        // 4xdword reads instead of 4x{elements}.
+        workgroupTileSizes[shapeDim] = scaledTileSize;
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Chosen workgroup tile size: " << scaledTileSize << "\n");
+        if (vectorizable && wgDim == 0 && !lossFactor && candidate % 4 == 0) {
+          // Use size-1 vectors to increase parallelism if larger ones cause
+          // idle threads in the subgroup.
+          bool hasIdleThreads =
+              partitionableLoops.size() == 1 && candidate <= subgroupSize;
+          int vectorSize = hasIdleThreads ? 1 : 4;
+          LLVM_DEBUG(llvm::dbgs() << "Use vector size: " << vectorSize << "\n");
+          threadTileSizes[shapeDim] = vectorSize * scaleToByte;
+          workgroupSize[wgDim] = candidate / vectorSize;
+          assert(numThreads % (candidate / vectorSize) == 0);
+          numThreads /= candidate / vectorSize;
+        } else {
+          if (wgDim == 0)
+            vectorizable = false;
+          threadTileSizes[shapeDim] = scaleToByte;
+          workgroupSize[wgDim] = candidate;
+          assert(numThreads % candidate == 0);
+          numThreads /= candidate;
+        }
+        assert(numThreads >= 1);
+        break;
+      }
+
+      // Stop if we have distributed all threads.
+      if (numThreads == 1)
+        break;
+      wgDim++;
+    }
+    return numThreads;
+  };
+
+  // First try to see if we can use up all threads without any loss.
+  if (distributeToThreads(subgroupSize) != 1) {
+    // Otherwise, allow a larger and larger loss factor.
+
+    // Threads for distribution. Use at least 32.
+    int64_t numThreads = std::max(subgroupSize, 32);
+    // We can tolerate (1 / lossFactor) of the threads in the workgroup being
+    // idle.
+    int64_t lossFactor = 32;
+
+    for (; lossFactor >= 1; lossFactor >>= 1) {
+      if (distributeToThreads(numThreads, lossFactor) == 1)
+        break;
+    }
+  }
+
+  // TODO(qedawkins): Currently scf.forall resolution only supports static
+  // trip counts, meaning the workgroup tile size must perfectly divide the
+  // loop bound (and the thread tile size must perfectly divide the workgroup
+  // tile) so that the trip counts are static. Remove this check once proper
+  // dynamic trip count resolution support is added.
+  for (auto [loopId, threadTile] : llvm::enumerate(threadTileSizes)) {
+    if (threadTile == 0) {
+      continue;
+    }
+    int64_t bound = loopBounds[loopId];
+    int64_t wkgpTile = workgroupTileSizes[loopId];
+    if (bound % wkgpTile != 0 || wkgpTile % threadTile != 0) {
+      return failure();
+    }
+  }
+
+  TileSizesListType tileSizes;
+  tileSizes.push_back(workgroupTileSizes);
+  tileSizes.push_back(threadTileSizes);
+
+  // Attach the per-level tile sizes as a lowering config attribute for later
+  // access in the pipeline.
+  MLIRContext *context = linalgOp.getContext();
+  SmallVector<NamedAttribute> attrs;
+  Builder b(context);
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getIndexArrayAttr(workgroupTileSizes));
+
+  attrs.emplace_back(StringAttr::get(context, "thread"),
+                     b.getIndexArrayAttr(threadTileSizes));
+
+  // Heuristic value chosen to limit maximum vector sizes when tiling below.
+  const unsigned maxVectorSize = 32;
+
+  // Try to tile all reductions by some small factor, preferably 4, when
+  // possible. This gives us a chance to perform a vector4 load if an input
+  // has its innermost dimension being a reduction. It also avoids generating
+  // too many instructions when unrolling vectors later. We limit the expected
+  // vector size by estimating it from the size of the iteration space tile
+  // and limit it to a reasonable value. We process the loops from innermost
+  // to outermost to try to align loads along inner dimensions.
+  int64_t vectorSize = 1;
+  int64_t numLoops = linalgOp.getNumLoops();
+  SmallVector<utils::IteratorType> iterTypes = linalgOp.getIteratorTypesArray();
+  SmallVector<int64_t> loopTileSizes(numLoops, 0);
+  for (auto [reverseIdx, iter] : llvm::enumerate(llvm::reverse(iterTypes))) {
+    unsigned i = numLoops - reverseIdx - 1;
+    if (linalg::isReductionIterator(iter) || i >= workgroupTileSizes.size() ||
+        workgroupTileSizes[i] == 0) {
+      int64_t tileSize = getReductionTilingFactor(loopBounds[i]);
+      if (vectorSize * tileSize > maxVectorSize) {
+        tileSize = 1;
+      }
+      vectorSize *= tileSize;
+      loopTileSizes[i] = tileSize;
+    }
+  }
+  if (llvm::any_of(loopTileSizes, [](int64_t s) { return s != 0; })) {
+    attrs.emplace_back(StringAttr::get(context, "reduction"),
+                       b.getIndexArrayAttr(loopTileSizes));
+  }
+
+  auto configDict = DictionaryAttr::get(context, attrs);
+  auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
+
+  LDBG("Selected tile and fuse lowering config: " << loweringConfig << "\n");
+
+  // TODO(qedawkins): Use a shared pipeline identifier here.
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPoint, op, loweringConfig,
+      IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse,
+      workgroupSize, subgroupSize, DictionaryAttr());
+}
+
 } // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
index 913403193fe0..f87fab963a98 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
@@ -20,6 +20,12 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
                                       mlir::FunctionOpInterface entryPoint,
                                       Operation *op);
 
+/// Helper for setting up a default tile and fuse config for targeting
+/// simple thread distribution. Currently restricted to linalg ops.
+LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
+                                           mlir::FunctionOpInterface entryPoint,
+                                           Operation *op);
+
 } // namespace mlir::iree_compiler::IREE::GPU
 
 #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TARGETUTILS_CONFIGUTILS_H_
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index d7ee4e04a5f5..72091ea9ceeb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -46,9 +46,15 @@
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 namespace mlir::iree_compiler {
-llvm::cl::opt<bool> clGPUEnableTileAndFuse(
-    "iree-codegen-llvmgpu-use-tile-and-fuse",
-    llvm::cl::desc("enable the usage of the tile and fuse pipeline"),
+llvm::cl::opt<bool> clGPUTestTileAndFuseMatmul(
+    "iree-codegen-llvmgpu-test-tile-and-fuse-matmul",
+    llvm::cl::desc("test the tile and fuse pipeline for matmul"),
+    llvm::cl::init(false));
+
+llvm::cl::opt<bool> clGPUTestTileAndFuseVectorize(
+    "iree-codegen-llvmgpu-test-tile-and-fuse-vectorize",
+    llvm::cl::desc(
+        "test the tile and fuse pipeline for all supported operations"),
     llvm::cl::init(false));
 
 llvm::cl::opt<bool> clGPUEnableVectorDistribution(
@@ -1946,10 +1952,19 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
     LDBG("Transform Dialect Config");
     return success();
   }
-  if (clGPUEnableTileAndFuse && succeeded(IREE::GPU::setMatmulLoweringConfig(
-                                    target, entryPointFn, computeOp))) {
-    LDBG("Tile and fuse matmul config");
-    return success();
+  if (clGPUTestTileAndFuseMatmul) {
+    if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn,
+                                                     computeOp))) {
+      LDBG("Tile and fuse matmul config");
+      return success();
+    }
+  }
+  if (clGPUTestTileAndFuseVectorize) {
+    if (succeeded(IREE::GPU::setTileAndFuseLoweringConfig(target, entryPointFn,
+                                                          computeOp))) {
+      LDBG("Tile and fuse default config");
+      return success();
+    }
   }
   if (succeeded(setVectorDistributionConfig(target, entryPointFn, computeOp))) {
     return success();
@@ -2070,6 +2085,7 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) {
       }
     }
   }
+  // Translation info (lowering pipeline) is already set.
   return success();
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
index 2cd705dc844d..dcf287bebdc5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
@@ -280,6 +280,11 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
     if (succeeded(setWarpReductionConfig(target, entryPointFn, linalgOp))) {
       return success();
     }
+    // TODO: Add configurations for matmul here too.
+    if (succeeded(IREE::GPU::setTileAndFuseLoweringConfig(target, entryPointFn,
+                                                          computeOp))) {
+      return success();
+    }
   }
 
   return failure();
@@ -386,7 +391,10 @@ LogicalResult initROCDLLaunchConfig(FunctionOpInterface funcOp) {
   if (failed(setRootConfig(target, funcOp, rootOp)))
     return failure();
 
-  propagateLoweringConfig(rootOp, computeOps);
+  if (getTranslationInfo(funcOp).getDispatchLoweringPassPipeline() !=
+      IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
+    propagateLoweringConfig(rootOp, computeOps);
+  }
   return success();
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index 2e36ccfc9b2a..cc572592f730 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -1,6 +1,10 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 \
-// RUN:   --iree-codegen-llvmgpu-use-tile-and-fuse --iree-codegen-llvmgpu-use-vector-distribution=false \
-// RUN:   --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
+// RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx940 \
+// RUN:   --iree-codegen-llvmgpu-test-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
+// RUN:   --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
+
+// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
+// to be migrated to the ROCDL heuristics, but for now it is just physically
+// located here.
 
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
@@ -22,8 +26,8 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
   return %7 : tensor<2x10x64x64xf16>
 }
 
-// CHECK: #iree_codegen.translation_info
 // CHECK-LABEL: func.func @expanded_matmul_transpose_b
+// CHECK-SAME: #iree_codegen.translation_info
 
 // Verify that the fill does not have the lowering config propagated to it.
 // CHECK: linalg.fill ins
@@ -45,8 +49,8 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<
   return %7 : tensor<1024x1024xf32>
 }
 
-// CHECK: #iree_codegen.translation_info
 // CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024
+// CHECK-SAME: #iree_codegen.translation_info
 
 // Verify that the fill does not have the lowering config propagated to it.
 // CHECK: linalg.fill ins
@@ -56,3 +60,79 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<
 // CHECK-SAME: reduction = [0 : index, 0 : index, 4 : index]
 // CHECK-SAME: subgroup = [2 : index, 4 : index, 0 : index]
 // CHECK-SAME: workgroup = [64 : index, 128 : index, 0 : index]
+
+// -----
+
+module {
+  func.func @conv_nhwc(%3: tensor<2x258x514x768xf16>, %4: tensor<3x3x768x256xf16>) -> tensor<2x256x512x256xf32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %5 = tensor.empty() : tensor<2x256x512x256xf32>
+    %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x256x512x256xf32>) -> tensor<2x256x512x256xf32>
+    %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%3, %4 : tensor<2x258x514x768xf16>, tensor<3x3x768x256xf16>) outs(%6 : tensor<2x256x512x256xf32>) -> tensor<2x256x512x256xf32>
+    return %7 : tensor<2x256x512x256xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @conv_nhwc
+// CHECK-SAME: #iree_codegen.translation_info
+// CHECK: linalg.conv_2d_nhwc_hwcf {{.*}} lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: reduction = [0 : index, 0 : index, 0 : index, 0 : index, 1 : index, 3 : index, 4 : index]
+// CHECK-SAME: thread = [1 : index, 1 : index, 1 : index, 1 : index, 0 : index, 0 : index, 0 : index]
+// CHECK-SAME: workgroup = [1 : index, 1 : index, 1 : index, 64 : index, 0 : index, 0 : index, 0 : index]
+
+// -----
+
+module {
+  func.func @matmul_dynamic_dim(%11: tensor<?x256xf16>, %12: tensor<256x256xf16>) -> tensor<?x256xf32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %8 = tensor.dim %11, %c0 : tensor<?x256xf16>
+    %13 = tensor.empty(%8) : tensor<?x256xf32>
+    %14 = linalg.fill ins(%cst : f32) outs(%13 : tensor<?x256xf32>) -> tensor<?x256xf32>
+    %15 = linalg.matmul ins(%11, %12 : tensor<?x256xf16>, tensor<256x256xf16>) outs(%14 : tensor<?x256xf32>) -> tensor<?x256xf32>
+    return %15 : tensor<?x256xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @matmul_dynamic_dim
+// CHECK-SAME: #iree_codegen.translation_info
+// CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: reduction = [0 : index, 0 : index, 4 : index]
+// CHECK-SAME: thread = [1 : index, 1 : index, 0 : index]
+// CHECK-SAME: workgroup = [1 : index, 64 : index, 0 : index]
+
+// -----
+
+module {
+  func.func @elementwise_dynamic_dim(%11: tensor<?x256xf16>, %12: tensor<?x256xf16>) -> tensor<?x256xf16> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %8 = tensor.dim %11, %c0 : tensor<?x256xf16>
+    %13 = tensor.empty(%8) : tensor<?x256xf16>
+    %15 = linalg.add ins(%11, %12 : tensor<?x256xf16>, tensor<?x256xf16>) outs(%13 : tensor<?x256xf16>) -> tensor<?x256xf16>
+    return %15 : tensor<?x256xf16>
+  }
+}
+
+// CHECK-LABEL: func.func @elementwise_dynamic_dim
+// CHECK-SAME: #iree_codegen.translation_info
+// CHECK: linalg.add {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: thread = [1 : index, 1 : index]
+// CHECK-SAME: workgroup = [1 : index, 64 : index]
+
+// -----
+
+module @elementwise_unaligned {
+  func.func @elementwise_unaligned(%11: tensor<180x180xf16>, %12: tensor<180x180xf16>) -> tensor<180x180xf16> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %13 = tensor.empty() : tensor<180x180xf16>
+    %15 = linalg.add ins(%11, %12 : tensor<180x180xf16>, tensor<180x180xf16>) outs(%13 : tensor<180x180xf16>) -> tensor<180x180xf16>
+    return %15 : tensor<180x180xf16>
+  }
+}
+
+// Verify that this does not select the TileAndFuse pipeline due to issues
+// with resolving the resulting dynamic scf.forall loops.
+// CHECK-LABEL: module @elementwise_unaligned
+// CHECK-NOT: LLVMGPUTileAndFuse
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index b16dca1a666b..a226f79cce4e 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -1262,41 +1262,6 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target,
 // Everything Default Configuration
 //===----------------------------------------------------------------------===//
 
-/// Returns a small tiling factor for the given reduction `dimSize`.
-/// Returns 0 to avoid tiling.
-static int getReductionTilingFactor(int64_t dimSize) {
-  if (dimSize % 4 == 0)
-    return 4;
-
-  // Try to find the smallest prime factor as the tiling factor. As a trade off
-  // between generated code size and compilation time, only look at prime
-  // numbers less than 50 right now.
-  static constexpr std::array<int, 15> primeNumbers = {
-      2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47};
-  for (int n : primeNumbers) {
-    if (dimSize % n == 0)
-      return n;
-  }
-
-  return 1; // Otherwise just tile with size 1.
-}
-
-/// Returns the minimal element bitwidth used in the operands and results of the
-/// given Linalg op.
-static int64_t getMinElementBitwidth(linalg::LinalgOp linalgOp) {
-  unsigned bitwidth = std::numeric_limits<unsigned>::max();
-  for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
-    unsigned b =
-        IREE::Util::getTypeBitWidth(getElementTypeOrSelf(operand->get()));
-    bitwidth = std::min(bitwidth, b);
-  }
-  for (Value result : linalgOp.getDpsInits()) {
-    unsigned b = IREE::Util::getTypeBitWidth(getElementTypeOrSelf(result));
-    bitwidth = std::min(bitwidth, b);
-  }
-  return bitwidth;
-};
-
 static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target,
                                         Operation *op,
                                         bool allowVectorization = true) {
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index d4d37be9ec56..4ce88aec7afd 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -799,6 +799,39 @@ FailureOr<int64_t> getSoftwarePipelineStoreStage(DictionaryAttr config) {
   return llvm::cast<IntegerAttr>(stage).getInt();
 }
 
+/// Returns a small tiling factor for the given reduction `dimSize`.
+/// Returns 1 to avoid tiling.
+int getReductionTilingFactor(int64_t dimSize) {
+  if (dimSize % 4 == 0)
+    return 4;
+
+  // Try to find the smallest prime factor as the tiling factor. As a trade off
+  // between generated code size and compilation time, only look at prime
+  // numbers less than 50 right now.
+  static constexpr std::array<int, 15> primeNumbers = {
+      2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47};
+  for (int n : primeNumbers) {
+    if (dimSize % n == 0)
+      return n;
+  }
+
+  return 1; // Otherwise just tile with size 1.
+}
+
+int64_t getMinElementBitwidth(linalg::LinalgOp linalgOp) {
+  unsigned bitwidth = std::numeric_limits<unsigned>::max();
+  for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
+    unsigned b =
+        IREE::Util::getTypeBitWidth(getElementTypeOrSelf(operand->get()));
+    bitwidth = std::min(bitwidth, b);
+  }
+  for (Value result : linalgOp.getDpsInits()) {
+    unsigned b = IREE::Util::getTypeBitWidth(getElementTypeOrSelf(result));
+    bitwidth = std::min(bitwidth, b);
+  }
+  return bitwidth;
+};
+
 //===---------------------------------------------------------------------===//
 // Misc. utility functions
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
index 11bbde6632be..1125957c03f4 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
@@ -198,6 +198,14 @@ getSoftwarePipeliningAttrDict(MLIRContext *context,
 FailureOr<int64_t> getSoftwarePipelineDepth(DictionaryAttr);
 FailureOr<int64_t> getSoftwarePipelineStoreStage(DictionaryAttr);
 
+// Returns a small tiling factor for the given reduction `dimSize`.
+// Returns 1 to avoid tiling.
+int getReductionTilingFactor(int64_t dimSize);
+
+// Returns the minimal element bitwidth used in the operands and results of the
+// given Linalg op.
+int64_t getMinElementBitwidth(linalg::LinalgOp linalgOp);
+
 //===---------------------------------------------------------------------===//
 // Misc. utility functions.
 //===---------------------------------------------------------------------===//
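
For illustration only (this sketch is not part of the patch): the heart of the
new configuration logic is the candidate-based tile-size selection inside
`distributeToThreads`. The standalone C++ sketch below mirrors that heuristic
for a single statically sized parallel dimension; the name `pickTileSize`, the
single-dimension simplification, and the omission of vectorization and
byte-scaling are assumptions made purely for this example. It walks
power-of-two candidates downward from the thread count, accepts an evenly
dividing candidate when no loss factor is given, and otherwise tolerates up to
candidate / lossFactor idle threads, matching the loss-factor fallback loop in
the patch.

#include <cstdint>
#include <cstdio>
#include <optional>

// Pick a per-workgroup tile size for one parallel dimension with static bound
// `loopBound`, distributing work across up to `numThreads` threads.
static int64_t pickTileSize(int64_t loopBound, int64_t numThreads,
                            std::optional<int64_t> lossFactor = std::nullopt) {
  for (int64_t candidate = numThreads; candidate >= 1; candidate >>= 1) {
    int64_t remainder = loopBound % candidate;
    if (remainder != 0) {
      // Without a loss factor, only accept candidates that divide evenly.
      if (!lossFactor)
        continue;
      // With a loss factor, accept the candidate if at most
      // candidate / lossFactor threads would be idle in the last workgroup.
      int64_t idleThreads = candidate - remainder;
      if (idleThreads > candidate / *lossFactor)
        continue;
    }
    return candidate;
  }
  return 1;
}

int main() {
  // 256 is divisible by the full subgroup, so the whole subgroup is used.
  std::printf("bound 256 -> tile %lld\n", (long long)pickTileSize(256, 64));
  // 180 has no large power-of-two divisor; the exact pass falls back to 4.
  std::printf("bound 180 -> tile %lld\n", (long long)pickTileSize(180, 64));
  // Tolerating up to half of the threads being idle accepts the full subgroup.
  std::printf("bound 180, loss 2 -> tile %lld\n",
              (long long)pickTileSize(180, 64, 2));
  return 0;
}

A subsequent pass over the remaining loop dimensions can then use
getReductionTilingFactor, which this patch moves into Codegen/Utils, to pick
small reduction tile sizes in the same spirit.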