Commit bc74db7

Rely on -cl-fp32-correctly-rounded-divide-sqrt for precise divide and sqrt
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 91cac59 commit bc74db7

3 files changed (+27, −59)

third_party/intel/backend/compiler.py

Lines changed: 9 additions & 5 deletions
@@ -201,6 +201,12 @@ def make_ttir(mod, metadata, opt):
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
         pm.run(mod, 'make_ttir')
+
+        if intel.has_precise_divide_sqrt(mod):
+            metadata["build_flags"] = "-cl-fp32-correctly-rounded-divide-sqrt"
+        else:
+            metadata["build_flags"] = ""
+
         return mod
 
     @staticmethod
@@ -364,15 +370,13 @@ def make_spv(src, metadata, options, device_arch):
         spirv, name = intel.translate_to_spirv(src)
         metadata["name"] = name
         if options.grf_mode == 'small':
-            metadata["build_flags"] = "-cl-intel-128-GRF-per-thread"
+            metadata["build_flags"] += " -cl-intel-128-GRF-per-thread"
         elif options.grf_mode == 'large':
             if options.num_warps > 32:
                 raise RuntimeError("grf_mode = large cannot be used with num_warps > 32")
-            metadata["build_flags"] = "-cl-intel-256-GRF-per-thread"
+            metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
         elif options.grf_mode == 'auto':
-            metadata["build_flags"] = "-cl-intel-enable-auto-large-GRF-mode"
-        else:
-            metadata["build_flags"] = ""
+            metadata["build_flags"] += " -cl-intel-enable-auto-large-GRF-mode"
 
         if knobs.intel.disable_igc_opt:
             metadata["build_flags"] += " -cl-opt-disable"
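Taken together, the two hunks above change build_flags from being overwritten in make_spv to being seeded in make_ttir and only appended to afterwards. A condensed sketch of the resulting behaviour (not the backend code itself; the names come from the diff above, and the num_warps check is omitted):

def collect_build_flags(mod, metadata, options, knobs, intel):
    # Seeded in make_ttir: request correctly rounded fp32 divide/sqrt from the
    # device compiler only when the module contains precise divide/sqrt ops.
    if intel.has_precise_divide_sqrt(mod):
        metadata["build_flags"] = "-cl-fp32-correctly-rounded-divide-sqrt"
    else:
        metadata["build_flags"] = ""

    # make_spv then appends (rather than assigns) the GRF-mode and opt flags,
    # so the precise-rounding flag is preserved.
    if options.grf_mode == 'small':
        metadata["build_flags"] += " -cl-intel-128-GRF-per-thread"
    elif options.grf_mode == 'large':
        metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
    elif options.grf_mode == 'auto':
        metadata["build_flags"] += " -cl-intel-enable-auto-large-GRF-mode"
    if knobs.intel.disable_igc_opt:
        metadata["build_flags"] += " -cl-opt-disable"
    return metadata["build_flags"]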

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 8 additions & 54 deletions
@@ -970,8 +970,8 @@ struct ElementwiseOpConversion
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
-    assert((!getElementType(op.getLhs()).isBF16() &&
-            !getElementType(op.getRhs()).isBF16()) &&
+    assert((!getElementType(operands[0][0]).isBF16() &&
+            !getElementType(operands[0][1]).isBF16()) &&
            "unsupported conversion");
     return {
         rewriter.create<DestOp>(loc, elemTy, operands[0][0], operands[0][1])};
@@ -1146,57 +1146,10 @@ struct PreciseSqrtOpConversion
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    Value input = operands[0][0];
-    Type origTy = input.getType();
-    if (!origTy.isF64())
-      input = b.fpext(f64_ty, input);
-    Type funcType = LLVM::LLVMFunctionType::get(f64_ty, {f64_ty});
-    LLVM::LLVMFuncOp funcOp =
-        appendOrGetExternFuncOp(rewriter, op, "__imf_sqrt_rn", funcType);
-    funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
-    LLVM::CallOp callOp =
-        LLVM::createLLVMCallOp(rewriter, loc, funcOp, {input});
-    callOp.setCConv(funcOp.getCConv());
-    Value result = callOp.getResult();
-    if (!origTy.isF64())
-      result = rewriter.create<LLVM::FPTruncOp>(loc, origTy, result);
-    return {result};
-  }
-};
-
-template <typename TritonOp>
-struct OpToExternCallConversion
-    : public ElementwiseOpConversionBase<TritonOp,
-                                         OpToExternCallConversion<TritonOp>> {
-  using Base =
-      ElementwiseOpConversionBase<TritonOp, OpToExternCallConversion<TritonOp>>;
-  using Base::Base;
-  using Adaptor = typename Base::OpAdaptor;
-
-  explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter,
-                                    ModuleAxisInfoAnalysis &axisAnalysisPass,
-                                    StringRef externFuncName,
-                                    PatternBenefit benefit)
-      : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
-                                          benefit),
-        funcName(externFuncName) {}
-
-  SmallVector<Value> createDestOps(TritonOp op, Adaptor adaptor,
-                                   ConversionPatternRewriter &rewriter,
-                                   Type elemTy, MultipleOperandsRange operands,
-                                   Location loc) const {
-    Type funcType = getFunctionType(elemTy, operands[0]);
-    LLVM::LLVMFuncOp funcOp =
-        appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
-    funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
-    auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]);
-    callOp.setCConv(funcOp.getCConv());
-    return {callOp.getResult()};
+    // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise sqrt.
+    return {rewriter.create<LLVM::SqrtOp>(loc, elemTy, operands[0],
+                                          adaptor.getAttributes().getValue())};
   }
-
-private:
-  StringRef funcName;
 };
 
 // Following two patterns are copied from the common part to fix-up calling
@@ -1273,8 +1226,9 @@ void populateElementwiseOpToLLVMPatterns(
 
   patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
                                         benefit);
-  patterns.add<OpToExternCallConversion<triton::PreciseDivFOp>>(
-      typeConverter, axisInfoAnalysis, "__imf_fdiv_rn", benefit);
+  // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise divide.
+  patterns.add<ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>(
+      typeConverter, axisInfoAnalysis, benefit);
   patterns.add<MulhiUIOpConversion>(typeConverter, axisInfoAnalysis, targetInfo,
                                     benefit);
   patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
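The deleted PreciseSqrtOpConversion body emulated correct rounding by extending the fp32 input to fp64, calling __imf_sqrt_rn, and truncating the result back; the build flag instead asks the device compiler to make the plain llvm.sqrt / llvm.fdiv correctly rounded. A host-side NumPy sketch of why the two strategies agree bit-for-bit for fp32 (an illustration only, not Triton code):

import numpy as np

# What the removed lowering computed for tt.precise_sqrt on fp32 inputs:
# extend to f64, take a correctly rounded f64 sqrt, truncate back to f32.
def old_path_sqrt(x32):
    return np.sqrt(x32.astype(np.float64)).astype(np.float32)

# What the build flag now guarantees: a single correctly rounded f32 sqrt.
# On CPUs, np.sqrt on float32 is already correctly rounded, so it serves as
# a stand-in reference here.
def flag_path_sqrt(x32):
    return np.sqrt(x32)

x = np.random.default_rng(0).random(1 << 16, dtype=np.float32)
# Double rounding through f64 is harmless for sqrt (and divide) of f32 values,
# since f64 carries more than 2*24 + 2 significand bits, so both paths match.
assert np.array_equal(old_path_sqrt(x), flag_path_sqrt(x))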

third_party/intel/triton_xpu.cc

Lines changed: 10 additions & 0 deletions
@@ -299,6 +299,16 @@ void init_triton_intel(py::module &&m) {
     return py::int_(ret);
   });
 
+  m.def("has_precise_divide_sqrt", [](mlir::ModuleOp &mod) -> bool {
+    using namespace mlir;
+    WalkResult result = mod.walk([&](Operation *op) {
+      if (isa<mlir::triton::PreciseDivFOp, mlir::triton::PreciseSqrtOp>(op))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    return result.wasInterrupted();
+  });
+
   // FIXME: This is for internal experimentation. In the end we will need a
   // producer flag (e.g. PyTorch flag) to allow the Triton compiler to use the
   // fast math semantics on all arithmetic operations.
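For reference, a kernel like the following is the kind of input that makes has_precise_divide_sqrt return true and therefore triggers the new build flag. This is a sketch: the frontend entry points tl.fdiv(..., ieee_rounding=True) and tl.sqrt_rn are assumed to lower to tt.precise_divf / tt.precise_sqrt, which is not part of this commit.

import triton
import triton.language as tl

@triton.jit
def precise_div_sqrt_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    # IEEE-rounded divide and correctly rounded sqrt: assumed to lower to
    # tt.precise_divf and tt.precise_sqrt, the ops the new module walk detects.
    z = tl.fdiv(x, y, ieee_rounding=True)
    z = tl.sqrt_rn(z)
    tl.store(out_ptr + offs, z, mask=mask)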
