1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -184,7 +184,6 @@ iree_compiler_cc_library(
"TileInferenceUtils.h",
"Transforms.h",
"UserConfig.h",
"VectorLayoutAnalysis.h",
],
deps = [
":PassHeaders",
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -65,7 +65,6 @@ iree_cc_library(
"TileInferenceUtils.h"
"Transforms.h"
"UserConfig.h"
"VectorLayoutAnalysis.h"
SRCS
"AddFastMathFlags.cpp"
"BlockDynamicDimensions.cpp"

Large diffs are not rendered by default.

@@ -7,7 +7,6 @@
#include <cstdint>
#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
@@ -5,7 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -17,6 +17,8 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <deque>

#define DEBUG_TYPE "iree-codegen-gpu-vector-distribution"

using namespace mlir::iree_compiler::IREE::VectorExt;
@@ -34,16 +36,17 @@ constexpr StringLiteral kVectorLayoutRedistributeAttrName =
/// Set signature for the operation based on the analysis. Returns failure if
/// an operation contains vectors that cannot be distributed i.e. they have no
/// layout.
LogicalResult setOpSignature(Operation *op, VectorLayoutAnalysis &analysis,
const VectorLayoutOptions &options) {
LogicalResult
setOpSignature(Operation *op,
llvm::MapVector<Value, VectorLayoutInterface> &layouts,
const VectorLayoutOptions &options) {
SmallVector<Attribute> operands;
SmallVector<Attribute> results;

for (Value operand : op->getOperands()) {
if (auto vectorOperand = dyn_cast<VectorValue>(operand)) {
if (auto layout =
analysis.getLayout<VectorLayoutInterface>(vectorOperand)) {
operands.push_back(layout);
if (layouts.contains(vectorOperand)) {
operands.push_back(layouts[vectorOperand]);
continue;
}
if (auto layout = options.getDefaultLayout(vectorOperand.getType())) {
@@ -57,9 +60,8 @@ LogicalResult setOpSignature(Operation *op, VectorLayoutAnalysis &analysis,

for (Value result : op->getResults()) {
if (auto vectorResult = dyn_cast<VectorValue>(result)) {
if (auto layout =
analysis.getLayout<VectorLayoutInterface>(vectorResult)) {
results.push_back(layout);
if (layouts.contains(vectorResult)) {
results.push_back(layouts[vectorResult]);
continue;
}
if (auto layout = options.getDefaultLayout(vectorResult.getType())) {
@@ -356,17 +358,19 @@ LogicalResult distributeVectorOps(Operation *root,
VectorLayoutOptions &options) {
// Run the analysis and determine the layouts.
LLVM_DEBUG(llvm::dbgs() << "Running Layout Analysis\n");
VectorLayoutAnalysis analysis(root);
if (failed(analysis.run()))
llvm::MapVector<Value, VectorLayoutInterface> layouts;
if (failed(propagateVectorLayoutInfo(root, layouts))) {
LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Failed\n");
return failure();
}
LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Succeeded\n");
LLVM_DEBUG(llvm::dbgs() << "\n\n");

// Go to each operation, and set its distribution signature.
LLVM_DEBUG(
llvm::dbgs() << "Setting distribution signatures for operations\n");
root->walk([&](Operation *op) {
if (failed(setOpSignature(op, analysis, options))) {
if (failed(setOpSignature(op, layouts, options))) {
LLVM_DEBUG({
llvm::dbgs() << "Skipping operation because not all vector "
"operands/results have a layout:\n";
@@ -7,7 +7,6 @@
#ifndef IREE_COMPILER_CODEGEN_COMMON_GPU_VECTOR_DISTRIBUTION_H_
#define IREE_COMPILER_CODEGEN_COMMON_GPU_VECTOR_DISTRIBUTION_H_

#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -16,6 +15,8 @@

namespace mlir::iree_compiler {

using IREE::VectorExt::VectorLayoutInterface;

/// A signature describing the layout for each value of vector type which is
/// an operand or result of this operation.
///
@@ -11,7 +11,6 @@
// batch, m, k1, k2
#lowering_config = #iree_gpu.lowering_config<{reduction = [ 0, 0, 0, 32]}>

// CHECK-LABEL: func.func @online_attention_fail_to_pad_no_mask
func.func @online_attention_fail_to_pad_no_mask(%query: tensor<192x1024x64xf32>, %key: tensor<192x?x64xf32>, %value: tensor<192x?x64xf32>) -> tensor<192x1024x64xf32> {
%scale = arith.constant 1.0 : f32

@@ -25,10 +24,7 @@ func.func @online_attention_fail_to_pad_no_mask(%query: tensor<192x1024x64xf32>,
%acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
%sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>

// CHECK-NOT: tensor.pad
// CHECK: iree_linalg_ext.online_attention {{.*}} ins(%{{[0-9a-z_]*}}, %{{[0-9a-z_]*}}, %{{[0-9a-z_]*}}, %{{[0-9a-z_]*}}
// CHECK-SAME: : tensor<192x1024x64xf32>, tensor<192x?x64xf32>, tensor<192x?x64xf32>, f32)
// expected-remark@+1{{failed to pad op: requires a mask operand to pad to the proper value. Consider materializing the mask operand explicitly.}}
// expected-error@+1{{Padding OnlineAttention without existing mask is not yet supported}}
%out:3 = iree_linalg_ext.online_attention
{
indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR],
@@ -118,7 +114,6 @@ func.func @online_attention_tile_then_pad_7(%n_batches: index, %query: tensor<?x
%acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<?x1021xf32>) -> tensor<?x1021xf32>
%sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<?x1021xf32>) -> tensor<?x1021xf32>

// CHECK: arith.constant 0xFF800000 : f32
// CHECK-COUNT-7: tensor.pad
// CHECK: iree_linalg_ext.online_attention {{.*}} ins(%{{[0-9a-z_]*}}, %{{[0-9a-z_]*}}, %{{[0-9a-z_]*}}, %{{[0-9a-z_]*}}, %{{[0-9a-z_]*}}
// CHECK-SAME: : tensor<4x8x64xf32>, tensor<4x32x64xf32>, tensor<4x32x64xf32>, f32, tensor<4x8xf32>)
@@ -23,10 +23,10 @@
// CHECK: tensor.pad %arg0 low[0, 0] high[0, %[[CEIL]]
// CHECK: linalg.generic
// CHECK: ^bb0(
// CHECK: arith.subf
// CHECK: %[[EXP:.+]] = math.exp
// CHECK: %[[INDEX1:.+]] = linalg.index 1 : index
// CHECK: %[[CMP:.+]] = arith.cmpi ult, %[[INDEX1]], %[[DIMARG0]] : index
// CHECK: arith.subf
// CHECK: %[[EXP:.+]] = math.exp
// CHECK: %[[SELECT:.+]] = arith.select %[[CMP]], %[[EXP:.+]], %[[ZEROF32]] : f32
// CHECK: %[[ADD:.+]] = arith.addf %[[SELECT]], %out : f32
// CHECK: linalg.yield %[[ADD]] : f32
@@ -123,9 +123,9 @@ func.func @min_reduction(%arg0: tensor<1x?xf32>, %arg1: tensor<1xf32>) -> tensor
// CHECK-DAG: %[[DIMARG0:.+]] = tensor.dim %arg0, %[[C1]] : tensor<1x?xf16>
// CHECK: linalg.generic
// CHECK: ^bb0(
// CHECK: %[[MUL:.+]] = arith.mulf
// CHECK: %[[INDEX1:.+]] = linalg.index 1 : index
// CHECK: %[[CMP:.+]] = arith.cmpi ult, %[[INDEX1]], %[[DIMARG0]] : index
// CHECK: %[[MUL:.+]] = arith.mulf
// CHECK: %[[SELECT:.+]] = arith.select %[[CMP]], %[[MUL]], %[[ZERO]] : f16
// CHECK: %[[ADD:.+]] = arith.addf %out, %[[SELECT]] : f16
// CHECK: linalg.yield %[[ADD]] : f16
@@ -159,10 +159,10 @@ func.func @standard_inner_product(%arg0 : tensor<1x?xf16>, %arg1 : tensor<1x?xf1
// CHECK-NOT: tensor.pad
// CHECK: linalg.generic
// CHECK: ^bb0(
// CHECK: %[[MUL:.+]] = arith.mulf
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MUL]] : f32 to f16
// CHECK: %[[INDEX1:.+]] = linalg.index 1 : index
// CHECK: %[[CMP:.+]] = arith.cmpi ult, %[[INDEX1]], %[[DIMARG0]] : index
// CHECK: %[[MUL:.+]] = arith.mulf
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MUL]] : f32 to f16
// CHECK: %[[SELECT:.+]] = arith.select %[[CMP]], %[[TRUNC]], %[[ZERO]] : f16
// CHECK: %[[ADD:.+]] = arith.addf %out, %[[SELECT]] : f16
// CHECK: linalg.yield %[[ADD]] : f16
@@ -191,8 +191,8 @@ func.func @standard_inner_product_with_trunc(%arg0 : tensor<1x?xf32>, %arg1 : te

// CHECK-LABEL: product_of_sum_reduction
// CHECK: %[[ONE:.+]] = arith.constant 1.000000e+00 : f16
// CHECK: %[[ADD:.+]] = arith.addf
// CHECK: %[[CMP:.+]] = arith.cmpi
// CHECK: %[[ADD:.+]] = arith.addf
// CHECK: arith.select %[[CMP]], %[[ADD]], %[[ONE]] : f16
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
@@ -223,12 +223,12 @@ func.func @product_of_sum_reduction(%arg0 : tensor<1x?xf16>, %arg1 : tensor<1x?x
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf16>
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf16>
// CHECK: linalg.generic
// CHECK-DAG: %[[MUL:.+]] = arith.mulf
// CHECK-DAG: %[[INDEX0:.+]] = linalg.index 0 : index
// CHECK-DAG: %[[INDEX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[CMP0:.+]] = arith.cmpi ult, %[[INDEX0]], %[[DIM1]] : index
// CHECK-DAG: %[[CMP1:.+]] = arith.cmpi ult, %[[INDEX1]], %[[DIM0]] : index
// CHECK: %[[AND:.+]] = arith.andi %[[CMP0]], %[[CMP1]] : i1
// CHECK-DAG: %[[MUL:.+]] = arith.mulf
// CHECK: %[[SELECT:.+]] = arith.select %[[AND]], %[[MUL]], %[[ZEROF16]] : f16
// CHECK: %[[ADD:.+]] = arith.addf %out, %[[SELECT]] : f16
#map = affine_map<(d0, d1) -> (d1, d0)>
@@ -12,7 +12,6 @@
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
70 changes: 70 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -77,6 +77,76 @@ FailureOr<IREETilingResult>
tileDispatchUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
linalg::LinalgTilingOptions options);

namespace IREE::VectorExt {
class VectorLayoutInterface;
} // namespace IREE::VectorExt

/// Analyzes the root op and its nested ops to propagate vector layouts
/// originating from to_layout operations. Example:
///
/// %root = vector.transfer_read
/// |
/// --> anchored to layout L (using a to_layout op)
/// %root2 = vector.transfer_read
/// %c = arith.mulf %root, %b
/// |
/// --> %root, %b and %c must have the same layout
/// %e = arith.divf %b, %root2
/// |
/// --> %root2, %b and %e must have the same layout
///
/// Here, the user provided an anchor point for %root, fixing its layout to L.
/// The analysis then uses its inference rules to find the layouts of other
/// values:
///
/// %root = vector.transfer_read
/// |
///        --> inferred to layout L
/// %root2 = vector.transfer_read
/// |
///        --> inferred to layout L
/// %c = arith.mulf %root, %b
/// |
///        --> inferred to layout L
/// %e = arith.divf %b, %root2
/// |
///        --> inferred to layout L
///
/// If at any point a value has a layout, but the user of that value requires
/// a different layout, the analysis inserts a resolution operation. This
/// resolution operation is `iree_vector_ext.layout_conflict_resolution`.
/// For example:
///
/// %0 = vector.transfer_read
/// |
/// --> anchored to layout L
/// %1 = vector.transfer_read
/// |
/// --> anchored to layout L'
/// arith.addf %0, %1
/// |
/// --> %0 and %1 must have the same layout
///
/// To resolve the conflict, the analysis chooses one of the layouts, say
/// L, and inserts a resolution operation to convert the other layout to L.
///
/// %0 = vector.transfer_read
/// |
/// --> anchored to layout L
/// %1 = vector.transfer_read
/// |
/// --> anchored to layout L'
/// %resolved = iree_vector_ext.layout_conflict_resolution %1
/// |
///               --> inferred to layout L
/// arith.addf %0, %resolved
///
/// The analysis itself does not try to resolve the conflict; it leaves the
/// resolution to the user.
LogicalResult propagateVectorLayoutInfo(
Operation *root,
llvm::MapVector<Value, IREE::VectorExt::VectorLayoutInterface> &layouts);
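
// A minimal usage sketch of the declaration above, mirroring how
// distributeVectorOps in GPUVectorDistribution.cpp consumes this analysis;
// `rootOp` and the debug printing are illustrative placeholders, not code
// from this change:
//
//   llvm::MapVector<Value, IREE::VectorExt::VectorLayoutInterface> layouts;
//   if (failed(propagateVectorLayoutInfo(rootOp, layouts)))
//     return failure();
//   for (auto &[val, layout] : layouts)
//     llvm::dbgs() << val << " -> " << layout << "\n";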

/// Transform a `scf.for` loop with a strictly positive step
/// for %i = %lb to %ub step %s
/// into a 0-based loop with step 1
Expand Down