@@ -970,8 +970,8 @@ struct ElementwiseOpConversion
970970                                    ConversionPatternRewriter &rewriter,
971971                                    Type elemTy, MultipleOperandsRange operands,
972972                                    Location loc) const  {
973-     assert ((!getElementType (op. getLhs () ).isBF16 () &&
974-             !getElementType (op. getRhs () ).isBF16 ()) &&
973+     assert ((!getElementType (operands[ 0 ][ 0 ] ).isBF16 () &&
974+             !getElementType (operands[ 0 ][ 1 ] ).isBF16 ()) &&
975975           " unsupported conversion"  );
976976    return  {
977977        rewriter.create <DestOp>(loc, elemTy, operands[0 ][0 ], operands[0 ][1 ])};
@@ -1146,57 +1146,10 @@ struct PreciseSqrtOpConversion
11461146                                   ConversionPatternRewriter &rewriter,
11471147                                   Type elemTy, MultipleOperandsRange operands,
11481148                                   Location loc) const  {
1149-     auto  b = TritonLLVMOpBuilder (loc, rewriter);
1150-     Value input = operands[0 ][0 ];
1151-     Type origTy = input.getType ();
1152-     if  (!origTy.isF64 ())
1153-       input = b.fpext (f64_ty, input);
1154-     Type funcType = LLVM::LLVMFunctionType::get (f64_ty, {f64_ty});
1155-     LLVM::LLVMFuncOp funcOp =
1156-         appendOrGetExternFuncOp (rewriter, op, " __imf_sqrt_rn"  , funcType);
1157-     funcOp.setCConv (triton::gpu::intel::getDefaultCConv (op));
1158-     LLVM::CallOp callOp =
1159-         LLVM::createLLVMCallOp (rewriter, loc, funcOp, {input});
1160-     callOp.setCConv (funcOp.getCConv ());
1161-     Value result = callOp.getResult ();
1162-     if  (!origTy.isF64 ())
1163-       result = rewriter.create <LLVM::FPTruncOp>(loc, origTy, result);
1164-     return  {result};
1165-   }
1166- };
1167- 
1168- template  <typename  TritonOp>
1169- struct  OpToExternCallConversion 
1170-     : public ElementwiseOpConversionBase<TritonOp,
1171-                                          OpToExternCallConversion<TritonOp>> {
1172-   using  Base =
1173-       ElementwiseOpConversionBase<TritonOp, OpToExternCallConversion<TritonOp>>;
1174-   using  Base::Base;
1175-   using  Adaptor = typename  Base::OpAdaptor;
1176- 
1177-   explicit  OpToExternCallConversion (LLVMTypeConverter &typeConverter,
1178-                                     ModuleAxisInfoAnalysis &axisAnalysisPass,
1179-                                     StringRef externFuncName,
1180-                                     PatternBenefit benefit)
1181-       : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
1182-                                           benefit),
1183-         funcName(externFuncName) {}
1184- 
1185-   SmallVector<Value> createDestOps (TritonOp op, Adaptor adaptor,
1186-                                    ConversionPatternRewriter &rewriter,
1187-                                    Type elemTy, MultipleOperandsRange operands,
1188-                                    Location loc) const  {
1189-     Type funcType = getFunctionType (elemTy, operands[0 ]);
1190-     LLVM::LLVMFuncOp funcOp =
1191-         appendOrGetExternFuncOp (rewriter, op, funcName, funcType);
1192-     funcOp.setCConv (triton::gpu::intel::getDefaultCConv (op));
1193-     auto  callOp = LLVM::createLLVMCallOp (rewriter, loc, funcOp, operands[0 ]);
1194-     callOp.setCConv (funcOp.getCConv ());
1195-     return  {callOp.getResult ()};
1149+     //  Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise sqrt.
1150+     return  {rewriter.create <LLVM::SqrtOp>(loc, elemTy, operands[0 ],
1151+                                           adaptor.getAttributes ().getValue ())};
11961152  }
1197- 
1198- private: 
1199-   StringRef funcName;
12001153};
12011154
12021155//  Following two patterns are copied from the common part to fix-up calling
@@ -1273,8 +1226,9 @@ void populateElementwiseOpToLLVMPatterns(
12731226
12741227  patterns.add <PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
12751228                                        benefit);
1276-   patterns.add <OpToExternCallConversion<triton::PreciseDivFOp>>(
1277-       typeConverter, axisInfoAnalysis, " __imf_fdiv_rn"  , benefit);
1229+   //  Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise divide.
1230+   patterns.add <ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>(
1231+       typeConverter, axisInfoAnalysis, benefit);
12781232  patterns.add <MulhiUIOpConversion>(typeConverter, axisInfoAnalysis, targetInfo,
12791233                                    benefit);
12801234  patterns.add <ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
0 commit comments