Skip to content

Commit 6396a73

Browse files
authored
reshape shape expression, drop reshape permute, test reshape oom (#5918)
1 parent 3571d7e commit 6396a73

33 files changed

+2269
-1222
lines changed

docs/developer-guide/expression.md

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
### expression
2+
3+
expression is used in the reshape slice parameter to express the dynamic shape or subscript value based on the expression formula and input shape
4+
5+
Compared with directly converting the expression calculation process into multiple operators, the motivation for using expression
6+
* No additional shape concat and other operators will be generated due to dynamic calculation, which greatly reduces the number of layers of the ncnn model and makes it easier to view the model structure and modify expression
7+
* Shape or subscript evaluations are usually single-digit operations, which are more suitable for direct completion on the CPU without layout conversion and kernel call overhead
8+
9+
In the param file, `Reshape` layer can contain 6=expression
10+
11+
The pnnx tool can automatically convert `pnnx.Expression` to the expr parameter of ncnn `Reshape`
12+
13+
* Convert to 0w, 0h, 0d or 0c according to the input shape rank and `size(@0,1)`
14+
* Automatically remove the batch dimension according to the input batch index
15+
* Convert `pnnx.Expression` and `Tensor.reshape`/`Tensor.view` two operators are fused into ncnn `Reshape`
16+
* Automatically summarize the number of references, exclude duplicate references and sort the indexes of references
17+
* Convert the customary shape representation order, such as CHW to WHC
18+
19+
Example pnnx.param where A and B are 3D tensors
20+
```
21+
pnnx.Expression expr 2 1 A B shape expr=[add(size(@1,0),2),mul(size(@0,1),2),-1]
22+
Tensor.reshape reshape 2 1 A shape out
23+
```
24+
25+
pnnx.py
26+
```python
27+
shape = [(B.size(0) + 2), (A.size(1) * 2), -1]
28+
out = A.reshape(*shape)
29+
```
30+
31+
Converted to ncnn.param
32+
```
33+
Reshape reshape 2 1 A B out 6="-1,*(0h,2),+(1c,2)"
34+
```
35+
36+
### syntax
37+
38+
Use infix expression, format is `op(arg0,arg1,...)`, multiple operations can be nested, multiple sizes are separated by commas, and numbers can be integers or decimals
39+
40+
Among them, the commonly used `add` `sub` `mul` `div` `floor_div` are abbreviated as `+` `-` `*` `/` `//`, and other arithmetic operations use names, such as `sin` `ceil` `max`, etc.
41+
42+
* `max(2,3)`
43+
* `floor(sin(3.14))`
44+
* `+(*(-2,1),10)` means (-2 * 1) + 10
45+
* `1,2,+(3,2)` list can represent output shape with 3-rank
46+
47+
The input shape can be referenced at runtime, format is `id(w|h|d|c)`, the maximum id is 9, which means that up to 10 inputs can be referenced
48+
49+
Assuming that the Reshape layer has two input blobs, A and B, then
50+
51+
* `0w,1h` means A.w, B.h
52+
* `*(+(0c,1c),2)` means (A.c + B.c) * 2
53+
54+
### helper api
55+
56+
```cpp
57+
#include "expression.h"
58+
59+
int count_expression_blobs(const std::string& expr);
60+
61+
int eval_list_expression(const std::string& expr, const std::vector<Mat>& blobs, std::vector<int>& outlist);
62+
```
63+
64+
* `count_expression_blobs`
65+
66+
Pass expression to get the number of inputs it references, such as `0w,1h` returns 2
67+
68+
* `eval_list_expression`
69+
70+
Evaluate the result list according to expression and input blob calculate. If the calculation result is a floating point number, it will be automatically truncated to an integer.
71+
72+
### supported operator
73+
74+
|type|operators|
75+
|---|---|
76+
|float to int|`trunc` `ceil` `floor` `round`|
77+
|binary arithmetic|`+` `-` `*` `/` `//` `max` `min` `pow` `fmod` `remainder` `atan2` `logaddexp`|
78+
|unary arithmetic|`abs` `neg` `sign` `square` `sqrt` `rsqrt` `reciprocal` `exp` `log` `log10` `sin` `asin` `cos` `acos` `tan` `atan` `sinh` `asinh` `cosh` `acosh` `tanh` `atanh`|
79+
|integer bitwise|`and` `or` `xor` `lshift` `rshift`|

docs/developer-guide/operators.md

+1-3
Original file line numberDiff line numberDiff line change
@@ -1688,8 +1688,7 @@ y = float2int8(x3 * scale_out)
16881688

16891689
# Reshape
16901690
```
1691-
if permute == 1 y = hwc2chw(reshape(chw2hwc(x)))
1692-
else y = reshape(x)
1691+
y = reshape(x)
16931692
```
16941693

16951694
* one_blob_only
@@ -1700,7 +1699,6 @@ else y = reshape(x)
17001699
| 1 | h | int | -233 | |
17011700
| 11 | d | int | -233 | |
17021701
| 2 | c | int | -233 | |
1703-
| 3 | permute | int | 0 | |
17041702

17051703
Reshape flag:
17061704
- 0 = copy from bottom

src/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ set(ncnn_SRCS
2222
command.cpp
2323
cpu.cpp
2424
datareader.cpp
25+
expression.cpp
2526
gpu.cpp
2627
layer.cpp
2728
mat.cpp
@@ -731,6 +732,7 @@ if(NCNN_INSTALL_SDK)
731732
command.h
732733
cpu.h
733734
datareader.h
735+
expression.h
734736
gpu.h
735737
layer.h
736738
layer_shader_type.h

0 commit comments

Comments
 (0)