forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathexpand_squeeze_dims_op.cc
194 lines (146 loc) · 5.73 KB
/
expand_squeeze_dims_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#include "caffe2/operators/expand_squeeze_dims_op.h"
namespace caffe2 {
// Register the CPU implementations declared in expand_squeeze_dims_op.h.
// ExpandDims inserts size-1 axes; Squeeze removes size-1 axes — they are
// exact inverses of each other given the same `dims` argument.
REGISTER_CPU_OPERATOR(ExpandDims, ExpandDimsOp<CPUContext>);
REGISTER_CPU_OPERATOR(Squeeze, SqueezeOp<CPUContext>);
// Schema + shape inference for ExpandDims. The inference function computes
// the output shape by inserting a size-1 axis at each position listed in the
// `dims` argument (positions are indices into the *output* shape).
OPERATOR_SCHEMA(ExpandDims)
.NumInputs(1)
.NumOutputs(1)
.AllowInplace({{0, 0}})
.TensorInferenceFunction([](const OperatorDef& def,
const vector<TensorShape>& in) {
ArgumentHelper helper(def);
auto dims = helper.template GetRepeatedArgument<int>("dims");
auto originalSize = dims.size();
// `dims` is mandatory; an empty list would make front()/back() below UB.
CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
// Sort then deduplicate; duplicates are tolerated with a warning rather
// than rejected.
std::sort(dims.begin(), dims.end());
dims.erase(std::unique(dims.begin(), dims.end()), dims.end());
if (dims.size() < originalSize) {
LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
}
// dims is sorted, so checking front()/back() covers all entries.
CAFFE_ENFORCE(dims.front() >= 0, "Dimension ids must be non-negative.");
// Output rank is input rank + |dims|; the largest requested position must
// fit inside that output rank.
CAFFE_ENFORCE_GE(
in[0].dims_size() + dims.size(),
dims.back() + 1,
"Input needs at least ",
(1 + dims.back() - dims.size()),
" dimensions given `dims`.");
vector<TensorShape> out(1);
// Walk the output positions: copy input dims up to each insertion point
// (cur_pos tracks the output position, idx the input position), then
// emit the inserted size-1 axis.
int cur_pos = 0;
int idx = 0;
for (const auto new_dim : dims) {
for (int i = cur_pos; i < new_dim; i++) {
out[0].add_dims(in[0].dims(idx++));
}
out[0].add_dims(1);
cur_pos = new_dim + 1;
}
// Copy any remaining input dims after the last insertion point.
for (; idx < in[0].dims_size(); idx++) {
out[0].add_dims(in[0].dims(idx));
}
out[0].set_data_type(in[0].data_type());
return out;
})
.SetDoc(R"DOC(
The *ExpandDims* op inserts single-dimensional entries into the shape of the input tensor *data,* and produces a single output tensor *expanded*. The op also takes an argument *dims* with a list of dimensions for where to add the single dimensional entries. If the same blob is provided as input and output, the operation is copy-free. This is the exact inverse operation of *Squeeze*.
Github Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/expand_squeeze_dims_op.h
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/expand_squeeze_dims_op.cc
<details>
<summary> <b>Example</b> </summary>
**Code**
```
workspace.ResetWorkspace()
op = core.CreateOperator(
"ExpandDims",
["data"],
["expanded"],
dims=[0,1],
)
workspace.FeedBlob("data", np.zeros((100,100)).astype(np.float32))
print("data.shape:", workspace.FetchBlob("data").shape)
workspace.RunOperatorOnce(op)
print("expanded.shape:", workspace.FetchBlob("expanded").shape)
```
**Result**
```
data.shape: (100, 100)
expanded.shape: (1, 1, 100, 100)
```
</details>
)DOC")
.Input(0, "data", "Input tensor of data to be operated on.")
.Output(0, "expanded", "Reshaped tensor with same data as input.")
.Arg(
"dims",
"*(type: [int])* List of dimensions of *data* to add single dimensional entry.")
.InheritOnnxSchema();
// Schema + shape inference for Squeeze. The inference function delegates the
// shape computation to SqueezeOp<CPUContext>::ComputeDims so schema-time
// inference and runtime execution cannot diverge.
OPERATOR_SCHEMA(Squeeze)
.NumInputs(1)
.NumOutputs(1)
.AllowInplace({{0, 0}})
.SetDoc(R"DOC(
The *Squeeze* op removes single-dimensional entries from the shape of the input tensor *data,* and produces a single output tensor *squeezed*. The op also takes an argument *dims* with a list of dimensions to squeeze. If the same blob is provided as input and output, the operation is copy-free. This is the exact inverse operation of *ExpandDims* given the same *dims* argument.
Github Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/expand_squeeze_dims_op.h
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/expand_squeeze_dims_op.cc
<details>
<summary> <b>Example</b> </summary>
**Code**
```
workspace.ResetWorkspace()
op = core.CreateOperator(
"Squeeze",
["data"],
["squeezed"],
dims=[0,1],
)
workspace.FeedBlob("data", np.zeros((1,1,100,100)).astype(np.float32))
print("data.shape:", workspace.FetchBlob("data").shape)
workspace.RunOperatorOnce(op)
print("squeezed.shape:", workspace.FetchBlob("squeezed").shape)
```
**Result**
```
data.shape: (1, 1, 100, 100)
squeezed.shape: (100, 100)
```
</details>
)DOC")
.Input(0, "data", "Input tensor of data to be operated on.")
.Output(0, "squeezed", "Reshaped tensor with same data as input.")
.Arg("dims", "*(type: [int])* List of dimensions of *data* to squeeze out.")
.TensorInferenceFunction([](const OperatorDef& def,
const vector<TensorShape>& in) {
ArgumentHelper helper(def);
auto dims = helper.template GetRepeatedArgument<int>("dims");
auto originalSize = dims.size();
// FIX: an absent/empty `dims` argument previously reached dims.front()
// below, which is undefined behavior on an empty vector. Enforce the same
// precondition the ExpandDims schema already enforces.
CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
// Sort then deduplicate; duplicates are tolerated with a warning rather
// than rejected.
std::sort(dims.begin(), dims.end());
dims.erase(std::unique(dims.begin(), dims.end()), dims.end());
if (dims.size() < originalSize) {
LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
}
// dims is sorted, so checking front() covers all entries.
CAFFE_ENFORCE(dims.front() >= 0, "Dimension ids must be non-negative.");
vector<TensorShape> out(1);
std::vector<int> newDims =
SqueezeOp<CPUContext>::ComputeDims(GetDimsVector(in[0]), dims);
out[0] = CreateTensorShape(newDims, in[0].data_type());
return out;
})
.InheritOnnxSchema();
// Gradient maker for Squeeze: since ExpandDims is its exact inverse, the
// gradient of the output is re-expanded with the same `dims` argument
// (copied implicitly from the forward def by SingleGradientDef).
class GetSqueezeGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
const vector<string> grad_inputs{GO(0)};
const vector<string> grad_outputs{GI(0)};
return SingleGradientDef("ExpandDims", "", grad_inputs, grad_outputs);
}
};
REGISTER_GRADIENT(Squeeze, GetSqueezeGradient);
// Gradient maker for ExpandDims: the inserted size-1 axes are squeezed back
// out of the output gradient using the same `dims` argument (copied
// implicitly from the forward def by SingleGradientDef).
class GetExpandDimsGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
"Squeeze",
"",
vector<string>{GO(0)},
vector<string>{GI(0)});
}
};
REGISTER_GRADIENT(ExpandDims, GetExpandDimsGradient);
}