forked from llvm/llvm-project
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][Linalg] Refine how broadcast dims are treated (llvm#99015)
This PR fixes how broadcast dims (identified as "zero" results in permutation maps) corresponding to a reduction iterator are vectorised in the case of generic Ops. Here's an example: ```mlir #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)> func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) { %0 = tensor.empty() : tensor<1x12x197x1xf32> %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x12x197x197xf32>) outs(%0 : tensor<1x12x197x1xf32>) { ^bb0(%in: f32, %out: f32): %818 = arith.addf %in, %out : f32 linalg.yield %818 : f32 } -> tensor<1x12x197x1xf32> return %1 : tensor<1x12x197x1xf32> } ``` This is a perfectly valid Generic Op, but currently triggers two issues in the vectoriser. The root cause is this map: ```mlir #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)> ``` This map triggers an assert in `reindexIndexingMap` - this hook incorrectly assumes that every result in the input map is a `dim` expression and that there are no constants. That's not the case in this example. `reindexIndexingMap` is extended to allow maps like the one above. For now, only constant "zero" results are allowed. This can be extended in the future once a good motivating example is available. Separately, the permutation map highlighted above "breaks" mask calculation (ATM masks are always computed, even in the presence of static shapes). When applying the following permutation: ```mlir (d0, d1, d2, d3) -> (d0, d1, d2, 0) ``` to these canonical shapes (corresponding to the example above): ``` (1, 12, 197, 197) ``` we end up with the following error: ```bash error: vector types must have positive constant sizes but got 1, 12, 197, 0 ``` The error makes sense and indicates that we should update the permutation map above to: ``` (d0, d1, d2, d3) -> (d0, d1, d2) ``` This would correctly give the following vector type: ``` vector<1x12x197xi1> ``` Fixes llvm#97247
- Loading branch information
1 parent
0b6c816
commit 1eea819
Showing
5 changed files
with
148 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters