Skip to content

Commit b47ed8c

Browse files
authored
[CPU] Add contract fast-math-flag to arith operations (#14551)
This patch adds the `contract` FMF to some arith operations so that they can be folded into an fma instruction. We are doing this by default as we are lowering matmul ops by default to fmas. We will add different fp modes to have more control on fp optimizations depending on the tolerance to fp errors.
1 parent d1d03cb commit b47ed8c

File tree

10 files changed

+73
-0
lines changed

10 files changed

+73
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright 2023 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Common/PassDetail.h"
8+
#include "iree/compiler/Codegen/Common/Passes.h"
9+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
10+
11+
#define DEBUG_TYPE "iree-codegen-add-fast-math-flags"
12+
13+
using namespace mlir;
14+
using namespace mlir::iree_compiler;
15+
16+
/// Add `contract` FMF to operations that support it.
17+
static void addContractFMF(Operation *op) {
18+
LLVM::FastmathFlags contract = LLVM::FastmathFlags::contract;
19+
TypeSwitch<Operation *>(op)
20+
.Case<LLVM::FMulOp, LLVM::FAddOp, LLVM::FSubOp, LLVM::FNegOp>(
21+
[&](auto llvmOp) { llvmOp.setFastmathFlags(contract); });
22+
}
23+
24+
namespace {
25+
26+
/// Add the corresponding fast-math flags to operations given a floating-point
27+
/// optimization mode.
28+
// TODO: For now we only allow default flags, such as arithmetic reassociation.
29+
struct AddFastMathFlagsPass
30+
: public AddFastMathFlagsBase<AddFastMathFlagsPass> {
31+
public:
32+
using AddFastMathFlagsBase::AddFastMathFlagsBase;
33+
34+
void runOnOperation() override {
35+
getOperation()->walk([](Operation *op) { addContractFMF(op); });
36+
}
37+
};
38+
39+
} // namespace
40+
41+
std::unique_ptr<OperationPass<LLVM::LLVMFuncOp>>
42+
mlir::iree_compiler::createAddFastMathFlagsPass() {
43+
return std::make_unique<AddFastMathFlagsPass>();
44+
}

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ iree_compiler_cc_library(
143143
iree_compiler_cc_library(
144144
name = "Common",
145145
srcs = [
146+
"AddFastMathFlags.cpp",
146147
"BubbleUpOrdinalOps.cpp",
147148
"BufferizationAnalysis.cpp",
148149
"BufferizeCopyOnlyDispatchesPass.cpp",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ iree_cc_library(
118118
"Transforms.h"
119119
"UserConfig.h"
120120
SRCS
121+
"AddFastMathFlags.cpp"
121122
"BubbleUpOrdinalOps.cpp"
122123
"BufferizationAnalysis.cpp"
123124
"BufferizeCopyOnlyDispatchesPass.cpp"

compiler/src/iree/compiler/Codegen/Common/PassDetail.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
1111
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1212
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1314
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1415
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
1516
#include "mlir/Pass/Pass.h"

compiler/src/iree/compiler/Codegen/Common/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h"
1616
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
17+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1718
#include "mlir/Pass/Pass.h"
1819

1920
namespace mlir {
@@ -37,6 +38,8 @@ void addIREEComprehensiveBufferizePasses(
3738
std::nullopt,
3839
std::optional<BufferizationOptions::MemCpyFn> memCpyFn = std::nullopt);
3940

41+
std::unique_ptr<OperationPass<LLVM::LLVMFuncOp>> createAddFastMathFlagsPass();
42+
4043
/// Pass to bubble up ordinal operations to allow workgroup count computation
4144
/// based on slices to correlate back to workload computation.
4245
std::unique_ptr<Pass> createBubbleUpOrdinalOpsPass();

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ include "mlir/Pass/PassBase.td"
1313
// Common passes for all backends (keep alphabetical)
1414
//===---------------------------------------------------------------------===//
1515

16+
def AddFastMathFlags
17+
: Pass<"iree-codegen-add-fast-math-flags", "LLVM::LLVMFuncOp"> {
18+
let summary = "Add fast math flags to all the operations supporting them, "
19+
"given a floating-point mode.";
20+
let constructor = "mlir::iree_compiler::createAddFastMathFlagsPass()";
21+
}
22+
1623
def BubbleUpOrdinalOps : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> {
1724
let summary = "Bubbles op ordinal ops to allow for workgroup count computation";
1825
let constructor = "mlir::iree_compiler::createBubbleUpOrdinalOpsPass()";

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ iree_lit_test_suite(
1818
name = "lit",
1919
srcs = enforce_glob(
2020
[
21+
"add_fmfs.mlir",
2122
"affinemin_canonicalization.mlir",
2223
"batch_matmuls.mlir",
2324
"bubble_up_ordinal_ops.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ iree_lit_test_suite(
1414
NAME
1515
lit
1616
SRCS
17+
"add_fmfs.mlir"
1718
"affinemin_canonicalization.mlir"
1819
"batch_matmuls.mlir"
1920
"bubble_up_ordinal_ops.mlir"
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: iree-opt -iree-codegen-add-fast-math-flags --split-input-file %s | FileCheck %s
2+
3+
// LABEL: llvm.func @fmfs
4+
llvm.func @fmfs() -> f32 {
5+
%c3 = llvm.mlir.constant(3.000000e+00 : f32) : f32
6+
%c6 = llvm.mlir.constant(6.000000e+00 : f32) : f32
7+
%mul = llvm.fmul %c3, %c3 : f32
8+
%add = llvm.fadd %c3, %c6 : f32
9+
llvm.return %add : f32
10+
}
11+
12+
// CHECK: llvm.fmul %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<contract>} : f32
13+
// CHECK: llvm.fadd %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<contract>} : f32

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,7 @@ static void addLowerToLLVMPasses(OpPassManager &passManager) {
752752

753753
passManager.addPass(createCanonicalizerPass());
754754
passManager.addPass(createCSEPass());
755+
passManager.addNestedPass<LLVM::LLVMFuncOp>(createAddFastMathFlagsPass());
755756
}
756757

757758
void buildLLVMCPUCodegenPassPipeline(OpPassManager &passManager) {

0 commit comments

Comments
 (0)