-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[NVPTX] Select bfloat16 add/mul/sub as fma on SM80 #121065
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
4b5a501
to
299a919
Compare
@llvm/pr-subscribers-backend-nvptx @llvm/pr-subscribers-llvm-selectiondag Author: None (peterbell10) ChangesSM80 has fma for bfloat16 but not add/mul/sub. Currently these are just promoted to f32 but we can instead write them in terms of the fma:
Unfortunately there is no `fma.ftz`, so when ftz is enabled we still fall back to promotion. This is also the inverse of some generic DAGCombiner patterns, so I've had to add checks to avoid it reversing the legalization, which would cause an infinite loop. Patch is 40.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121065.diff 9 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6cbfef2d238bbe..a50ac311c82869 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -17534,10 +17534,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
return N2;
}
+ const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
+ !TLI.isOperationLegal(ISD::FADD, VT));
+
// FIXME: Support splat of constant.
- if (N0CFP && N0CFP->isExactlyValue(1.0))
+ if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
- if (N1CFP && N1CFP->isExactlyValue(1.0))
+ if (!PreferFMAAdd && N1CFP && N1CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
// Canonicalize (fma c, x, y) -> (fma x, c, y)
@@ -17569,7 +17572,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
// (fma x, -1, y) -> (fadd (fneg x), y)
// FIXME: Support splat of constant.
- if (N1CFP) {
+ if (N1CFP && !PreferFMAAdd) {
if (N1CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
@@ -17579,15 +17582,14 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
AddToWorklist(RHSNeg.getNode());
return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
}
-
- // fma (fneg x), K, y -> fma x -K, y
- if (matcher.match(N0, ISD::FNEG) &&
- (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
- (N1.hasOneUse() &&
- !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
- return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
- matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
- }
+ }
+ // fma (fneg x), K, y -> fma x -K, y
+ if (N1CFP && matcher.match(N0, ISD::FNEG) &&
+ (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
+ (N1.hasOneUse() &&
+ !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
+ return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+ matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
}
// FIXME: Support splat of constant.
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5c1f717694a4c7..47f56abae3c056 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -853,6 +853,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
+ // Lower bf16 add/mul/sub as fma when it avoids promotion
+ for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
+ for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
+ if (getOperationAction(Op, VT) != Legal &&
+ getOperationAction(ISD::FMA, VT) == Legal) {
+ setOperationAction(Op, VT, Custom);
+ }
+ }
+ }
+
// f16/f16x2 neg was introduced in PTX 60, SM_53.
const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
STI.getPTXVersion() >= 60 &&
@@ -2490,6 +2500,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
}
+static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
+ EVT VT = N->getValueType(0);
+ EVT NVT = MVT::f32;
+ if (VT.isVector()) {
+ NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
+ }
+ SDLoc DL(N);
+ SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
+ SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
+ SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
+ return DAG.getFPExtendOrRound(Res, DL, VT);
+}
+
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
+
+ // FADD(a, b) -> FMA(a, 1.0, b)
+ SDLoc DL(Op);
+ auto VT = Op.getValueType();
+ auto One = DAG.getConstantFP(1.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
+SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
+
+ // FSUB(a, b) -> FMA(b, -1.0, a)
+ SDLoc DL(Op);
+ auto VT = Op.getValueType();
+ auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
+ Op->getOperand(0)};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
+SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
+
+ // FMUL(a, b) -> FMA(a, b, 0.0)
+ SDLoc DL(Op);
+ auto VT = Op.getValueType();
+ auto Zero = DAG.getConstantFP(0.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), Zero};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
SelectionDAG &DAG) const {
assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
@@ -2681,6 +2747,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerSTACKSAVE(Op, DAG);
case ISD::CopyToReg:
return LowerCopyToReg_128(Op, DAG);
+ case ISD::FADD:
+ return LowerFADD(Op, DAG);
+ case ISD::FSUB:
+ return LowerFSUB(Op, DAG);
+ case ISD::FMUL:
+ return LowerFMUL(Op, DAG);
+
default:
llvm_unreachable("Custom lowering not defined for operation");
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 4a98fe21b81dc6..b7d32dd5327646 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -279,6 +279,10 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFADD(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFSUB(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
+
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/NVPTX/atomics-sm90.ll b/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
index f81b785f13225c..67552b95e04915 100644
--- a/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
+++ b/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
@@ -46,58 +46,52 @@ define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat
; CHECKPTX71-LABEL: test(
; CHECKPTX71: {
; CHECKPTX71-NEXT: .reg .pred %p<5>;
-; CHECKPTX71-NEXT: .reg .b16 %rs<22>;
+; CHECKPTX71-NEXT: .reg .b16 %rs<26>;
; CHECKPTX71-NEXT: .reg .b32 %r<4>;
-; CHECKPTX71-NEXT: .reg .f32 %f<12>;
; CHECKPTX71-EMPTY:
; CHECKPTX71-NEXT: // %bb.0:
; CHECKPTX71-NEXT: ld.param.b16 %rs13, [test_param_3];
; CHECKPTX71-NEXT: ld.param.u32 %r3, [test_param_2];
; CHECKPTX71-NEXT: ld.param.u32 %r2, [test_param_1];
; CHECKPTX71-NEXT: ld.param.u32 %r1, [test_param_0];
-; CHECKPTX71-NEXT: ld.b16 %rs18, [%r1];
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f1, %rs13;
+; CHECKPTX71-NEXT: ld.b16 %rs22, [%r1];
; CHECKPTX71-NEXT: $L__BB0_1: // %atomicrmw.start14
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f2, %rs18;
-; CHECKPTX71-NEXT: add.rn.f32 %f3, %f2, %f1;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs14, %f3;
-; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs18, %rs14;
-; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs18;
-; CHECKPTX71-NEXT: mov.u16 %rs18, %rs3;
+; CHECKPTX71-NEXT: mov.b16 %rs14, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs15, %rs22, %rs14, %rs13;
+; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs22, %rs15;
+; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs22;
+; CHECKPTX71-NEXT: mov.u16 %rs22, %rs3;
; CHECKPTX71-NEXT: @%p1 bra $L__BB0_1;
; CHECKPTX71-NEXT: // %bb.2: // %atomicrmw.end13
-; CHECKPTX71-NEXT: ld.b16 %rs19, [%r1];
+; CHECKPTX71-NEXT: ld.b16 %rs23, [%r1];
; CHECKPTX71-NEXT: $L__BB0_3: // %atomicrmw.start8
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f4, %rs19;
-; CHECKPTX71-NEXT: add.rn.f32 %f5, %f4, 0f3F800000;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs15, %f5;
-; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs19, %rs15;
-; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs19;
-; CHECKPTX71-NEXT: mov.u16 %rs19, %rs6;
+; CHECKPTX71-NEXT: mov.b16 %rs16, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs17, %rs23, %rs16, %rs16;
+; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs23, %rs17;
+; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs23;
+; CHECKPTX71-NEXT: mov.u16 %rs23, %rs6;
; CHECKPTX71-NEXT: @%p2 bra $L__BB0_3;
; CHECKPTX71-NEXT: // %bb.4: // %atomicrmw.end7
-; CHECKPTX71-NEXT: ld.global.b16 %rs20, [%r2];
+; CHECKPTX71-NEXT: ld.global.b16 %rs24, [%r2];
; CHECKPTX71-NEXT: $L__BB0_5: // %atomicrmw.start2
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f7, %rs20;
-; CHECKPTX71-NEXT: add.rn.f32 %f8, %f7, %f1;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs16, %f8;
-; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs20, %rs16;
-; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs20;
-; CHECKPTX71-NEXT: mov.u16 %rs20, %rs9;
+; CHECKPTX71-NEXT: mov.b16 %rs18, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs19, %rs24, %rs18, %rs13;
+; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs24, %rs19;
+; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs24;
+; CHECKPTX71-NEXT: mov.u16 %rs24, %rs9;
; CHECKPTX71-NEXT: @%p3 bra $L__BB0_5;
; CHECKPTX71-NEXT: // %bb.6: // %atomicrmw.end1
-; CHECKPTX71-NEXT: ld.shared.b16 %rs21, [%r3];
+; CHECKPTX71-NEXT: ld.shared.b16 %rs25, [%r3];
; CHECKPTX71-NEXT: $L__BB0_7: // %atomicrmw.start
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f10, %rs21;
-; CHECKPTX71-NEXT: add.rn.f32 %f11, %f10, %f1;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs17, %f11;
-; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs21, %rs17;
-; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs21;
-; CHECKPTX71-NEXT: mov.u16 %rs21, %rs12;
+; CHECKPTX71-NEXT: mov.b16 %rs20, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs21, %rs25, %rs20, %rs13;
+; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs25, %rs21;
+; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs25;
+; CHECKPTX71-NEXT: mov.u16 %rs25, %rs12;
; CHECKPTX71-NEXT: @%p4 bra $L__BB0_7;
; CHECKPTX71-NEXT: // %bb.8: // %atomicrmw.end
; CHECKPTX71-NEXT: ret;
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 6828bac18cad7f..eeb13b52130042 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -42,17 +42,14 @@ define bfloat @test_fadd(bfloat %0, bfloat %1) {
;
; SM80-LABEL: test_fadd(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<4>;
-; SM80-NEXT: .reg .f32 %f<4>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_param_0];
; SM80-NEXT: ld.param.b16 %rs2, [test_fadd_param_1];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
-; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
+; SM80-NEXT: mov.b16 %rs3, 0x3F80;
+; SM80-NEXT: fma.rn.bf16 %rs4, %rs1, %rs3, %rs2;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fadd(
@@ -113,17 +110,14 @@ define bfloat @test_fsub(bfloat %0, bfloat %1) {
;
; SM80-LABEL: test_fsub(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<4>;
-; SM80-NEXT: .reg .f32 %f<4>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fsub_param_0];
; SM80-NEXT: ld.param.b16 %rs2, [test_fsub_param_1];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
-; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
+; SM80-NEXT: mov.b16 %rs3, 0xBF80;
+; SM80-NEXT: fma.rn.bf16 %rs4, %rs2, %rs3, %rs1;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fsub(
@@ -202,23 +196,14 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_faddx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
-; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
-; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: add.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_1];
+; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_0];
+; SM80-NEXT: mov.b32 %r3, 1065369472;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_faddx2(
@@ -303,23 +288,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fsubx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
; SM80-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: mov.b32 %r3, -1082081408;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fsubx2(
@@ -404,23 +380,14 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fmulx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
-; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_0];
-; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_1];
+; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_0];
+; SM80-NEXT: mov.b32 %r3, 0;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r1, %r3;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fmulx2(
@@ -727,15 +694,13 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
;
; SM80-LABEL: test_fadd_imm_1(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<3>;
-; SM80-NEXT: .reg .f32 %f<3>;
+; SM80-NEXT: .reg .b16 %rs<4>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs2, %f2;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs2;
+; SM80-NEXT: mov.b16 %rs2, 0x3F80;
+; SM80-NEXT: fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fadd_imm_1(
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 03cdeb9683abae..31d089a19450e1 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -22,19 +22,14 @@ define <2 x bfloat> @test_ret_const() #0 {
define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
; SM80-LABEL: test_fadd_imm_0(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<3>;
-; SM80-NEXT: .reg .b32 %r<3>;
-; SM80-NEXT: .reg .f32 %f<5>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fadd_imm_0_param_0];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT: cvt.f32.bf16 %f3, %rs2;
-; SM80-NEXT: add.rn.f32 %f4, %f3, 0f40000000;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r2, %f4, %f2;
-; SM80-NEXT: st.param.b32 [func_retval0], %r2;
+; SM80-NEXT: mov.b32 %r2, 1073758080;
+; SM80-NEXT: mov.b32 %r3, 1065369472;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r1, %r3, %r2;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fadd_imm_0(
@@ -54,15 +49,13 @@ define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
define bfloat @test_fadd_imm_1(bfloat %a) #0 {
; SM80-LABEL: test_fadd_imm_1(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<3>;
-; SM80-NEXT: .reg .f32 %f<3>;
+; SM80-NEXT: .reg .b16 %rs<4>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs2, %f2;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs2;
+; SM80-NEXT: mov.b16 %rs2, 0x3F80;
+; SM80-NEXT: fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fadd_imm_1(
@@ -82,23 +75,14 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-LABEL: test_fsubx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
; SM80-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: mov.b32 %r3, -1082081408;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fsubx2(
@@...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The purpose of this PR is to eliminate the conversions to/from fp32, correct? If so, can you add that to the description?
if (getOperationAction(Op, VT) != Legal && | ||
getOperationAction(ISD::FMA, VT) == Legal) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is cumbersome, we usually don't write legalizer rules in terms of other legalizer rules.
I think you'd be best off just putting this logic into the default Expand action for add/fmul/fsub. If the FMA is legal, you emit the appropriate sequence before falling back to the default libcall expansion. Then you shouldn't need to touch the target rules here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is cumbersome, we usually don't write legalizer rules in terms of other legalizer rules.
I could easily write it in terms of SM and PTX version ranges instead, I just felt this more directly expressed the rationale behind the version range.
I think you'd be best off just putting this logic into the default Expand action for add/fmul/fsub. If the FMA is legal, you emit the appropriate sequence before falling back to the default libcall expansion. Then you shouldn't need to touch the target rules here
I'm not sure this makes sense. The FTZ logic is target specific and we also want to fallback to promotion, not a libcall here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this makes sense. The FTZ logic is target specific and we also want to fallback to promotion, not a libcall here.
The custom expansion can defer to the default expansion depending on the function state. Yes, the default Expand action can conditionally use the to FMA path instead of the default libcall expansion. The core transform code can be in the generic legalizer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright, done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be simplified: the ISD::FADD, ISD::FMUL, ISD::FSUB mode for bf16 types does not need to be derived from FMA. It should just be set in sync with the action on FMA, at the same place where we set it for FMA.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`setBF16OperationAction` only allows two variants, supported or unsupported. So we can't set custom for an in-between range.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Based on the clarification we've arrived at while sorting out my confusion about applicability of the patch, instead of relying on `setBF16OperationAction` we can set the appropriate action by checking the SM/PTX version range directly. It would be easier to understand than deriving the right action indirectly, from the actions set on other ops.
0fdfca3
to
ee6635b
Compare
Could we just declare bfloat16 add/mul/sub legal when FMA is available, and just lower them as FMA via pattern-matching, without messing with the legalizer and combiner? I'm not quite sure why we want to convert those mul/add to FMA in the DAG itself. |
} | ||
// fma (fneg x), K, y -> fma x -K, y | ||
if (N1CFP && matcher.match(N0, ISD::FNEG) && | ||
(TLI.isOperationLegal(ISD::ConstantFP, VT) || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since ConstantFP doesn't get used for vectors, this won't handle them correctly. I thought we had a helper for this but I can't seem to find it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't my code, but there is a FIXME
comment above.
Okay, I've changed this to make add/mul/sub legal and added manual selection code which I believe is required to generate the constant inputs. PTAL. |
689f92a
to
f634361
Compare
@@ -3671,14 +3671,21 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) { | |||
Results.push_back(ExpandConstant(CP)); | |||
break; | |||
} | |||
case ISD::FADD: { | |||
if (SDValue Expand = TLI.expandFADD(Node, DAG)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style nit: No need for {} for single-statement bodies. Applies here and in other places.
if (IsNativelySupported) | ||
return false; | ||
|
||
assert(VT == MVT::bf16 || VT == MVT::v2bf16); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assert does nothing here as we would have returned false for non-bf16 types, already.
If we're not expected to be called with non-bf16 types, I'd move the assert upwards, next to where `VT` is defined, and use it instead of the early return.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm asserting it's either a scalar or vector of length 2.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The point is that both would be ruled out by `if (VT.getScalarType() != MVT::bf16)` above.
auto API = APF.bitcastToAPInt(); | ||
API = API.concat(API); | ||
auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32); | ||
return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does PTX have a way to specify/use bf16 constant args? It would be nice to avoid passing the constant via a register.
I suspect there's no good way to use a constant for `bf16x2`, but I would assume that there should be a way to use FP constants for a scalar FMA.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No there doesn't seem to be any support for bf16 immediate values, ptxas complains
ptxas /tmp/tmplmhxk1av.ptx, line 79; error : Arguments mismatch for instruction 'fma'
ptxas fatal : Ptx assembly aborted due to errors
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. Please add a comment about that. Otherwise these register moves look rather questionable.
}; | ||
|
||
switch (N->getOpcode()) { | ||
case ISD::FADD: { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit. We do not need `{}` around individual cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed
if (getOperationAction(Op, VT) != Legal && | ||
getOperationAction(ISD::FMA, VT) == Legal) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be simplified: the ISD::FADD, ISD::FMUL, ISD::FSUB mode for bf16 types does not need to be derived from FMA. It should just be set in sync with the action on FMA, at the same place where we set it for FMA.
; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3; | ||
; SM80-NEXT: st.param.b16 [func_retval0], %rs3; | ||
; SM80-NEXT: mov.b16 %rs3, 0x3F80; | ||
; SM80-NEXT: fma.rn.bf16 %rs4, %rs1, %rs3, %rs2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need some tests for FTZ mode, now, too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Every test here has an `SM80-FTZ-NEXT` variant; they just don't show up in the diff because they haven't changed.
f634361
to
ddac0cb
Compare
After checking the PTX spec, the question I've got is -- should we bother with this attempt to lower those ops to FMA at all? It will only be beneficial on targets whose PTX version lacks native bf16 add/mul/sub [exact clause lost in this capture]. These days nobody should be using anything older than CUDA-11.8, and even that is firmly on the way out. |
The spec says that fma with bf16 types requires sm_80 or higher [quoted spec text missing from this capture]. I don't see any suggestion that that only applies for specific PTX versions. |
FMA instruction for bf16 types requires PTX 7.0 and sm_80: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-fma FADD/FMUL for bf16 requires PTX 7.8 and sm_90. https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul So, the only cases where we can benefit from this patch are:
sm_90 or newer GPUs require PTX 7.8 and therefore do not benefit from the patch.
Update: what I missed is that mul/add instructions are still absent for sm_80 even in the newer PTX versions. |
Exactly... It requires sm_90, so for sm_80 there is support for fma but not add/mul/sub. What's confusing here? |
This is where your logic doesn't make sense. Why would PTX >= 7.8 mean that sm_80 doesn't benefit from this patch? |
Ugh. Indeed. You're absolutely right. sm_80 still does not have mul/add, even in newer PTX versions. I was concentrating on sm_90 and somehow ignored that sm_80 is still affected. OK, the patch is still useful. Sorry about the noise. |
93102ec
to
20cb18f
Compare
Are there any remaining blockers here? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few nits, but LGTM overall.
if (IsNativelySupported) | ||
return false; | ||
|
||
assert(VT == MVT::bf16 || VT == MVT::v2bf16); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The point is that both would be ruled out by `if (VT.getScalarType() != MVT::bf16)` above.
auto API = APF.bitcastToAPInt(); | ||
API = API.concat(API); | ||
auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32); | ||
return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. Please add a comment about that. Otherwise these register moves look rather questionable.
if (getOperationAction(Op, VT) != Legal && | ||
getOperationAction(ISD::FMA, VT) == Legal) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Based on the clarification we've arrived at while sorting out my confusion about applicability of the patch, instead of relying on `setBF16OperationAction` we can set the appropriate action by checking the SM/PTX version range directly. It would be easier to understand than deriving the right action indirectly, from the actions set on other ops.
SM80 has fma for bfloat16 but not add/mul/sub. Currently these are just promoted to f32 but we can instead write them in terms of the fma: ``` FADD(a, b) -> FMA(a, 1.0, b) FMUL(a, b) -> FMA(a, b, 0.0) FSUB(a, b) -> FMA(b, -1.0, a) ``` Unfortunately there is no `fma.ftz` so when ftz is enabled, we still fall back to promotion. This is also the inverse of some generic DAGCombiner patterns, so I've had to add checks to avoid it reversing the legalization which would cause an infinite loop.
20cb18f
to
5df1167
Compare
5df1167
to
403aaee
Compare
SM80 has fma for bfloat16 but not add/mul/sub. Currently these ops incur a promotion to f32, but we can avoid this by writing them in terms of the fma:
Unfortunately there is no
fma.ftz
so when ftz is enabled, we still fall back to promotion.This is also the inverse of some generic DAGCombiner patterns, so I've had to add checks to avoid it reversing the legalization which would cause an infinite loop.