2 changes: 1 addition & 1 deletion build_tools/ci/cpu_comparison/input_generator.py
@@ -131,7 +131,7 @@ def generate_and_write_input(bin_fn, nb_elements, element_type, input_number, se

    # Random integer values in range [lower_bound, upper_bound)
    lower_bound = 0
-   upper_bound = 10
+   upper_bound = 2

    rng = get_generator(seed)
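
Note on the change above: narrowing the range to [0, 2) makes every generated integer 0 or 1, likely to keep the sums in the new reduction tests exact even in bf16 (small integers and small-integer sums are exactly representable in bf16, so no rounding error accumulates). A minimal sketch of the effect, assuming get_generator returns a numpy Generator:

    import numpy as np

    rng = np.random.default_rng(42)       # stand-in for get_generator(seed)
    values = rng.integers(0, 2, size=16)  # every element is 0 or 1
    assert values.sum() <= 16             # a sum this small is exact in bf16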

47 changes: 38 additions & 9 deletions build_tools/ci/cpu_comparison/run.py
@@ -196,6 +196,7 @@ def get_dir(self, config):
    def get_filename(self, config):
        return self.get_dir(config) / f"{self.name}.mlir"

+   # Correctness check: compare the AIE result against the CPU backend.
    def vs_cpu(self, config):
        filename = self.get_filename(config)

@@ -2518,16 +2519,44 @@ def __init__(self):
        # )

        # Reduction op tests:
-       self.register(
-           Reduction(
-               file_base_name="reduction_sum",
-               function_name="reduction_sum",
-               test_params=TestParams(
-                   name_suffix="sum",
-                   tile_pipeline="general-copy",
-               ),
-           )
-       )
+       for data_type in ["bf16"]:
+           # for shapes in [
+           #     "2x128",
+           #     "16x128",
+           #     "32x128",
+           #     "64x128",
+           #     "128x128",
+           #     "256x128",
+           #     "512x128",
+           #     "1024x128",
+           #     "2048x128",
+           #     "4096x128",
+           # ]:
+           # custom_input = 1.0 * np.ones((2, 16), dtype=np.float16)  # f32
+           # print(custom_input)
+           self.register(
+               Reduction(
+                   file_base_name=f"reduction_sum_{data_type}",
+                   function_name="reduction_sum",
+                   test_params=TestParams(
+                       tile_pipeline="general-copy",
+                       run_on_target=["npu4"],
+                       use_chess=False,
+                       use_chess_for_ukernel=False,
+                       use_ukernel=False,
+                       enable_ctrlpkt=False,
+                       run_benchmark=False,  # correctness run only
+                       stack_size=2048,
+                       lower_to_aie_pipeline="objectFifo",
+                       n_repeats=2,
+                       n_kernel_runs=10,  # TODO: confirm how this scales the kernel-run count
+                       n_reconfigure_runs=0,
+                       # preset_inputs={1: custom_input},
+                       aie_compilation_flags=[
+                           "--iree-hal-target-backends=amd-aie"
+                       ],
+                   ),
+               )
+           )

        # Soak testing.
        # See https://github.com/nod-ai/iree-amd-aie/issues/1264
21 changes: 21 additions & 0 deletions build_tools/ci/cpu_comparison/test_files/reduction_sum_bf16.mlir
@@ -0,0 +1,21 @@
// These lines are required for e2e numerical testing:
// input 2x16xbf16
// output 2xbf16

// Constraints on <D0xD1> (format: [min, max]):
// D0 = [8, no-limit]
// D1 = [16, 1024]

!in_ty = tensor<2x16xbf16>
!out_ty = tensor<2xbf16>
func.func @reduction_sum(%arg0: !in_ty) -> !out_ty {
  %cst = arith.constant 0.0 : bf16
  %3 = tensor.empty() : !out_ty
  %4 = linalg.fill ins(%cst : bf16) outs(%3 : !out_ty) -> !out_ty
  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : !in_ty) outs(%4 : !out_ty) {
  ^bb0(%in: bf16, %out: bf16):
    %6 = arith.addf %in, %out : bf16
    linalg.yield %6 : bf16
  } -> !out_ty
  return %5 : !out_ty
}
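
For reference, what the e2e check computes, as a numpy sketch (the harness derives its expected output from a CPU backend run of the same function, not from this snippet):

    import numpy as np

    x = np.random.default_rng(0).integers(0, 2, size=(2, 16)).astype(np.float32)  # bf16 stand-in
    expected = x.sum(axis=1)  # one sum per row: shape (2,), matching the 2xbf16 output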
@@ -1,9 +1,13 @@
// These lines are required for e2e numerical testing:
-// input 1024x128xf32
-// output 1024xf32
+// input 2x128xf32
+// output 2xf32

-!in_ty = tensor<1024x128xf32>
-!out_ty = tensor<1024xf32>
+// Constraints on <D0xD1>:
+// format: [min, max]
+// D0 = [2, no-limit]
+// D1 = [2, 256]
+!in_ty = tensor<2x128xf32>
+!out_ty = tensor<2xf32>

func.func @reduction_sum(%arg0: !in_ty) -> !out_ty {
  %cst = arith.constant 0.0 : f32
42 changes: 42 additions & 0 deletions build_tools/ci/cpu_comparison/test_files/treads.mlir
@@ -0,0 +1,42 @@
// RUN: iree-opt %s -pass-pipeline='builtin.module(func.func(iree-amdaie-vectorization))' | FileCheck %s
// TODO: add CHECK lines; FileCheck fails on a test with no check strings.
// To inspect the unrolling in isolation:
// iree-opt --pass-pipeline='builtin.module(func.func(iree-amdaie-vectorization))' treads.mlir --debug-only=vector-unroll &> after.mlir

#map2 = affine_map<(d0) -> (d0 * 16)>

module {
  func.func @test_transfer_read_alloc() -> vector<32xbf16> {
    %alloc = memref.alloc() : memref<32xbf16, 2 : i32>
    %c0_11 = arith.constant 0 : index
    // Note: the padding value %10 is a scalar poison value, while the result
    // of the transfer_read is a vector.
    %10 = ub.poison : bf16
    %13 = vector.transfer_read %alloc[%c0_11], %10 {in_bounds = [true]} : memref<32xbf16, 2 : i32>, vector<32xbf16>
    return %13 : vector<32xbf16>
  }

  func.func @test_transfer_write() {
    %cst = arith.constant dense<0.000000e+00> : vector<32xbf16>
    %c0_11 = arith.constant 0 : index
    %alloc = memref.alloc() : memref<32xbf16, 2 : i32>
    vector.transfer_write %cst, %alloc[%c0_11] {in_bounds = [true]} : vector<32xbf16>, memref<32xbf16, 2 : i32>
    return
  }

  func.func @test_transfer_read_subview() -> vector<32x16xbf16> {
    // Constants.
    %c8 = arith.constant 8 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c0_11 = arith.constant 0 : index
    %10 = ub.poison : bf16

    // Memory allocation.
    %alloc_0 = memref.alloc() : memref<32x128xbf16, 2 : i32>
    // Apply the affine map; the offset is currently static (index 0).
    %11 = affine.apply #map2(%c0)
    %subview = memref.subview %alloc_0[0, %11] [32, 16] [1, 1] : memref<32x128xbf16, 2 : i32> to memref<32x16xbf16, strided<[128, 1], offset: ?>, 2 : i32>
    %12 = vector.transfer_read %subview[%c0_11, %c0_11], %10 {in_bounds = [true, true]} : memref<32x16xbf16, strided<[128, 1], offset: ?>, 2 : i32>, vector<32x16xbf16>

    return %12 : vector<32x16xbf16>
  }
}
@@ -1002,38 +1002,7 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target) {
  });

  target.addDynamicallyLegalOp<vector::ReductionOp>(
-     [=](vector::ReductionOp op) {
-       if (auto kind = op.getKind(); kind != vector::CombiningKind::ADD &&
-                                     kind != vector::CombiningKind::MINSI &&
-                                     kind != vector::CombiningKind::MINUI &&
-                                     kind != vector::CombiningKind::MINIMUMF &&
-                                     kind != vector::CombiningKind::MAXSI &&
-                                     kind != vector::CombiningKind::MAXUI &&
-                                     kind != vector::CombiningKind::MAXIMUMF)
-         return true;
-
-       auto vType = dyn_cast<VectorType>(op.getVector().getType());
-       if (!vType) return true;
-
-       llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
-       laneSizeElWidthPairSet.insert({64, 8});
-       laneSizeElWidthPairSet.insert({32, 16});
-       laneSizeElWidthPairSet.insert({32, 32});
-       laneSizeElWidthPairSet.insert({16, 32});
-
-       Type scalarType = vType.getElementType();
-       unsigned elWidth = scalarType.getIntOrFloatBitWidth();
-       unsigned laneSize = aievec::getVectorLaneSize(vType);
-
-       if (isa<IntegerType>(scalarType) &&
-           !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
-         return true;
-
-       if (isa<FloatType>(scalarType) && laneSize != 16 && laneSize != 32)
-         return true;
-
-       return false;
-     });
+     // Keep all vector.reduction ops legal: they are no longer rewritten by
+     // AIEVec and instead reach the LLVM lowering (see the
+     // reassociateFPReductions option enabled in addMLIRAIELoweringPasses).
+     [=](vector::ReductionOp op) { return true; });

  target.addIllegalOp<vector::ContractionOp, vector::TransposeOp>();
}
@@ -32,6 +32,28 @@ namespace {
// GenericVectorization will be extended in the future to support more
// AIE-specific vectorization patterns.

+// Computes the native shape for unrolling: a 2-D vector shape {x, y} is
+// mapped to {1, y}.
+static std::optional<SmallVector<int64_t>> vectorLeadingOneDimShape(
+    Operation *op) {
+  auto vectorOp = dyn_cast<VectorUnrollOpInterface>(op);
+  if (!vectorOp) return std::nullopt;
+  auto shape = vectorOp.getShapeForUnroll();
+  if (!shape) return std::nullopt;
+
+  // For 2-D vectors, return {1, shape[1]} to unroll the leading dimension.
+  if (shape->size() == 2) {
+    return SmallVector<int64_t>{1, (*shape)[1]};
+  }
+  return std::nullopt;
+}
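+// For example, a vector.transfer_read producing vector<32x16xbf16> gets the
+// native shape {1, 16} here, so the unroller rewrites it into 32 reads of
+// vector<1x16xbf16>.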

+void populateVectorUnrollPatterns(RewritePatternSet &vectorizationPatterns) {
+  vector::UnrollVectorOptions options;
+  options.setNativeShapeFn(vectorLeadingOneDimShape);
+  vector::populateVectorUnrollPatterns(vectorizationPatterns, options);
+}

class AMDAIEVectorizationPass
: public impl::AMDAIEVectorizationBase<AMDAIEVectorizationPass> {
public:
@@ -82,17 +104,16 @@ void AMDAIEVectorizationPass::runOnOperation() {
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      if (isElementwise(genericOp)) {
        for (Operation &innerOps : genericOp.getBody()->getOperations()) {
-         if (!isa<arith::TruncFOp, arith::TruncIOp, linalg::YieldOp>(
+         if (!isa<arith::TruncFOp, arith::TruncIOp, linalg::YieldOp, arith::AddFOp>(
                innerOps)) {
            return WalkResult::advance();
          }
        }
      }
    }

    // AIE architecture has no vector instructions for 32/64-bit types.
-   if (!isa<linalg::FillOp>(op) && !hasOperandWithSmallElementType(op))
-     return WalkResult::advance();
+   // if (!isa<linalg::FillOp>(op) && !hasOperandWithSmallElementType(op))
+   //   return WalkResult::advance();

    candidates.push_back(op);
    return WalkResult::advance();
@@ -107,17 +128,34 @@
}

  RewritePatternSet vectorizationPatterns(funcOp.getContext());

  vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
  vector::populateSinkVectorOpsPatterns(vectorizationPatterns);

+ // TODO: is the pattern set below really needed? Including it prevents
+ // broadcasting in vector.transfer_read ops.
  vector::populateVectorTransferPermutationMapLoweringPatterns(
      vectorizationPatterns);

  vector::populateVectorMultiReductionLoweringPatterns(
      vectorizationPatterns,
      vector::VectorMultiReductionLowering::InnerReduction);
+ // Convert transfer_read/transfer_write ops into vector.load/store ops.
+ {
+   vector::populateVectorToVectorCanonicalizationPatterns(
+       vectorizationPatterns);
+   // 1. Unroll 2-D vector ops along the leading dimension.
+   populateVectorUnrollPatterns(vectorizationPatterns);
+   // 2. Drop the leading unit dimension:
+   //    vector<1x10xbf16> -> vector<10xbf16>.
+   vector::populateCastAwayVectorLeadingOneDimPatterns(vectorizationPatterns,
+                                                       /*benefit=*/1);
+   // 3. Lower the now-1-D transfer_read/write ops to vector.load/store.
+   vector::populateVectorTransferLoweringPatterns(vectorizationPatterns);
+   vector::populateVectorToVectorCanonicalizationPatterns(
+       vectorizationPatterns);
+ }
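+ // Net effect on a reduction tile (illustrative shapes):
+ //   vector.transfer_read ... vector<16x32xbf16>
+ //   --(1. unroll)-->        16 ops on vector<1x32xbf16>
+ //   --(2. drop unit dim)--> 16 ops on vector<32xbf16>
+ //   --(3. lower)-->         16 x vector.load ... vector<32xbf16>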

  (void)applyPatternsGreedily(funcOp, std::move(vectorizationPatterns));
}
} // namespace
@@ -878,12 +878,17 @@ static LogicalResult setRootConfigForReductionCopyPipeline(
          llvm::cast<ShapedType>(linalgOp.getDpsInputOperand(0)->get().getType())
              .getShape();
  assert(inputShape.size() == 2 && "expected the input as 2D");
- int64_t m1Tile = std::min<int64_t>(inputShape[0], 32);
+ int64_t m1Tile = std::min<int64_t>(inputShape[0], 16);
  int64_t m0Tile = std::min<int64_t>(inputShape[0], numRows * numCols * m1Tile);

+ // A '0' tile size means the dimension is not tiled, i.e. all values in it
+ // are taken (as in Python slicing). Example: {D0xD1} -> {m0Tile x D1}.
  SmallVector<int64_t> tileSizeLevel0 = {m0Tile, 0};
  SmallVector<int64_t> tileSizeLevel1 = {m1Tile, 0};
- SmallVector<int64_t> tileSizeLevel2 = {0, 0};
+ // The Peano legalizer vectorizes <32xbf16>, so split the reduction
+ // dimension into chunks of 32; otherwise Peano generates scalarized code,
+ // which causes out-of-program-memory errors.
+ SmallVector<int64_t> tileSizeLevel2 = {0, 32};
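+ // Worked example (hypothetical array size: numRows = 4, numCols = 4):
+ // for a 2x128 input, m1Tile = min(2, 16) = 2 and
+ // m0Tile = min(2, 4 * 4 * 2) = 2, giving tile-size levels
+ // {2, 0}, {2, 0}, {0, 32}.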
  if (failed(setOpConfigAndEntryPointFnTranslation(
          entryPointFn, linalgOp,
          TileSizesListType{tileSizeLevel0, tileSizeLevel1, tileSizeLevel2},
@@ -959,7 +959,11 @@ void addMLIRAIELoweringPasses(OpPassManager &pm,
  pm.addPass(createCanonicalizerPass());
  pm.addPass(createCSEPass());
  pm.addPass(aievec::createConvertAIEVecToLLVMPass());
- pm.addPass(createConvertVectorToLLVMPass());
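+ // reassociateFPReductions lets the LLVM lowering turn the strictly ordered
+ // FP reduction into a tree reduction. Floating-point addition is not
+ // associative, e.g. (1e16 + -1e16) + 1.0 == 1.0 while
+ // 1e16 + (-1e16 + 1.0) == 0.0, so results can differ slightly from
+ // sequential accumulation in exchange for vectorizable code.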
+ {
+   ConvertVectorToLLVMPassOptions opts{};
+   opts.reassociateFPReductions = true;
+   pm.addPass(createConvertVectorToLLVMPass(opts));
+ }
  pm.addPass(memref::createExpandStridedMetadataPass());
  pm.addPass(createLowerAffinePass());
  pm.addPass(createConvertMathToLLVMPass());