[NVPTX] Select bfloat16 add/mul/sub as fma on SM80 #121065


Merged
merged 7 commits into from
Jan 16, 2025

Conversation

peterbell10
Contributor

@peterbell10 peterbell10 commented Dec 24, 2024

SM80 has fma for bfloat16 but not add/mul/sub. Currently these ops incur a promotion to f32, but we can avoid this by writing them in terms of the fma:

FADD(a, b) -> FMA(a, 1.0, b)
FMUL(a, b) -> FMA(a, b, -0.0)
FSUB(a, b) -> FMA(b, -1.0, a)

Unfortunately, there is no fma.ftz for bf16, so when ftz is enabled we still fall back to promotion.

This is also the inverse of some generic DAGCombiner patterns, so I've had to add checks to stop them reversing the legalization, which would cause an infinite loop.


github-actions bot commented Dec 24, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@peterbell10 peterbell10 force-pushed the nvptx-fma-sm80 branch 2 times, most recently from 4b5a501 to 299a919 Compare December 24, 2024 20:30
@peterbell10 peterbell10 marked this pull request as ready for review December 24, 2024 21:46
@llvmbot llvmbot added backend:NVPTX llvm:SelectionDAG SelectionDAGISel as well labels Dec 24, 2024
@llvmbot
Member

llvmbot commented Dec 24, 2024

@llvm/pr-subscribers-backend-nvptx

@llvm/pr-subscribers-llvm-selectiondag

Author: None (peterbell10)

Changes

SM80 has fma for bfloat16 but not add/mul/sub. Currently these are just promoted to f32 but we can instead write them in terms of the fma:

FADD(a, b) -> FMA(a, 1.0, b)
FMUL(a, b) -> FMA(a, b, 0.0)
FSUB(a, b) -> FMA(b, -1.0, a)

Unfortunately there is no fma.ftz so when ftz is enabled, we still fall back to promotion.

This is also the inverse of some generic DAGCombiner patterns, so I've had to add checks to avoid it reversing the legalization which would cause an infinite loop.


Patch is 40.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121065.diff

9 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+14-12)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+73)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+4)
  • (modified) llvm/test/CodeGen/NVPTX/atomics-sm90.ll (+25-31)
  • (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+28-63)
  • (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+37-103)
  • (modified) llvm/test/CodeGen/NVPTX/fma-relu-contract.ll (+12-48)
  • (modified) llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll (+12-38)
  • (modified) llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll (+24-86)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6cbfef2d238bbe..a50ac311c82869 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -17534,10 +17534,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
       return N2;
   }
 
+  const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
+                             !TLI.isOperationLegal(ISD::FADD, VT));
+
   // FIXME: Support splat of constant.
-  if (N0CFP && N0CFP->isExactlyValue(1.0))
+  if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
     return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
-  if (N1CFP && N1CFP->isExactlyValue(1.0))
+  if (!PreferFMAAdd && N1CFP && N1CFP->isExactlyValue(1.0))
     return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
 
   // Canonicalize (fma c, x, y) -> (fma x, c, y)
@@ -17569,7 +17572,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
 
   // (fma x, -1, y) -> (fadd (fneg x), y)
   // FIXME: Support splat of constant.
-  if (N1CFP) {
+  if (N1CFP && !PreferFMAAdd) {
     if (N1CFP->isExactlyValue(1.0))
       return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
 
@@ -17579,15 +17582,14 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
       AddToWorklist(RHSNeg.getNode());
       return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
     }
-
-    // fma (fneg x), K, y -> fma x -K, y
-    if (matcher.match(N0, ISD::FNEG) &&
-        (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
-         (N1.hasOneUse() &&
-          !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
-      return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
-                             matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
-    }
+  }
+  // fma (fneg x), K, y -> fma x -K, y
+  if (N1CFP && matcher.match(N0, ISD::FNEG) &&
+      (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
+       (N1.hasOneUse() &&
+        !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
+    return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+                           matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
   }
 
   // FIXME: Support splat of constant.
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5c1f717694a4c7..47f56abae3c056 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -853,6 +853,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
       AddPromotedToType(Op, MVT::bf16, MVT::f32);
   }
 
+  // Lower bf16 add/mul/sub as fma when it avoids promotion
+  for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
+    for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
+      if (getOperationAction(Op, VT) != Legal &&
+          getOperationAction(ISD::FMA, VT) == Legal) {
+        setOperationAction(Op, VT, Custom);
+      }
+    }
+  }
+
   // f16/f16x2 neg was introduced in PTX 60, SM_53.
   const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
                                         STI.getPTXVersion() >= 60 &&
@@ -2490,6 +2500,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
   return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
 }
 
+static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+  EVT NVT = MVT::f32;
+  if (VT.isVector()) {
+    NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
+  }
+  SDLoc DL(N);
+  SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
+  SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
+  SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
+  return DAG.getFPExtendOrRound(Res, DL, VT);
+}
+
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
+  // No fma.ftz for bf16, so fall back to promotion
+  if (useF32FTZ(DAG.getMachineFunction())) {
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+  }
+
+  // FADD(a, b) -> FMA(a, 1.0, b)
+  SDLoc DL(Op);
+  auto VT = Op.getValueType();
+  auto One = DAG.getConstantFP(1.0, DL, VT);
+  SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
+  return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
+SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
+  // No fma.ftz for bf16, so fall back to promotion
+  if (useF32FTZ(DAG.getMachineFunction())) {
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+  }
+
+  // FSUB(a, b) -> FMA(b, -1.0, a)
+  SDLoc DL(Op);
+  auto VT = Op.getValueType();
+  auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
+  SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
+                                   Op->getOperand(0)};
+  return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
+SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+  // No fma.ftz for bf16, so fall back to promotion
+  if (useF32FTZ(DAG.getMachineFunction())) {
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+  }
+
+  // FMUL(a, b) -> FMA(a, b, 0.0)
+  SDLoc DL(Op);
+  auto VT = Op.getValueType();
+  auto Zero = DAG.getConstantFP(0.0, DL, VT);
+  SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), Zero};
+  return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
 SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
                                             SelectionDAG &DAG) const {
   assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
@@ -2681,6 +2747,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerSTACKSAVE(Op, DAG);
   case ISD::CopyToReg:
     return LowerCopyToReg_128(Op, DAG);
+  case ISD::FADD:
+    return LowerFADD(Op, DAG);
+  case ISD::FSUB:
+    return LowerFSUB(Op, DAG);
+  case ISD::FMUL:
+    return LowerFMUL(Op, DAG);
+
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 4a98fe21b81dc6..b7d32dd5327646 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -279,6 +279,10 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerFADD(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFSUB(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/test/CodeGen/NVPTX/atomics-sm90.ll b/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
index f81b785f13225c..67552b95e04915 100644
--- a/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
+++ b/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
@@ -46,58 +46,52 @@ define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat
 ; CHECKPTX71-LABEL: test(
 ; CHECKPTX71:       {
 ; CHECKPTX71-NEXT:    .reg .pred %p<5>;
-; CHECKPTX71-NEXT:    .reg .b16 %rs<22>;
+; CHECKPTX71-NEXT:    .reg .b16 %rs<26>;
 ; CHECKPTX71-NEXT:    .reg .b32 %r<4>;
-; CHECKPTX71-NEXT:    .reg .f32 %f<12>;
 ; CHECKPTX71-EMPTY:
 ; CHECKPTX71-NEXT:  // %bb.0:
 ; CHECKPTX71-NEXT:    ld.param.b16 %rs13, [test_param_3];
 ; CHECKPTX71-NEXT:    ld.param.u32 %r3, [test_param_2];
 ; CHECKPTX71-NEXT:    ld.param.u32 %r2, [test_param_1];
 ; CHECKPTX71-NEXT:    ld.param.u32 %r1, [test_param_0];
-; CHECKPTX71-NEXT:    ld.b16 %rs18, [%r1];
-; CHECKPTX71-NEXT:    cvt.f32.bf16 %f1, %rs13;
+; CHECKPTX71-NEXT:    ld.b16 %rs22, [%r1];
 ; CHECKPTX71-NEXT:  $L__BB0_1: // %atomicrmw.start14
 ; CHECKPTX71-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT:    cvt.f32.bf16 %f2, %rs18;
-; CHECKPTX71-NEXT:    add.rn.f32 %f3, %f2, %f1;
-; CHECKPTX71-NEXT:    cvt.rn.bf16.f32 %rs14, %f3;
-; CHECKPTX71-NEXT:    atom.cas.b16 %rs3, [%r1], %rs18, %rs14;
-; CHECKPTX71-NEXT:    setp.ne.s16 %p1, %rs3, %rs18;
-; CHECKPTX71-NEXT:    mov.u16 %rs18, %rs3;
+; CHECKPTX71-NEXT:    mov.b16 %rs14, 0x3F80;
+; CHECKPTX71-NEXT:    fma.rn.bf16 %rs15, %rs22, %rs14, %rs13;
+; CHECKPTX71-NEXT:    atom.cas.b16 %rs3, [%r1], %rs22, %rs15;
+; CHECKPTX71-NEXT:    setp.ne.s16 %p1, %rs3, %rs22;
+; CHECKPTX71-NEXT:    mov.u16 %rs22, %rs3;
 ; CHECKPTX71-NEXT:    @%p1 bra $L__BB0_1;
 ; CHECKPTX71-NEXT:  // %bb.2: // %atomicrmw.end13
-; CHECKPTX71-NEXT:    ld.b16 %rs19, [%r1];
+; CHECKPTX71-NEXT:    ld.b16 %rs23, [%r1];
 ; CHECKPTX71-NEXT:  $L__BB0_3: // %atomicrmw.start8
 ; CHECKPTX71-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT:    cvt.f32.bf16 %f4, %rs19;
-; CHECKPTX71-NEXT:    add.rn.f32 %f5, %f4, 0f3F800000;
-; CHECKPTX71-NEXT:    cvt.rn.bf16.f32 %rs15, %f5;
-; CHECKPTX71-NEXT:    atom.cas.b16 %rs6, [%r1], %rs19, %rs15;
-; CHECKPTX71-NEXT:    setp.ne.s16 %p2, %rs6, %rs19;
-; CHECKPTX71-NEXT:    mov.u16 %rs19, %rs6;
+; CHECKPTX71-NEXT:    mov.b16 %rs16, 0x3F80;
+; CHECKPTX71-NEXT:    fma.rn.bf16 %rs17, %rs23, %rs16, %rs16;
+; CHECKPTX71-NEXT:    atom.cas.b16 %rs6, [%r1], %rs23, %rs17;
+; CHECKPTX71-NEXT:    setp.ne.s16 %p2, %rs6, %rs23;
+; CHECKPTX71-NEXT:    mov.u16 %rs23, %rs6;
 ; CHECKPTX71-NEXT:    @%p2 bra $L__BB0_3;
 ; CHECKPTX71-NEXT:  // %bb.4: // %atomicrmw.end7
-; CHECKPTX71-NEXT:    ld.global.b16 %rs20, [%r2];
+; CHECKPTX71-NEXT:    ld.global.b16 %rs24, [%r2];
 ; CHECKPTX71-NEXT:  $L__BB0_5: // %atomicrmw.start2
 ; CHECKPTX71-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT:    cvt.f32.bf16 %f7, %rs20;
-; CHECKPTX71-NEXT:    add.rn.f32 %f8, %f7, %f1;
-; CHECKPTX71-NEXT:    cvt.rn.bf16.f32 %rs16, %f8;
-; CHECKPTX71-NEXT:    atom.global.cas.b16 %rs9, [%r2], %rs20, %rs16;
-; CHECKPTX71-NEXT:    setp.ne.s16 %p3, %rs9, %rs20;
-; CHECKPTX71-NEXT:    mov.u16 %rs20, %rs9;
+; CHECKPTX71-NEXT:    mov.b16 %rs18, 0x3F80;
+; CHECKPTX71-NEXT:    fma.rn.bf16 %rs19, %rs24, %rs18, %rs13;
+; CHECKPTX71-NEXT:    atom.global.cas.b16 %rs9, [%r2], %rs24, %rs19;
+; CHECKPTX71-NEXT:    setp.ne.s16 %p3, %rs9, %rs24;
+; CHECKPTX71-NEXT:    mov.u16 %rs24, %rs9;
 ; CHECKPTX71-NEXT:    @%p3 bra $L__BB0_5;
 ; CHECKPTX71-NEXT:  // %bb.6: // %atomicrmw.end1
-; CHECKPTX71-NEXT:    ld.shared.b16 %rs21, [%r3];
+; CHECKPTX71-NEXT:    ld.shared.b16 %rs25, [%r3];
 ; CHECKPTX71-NEXT:  $L__BB0_7: // %atomicrmw.start
 ; CHECKPTX71-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT:    cvt.f32.bf16 %f10, %rs21;
-; CHECKPTX71-NEXT:    add.rn.f32 %f11, %f10, %f1;
-; CHECKPTX71-NEXT:    cvt.rn.bf16.f32 %rs17, %f11;
-; CHECKPTX71-NEXT:    atom.shared.cas.b16 %rs12, [%r3], %rs21, %rs17;
-; CHECKPTX71-NEXT:    setp.ne.s16 %p4, %rs12, %rs21;
-; CHECKPTX71-NEXT:    mov.u16 %rs21, %rs12;
+; CHECKPTX71-NEXT:    mov.b16 %rs20, 0x3F80;
+; CHECKPTX71-NEXT:    fma.rn.bf16 %rs21, %rs25, %rs20, %rs13;
+; CHECKPTX71-NEXT:    atom.shared.cas.b16 %rs12, [%r3], %rs25, %rs21;
+; CHECKPTX71-NEXT:    setp.ne.s16 %p4, %rs12, %rs25;
+; CHECKPTX71-NEXT:    mov.u16 %rs25, %rs12;
 ; CHECKPTX71-NEXT:    @%p4 bra $L__BB0_7;
 ; CHECKPTX71-NEXT:  // %bb.8: // %atomicrmw.end
 ; CHECKPTX71-NEXT:    ret;
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 6828bac18cad7f..eeb13b52130042 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -42,17 +42,14 @@ define bfloat @test_fadd(bfloat %0, bfloat %1) {
 ;
 ; SM80-LABEL: test_fadd(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<4>;
-; SM80-NEXT:    .reg .f32 %f<4>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b16 %rs1, [test_fadd_param_0];
 ; SM80-NEXT:    ld.param.b16 %rs2, [test_fadd_param_1];
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs1;
-; SM80-NEXT:    add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs3, %f3;
-; SM80-NEXT:    st.param.b16 [func_retval0], %rs3;
+; SM80-NEXT:    mov.b16 %rs3, 0x3F80;
+; SM80-NEXT:    fma.rn.bf16 %rs4, %rs1, %rs3, %rs2;
+; SM80-NEXT:    st.param.b16 [func_retval0], %rs4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fadd(
@@ -113,17 +110,14 @@ define bfloat @test_fsub(bfloat %0, bfloat %1) {
 ;
 ; SM80-LABEL: test_fsub(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<4>;
-; SM80-NEXT:    .reg .f32 %f<4>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b16 %rs1, [test_fsub_param_0];
 ; SM80-NEXT:    ld.param.b16 %rs2, [test_fsub_param_1];
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs1;
-; SM80-NEXT:    sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs3, %f3;
-; SM80-NEXT:    st.param.b16 [func_retval0], %rs3;
+; SM80-NEXT:    mov.b16 %rs3, 0xBF80;
+; SM80-NEXT:    fma.rn.bf16 %rs4, %rs2, %rs3, %rs1;
+; SM80-NEXT:    st.param.b16 [func_retval0], %rs4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fsub(
@@ -202,23 +196,14 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_faddx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<5>;
-; SM80-NEXT:    .reg .b32 %r<4>;
-; SM80-NEXT:    .reg .f32 %f<7>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
-; SM80-NEXT:    ld.param.b32 %r1, [test_faddx2_param_0];
-; SM80-NEXT:    ld.param.b32 %r2, [test_faddx2_param_1];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT:    add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT:    add.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
+; SM80-NEXT:    ld.param.b32 %r1, [test_faddx2_param_1];
+; SM80-NEXT:    ld.param.b32 %r2, [test_faddx2_param_0];
+; SM80-NEXT:    mov.b32 %r3, 1065369472;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_faddx2(
@@ -303,23 +288,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_fsubx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<5>;
-; SM80-NEXT:    .reg .b32 %r<4>;
-; SM80-NEXT:    .reg .f32 %f<7>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fsubx2_param_0];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT:    sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT:    sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
+; SM80-NEXT:    mov.b32 %r3, -1082081408;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fsubx2(
@@ -404,23 +380,14 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_fmulx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<5>;
-; SM80-NEXT:    .reg .b32 %r<4>;
-; SM80-NEXT:    .reg .f32 %f<7>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
-; SM80-NEXT:    ld.param.b32 %r1, [test_fmulx2_param_0];
-; SM80-NEXT:    ld.param.b32 %r2, [test_fmulx2_param_1];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT:    mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT:    mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
+; SM80-NEXT:    ld.param.b32 %r1, [test_fmulx2_param_1];
+; SM80-NEXT:    ld.param.b32 %r2, [test_fmulx2_param_0];
+; SM80-NEXT:    mov.b32 %r3, 0;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r1, %r3;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fmulx2(
@@ -727,15 +694,13 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
 ;
 ; SM80-LABEL: test_fadd_imm_1(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<3>;
-; SM80-NEXT:    .reg .f32 %f<3>;
+; SM80-NEXT:    .reg .b16 %rs<4>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs2, %f2;
-; SM80-NEXT:    st.param.b16 [func_retval0], %rs2;
+; SM80-NEXT:    mov.b16 %rs2, 0x3F80;
+; SM80-NEXT:    fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
+; SM80-NEXT:    st.param.b16 [func_retval0], %rs3;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fadd_imm_1(
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 03cdeb9683abae..31d089a19450e1 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -22,19 +22,14 @@ define <2 x bfloat> @test_ret_const() #0 {
 define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
 ; SM80-LABEL: test_fadd_imm_0(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<3>;
-; SM80-NEXT:    .reg .b32 %r<3>;
-; SM80-NEXT:    .reg .f32 %f<5>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fadd_imm_0_param_0];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT:    cvt.f32.bf16 %f3, %rs2;
-; SM80-NEXT:    add.rn.f32 %f4, %f3, 0f40000000;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r2, %f4, %f2;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r2;
+; SM80-NEXT:    mov.b32 %r2, 1073758080;
+; SM80-NEXT:    mov.b32 %r3, 1065369472;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r1, %r3, %r2;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM90-LABEL: test_fadd_imm_0(
@@ -54,15 +49,13 @@ define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
 define bfloat @test_fadd_imm_1(bfloat %a) #0 {
 ; SM80-LABEL: test_fadd_imm_1(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<3>;
-; SM80-NEXT:    .reg .f32 %f<3>;
+; SM80-NEXT:    .reg .b16 %rs<4>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs2, %f2;
-; SM80-NEXT:    st.param.b16 [func_retval0], %rs2;
+; SM80-NEXT:    mov.b16 %rs2, 0x3F80;
+; SM80-NEXT:    fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
+; SM80-NEXT:    st.param.b16 [func_retval0], %rs3;
 ; SM80-NEXT:    ret;
 ;
 ; SM90-LABEL: test_fadd_imm_1(
@@ -82,23 +75,14 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
 define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-LABEL: test_fsubx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<5>;
-; SM80-NEXT:    .reg .b32 %r<4>;
-; SM80-NEXT:    .reg .f32 %f<7>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fsubx2_param_0];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT:    sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT:    sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
+; SM80-NEXT:    mov.b32 %r3, -1082081408;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM90-LABEL: test_fsubx2(
@@...
[truncated]

Contributor

@justinfargnoli justinfargnoli left a comment


The purpose of this PR is to eliminate the conversions to/from fp32, correct? If so, can you add that to the description?

Comment on lines 859 to 860
if (getOperationAction(Op, VT) != Legal &&
getOperationAction(ISD::FMA, VT) == Legal) {
Contributor

This is cumbersome; we usually don't write legalizer rules in terms of other legalizer rules.

I think you'd be best off just putting this logic into the default Expand action for fadd/fmul/fsub. If the FMA is legal, you emit the appropriate sequence before falling back to the default libcall expansion. Then you shouldn't need to touch the target rules here.

Contributor Author

This is cumbersome, we usually don't write legalizer rules in terms of other legalizer rules.

I could easily write it in terms of SM and PTX version ranges instead; I just felt this more directly expressed the rationale behind the version range.

I think you'd be best off just putting this logic into the default Expand action for add/fmul/fsub. If the FMA is legal, you emit the appropriate sequence before falling back to the default libcall expansion. Then you shouldn't need to touch the target rules here

I'm not sure this makes sense. The FTZ logic is target specific, and we also want to fall back to promotion, not a libcall, here.

Contributor

I'm not sure this makes sense. The FTZ logic is target specific and we also want to fall back to promotion, not a libcall here.

The custom expansion can defer to the default expansion depending on the function state. Yes, the default Expand action can conditionally use the FMA path instead of the default libcall expansion. The core transform code can be in the generic legalizer.

Contributor Author

Alright, done.

Member

I think this should be simplified: the ISD::FADD, ISD::FMUL, ISD::FSUB action for bf16 types does not need to be derived from FMA. It should just be set in sync with the action on FMA, at the same place where we set it for FMA.

Contributor Author

setBF16OperationAction only allows two variants, supported or unsupported, so we can't set Custom for an in-between range.

Member

Based on the clarification we've arrived at while sorting out my confusion about the applicability of the patch: instead of relying on setBF16OperationAction, we can set the appropriate action by checking the SM/PTX version range directly. That would be easier to understand than deriving the right action indirectly from the actions set on other ops.

@peterbell10 peterbell10 force-pushed the nvptx-fma-sm80 branch 2 times, most recently from 0fdfca3 to ee6635b Compare December 27, 2024 19:59
@Artem-B
Member

Artem-B commented Jan 8, 2025

Could we just declare bfloat16 add/mul/sub legal when FMA is available, and just lower them as FMA via pattern-matching, without messing with the legalizer and combiner? I'm not quite sure why we want to convert those mul/add to FMA in the DAG itself.

}
// fma (fneg x), K, y -> fma x -K, y
if (N1CFP && matcher.match(N0, ISD::FNEG) &&
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
Contributor

Since ConstantFP doesn't get used for vectors, this won't handle them correctly. I thought we had a helper for this but I can't seem to find it.

Contributor Author

This isn't my code, but there is a FIXME comment above.

@peterbell10 peterbell10 changed the title [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80 [NVPTX] Select bfloat16 add/mul/sub as fma on SM80 Jan 9, 2025
@peterbell10
Contributor Author

Okay, I've changed this to make add/mul/sub legal and added manual selection code which I believe is required to generate the constant inputs. PTAL.

@peterbell10 peterbell10 removed the llvm:SelectionDAG SelectionDAGISel as well label Jan 9, 2025
@@ -3671,14 +3671,21 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Results.push_back(ExpandConstant(CP));
break;
}
case ISD::FADD: {
if (SDValue Expand = TLI.expandFADD(Node, DAG)) {
Member

style nit: No need for {} for single-statement bodies. Applies here and in other places.

if (IsNativelySupported)
return false;

assert(VT == MVT::bf16 || VT == MVT::v2bf16);
Member

The assert does nothing here, as we would have returned false for non-bf16 types already.

If we're not expected to be called with non-bf16 types, I'd move the assert upwards, next to where VT is defined, and use it instead of the early return.

Contributor Author

I'm asserting it's either a scalar or vector of length 2.

Member

The point is that both would be ruled out by if (VT.getScalarType() != MVT::bf16) above.

auto API = APF.bitcastToAPInt();
API = API.concat(API);
auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0);
Member

Does PTX have a way to specify/use bf16 constant args? It would be nice to avoid passing the constant via a register.
I suspect there's no good way to use a constant for bf16x2, but I would assume that there should be a way to use FP constants for a scalar FMA.

Contributor Author

No, there doesn't seem to be any support for bf16 immediate values; ptxas complains:

ptxas /tmp/tmplmhxk1av.ptx, line 79; error   : Arguments mismatch for instruction 'fma'
ptxas fatal   : Ptx assembly aborted due to errors

Member

OK. Please add a comment about that. Otherwise these register moves look rather questionable.

};

switch (N->getOpcode()) {
case ISD::FADD: {
Member

Nit. We do not need {} around individual cases.

Contributor Author

Removed


; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
; SM80-NEXT: mov.b16 %rs3, 0x3F80;
; SM80-NEXT: fma.rn.bf16 %rs4, %rs1, %rs3, %rs2;
Member

We need some tests for FTZ mode, now, too.

Contributor Author

Every test here has an SM80-FTZ-NEXT variant, they just don't show up in the diff because they haven't changed.
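As a side note on the CHECK lines above: the 0x3F80 moved into %rs3 is the bf16 encoding of 1.0. A minimal sketch of how these bit patterns arise (bf16 is the top half of the IEEE-754 binary32 encoding; plain truncation is exact for the constants this lowering needs):

```python
import struct

def bf16_bits(x: float) -> int:
    # bf16 keeps the top 16 bits of the binary32 encoding of x.
    # Truncation (no rounding) is exact for 1.0, -1.0, and -0.0.
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

assert bf16_bits(1.0) == 0x3F80   # the FADD multiplicand seen in the CHECK lines
assert bf16_bits(-1.0) == 0xBF80  # the FSUB multiplicand
assert bf16_bits(-0.0) == 0x8000  # the FMUL addend
```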

@Artem-B

Artem-B commented Jan 9, 2025

After checking the PTX spec, the question I've got is -- should we bother with this attempt to lower those ops to FMA at all?

It will only be beneficial on sm_80 for PTX versions between 7.0 and 7.7 (CUDA-11.1 to 11.7).

These days nobody should be using anything older than CUDA-11.8 and even that is firmly on the way out.
I do not really see much point adding something to support a use case that's already obsolete.

@peterbell10
Contributor Author

The spec says

add{.rnd}.bf16 and add{.rnd}.bf16x2 requires sm_90 or higher.

I don't see any suggestion that that only applies for specific PTX versions.

@Artem-B

Artem-B commented Jan 9, 2025

The spec says

add{.rnd}.bf16 and add{.rnd}.bf16x2 requires sm_90 or higher.

I don't see any suggestion that that only applies for specific PTX versions.

sm_90 is only supported by PTX 7.8 and newer:
[screenshot of the PTX ISA version table, showing sm_90 support starting with PTX ISA 7.8]

FMA instruction for bf16 types requires PTX 7.0 and sm_80: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-fma

FADD/FMUL for bf16 requires PTX 7.8 and sm_90. https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul

So, the only cases where we can benefit from this patch are:

  • PTX >= 7.0 and GPU >= sm_80 (otherwise, there's no bf16 FMA support)
  • PTX < 7.8 (otherwise FMUL is available, and we don't need the patch).

sm_90 or newer GPUs require PTX 7.8 and therefore do not benefit from the patch.
What's left is sm_80 and PTX versions 7.0 through 7.7.

What do I miss?

Update: what I missed is that mul/add instructions are still absent for sm_80 even in the newer PTX versions, as @peterbell10 correctly pointed out below.
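The support matrix being debated here can be summarized as a hypothetical decision table (a sketch only: `sm`/`ptx` are integer-coded versions, e.g. 78 for PTX ISA 7.8, and the function name is made up for illustration):

```python
def bf16_arith_lowering(sm: int, ptx: int, op: str) -> str:
    # Per the PTX ISA references above: bf16 fma needs PTX >= 7.0 and sm_80;
    # native bf16 add/mul/sub need PTX >= 7.8 and sm_90.
    if op == "fma":
        return "native" if sm >= 80 and ptx >= 70 else "promote-f32"
    if sm >= 90 and ptx >= 78:   # add/mul/sub exist natively
        return "native"
    if sm >= 80 and ptx >= 70:   # rewrite via fma (this patch)
        return "fma-rewrite"
    return "promote-f32"

assert bf16_arith_lowering(80, 83, "add") == "fma-rewrite"  # sm_80 benefits even on new PTX
assert bf16_arith_lowering(90, 78, "add") == "native"
assert bf16_arith_lowering(70, 78, "add") == "promote-f32"
```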

@peterbell10
Contributor Author

FADD/FMUL for bf16 requires PTX 7.8 and sm_90. https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul

Exactly... It requires sm_90, so for sm_80 there is support for fma but not add/mul/sub. What's confusing here?

@peterbell10
Contributor Author

peterbell10 commented Jan 10, 2025

PTX < 7.8 (otherwise FMUL is available, and we don't need the patch).

This is where your logic doesn't make sense. Why would PTX >= 7.8 mean that sm_80 doesn't benefit from this patch?

@Artem-B

Artem-B commented Jan 10, 2025

Why would PTX >= 7.8 mean that sm_80 doesn't benefit from this patch?

Ugh. Indeed. You're absolutely right. sm_80 still does not have mul/add, even in newer PTX versions. I was concentrating on sm_90 and somehow ignored that sm_80 is still affected.

OK, the patch is still useful. Sorry about the noise.

@peterbell10
Contributor Author

Are there any remaining blockers here?


@Artem-B Artem-B left a comment


Few nits, but LGTM overall.

if (IsNativelySupported)
return false;

assert(VT == MVT::bf16 || VT == MVT::v2bf16);
Member

The point is that both would be ruled out by if (VT.getScalarType() != MVT::bf16) above.

auto API = APF.bitcastToAPInt();
API = API.concat(API);
auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0);
Member

OK. Please add a comment about that. Otherwise these register moves look rather questionable.
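To illustrate why the register move in the snippet above is there: for v2bf16 the 16-bit constant pattern is concatenated with itself, so a single IMOV32ri materializes the splat. A rough sketch of the effect of `API.concat(API)`:

```python
def splat_bf16x2(bits16: int) -> int:
    # Mirrors API.concat(API): duplicate the 16-bit bf16 pattern into both
    # halves of a 32-bit immediate so one 32-bit move yields a v2bf16 splat.
    return (bits16 << 16) | bits16

assert splat_bf16x2(0x3F80) == 0x3F803F80  # v2bf16 splat of 1.0
assert splat_bf16x2(0x8000) == 0x80008000  # v2bf16 splat of -0.0
```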

Comment on lines 859 to 860
if (getOperationAction(Op, VT) != Legal &&
getOperationAction(ISD::FMA, VT) == Legal) {
Member

Based on the clarification we arrived at while sorting out my confusion about the applicability of the patch: instead of relying on setBF16OperationAction, we can set the appropriate action by checking the SM/PTX version range directly. That would be easier to understand than deriving the right action indirectly from the actions set on other ops.

SM80 has fma for bfloat16 but not add/mul/sub. Currently these are
just promoted to f32 but we can instead write them in terms of the
fma:
```
FADD(a, b) -> FMA(a, 1.0, b)
FMUL(a, b) -> FMA(a, b, -0.0)
FSUB(a, b) -> FMA(b, -1.0, a)
```

Unfortunately there is no `fma.ftz` so when ftz is enabled, we still
fall back to promotion.

This is also the inverse of some generic DAGCombiner patterns, so
I've had to add checks to avoid it reversing the legalization which
would cause an infinite loop.
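The three identities in the commit message can be sanity-checked at the value level with ordinary double arithmetic (this emulated fma is not fused, so it demonstrates only the identities, not the rounding behavior of the hardware instruction):

```python
import math

def fma(a: float, b: float, c: float) -> float:
    # Emulated, non-fused fma: sufficient to check the value-level identities.
    return a * b + c

for a, b in [(1.5, -2.25), (-0.0, 3.0), (0.5, 0.5)]:
    assert fma(a, 1.0, b) == a + b    # FADD(a, b) -> FMA(a, 1.0, b)
    assert fma(a, b, -0.0) == a * b   # FMUL(a, b) -> FMA(a, b, -0.0)
    assert fma(b, -1.0, a) == a - b   # FSUB(a, b) -> FMA(b, -1.0, a)

# The -0.0 addend matters for FMUL: it preserves the sign of a zero product,
# whereas a +0.0 addend would turn -0.0 * b into +0.0.
assert math.copysign(1.0, fma(-0.0, 3.0, -0.0)) == -1.0
assert math.copysign(1.0, (-0.0 * 3.0) + 0.0) == 1.0
```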
@peterbell10 peterbell10 merged commit 5e5fd0e into llvm:main Jan 16, 2025
8 checks passed
@peterbell10 peterbell10 deleted the nvptx-fma-sm80 branch January 16, 2025 14:53