Skip to content

Commit 56e635e

Browse files
authored
[tosa] : Add option to enable/disable patterns selectively. (#4485)
Consider the source IR: ``` func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> { %int7 = torch.constant.int 7 %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 %false = torch.constant.bool false %none = torch.constant.none %kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int> %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %false, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32> return %0 : !torch.vtensor<[1,512,1,1],f32> } ``` When lowered through TOSA path we get ``` ❯ torch-mlir-opt --convert-torch-to-tosa /tmp/torch.mlir --mlir-print-op-generic | mlir-opt -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,canonicalize,cse))" --allow-unregistered-dialect #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> module { func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> { %cst = arith.constant 4.900000e+01 : f32 %cst_0 = arith.constant 0.000000e+00 : f32 %0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[1,512,7,7],f32>) -> tensor<1x512x7x7xf32> %1 = "torch.constant.int"() <{value = 7 : i64}> : () -> !torch.int %2 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int %3 = "torch.constant.int"() <{value = 0 : i64}> : () -> !torch.int %4 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool %5 = "torch.constant.none"() : () -> !torch.none %6 = "torch.prim.ListConstruct"(%1, %1) : (!torch.int, !torch.int) -> !torch.list<int> %7 = "torch.prim.ListConstruct"(%2, %2) : (!torch.int, !torch.int) -> !torch.list<int> %8 = 
"torch.prim.ListConstruct"(%3, %3) : (!torch.int, !torch.int) -> !torch.list<int> %9 = tensor.empty() : tensor<1x7x7x512xf32> %transposed = linalg.transpose ins(%0 : tensor<1x512x7x7xf32>) outs(%9 : tensor<1x7x7x512xf32>) permutation = [0, 2, 3, 1] %10 = tensor.empty() : tensor<1x1x1x512xf32> %11 = linalg.fill ins(%cst_0 : f32) outs(%10 : tensor<1x1x1x512xf32>) -> tensor<1x1x1x512xf32> %12 = tensor.empty() : tensor<7x7xf32> %13 = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%transposed, %12 : tensor<1x7x7x512xf32>, tensor<7x7xf32>) outs(%11 : tensor<1x1x1x512xf32>) -> tensor<1x1x1x512xf32> %14 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%13 : tensor<1x1x1x512xf32>) outs(%10 : tensor<1x1x1x512xf32>) { ^bb0(%in: f32, %out: f32): %17 = arith.divf %in, %cst : f32 linalg.yield %17 : f32 } -> tensor<1x1x1x512xf32> %15 = tensor.empty() : tensor<1x512x1x1xf32> %transposed_1 = linalg.transpose ins(%14 : tensor<1x1x1x512xf32>) outs(%15 : tensor<1x512x1x1xf32>) permutation = [0, 3, 1, 2] %16 = "torch_c.from_builtin_tensor"(%transposed_1) : (tensor<1x512x1x1xf32>) -> !torch.vtensor<[1,512,1,1],f32> return %16 : !torch.vtensor<[1,512,1,1],f32> } } ``` When lowered through linalg path we get: ``` ❯ torch-mlir-opt --convert-torch-to-linalg /tmp/torch.mlir --mlir-print-op-generic | mlir-opt -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,canonicalize,cse))" --allow-unregistered-dialect #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> module { func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> { %cst = arith.constant 4.900000e+01 : f32 %cst_0 = arith.constant 0.000000e+00 : f32 %0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[1,512,7,7],f32>) -> tensor<1x512x7x7xf32> %1 = "torch.constant.int"() <{value = 7 : i64}> : () -> !torch.int %2 = 
"torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int %3 = "torch.constant.int"() <{value = 0 : i64}> : () -> !torch.int %4 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool %5 = "torch.constant.none"() : () -> !torch.none %6 = "torch.prim.ListConstruct"(%1, %1) : (!torch.int, !torch.int) -> !torch.list<int> %7 = "torch.prim.ListConstruct"(%2, %2) : (!torch.int, !torch.int) -> !torch.list<int> %8 = "torch.prim.ListConstruct"(%3, %3) : (!torch.int, !torch.int) -> !torch.list<int> %9 = tensor.empty() : tensor<1x512x1x1xf32> %10 = linalg.fill ins(%cst_0 : f32) outs(%9 : tensor<1x512x1x1xf32>) -> tensor<1x512x1x1xf32> %11 = tensor.empty() : tensor<7x7xf32> %12 = linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%0, %11 : tensor<1x512x7x7xf32>, tensor<7x7xf32>) outs(%10 : tensor<1x512x1x1xf32>) -> tensor<1x512x1x1xf32> %13 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%12 : tensor<1x512x1x1xf32>) outs(%9 : tensor<1x512x1x1xf32>) { ^bb0(%in: f32, %out: f32): %15 = arith.divf %in, %cst : f32 linalg.yield %15 : f32 } -> tensor<1x512x1x1xf32> %14 = "torch_c.from_builtin_tensor"(%13) : (tensor<1x512x1x1xf32>) -> !torch.vtensor<[1,512,1,1],f32> return %14 : !torch.vtensor<[1,512,1,1],f32> } } ``` Because of layout mismatch between PyTorch (NCHW) and TOSA (NHWC), there will be two additional transpose operations in the TOSA path. This requires two additional buffers which leads to a problem for resource-constrained embedded HW which don't have enough memory. This change adds an option to selectively enable/disable legalizations through the TOSA path, so that for the `tosa_linalg` path we can choose to not lower some ops (depending on the target HW) through TOSA and instead let it lower through the linalg path that runs after TOSA path.
1 parent 64ca81a commit 56e635e

File tree

6 files changed

+154
-45
lines changed

6 files changed

+154
-45
lines changed

include/torch-mlir/Conversion/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
133133
"bool", /*default=*/"true",
134134
"Require TorchToTosa full conversion by adding Torch Dialect to "
135135
"TorchToTosa list of illegal dialects">,
136+
ListOption<"disabledPatterns", "disabled-patterns",
137+
"std::string",
138+
"Patterns to disable by name during Torch to TOSA conversion">,
139+
ListOption<"enabledPatterns", "enabled-patterns",
140+
"std::string",
141+
"If non-empty, only these patterns are enabled during Torch to TOSA conversion">,
136142
];
137143
}
138144
#endif

include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
3737
// Convenience wrapper for users who want to pass options as individual
3838
// parameters
3939
std::unique_ptr<OperationPass<func::FuncOp>>
40-
createConvertTorchToTosaPass(bool requireFullTosaConversion);
40+
createConvertTorchToTosaPass(bool requireFullTosaConversion,
41+
ArrayRef<std::string> disabled_patterns,
42+
ArrayRef<std::string> enabled_patterns);
4143

4244
} // namespace torch
4345
} // namespace mlir

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ struct TosaBackendPipelineOptions
4747
llvm::cl::desc("Require full TorchToTosa conversion by adding Torch "
4848
"Dialect to TorchToTosa list of illegal dialects"),
4949
llvm::cl::init(true)};
50+
ListOption<std::string> disabledPatterns{
51+
*this, "disabled-patterns",
52+
::llvm::cl::desc(
53+
"Patterns to disable by name during Torch to TOSA conversion"),
54+
llvm::cl::ZeroOrMore};
55+
ListOption<std::string> enabledPatterns{
56+
*this, "enabled-patterns",
57+
::llvm::cl::desc("If non-empty, only these patterns are enabled during "
58+
"Torch to TOSA conversion"),
59+
llvm::cl::ZeroOrMore};
5060
};
5161

5262
/// Creates a pipeline that lowers from the torch backend contract to the

0 commit comments

Comments (0)