Skip to content

Commit 5295eee

Browse files
authored
[linalg] : Use (-)realmax instead of (-)inf to avoid usage of non-finites. (#4363)
This change replaces usage of non-finite value `inf` with finite value `realmax` for init value of various max/min operations depending on whether `allow-non-finites` option is disabled through the `fx.export_and_import` API -- no change in semantics of the ops. The default behavior of emitting `inf` is preserved.
1 parent f2a192c commit 5295eee

File tree

36 files changed

+711
-150
lines changed

36 files changed

+711
-150
lines changed

include/torch-mlir/Conversion/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
103103
compelling for modeling effects more broadly.
104104
}];
105105
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
106+
let options = [
107+
Option<"allowNonFinites", "allow-non-finites",
108+
"bool", /*default=*/"true",
109+
"When enabled (default), some ops may emit non-finites, for example, max pooling may compare values to an initial value of `-inf`. When disabled, non-finites will be replaced with the closest finite value for a given dtype.">,
110+
];
106111
}
107112

108113
def ConvertTorchToTensor : Pass<"convert-torch-to-tensor", "func::FuncOp"> {
@@ -141,6 +146,11 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
141146
pass also makes use of TMTensor Dialect, which the former one doesn't.
142147
}];
143148
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
149+
let options = [
150+
Option<"allowNonFinites", "allow-non-finites",
151+
"bool", /*default=*/"true",
152+
"When enabled (default), some ops may emit non-finites, for example, max pooling may compare values to an initial value of `-inf`. When disabled, non-finites will be replaced with the closest finite value for a given dtype.">,
153+
];
144154
}
145155

146156
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> {
@@ -168,6 +178,9 @@ def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp">
168178
// are unlikely to exceed the range of i32(4GiB)
169179
Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
170180
"Enable truncate index from i64 to i32(unsafely)">,
181+
Option<"allowNonFinites", "allow-non-finites",
182+
"bool", /*default=*/"true",
183+
"When enabled (default), some ops may emit non-finites, for example, max pooling may compare values to an initial value of `-inf`. When disabled, non-finites will be replaced with the closest finite value for a given dtype.">,
171184
];
172185
}
173186
#endif

include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717

1818
namespace mlir::torch::onnx_c {
1919

20+
#define GEN_PASS_DECL_CONVERTTORCHONNXTOTORCH
21+
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc"
22+
2023
std::unique_ptr<OperationPass<func::FuncOp>> createTorchOnnxToTorchPass();
24+
std::unique_ptr<OperationPass<func::FuncOp>>
25+
createTorchOnnxToTorchPass(bool allowNonFinites);
2126

2227
/// Registers all torch-mlir conversion passes.
2328
void registerTorchOnnxToTorchPasses();

include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ def ConvertTorchOnnxToTorch : Pass<"convert-torch-onnx-to-torch", "func::FuncOp"
2121
}];
2222

2323
let constructor = "mlir::torch::onnx_c::createTorchOnnxToTorchPass()";
24+
let options = [
25+
Option<"allowNonFinites", "allow-non-finites",
26+
"bool", /*default=*/"true",
27+
"When enabled (default), some ops may emit non-finites, for example, max pooling may compare values to an initial value of `-inf`. When disabled, non-finites will be replaced with the closest finite value for a given dtype.">,
28+
];
2429
}
2530

2631
#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES

include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,10 @@ struct OpBinder {
422422
Operation *op;
423423
};
424424

425+
struct OnnxTorchToTorchOptions {
426+
bool allowNonFinites = true;
427+
};
428+
425429
/// We use a single pattern per ONNX domain to handle all named custom
426430
/// ops.
427431
/// This allows us to avoid the n^2 problem on pattern application by
@@ -431,19 +435,25 @@ struct OpBinder {
431435
class OnnxCustomOpConversionPattern
432436
: public OpConversionPattern<Torch::OperatorOp> {
433437
public:
434-
using HandlerFn = LogicalResult (*)(OpBinder binder,
435-
ConversionPatternRewriter &rewriter);
438+
using HandlerFn =
439+
std::function<LogicalResult(OpBinder, ConversionPatternRewriter &)>;
440+
441+
using HandlerFnWithOptions =
442+
LogicalResult (*)(OpBinder binder, ConversionPatternRewriter &rewriter,
443+
const OnnxTorchToTorchOptions &options);
444+
436445
struct HandlerReg {
437446
HandlerReg(HandlerFn callback, int64_t sinceVersion)
438-
: callback(callback), sinceVersion(sinceVersion) {}
447+
: callback(std::move(callback)), sinceVersion(sinceVersion) {}
439448
HandlerFn callback;
440449
int64_t sinceVersion;
441450
};
442451

443452
OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix,
444-
int64_t domainVersion)
453+
int64_t domainVersion,
454+
const OnnxTorchToTorchOptions &options)
445455
: OpConversionPattern(context), domainPrefix(std::move(domainPrefix)),
446-
domainVersion(domainVersion) {
456+
domainVersion(domainVersion), options(options) {
447457
// Onnx lowerings could produce other Onnx operations during the rewrite.
448458
setHasBoundedRewriteRecursion();
449459
}
@@ -463,19 +473,25 @@ class OnnxCustomOpConversionPattern
463473
/// Multiple conversions can be registered for the same op, most
464474
/// commonly differing by their `sinceVersion`.
465475
void onOp(StringRef name, int64_t sinceVersion, HandlerFn callback);
476+
void onOp(StringRef name, int64_t sinceVersion,
477+
HandlerFnWithOptions callback);
478+
479+
const OnnxTorchToTorchOptions &getOptions() const { return options; }
466480

467481
private:
468482
std::string domainPrefix;
469483
int64_t domainVersion;
470484
DenseMap<StringAttr, SmallVector<HandlerReg, 1>> namedHandlers;
485+
OnnxTorchToTorchOptions options;
471486
};
472487

473488
// Patterns are split into chunks to speed compile time and reduce some
474489
// contention on the same source files.
475490
void populateComMicrosoftDomain(OnnxCustomOpConversionPattern &patterns);
476491
void populateDefaultDomainAtoF(OnnxCustomOpConversionPattern &patterns);
477492
void populateDefaultDomainGtoP(OnnxCustomOpConversionPattern &patterns);
478-
void populateDefaultDomainQtoZ(OnnxCustomOpConversionPattern &patterns);
493+
void populateDefaultDomainQtoZ(OnnxCustomOpConversionPattern &patterns,
494+
const OnnxTorchToTorchOptions &options);
479495

480496
} // namespace mlir::torch::onnx_c
481497

include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,16 @@
1717

1818
namespace mlir {
1919
namespace torch {
20+
21+
#define GEN_PASS_DECL_CONVERTTORCHTOLINALG
22+
#include "torch-mlir/Conversion/Passes.h.inc"
23+
2024
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
21-
}
25+
26+
std::unique_ptr<OperationPass<func::FuncOp>>
27+
createConvertTorchToLinalgPass(bool allowNonFinites);
28+
29+
} // namespace torch
2230
} // namespace mlir
2331

2432
#endif // TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H

include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ createConvertTorchToStablehloPass();
2626
// Convenience wrapper for users who want to pass options as individual
2727
// parameters
2828
std::unique_ptr<OperationPass<func::FuncOp>>
29-
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
29+
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index,
30+
bool allowNonFinites);
3031

3132
} // namespace torch
3233
} // namespace mlir

include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,16 @@
1515

1616
namespace mlir {
1717
namespace torch {
18+
19+
#define GEN_PASS_DECL_CONVERTTORCHTOTMTENSOR
20+
#include "torch-mlir/Conversion/Passes.h.inc"
21+
1822
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();
19-
}
23+
24+
std::unique_ptr<OperationPass<func::FuncOp>>
25+
createConvertTorchToTMTensorPass(bool allowNonFinites);
26+
27+
} // namespace torch
2028
} // namespace mlir
2129

2230
#endif // TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H

include/torch-mlir/Conversion/Utils/Utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ void getZeroPoint(Value value, Value &zeropoint);
122122
LogicalResult getQuantizationParams(Value value, Value &zeropoint, Value &scale,
123123
int64_t &axis);
124124

125+
APFloat getFloatInf(mlir::FloatType fpType, bool negative,
126+
bool allowNonFinites);
127+
125128
} // namespace Torch
126129
} // namespace torch
127130
} // namespace mlir

include/torch-mlir/Dialect/Torch/Transforms/Passes.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@ struct TorchLoweringPipelineOptions
6666
*this, "extra-library",
6767
llvm::cl::desc("Filename of MLIR module for splicing into the abstract "
6868
"interpretation library.")};
69+
70+
Option<bool> allowNonFinites{
71+
*this, "allow-non-finites",
72+
llvm::cl::desc(
73+
"When enabled (default), some ops may emit non-finites, for example, "
74+
"max pooling may compare values to an initial value of `-inf`. When "
75+
"disabled, non-finites will be replaced with the closest finite "
76+
"value for a given dtype."),
77+
llvm::cl::init(true)};
6978
};
7079

7180
/// Creates a pipeline that lowers the object graph IR that is produced by

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,22 @@ class ModuleOp;
2121

2222
namespace torch {
2323
namespace TorchConversion {
24+
struct LinalgOnTensorsBackendPipelineOptions
25+
: public PassPipelineOptions<LinalgOnTensorsBackendPipelineOptions> {
26+
Option<bool> allowNonFinites{
27+
*this, "allow-non-finites",
28+
llvm::cl::desc(
29+
"When enabled (default), some ops may emit non-finites, for example, "
30+
"max pooling may compare values to an initial value of `-inf`. When "
31+
"disabled, non-finites will be replaced with the closest finite "
32+
"value for a given dtype."),
33+
llvm::cl::init(true)};
34+
};
2435

2536
/// Creates a pipeline that lowers from the torch backend contract to the
2637
/// linalg-on-tensors backend contract.
27-
void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
38+
void createTorchBackendToLinalgOnTensorsBackendPipeline(
39+
OpPassManager &pm, const LinalgOnTensorsBackendPipelineOptions &options);
2840

2941
// Do not register the TOSA options if the TOSA target is disabled
3042
#ifdef TORCH_MLIR_ENABLE_TOSA
@@ -59,6 +71,14 @@ struct StablehloBackendPipelineOptions
5971
*this, "enable-i32-index",
6072
llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"),
6173
llvm::cl::init(false)};
74+
Option<bool> allowNonFinites{
75+
*this, "allow-non-finites",
76+
llvm::cl::desc(
77+
"When enabled (default), some ops may emit non-finites, for example, "
78+
"max pooling may compare values to an initial value of `-inf`. When "
79+
"disabled, non-finites will be replaced with the closest finite "
80+
"value for a given dtype."),
81+
llvm::cl::init(true)};
6282
};
6383

6484
void createTorchBackendToStablehloBackendPipeline(

0 commit comments

Comments
 (0)