Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aievec] MLIR->LLVM flow for float inverse #1697

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 54 additions & 19 deletions lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,28 @@ static bool matchExpOpForLUT(math::ExpOp::Adaptor adaptor) {

return true;
}
static bool matchInvOpForLUT(arith::DivFOp::Adaptor adaptor,
arith::DivFOp divOp) {
Type srcType = adaptor.getLhs().getType();
if (!divOp->hasOneUse() || isa<VectorType>(srcType) ||
!isa<FloatType>(srcType))
return false;

if (!isNarrowingOp(*divOp->getUsers().begin()))
return false;

auto fType = cast<FloatType>(srcType);
if (fType.getWidth() != 32)
return false;

auto constOp = divOp.getLhs().getDefiningOp<arith::ConstantOp>();
if (!constOp ||
cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious, does this FloatAttr floating point comparison include a built-in epsilon for err?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a comparison between double values, not FloatAttr objects. FloatAttr::getValue() returns an APFloat, and APFloat::convertToDouble() returns a double. So, no epsilon.

1.0f) {
return false;
}
return true;
}

//===----------------------------------------------------------------------===//
// Rewrite patterns
Expand Down Expand Up @@ -2010,6 +2032,34 @@ struct ComputeExpOpByLUTPattern : OpConversionPattern<math::ExpOp> {
}
};

// Lower the inverse of a scalar f32 that is immediately truncated to bf16 into
// a call to the externally defined `getInvBf16` function.
// Convert the pattern-
//   %cst = arith.constant 1.000000e+00 : f32
//   %0 = arith.divf %cst, %a : f32
//   %1 = arith.truncf %0 : f32 to bf16
// to -
//   %1 = func.call @getInvBf16(%a) : (f32) -> bf16
struct ComputeInvOpByLUTLLVMPattern : OpConversionPattern<arith::DivFOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!matchInvOpForLUT(adaptor, divOp))
      return failure();

    // matchInvOpForLUT guarantees exactly one user and that it satisfies
    // isNarrowingOp, but not that it is specifically an arith.truncf. Use
    // dyn_cast so any other narrowing op is a match failure, not an assertion.
    auto truncOp = dyn_cast<arith::TruncFOp>(*divOp->getUsers().begin());
    if (!truncOp)
      return failure();

    // The callee is declared as (f32) -> bf16, so the value it replaces must
    // be bf16 as well; e.g. a truncation to f16 must not be rewritten.
    Type bfloat16Ty = rewriter.getBF16Type();
    if (truncOp.getResult().getType() != bfloat16Ty)
      return failure();

    // All match conditions hold; only now start mutating the IR so that a
    // failed match never leaves a stray function declaration behind.
    StringRef funcName = "getInvBf16";
    auto moduleOp = divOp->getParentOfType<mlir::ModuleOp>();
    Type floatTy = rewriter.getF32Type();
    func::FuncOp fn_op =
        getOrInsertFuncDecl(rewriter, moduleOp, funcName, TypeRange{floatTy},
                            TypeRange{bfloat16Ty});

    // The call takes the divisor directly and stands in for the truncf
    // result; the now-unused divf is removed afterwards.
    rewriter.setInsertionPoint(truncOp);
    SmallVector<Value> invOperands = {adaptor.getRhs()};
    rewriter.replaceOpWithNewOp<func::CallOp>(truncOp, fn_op, invOperands);
    rewriter.eraseOp(divOp);

    return success();
  }
};

// Lower the inverse of a float to a function call
// Convert the pattern-
// %cst = arith.constant 1.000000e+00 : f32
Expand All @@ -2023,24 +2073,8 @@ struct ComputeInvOpByLUTPattern : OpConversionPattern<arith::DivFOp> {
LogicalResult
matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type srcType = adaptor.getLhs().getType();
if (!divOp->hasOneUse() || isa<VectorType>(srcType) ||
!isa<FloatType>(srcType))
if (!matchInvOpForLUT(adaptor, divOp))
return failure();

if (!isNarrowingOp(*divOp->getUsers().begin()))
return failure();

auto fType = cast<FloatType>(srcType);
if (fType.getWidth() != 32)
return failure();

auto constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
if (!constOp ||
cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
1.0f)
return failure();

StringRef includeName = "lut_based_ops.h";
auto moduleOp = divOp->getParentOfType<mlir::ModuleOp>();
rewriter.setInsertionPointToStart(
Expand Down Expand Up @@ -3095,18 +3129,19 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
>(patterns.getContext(), 128, 1024, 256, 1024);
patterns.add<
ComputeExpOpByLUTPattern,
ComputeInvOpByLUTPattern,
LowerVectorAddFOpToAIEVecAddElemOp,
LowerVectorSubFOpToAIEVecSubElemOp,
LowerVectorAddIOpToAIEVecAddElemOp,
LowerVectorSubIOpToAIEVecSubElemOp
>(patterns.getContext());
} else if (backend == TargetBackend::LLVMIR){
patterns.add<
ComputeExpOpByLUTLLVMPattern
ComputeExpOpByLUTLLVMPattern,
ComputeInvOpByLUTLLVMPattern
>(patterns.getContext());
}
patterns.add<
ComputeInvOpByLUTPattern,
ComputeTanhOpByLUTPattern,
ComputeSqrtOpPattern,
ComputeRsqrtOpPattern,
Expand Down
17 changes: 17 additions & 0 deletions test/Conversion/VectorToAIEVec/test_lut_based_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// CHECK-LABEL: func private @getExpBf16(vector<16xbf16>) -> vector<8xi64>
// CHECK-LABEL: func @test_exp_lut
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: vector<16xbf16>
module{
func.func @test_exp_lut(%a: vector<16xbf16>) -> vector<16xbf16> {
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
// CHECK: %[[CALL:.*]] = call @getExpBf16(%[[A]]) : (vector<16xbf16>) -> vector<8xi64>
Expand All @@ -13,3 +14,19 @@ func.func @test_exp_lut(%a: vector<16xbf16>) -> vector<16xbf16> {
// CHECK: return %[[SRS]] : vector<16xbf16>
return %0 : vector<16xbf16>
}

}

// Scalar reciprocal test: 1.0 / %a is computed in f32 and truncated to bf16.
// The FileCheck directives below expect a private declaration of @getInvBf16
// and the divf + truncf pair to be replaced by a single call to it.
module{
// CHECK-LABEL: func private @getInvBf16(f32) -> bf16
// CHECK-LABEL: func @test_inv_lut
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: f32
func.func @test_inv_lut(%a: f32) -> bf16{
// CHECK: %[[RET:.*]] = call @getInvBf16(%[[A]]) : (f32) -> bf16
%cst = arith.constant 1.000000e+00 : f32
%0 = arith.divf %cst, %a : f32
%1 = arith.truncf %0 : f32 to bf16
// CHECK: return %[[RET]] : bf16
return %1 : bf16
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Copyright (C) 2023, Advanced Micro Devices, Inc.

// REQUIRES: valid_xchess_license
// REQUIRES: peano, peano_and_chess
// RUN: mkdir -p %t/data; cd %t
// RUN: aie-opt %s --mlir-print-ir-after-all -affine-super-vectorize="virtual-vector-size=16" %vector-to-llvmir% -o llvmir.mlir >& mlir_passes.txt
// RUN: aie-translate --mlir-to-llvmir llvmir.mlir -o dut_part.ll
// RUN: %PEANO_INSTALL_DIR/bin/clang -S -emit-llvm %clang_aie2_lib_args -I%aie_runtime_lib%/AIE2/ -c %S/dut_simple.cc -o lut_based_ops.ll
// RUN: %PEANO_INSTALL_DIR/bin/clang -S -emit-llvm %clang_aie2_lib_args -c %aie_runtime_lib%/AIE2/lut_based_ops.cpp -o lut_constants.ll
// RUN: llvm-link -S lut_based_ops.ll dut_part.ll -o dut_functions.ll
// RUN: llvm-link -S lut_constants.ll dut_functions.ll -o dut.ll
// RUN: %PEANO_INSTALL_DIR/bin/clang %clang_aie2_args -c dut.ll -o dut.o
// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I%aie_runtime_lib%/AIE2 -I %aietools/include -DTO_LLVM -D__AIEARCH__=20 -D__AIENGINE__ -I. %S/testbench.cc dut.o
// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout
// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s
// CHECK: TEST PASSED

// Device under test: scales each of the 1024 bf16 elements of %arg0 by the
// reciprocal of the f32 scalar %arg1 (computed once, truncated to bf16) and
// stores the products into %arg2.
module {
func.func @dut(%arg0: memref<1024xbf16>{llvm.noalias}, %arg1: f32, %arg2: memref<1024xbf16>{llvm.noalias}) {
memref.assume_alignment %arg0, 32 : memref<1024xbf16>
memref.assume_alignment %arg2, 32 : memref<1024xbf16>
// Loop-invariant reciprocal: 1.0 / %arg1 in f32, then narrowed to bf16.
%cst = arith.constant 1.000000e+00 : f32
%0 = arith.divf %cst, %arg1 : f32
%1 = arith.truncf %0 : f32 to bf16
// Element-wise multiply of the input buffer by the reciprocal.
affine.for %arg3 = 0 to 1024 {
%2 = affine.load %arg0[%arg3] : memref<1024xbf16>
%3 = arith.mulf %1, %2 : bf16
affine.store %3, %arg2[%arg3] : memref<1024xbf16>
}
return
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
// Copyright (C) 2023, Advanced Micro Devices, Inc.

// REQUIRES: valid_xchess_license
// RUN: mkdir -p %t/data; cd %t
// RUN: aie-opt %s -affine-super-vectorize="virtual-vector-size=16" --convert-vector-to-aievec="aie-target=aie2" -lower-affine | aie-translate -aie2=true --aievec-to-cpp -o dut.cc
// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I%aie_runtime_lib%/AIE2 -I %aietools/include -D__AIEARCH__=20 -D__AIENGINE__ -I. -c %aie_runtime_lib%/AIE2/lut_based_ops.cpp -o lut_based_ops.o
// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I%aie_runtime_lib%/AIE2 -I %aietools/include -D__AIEARCH__=20 -D__AIENGINE__ -I. -c dut.cc -o dut.o
// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I%aie_runtime_lib%/AIE2 -I %aietools/include -D__AIEARCH__=20 -D__AIENGINE__ -I. %S/testbench.cc work/dut.o work/lut_based_ops.o
// RUN: mkdir -p data
// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout
// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s
// CHECK: TEST PASSED
Expand Down
1 change: 1 addition & 0 deletions test/unit_tests/aievec_tests/bf16_inv_lut/dut_simple.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include "lut_based_ops.h"
6 changes: 6 additions & 0 deletions test/unit_tests/aievec_tests/bf16_inv_lut/testbench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
#include <cstdio>
#include <cstdlib>

#ifdef TO_LLVM
extern "C" {
#endif
void dut(bfloat16 *restrict in0, float sum, bfloat16 *restrict out0);
#ifdef TO_LLVM
}
#endif
void dut_ref(bfloat16 *in0, float sum, bfloat16 *out0);

alignas(32) bfloat16 g_in0[IN0_SIZE];
Expand Down
Loading