Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 40 additions & 40 deletions lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,17 +494,17 @@ static bool matchExpOpForLUT(math::ExpOp::Adaptor adaptor) {
// Rewrite patterns
//===----------------------------------------------------------------------===//

// This pattern fold `vector.extract` and `vector.splat` into
// This pattern fold `vector.extract` and `vector.broadcast` into
// `aievec.broadcast` for AIE2
struct FoldVectorExtractAndSplatToAIEBroadcast
: OpConversionPattern<vector::SplatOp> {
: OpConversionPattern<vector::BroadcastOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto extOp = adaptor.getInput().getDefiningOp<vector::ExtractOp>();
auto extOp = adaptor.getSource().getDefiningOp<vector::ExtractOp>();

if (!extOp)
return failure();
Expand All @@ -513,7 +513,7 @@ struct FoldVectorExtractAndSplatToAIEBroadcast
auto pos = extOp.getStaticPosition();
int64_t posVal = pos[0];
auto srcVecType = cast<VectorType>(src.getType());
auto resultType = cast<VectorType>(splatOp.getResult().getType());
auto resultType = cast<VectorType>(bcastOp.getResult().getType());
if (srcVecType != resultType) {
if (srcVecType.getNumElements() != 2 * resultType.getNumElements())
return failure();
Expand All @@ -530,17 +530,17 @@ struct FoldVectorExtractAndSplatToAIEBroadcast
if (unsigned laneSize = getVectorLaneSize(resultType);
laneSize * elWidth == 512) {
// Common use case for the broadcast_elem intrinsic
rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(splatOp, resultType, src,
rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(bcastOp, resultType, src,
posVal);
} else if (laneSize * elWidth == 256) {
// e.g. need v16bf16 due to the subsequent v16accfloat operation
VectorType aievecBcastType =
createVectorType(512 / elWidth, resultType.getElementType());
auto concatOp = rewriter.create<aievec::ConcatOp>(
splatOp.getLoc(), aievecBcastType, SmallVector<Value>({src, src}));
bcastOp.getLoc(), aievecBcastType, SmallVector<Value>({src, src}));
auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
splatOp.getLoc(), aievecBcastType, concatOp.getResult(), posVal);
rewriter.replaceOpWithNewOp<aievec::ExtOp>(splatOp, resultType,
bcastOp.getLoc(), aievecBcastType, concatOp.getResult(), posVal);
rewriter.replaceOpWithNewOp<aievec::ExtOp>(bcastOp, resultType,
aieBcastOp.getResult(), 0);
} else if (laneSize * elWidth == 1024) {
// e.g. need v32int32 due to the subsequent v32acc32 operation
Expand All @@ -549,12 +549,12 @@ struct FoldVectorExtractAndSplatToAIEBroadcast
auto half = static_cast<int8_t>(posVal / resultType.getNumElements());
posVal -= half * resultType.getNumElements();
auto extOp =
rewriter.create<aievec::ExtOp>(splatOp.getLoc(), aievecBcastType, src,
rewriter.create<aievec::ExtOp>(bcastOp.getLoc(), aievecBcastType, src,
rewriter.getI8IntegerAttr(half));
auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
splatOp.getLoc(), aievecBcastType, extOp.getResult(), posVal);
bcastOp.getLoc(), aievecBcastType, extOp.getResult(), posVal);
rewriter.replaceOpWithNewOp<aievec::ConcatOp>(
splatOp, resultType,
bcastOp, resultType,
SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
} else {
return failure();
Expand All @@ -564,57 +564,57 @@ struct FoldVectorExtractAndSplatToAIEBroadcast
}
};

struct ConvertSplatToAIEBroadcast : OpConversionPattern<vector::SplatOp> {
struct ConvertSplatToAIEBroadcast : OpConversionPattern<vector::BroadcastOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (adaptor.getInput().getDefiningOp<vector::ExtractOp>())
if (adaptor.getSource().getDefiningOp<vector::ExtractOp>())
return failure();

auto resultType = cast<VectorType>(splatOp.getResult().getType());
auto resultType = cast<VectorType>(bcastOp.getResult().getType());
auto flatResultType = getFlattenedVectorType(resultType);
Type scalarType = resultType.getElementType();
unsigned elWidth = scalarType.getIntOrFloatBitWidth();
unsigned laneSize = getVectorLaneSize(resultType);
auto src = splatOp.getInput();
auto src = bcastOp.getSource();

if (laneSize * elWidth == 512) {
Value newOp = rewriter.create<aievec::BroadcastScalarOp>(
splatOp.getLoc(), flatResultType, src);
bcastOp.getLoc(), flatResultType, src);
if (resultType != flatResultType)
newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
newOp = rewriter.create<vector::ShapeCastOp>(bcastOp.getLoc(),
resultType, newOp);
rewriter.replaceOp(splatOp, newOp);
rewriter.replaceOp(bcastOp, newOp);
return success();
}

if (laneSize * elWidth == 256) {
VectorType vecType = createVectorType(512 / elWidth, scalarType);
auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
splatOp.getLoc(), vecType, src);
bcastOp.getLoc(), vecType, src);
Value newOp = rewriter.create<aievec::ExtOp>(
splatOp.getLoc(), flatResultType, aieBcastOp.getResult(), 0);
bcastOp.getLoc(), flatResultType, aieBcastOp.getResult(), 0);
if (resultType != flatResultType)
newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
newOp = rewriter.create<vector::ShapeCastOp>(bcastOp.getLoc(),
resultType, newOp);
rewriter.replaceOp(splatOp, newOp);
rewriter.replaceOp(bcastOp, newOp);
return success();
}

if (laneSize * elWidth == 1024) {
VectorType vecType = createVectorType(512 / elWidth, scalarType);
auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
splatOp.getLoc(), vecType, src);
bcastOp.getLoc(), vecType, src);
Value newOp = rewriter.create<aievec::ConcatOp>(
splatOp.getLoc(), flatResultType,
bcastOp.getLoc(), flatResultType,
SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
if (resultType != flatResultType)
newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
newOp = rewriter.create<vector::ShapeCastOp>(bcastOp.getLoc(),
resultType, newOp);
rewriter.replaceOp(splatOp, newOp);
rewriter.replaceOp(bcastOp, newOp);
return success();
}

Expand Down Expand Up @@ -961,19 +961,19 @@ struct FoldSplatToFMAOp : OpConversionPattern<aievec::aie1::FMAOp> {
dyn_cast<aievec::ConcatOp>(adaptor.getLhs().getDefiningOp());
if (!concatOp)
return failure();
vector::SplatOp splatOp = nullptr;
vector::BroadcastOp bcastOp = nullptr;
auto *concatDefOp = concatOp.getSources()[0].getDefiningOp();
if (concatDefOp)
splatOp = dyn_cast<vector::SplatOp>(concatDefOp);
bcastOp = dyn_cast<vector::BroadcastOp>(concatDefOp);
Value lhs = adaptor.getRhs();
if (!splatOp) {
splatOp = dyn_cast<vector::SplatOp>(adaptor.getRhs().getDefiningOp());
if (!splatOp)
if (!bcastOp) {
bcastOp = dyn_cast<vector::BroadcastOp>(adaptor.getRhs().getDefiningOp());
if (!bcastOp)
return failure();
lhs = concatOp.getSources()[0];
}
auto extOp =
dyn_cast<vector::ExtractOp>(splatOp.getInput().getDefiningOp());
dyn_cast<vector::ExtractOp>(bcastOp.getSource().getDefiningOp());
if (!extOp)
return failure();

Expand Down Expand Up @@ -3540,18 +3540,18 @@ static void configureAIEVecV1Legalizations(ConversionTarget &target,
if (!concatOp)
return true;

vector::SplatOp srcSplat = nullptr;
vector::BroadcastOp srcBcast = nullptr;
if (auto *lhsOp = concatOp.getSources()[0].getDefiningOp())
srcSplat = dyn_cast<vector::SplatOp>(lhsOp);
if (!srcSplat) {
srcBcast = dyn_cast<vector::BroadcastOp>(lhsOp);
if (!srcBcast) {
auto *rhsOp = op.getRhs().getDefiningOp();
if (!rhsOp)
return true;
srcSplat = dyn_cast<vector::SplatOp>(rhsOp);
srcBcast = dyn_cast<vector::BroadcastOp>(rhsOp);
}

if (srcSplat)
if (auto *srcOp = srcSplat.getInput().getDefiningOp())
if (srcBcast)
if (auto *srcOp = srcBcast.getSource().getDefiningOp())
return !isa<vector::ExtractOp>(srcOp);

return true;
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ struct ConvertSplatTransferReadToBroadcastPattern
adaptor.getPadding());
auto extractOp = rewriter.create<vector::ExtractOp>(
readOp.getLoc(), newReadOp.getResult(), ArrayRef<int64_t>{offset});
rewriter.replaceOpWithNewOp<vector::SplatOp>(
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
readOp, newReadOp.getVector().getType(), extractOp.getResult());
return success();
}
Expand All @@ -276,7 +276,7 @@ struct HoistCastOpToDataSourcePattern : public RewritePattern {
return failure();

// At the moment, we only accept ops we know we can swap with cast.
if (!isa<vector::BroadcastOp, vector::ExtractOp, vector::SplatOp,
if (!isa<vector::BroadcastOp, vector::ExtractOp,
vector::ExtractStridedSliceOp>(defOp))
return failure();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// (c) Copyright 2024 AMD Inc.

// RUN: aie-opt --verify-diagnostics --aie-materialize-bd-chains %s
// XFAIL:*

// This test ensures that the correct error gets emitted when a BD "chain" is not
// actually a proper chain, i.e. some blocks are not connected.
Expand Down
4 changes: 2 additions & 2 deletions test/dialect/AIEVec/precanonicalization-aieml-llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// CHECK-LABEL: @scalar_extsi_to_broadcast_swap(
// CHECK-SAME: %[[SIN:.*]]: i8
func.func @scalar_extsi_to_broadcast_swap(%s: i8) -> vector<32xi32> {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[SIN]] : vector<32xi8>
// CHECK: %[[EXT:.*]] = arith.extsi %[[SPLAT]] : vector<32xi8> to vector<32xi32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SIN]] : i8 to vector<32xi8>
// CHECK: %[[EXT:.*]] = arith.extsi %[[BCAST]] : vector<32xi8> to vector<32xi32>
%0 = arith.extsi %s : i8 to i32
%1 = vector.broadcast %0 : i32 to vector<32xi32>
return %1 : vector<32xi32>
Expand Down
22 changes: 11 additions & 11 deletions test/dialect/AIEVec/precanonicalization.mlir
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
// RUN: aie-opt %s -canonicalize-vector-for-aievec -canonicalize -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @splat(
// CHECK-LABEL: func.func @broadcast(
// CHECK-SAME: %[[MEM:.*]]: memref<?xi32>,
// CHECK-SAME: %[[POS:.*]]: index
func.func @splat(%m : memref<?xi32>, %pos : index) -> vector<8xi32> {
func.func @broadcast(%m : memref<?xi32>, %pos : index) -> vector<8xi32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
%c0_i32 = arith.constant 0 : i32
%i = affine.apply affine_map<(d0) -> (d0 + 5)>(%pos)
// CHECK: %[[V:.*]] = vector.transfer_read %[[MEM]][%[[POS]]], %[[C0]] : memref<?xi32>, vector<8xi32>
// CHECK: %[[E:.*]] = vector.extract %[[V]][5] : i32 from vector<8xi32>
// CHECK: %[[S:.*]] = vector.splat %[[E]] : vector<8xi32>
// CHECK: %[[B:.*]] = vector.broadcast %[[E]] : i32 to vector<8xi32>
%v = vector.transfer_read %m[%i], %c0_i32 {in_bounds = [true], permutation_map = affine_map<(d0) -> (0)>} : memref<?xi32>, vector<8xi32>
// CHECK: return %[[S]] : vector<8xi32>
// CHECK: return %[[B]] : vector<8xi32>
return %v : vector<8xi32>
}

// -----

// CHECK: #[[IDXMAP:.*]] = affine_map<()[s0] -> (s0 + 24)>
// CHECK-LABEL: func.func @far_splat(
// CHECK-LABEL: func.func @far_broadcast(
// CHECK-SAME: %[[MEM:.*]]: memref<?xi32>,
// CHECK-SAME: %[[POS:.*]]: index
func.func @far_splat(%m : memref<?xi32>, %pos : index) -> vector<8xi32> {
func.func @far_broadcast(%m : memref<?xi32>, %pos : index) -> vector<8xi32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
%c0_i32 = arith.constant 0 : i32
// CHECK: %[[IDX:.*]] = affine.apply #[[IDXMAP]]()[%[[POS]]]
%i = affine.apply affine_map<(d0) -> (d0 + 29)>(%pos)
// CHECK: %[[V:.*]] = vector.transfer_read %[[MEM]][%[[IDX]]], %[[C0]] : memref<?xi32>, vector<8xi32>
// CHECK: %[[E:.*]] = vector.extract %[[V]][5] : i32 from vector<8xi32>
// CHECK: %[[S:.*]] = vector.splat %[[E]] : vector<8xi32>
// CHECK: %[[B:.*]] = vector.broadcast %[[E]] : i32 to vector<8xi32>
%v = vector.transfer_read %m[%i], %c0_i32 {in_bounds = [true], permutation_map = affine_map<(d0) -> (0)>} : memref<?xi32>, vector<8xi32>
// CHECK: return %[[S]] : vector<8xi32>
// CHECK: return %[[B]] : vector<8xi32>
return %v : vector<8xi32>
}

Expand Down Expand Up @@ -61,9 +61,9 @@ func.func @rank_zero_transfer_read(%m : memref<i16>) -> vector<16xi16> {
// CHECK-DAG: %[[EXPMEM:.*]] = memref.expand_shape %[[MEM]] [] output_shape [1] : memref<i16> into memref<1xi16>
// CHECK: %[[LV:.*]] = vector.transfer_read %[[EXPMEM]][%[[C0idx]]], %[[C0i16]] : memref<1xi16>, vector<16xi16>
// CHECK: %[[E:.*]] = vector.extract %[[LV]][0] : i16 from vector<16xi16>
// CHECK: %[[S:.*]] = vector.splat %[[E]] : vector<16xi16>
// CHECK: %[[B:.*]] = vector.broadcast %[[E]] : i16 to vector<16xi16>
%v = vector.transfer_read %m[], %c0_i16 {in_bounds = [true], permutation_map = affine_map<()->(0)>} : memref<i16>, vector<16xi16>
// CHECK: return %[[S]] : vector<16xi16>
// CHECK: return %[[B]] : vector<16xi16>
return %v : vector<16xi16>
}

Expand All @@ -75,7 +75,7 @@ func.func @extsi_hoisting_through_extract_n_bcast(%v : vector<16xi8>)
-> vector<32xi32> {
// CHECK: %[[EXV:.*]] = arith.extsi %[[VEC]] : vector<16xi8> to vector<16xi32>
// CHECK: %[[EXS:.*]] = vector.extract %[[EXV]][7] : i32 from vector<16xi32>
// CHECK: %[[BCAST:.*]] = vector.splat %[[EXS]] : vector<32xi32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXS]] : i32 to vector<32xi32>
// CHECK: return %[[BCAST]] : vector<32xi32>
%si8 = vector.extract %v[7] : i8 from vector<16xi8>
%vi8 = vector.broadcast %si8 : i8 to vector<32xi8>
Expand Down
4 changes: 2 additions & 2 deletions utils/clone-llvm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
##===----------------------------------------------------------------------===##

# The LLVM commit to use.
LLVM_PROJECT_COMMIT=d1e43f6c1a28c3c64b3655be1fc1aff1029c48c8
DATETIME=2025073104
LLVM_PROJECT_COMMIT=064f02dac0c81c19350a74415b3245f42fed09dc
DATETIME=2025090500
WHEEL_VERSION=22.0.0.$DATETIME+${LLVM_PROJECT_COMMIT:0:8}

############################################################################################
Expand Down
Loading