diff --git a/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp b/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp index 3c2301e4b1..cce5c6ca9f 100644 --- a/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp +++ b/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp @@ -491,6 +491,28 @@ static bool matchExpOpForLUT(math::ExpOp::Adaptor adaptor) { return true; } +static bool matchInvOpForLUT(arith::DivFOp::Adaptor adaptor, + arith::DivFOp divOp) { + Type srcType = adaptor.getLhs().getType(); + if (!divOp->hasOneUse() || isa(srcType) || + !isa(srcType)) + return false; + + if (!isNarrowingOp(*divOp->getUsers().begin())) + return false; + + auto fType = cast(srcType); + if (fType.getWidth() != 32) + return false; + + auto constOp = divOp.getLhs().getDefiningOp(); + if (!constOp || + cast(constOp.getValue()).getValue().convertToDouble() != + 1.0f) { + return false; + } + return true; +} //===----------------------------------------------------------------------===// // Rewrite patterns @@ -2010,6 +2032,34 @@ struct ComputeExpOpByLUTPattern : OpConversionPattern { } }; +struct ComputeInvOpByLUTLLVMPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!matchInvOpForLUT(adaptor, divOp)) + return failure(); + + StringRef funcName = "getInvBf16"; + auto moduleOp = divOp->getParentOfType(); + Type floatTy = rewriter.getF32Type(); + Type bfloat16Ty = rewriter.getBF16Type(); + func::FuncOp fn_op = + getOrInsertFuncDecl(rewriter, moduleOp, funcName, TypeRange{floatTy}, + TypeRange{bfloat16Ty}); + + auto truncOp = cast(*divOp->getUsers().begin()); + + rewriter.setInsertionPoint(truncOp); + SmallVector invOperands = {adaptor.getRhs()}; + rewriter.replaceOpWithNewOp(truncOp, fn_op, invOperands); + rewriter.eraseOp(divOp); + + return success(); + } +}; + // Lower the inverse of a float to a function call // Convert the pattern- // %cst = arith.constant 1.000000e+00 : f32 @@ -2023,24 +2073,8 @@ struct ComputeInvOpByLUTPattern : OpConversionPattern { LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type srcType = adaptor.getLhs().getType(); - if (!divOp->hasOneUse() || isa(srcType) || - !isa(srcType)) + if (!matchInvOpForLUT(adaptor, divOp)) return failure(); - - if (!isNarrowingOp(*divOp->getUsers().begin())) - return failure(); - - auto fType = cast(srcType); - if (fType.getWidth() != 32) - return failure(); - - auto constOp = dyn_cast(divOp.getLhs().getDefiningOp()); - if (!constOp || - cast(constOp.getValue()).getValue().convertToDouble() != - 1.0f) - return failure(); - StringRef includeName = "lut_based_ops.h"; auto moduleOp = divOp->getParentOfType(); rewriter.setInsertionPointToStart( @@ -3095,6 +3129,7 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns, >(patterns.getContext(), 128, 1024, 256, 1024); patterns.add< ComputeExpOpByLUTPattern, + ComputeInvOpByLUTPattern, LowerVectorAddFOpToAIEVecAddElemOp, LowerVectorSubFOpToAIEVecSubElemOp, LowerVectorAddIOpToAIEVecAddElemOp, @@ -3102,11 +3137,11 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns, >(patterns.getContext()); } else if (backend == TargetBackend::LLVMIR){ patterns.add< - ComputeExpOpByLUTLLVMPattern + ComputeExpOpByLUTLLVMPattern, + ComputeInvOpByLUTLLVMPattern >(patterns.getContext()); } patterns.add< - ComputeInvOpByLUTPattern, ComputeTanhOpByLUTPattern, ComputeSqrtOpPattern, ComputeRsqrtOpPattern, diff --git a/test/Conversion/VectorToAIEVec/test_lut_based_ops.mlir b/test/Conversion/VectorToAIEVec/test_lut_based_ops.mlir index 282b1c3dc0..6a51f5ba2b 100644 --- a/test/Conversion/VectorToAIEVec/test_lut_based_ops.mlir +++ b/test/Conversion/VectorToAIEVec/test_lut_based_ops.mlir @@ -4,6 +4,7 @@ // CHECK-LABEL: func private @getExpBf16(vector<16xbf16>) -> vector<8xi64> // CHECK-LABEL: func @test_exp_lut // CHECK-SAME: %[[A:[A-Za-z0-9]+]]: vector<16xbf16> +module{ func.func @test_exp_lut(%a: vector<16xbf16>) -> vector<16xbf16> { // CHECK: %[[C0:.*]] = arith.constant 0 : i32 // CHECK: %[[CALL:.*]] = call @getExpBf16(%[[A]]) : (vector<16xbf16>) -> vector<8xi64> @@ -13,3 +14,19 @@ func.func @test_exp_lut(%a: vector<16xbf16>) -> vector<16xbf16> { // CHECK: return %[[SRS]] : vector<16xbf16> return %0 : vector<16xbf16> } + +} + +module{ +// CHECK-LABEL: func private @getInvBf16(f32) -> bf16 +// CHECK-LABEL: func @test_inv_lut +// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: f32 +func.func @test_inv_lut(%a: f32) -> bf16{ + // CHECK: %[[RET:.*]] = call @getInvBf16(%[[A]]) : (f32) -> bf16 + %cst = arith.constant 1.000000e+00 : f32 + %0 = arith.divf %cst, %a : f32 + %1 = arith.truncf %0 : f32 to bf16 + // CHECK: return %[[RET]] : bf16 + return %1 : bf16 +} +} diff --git a/test/unit_tests/aievec_tests/bf16_inv_lut/bf16_inv_lut-llvm.mlir b/test/unit_tests/aievec_tests/bf16_inv_lut/bf16_inv_lut-llvm.mlir new file mode 100644 index 0000000000..2b964184eb --- /dev/null +++ b/test/unit_tests/aievec_tests/bf16_inv_lut/bf16_inv_lut-llvm.mlir @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Copyright (C) 2023, Advanced Micro Devices, Inc. + +// REQUIRES: valid_xchess_license +// REQUIRES: peano, peano_and_chess +// RUN: mkdir -p %t/data; cd %t +// RUN: aie-opt %s --mlir-print-ir-after-all -affine-super-vectorize="virtual-vector-size=16" %vector-to-llvmir% -o llvmir.mlir >& mlir_passes.txt +// RUN: aie-translate --mlir-to-llvmir llvmir.mlir -o dut_part.ll +// RUN: %PEANO_INSTALL_DIR/bin/clang -S -emit-llvm %clang_aie2_lib_args -I%aie_runtime_lib%/AIE2/ -c %S/dut_simple.cc -o lut_based_ops.ll +// RUN: %PEANO_INSTALL_DIR/bin/clang -S -emit-llvm %clang_aie2_lib_args -c %aie_runtime_lib%/AIE2/lut_based_ops.cpp -o lut_constants.ll +// RUN: llvm-link -S lut_based_ops.ll dut_part.ll -o dut_functions.ll +// RUN: llvm-link -S lut_constants.ll dut_functions.ll -o dut.ll +// RUN: %PEANO_INSTALL_DIR/bin/clang %clang_aie2_args -c dut.ll -o dut.o +// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I%aie_runtime_lib%/AIE2 -I %aietools/include -DTO_LLVM -D__AIEARCH__=20 -D__AIENGINE__ -I. %S/testbench.cc dut.o +// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout +// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s +// CHECK: TEST PASSED + +module { + func.func @dut(%arg0: memref<1024xbf16>{llvm.noalias}, %arg1: f32, %arg2: memref<1024xbf16>{llvm.noalias}) { + memref.assume_alignment %arg0, 32 : memref<1024xbf16> + memref.assume_alignment %arg2, 32 : memref<1024xbf16> + %cst = arith.constant 1.000000e+00 : f32 + %0 = arith.divf %cst, %arg1 : f32 + %1 = arith.truncf %0 : f32 to bf16 + affine.for %arg3 = 0 to 1024 { + %2 = affine.load %arg0[%arg3] : memref<1024xbf16> + %3 = arith.mulf %1, %2 : bf16 + affine.store %3, %arg2[%arg3] : memref<1024xbf16> + } + return + } +} + diff --git a/test/unit_tests/aievec_tests/bf16_inv_lut/bf16_inv_lut.mlir b/test/unit_tests/aievec_tests/bf16_inv_lut/bf16_inv_lut.mlir index 9d0b435cd7..e0a786efa3 100644 --- a/test/unit_tests/aievec_tests/bf16_inv_lut/bf16_inv_lut.mlir +++ b/test/unit_tests/aievec_tests/bf16_inv_lut/bf16_inv_lut.mlir @@ -2,11 +2,11 @@ // Copyright (C) 2023, Advanced Micro Devices, Inc. // REQUIRES: valid_xchess_license +// RUN: mkdir -p %t/data; cd %t // RUN: aie-opt %s -affine-super-vectorize="virtual-vector-size=16" --convert-vector-to-aievec="aie-target=aie2" -lower-affine | aie-translate -aie2=true --aievec-to-cpp -o dut.cc // RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I%aie_runtime_lib%/AIE2 -I %aietools/include -D__AIEARCH__=20 -D__AIENGINE__ -I. -c %aie_runtime_lib%/AIE2/lut_based_ops.cpp -o lut_based_ops.o // RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I%aie_runtime_lib%/AIE2 -I %aietools/include -D__AIEARCH__=20 -D__AIENGINE__ -I. -c dut.cc -o dut.o // RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I%aie_runtime_lib%/AIE2 -I %aietools/include -D__AIEARCH__=20 -D__AIENGINE__ -I. %S/testbench.cc work/dut.o work/lut_based_ops.o -// RUN: mkdir -p data // RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout // RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s // CHECK: TEST PASSED diff --git a/test/unit_tests/aievec_tests/bf16_inv_lut/dut_simple.cc b/test/unit_tests/aievec_tests/bf16_inv_lut/dut_simple.cc new file mode 100644 index 0000000000..2f34961ad5 --- /dev/null +++ b/test/unit_tests/aievec_tests/bf16_inv_lut/dut_simple.cc @@ -0,0 +1 @@ +#include "lut_based_ops.h" diff --git a/test/unit_tests/aievec_tests/bf16_inv_lut/testbench.cc b/test/unit_tests/aievec_tests/bf16_inv_lut/testbench.cc index 1f0dd03cc8..39f5f3a8dc 100644 --- a/test/unit_tests/aievec_tests/bf16_inv_lut/testbench.cc +++ b/test/unit_tests/aievec_tests/bf16_inv_lut/testbench.cc @@ -5,7 +5,13 @@ #include #include +#ifdef TO_LLVM +extern "C" { +#endif void dut(bfloat16 *restrict in0, float sum, bfloat16 *restrict out0); +#ifdef TO_LLVM +} +#endif void dut_ref(bfloat16 *in0, float sum, bfloat16 *out0); alignas(32) bfloat16 g_in0[IN0_SIZE];