Commit 56e635e
authored
[tosa] : Add option to enable/disable patterns selectively. (#4485)
Consider the source IR:
```
func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> {
%int7 = torch.constant.int 7
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%none = torch.constant.none
%kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %false, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
return %0 : !torch.vtensor<[1,512,1,1],f32>
}
```
When lowered through the TOSA path we get:
```
❯ torch-mlir-opt --convert-torch-to-tosa /tmp/torch.mlir --mlir-print-op-generic | mlir-opt -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,canonicalize,cse))" --allow-unregistered-dialect
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
%cst = arith.constant 4.900000e+01 : f32
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[1,512,7,7],f32>) -> tensor<1x512x7x7xf32>
%1 = "torch.constant.int"() <{value = 7 : i64}> : () -> !torch.int
%2 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int
%3 = "torch.constant.int"() <{value = 0 : i64}> : () -> !torch.int
%4 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool
%5 = "torch.constant.none"() : () -> !torch.none
%6 = "torch.prim.ListConstruct"(%1, %1) : (!torch.int, !torch.int) -> !torch.list<int>
%7 = "torch.prim.ListConstruct"(%2, %2) : (!torch.int, !torch.int) -> !torch.list<int>
%8 = "torch.prim.ListConstruct"(%3, %3) : (!torch.int, !torch.int) -> !torch.list<int>
%9 = tensor.empty() : tensor<1x7x7x512xf32>
%transposed = linalg.transpose ins(%0 : tensor<1x512x7x7xf32>) outs(%9 : tensor<1x7x7x512xf32>) permutation = [0, 2, 3, 1]
%10 = tensor.empty() : tensor<1x1x1x512xf32>
%11 = linalg.fill ins(%cst_0 : f32) outs(%10 : tensor<1x1x1x512xf32>) -> tensor<1x1x1x512xf32>
%12 = tensor.empty() : tensor<7x7xf32>
%13 = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%transposed, %12 : tensor<1x7x7x512xf32>, tensor<7x7xf32>) outs(%11 : tensor<1x1x1x512xf32>) -> tensor<1x1x1x512xf32>
%14 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%13 : tensor<1x1x1x512xf32>) outs(%10 : tensor<1x1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%17 = arith.divf %in, %cst : f32
linalg.yield %17 : f32
} -> tensor<1x1x1x512xf32>
%15 = tensor.empty() : tensor<1x512x1x1xf32>
%transposed_1 = linalg.transpose ins(%14 : tensor<1x1x1x512xf32>) outs(%15 : tensor<1x512x1x1xf32>) permutation = [0, 3, 1, 2]
%16 = "torch_c.from_builtin_tensor"(%transposed_1) : (tensor<1x512x1x1xf32>) -> !torch.vtensor<[1,512,1,1],f32>
return %16 : !torch.vtensor<[1,512,1,1],f32>
}
}
```
When lowered through the linalg path we get:
```
❯ torch-mlir-opt --convert-torch-to-linalg /tmp/torch.mlir --mlir-print-op-generic | mlir-opt -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,canonicalize,cse))" --allow-unregistered-dialect
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
%cst = arith.constant 4.900000e+01 : f32
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[1,512,7,7],f32>) -> tensor<1x512x7x7xf32>
%1 = "torch.constant.int"() <{value = 7 : i64}> : () -> !torch.int
%2 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int
%3 = "torch.constant.int"() <{value = 0 : i64}> : () -> !torch.int
%4 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool
%5 = "torch.constant.none"() : () -> !torch.none
%6 = "torch.prim.ListConstruct"(%1, %1) : (!torch.int, !torch.int) -> !torch.list<int>
%7 = "torch.prim.ListConstruct"(%2, %2) : (!torch.int, !torch.int) -> !torch.list<int>
%8 = "torch.prim.ListConstruct"(%3, %3) : (!torch.int, !torch.int) -> !torch.list<int>
%9 = tensor.empty() : tensor<1x512x1x1xf32>
%10 = linalg.fill ins(%cst_0 : f32) outs(%9 : tensor<1x512x1x1xf32>) -> tensor<1x512x1x1xf32>
%11 = tensor.empty() : tensor<7x7xf32>
%12 = linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%0, %11 : tensor<1x512x7x7xf32>, tensor<7x7xf32>) outs(%10 : tensor<1x512x1x1xf32>) -> tensor<1x512x1x1xf32>
%13 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%12 : tensor<1x512x1x1xf32>) outs(%9 : tensor<1x512x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%15 = arith.divf %in, %cst : f32
linalg.yield %15 : f32
} -> tensor<1x512x1x1xf32>
%14 = "torch_c.from_builtin_tensor"(%13) : (tensor<1x512x1x1xf32>) -> !torch.vtensor<[1,512,1,1],f32>
return %14 : !torch.vtensor<[1,512,1,1],f32>
}
}
```
Because of the layout mismatch between PyTorch (NCHW) and TOSA (NHWC),
there are two additional transpose operations in the TOSA path. These
require two additional buffers, which is a problem for
resource-constrained embedded HW that doesn't have enough memory.
This change adds an option to selectively enable/disable legalizations
through the TOSA path, so that for the `tosa_linalg` path we can choose
to not lower some ops (depending on the target HW) through TOSA and
instead let them lower through the linalg path that runs after the TOSA path.

1 parent 64ca81a commit 56e635e
File tree
6 files changed
+154
-45
lines changed
- include/torch-mlir
- Conversion
- TorchToTosa
- Dialect/TorchConversion/Transforms
- lib
- Conversion/TorchToTosa
- Dialect/TorchConversion/Transforms
- test/Conversion/TorchToTosa
6 files changed
+154
-45
lines changed

| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
133 | 133 | | |
134 | 134 | | |
135 | 135 | | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
136 | 142 | | |
137 | 143 | | |
138 | 144 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
37 | 37 | | |
38 | 38 | | |
39 | 39 | | |
40 | | - | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
41 | 43 | | |
42 | 44 | | |
43 | 45 | | |
| |||
Lines changed: 10 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
47 | 47 | | |
48 | 48 | | |
49 | 49 | | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
50 | 60 | | |
51 | 61 | | |
52 | 62 | | |
| |||
0 commit comments