Commit f762b2b
Rely on -cl-fp32-correctly-rounded-divide-sqrt for precise divide and sqrt (#5415)

OpenCL supports the build option `-cl-fp32-correctly-rounded-divide-sqrt`, which requires single-precision divide (`OpFDiv`) and sqrt to be correctly rounded. The disadvantage is that the option applies at the kernel level, so every divide and sqrt in the kernel becomes precise; we cannot mix precise and approximate division in one kernel.

Signed-off-by: Whitney Tsang <[email protected]>

1 parent 8e5c989 commit f762b2b
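For illustration, here is a minimal Triton kernel that would trigger the new flag. This is a sketch, not code from this commit, and it assumes the usual frontend mapping of `tl.math.div_rn`/`tl.math.sqrt_rn` to `tt.precise_divf`/`tt.precise_sqrt`:

```python
import triton
import triton.language as tl

@triton.jit
def precise_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    # div_rn/sqrt_rn request correctly rounded results; their presence makes
    # the backend add -cl-fp32-correctly-rounded-divide-sqrt, so any other
    # fp32 divide in this same kernel would become correctly rounded as well.
    tl.store(out_ptr + offs, tl.math.sqrt_rn(tl.math.div_rn(x, y)), mask=mask)
```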

File tree: 3 files changed, +28 −59 lines

third_party/intel/backend/compiler.py

Lines changed: 8 additions & 5 deletions
```diff
@@ -202,6 +202,10 @@ def make_ttir(mod, metadata, opt):
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
         pm.run(mod, 'make_ttir')
+
+        if intel.has_precise_divide_sqrt(mod):
+            metadata["build_flags"] = "-cl-fp32-correctly-rounded-divide-sqrt"
+
         return mod
 
     @staticmethod
@@ -364,16 +368,15 @@ def make_llir(src, metadata, options):
     def make_spv(src, metadata, options, device_arch):
         spirv, name = intel.translate_to_spirv(src)
         metadata["name"] = name
+        metadata.setdefault("build_flags", "")
         if options.grf_mode == 'small':
-            metadata["build_flags"] = "-cl-intel-128-GRF-per-thread"
+            metadata["build_flags"] += " -cl-intel-128-GRF-per-thread"
         elif options.grf_mode == 'large':
             if options.num_warps > 32:
                 raise RuntimeError("grf_mode = large cannot be used with num_warps > 32")
-            metadata["build_flags"] = "-cl-intel-256-GRF-per-thread"
+            metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
         elif options.grf_mode == 'auto':
-            metadata["build_flags"] = "-cl-intel-enable-auto-large-GRF-mode"
-        else:
-            metadata["build_flags"] = ""
+            metadata["build_flags"] += " -cl-intel-enable-auto-large-GRF-mode"
 
         if knobs.intel.disable_igc_opt:
             metadata["build_flags"] += " -cl-opt-disable"
```

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 10 additions & 54 deletions
```diff
@@ -970,8 +970,8 @@ struct ElementwiseOpConversion
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
-    assert((!getElementType(op.getLhs()).isBF16() &&
-            !getElementType(op.getRhs()).isBF16()) &&
+    assert((!getElementType(operands[0][0]).isBF16() &&
+            !getElementType(operands[0][1]).isBF16()) &&
            "unsupported conversion");
     return {
         rewriter.create<DestOp>(loc, elemTy, operands[0][0], operands[0][1])};
@@ -1146,57 +1146,11 @@ struct PreciseSqrtOpConversion
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    Value input = operands[0][0];
-    Type origTy = input.getType();
-    if (!origTy.isF64())
-      input = b.fpext(f64_ty, input);
-    Type funcType = LLVM::LLVMFunctionType::get(f64_ty, {f64_ty});
-    LLVM::LLVMFuncOp funcOp =
-        appendOrGetExternFuncOp(rewriter, op, "__imf_sqrt_rn", funcType);
-    funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
-    LLVM::CallOp callOp =
-        LLVM::createLLVMCallOp(rewriter, loc, funcOp, {input});
-    callOp.setCConv(funcOp.getCConv());
-    Value result = callOp.getResult();
-    if (!origTy.isF64())
-      result = rewriter.create<LLVM::FPTruncOp>(loc, origTy, result);
-    return {result};
-  }
-};
-
-template <typename TritonOp>
-struct OpToExternCallConversion
-    : public ElementwiseOpConversionBase<TritonOp,
-                                         OpToExternCallConversion<TritonOp>> {
-  using Base =
-      ElementwiseOpConversionBase<TritonOp, OpToExternCallConversion<TritonOp>>;
-  using Base::Base;
-  using Adaptor = typename Base::OpAdaptor;
-
-  explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter,
-                                    ModuleAxisInfoAnalysis &axisAnalysisPass,
-                                    StringRef externFuncName,
-                                    PatternBenefit benefit)
-      : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
-                                          benefit),
-        funcName(externFuncName) {}
-
-  SmallVector<Value> createDestOps(TritonOp op, Adaptor adaptor,
-                                   ConversionPatternRewriter &rewriter,
-                                   Type elemTy, MultipleOperandsRange operands,
-                                   Location loc) const {
-    Type funcType = getFunctionType(elemTy, operands[0]);
-    LLVM::LLVMFuncOp funcOp =
-        appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
-    funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
-    auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]);
-    callOp.setCConv(funcOp.getCConv());
-    return {callOp.getResult()};
+    // FIXME: Use precise sqrt builtin: #5419
+    // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise sqrt.
+    return {rewriter.create<LLVM::SqrtOp>(loc, elemTy, operands[0],
+                                          adaptor.getAttributes().getValue())};
   }
-
-private:
-  StringRef funcName;
 };
 
 // Following two patterns are copied from the common part to fix-up calling
@@ -1273,8 +1227,10 @@ void populateElementwiseOpToLLVMPatterns(
 
   patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
                                         benefit);
-  patterns.add<OpToExternCallConversion<triton::PreciseDivFOp>>(
-      typeConverter, axisInfoAnalysis, "__imf_fdiv_rn", benefit);
+  // FIXME: Use precise divide builtin: #5419
+  // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise divide.
+  patterns.add<ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>(
+      typeConverter, axisInfoAnalysis, benefit);
   patterns.add<MulhiUIOpConversion>(typeConverter, axisInfoAnalysis, targetInfo,
                                     benefit);
   patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
```
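With the extern `__imf_*` calls gone, `tt.precise_divf` and `tt.precise_sqrt` lower to plain `LLVM::FDivOp`/`LLVM::SqrtOp`, and precision comes entirely from the build flag. For context, "correctly rounded" for an fp32 divide means the result equals the exact real quotient rounded once to float32. A NumPy sketch of that check (illustration only, not repo code):

```python
import numpy as np

def correctly_rounded_div_f32(x: np.float32, y: np.float32) -> np.float32:
    # float64 carries 53 significand bits, more than twice float32's 24, so
    # dividing in float64 and rounding back yields the correctly rounded
    # float32 quotient (double rounding is innocuous at this precision gap).
    return np.float32(np.float64(x) / np.float64(y))

x, y = np.float32(1.0), np.float32(3.0)
# NumPy's own float32 divide is correctly rounded, so the two must agree.
assert x / y == correctly_rounded_div_f32(x, y)
```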

third_party/intel/triton_xpu.cc

Lines changed: 10 additions & 0 deletions
```diff
@@ -301,6 +301,16 @@ void init_triton_intel(py::module &&m) {
     return py::int_(ret);
   });
 
+  m.def("has_precise_divide_sqrt", [](mlir::ModuleOp &mod) -> bool {
+    using namespace mlir;
+    WalkResult result = mod.walk([&](Operation *op) {
+      if (isa<mlir::triton::PreciseDivFOp, mlir::triton::PreciseSqrtOp>(op))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    return result.wasInterrupted();
+  });
+
   // FIXME: This is for internal experimentation. In the end we will need a
   // producer flag (e.g. PyTorch flag) to allow the Triton compiler to use the
   // fast math semantics on all arithmetic operations.
```
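The walk interrupts at the first precise op it finds, so the check is a cheap early-exit scan rather than a full traversal. On the Python side, this binding is what `make_ttir` consults; a hedged usage sketch, with the import path assumed rather than taken from the repo:

```python
# Hypothetical caller, mirroring the make_ttir hunk above; the exact module
# path for the Intel bindings is an assumption.
from triton._C.libtriton import intel

if intel.has_precise_divide_sqrt(mod):
    metadata["build_flags"] = "-cl-fp32-correctly-rounded-divide-sqrt"
```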
