Skip to content

Commit 4063746

Browse files
xiaomengy authored and facebook-github-bot committed
Optimize batch mm op when broadcast the second input (pytorch#21556)
Summary:
Pull Request resolved: pytorch#21556
Optimize batch mm op when broadcast the second input
Reviewed By: houseroad
Differential Revision: D15728914
fbshipit-source-id: c60441d69d4997dd32a3566780496c7ccda5e67a
1 parent d715012 commit 4063746

File tree

1 file changed

+60
-32
lines changed

1 file changed

+60
-32
lines changed

caffe2/operators/batch_matmul_op.h

+60-32
Original file line number | Diff line number | Diff line change
@@ -203,39 +203,67 @@ class BatchMatMulOp final : public Operator<Context> {
203203
Y_data,
204204
&context_);
205205
} else if (A_batch_size == 1) {
206-
math::GemmStridedBatched<T, Context, Engine>(
207-
trans_a_ ? CblasTrans : CblasNoTrans,
208-
trans_b_ ? CblasTrans : CblasNoTrans,
209-
Y_batch_size,
210-
M,
211-
N,
212-
K,
213-
1.0f,
214-
A_data,
215-
0,
216-
B_data,
217-
K * N,
218-
0.0f,
219-
Y_data,
220-
M * N,
221-
&context_);
206+
if (M == 1 && trans_b_) {
207+
math::Gemv<T, Context, Engine>(
208+
CblasNoTrans,
209+
B_batch_size * N,
210+
K,
211+
1.0f,
212+
B_data,
213+
A_data,
214+
0.0f,
215+
Y_data,
216+
&context_);
217+
} else {
218+
math::GemmStridedBatched<T, Context, Engine>(
219+
trans_a_ ? CblasTrans : CblasNoTrans,
220+
trans_b_ ? CblasTrans : CblasNoTrans,
221+
Y_batch_size,
222+
M,
223+
N,
224+
K,
225+
1.0f,
226+
A_data,
227+
0,
228+
B_data,
229+
K * N,
230+
0.0f,
231+
Y_data,
232+
M * N,
233+
&context_);
234+
}
222235
} else if (B_batch_size == 1) {
223-
math::GemmStridedBatched<T, Context, Engine>(
224-
trans_a_ ? CblasTrans : CblasNoTrans,
225-
trans_b_ ? CblasTrans : CblasNoTrans,
226-
Y_batch_size,
227-
M,
228-
N,
229-
K,
230-
1.0f,
231-
A_data,
232-
M * K,
233-
B_data,
234-
0,
235-
0.0f,
236-
Y_data,
237-
M * N,
238-
&context_);
236+
if (!trans_a_) {
237+
math::Gemm<T, Context, Engine>(
238+
CblasNoTrans,
239+
trans_b_ ? CblasTrans : CblasNoTrans,
240+
A_batch_size * M,
241+
N,
242+
K,
243+
1.0f,
244+
A_data,
245+
B_data,
246+
0.0f,
247+
Y_data,
248+
&context_);
249+
} else {
250+
math::GemmStridedBatched<T, Context, Engine>(
251+
CblasTrans,
252+
trans_b_ ? CblasTrans : CblasNoTrans,
253+
Y_batch_size,
254+
M,
255+
N,
256+
K,
257+
1.0f,
258+
A_data,
259+
M * K,
260+
B_data,
261+
0,
262+
0.0f,
263+
Y_data,
264+
M * N,
265+
&context_);
266+
}
239267
} else if (!is_broadcast_dims) {
240268
math::GemmStridedBatched<T, Context, Engine>(
241269
trans_a_ ? CblasTrans : CblasNoTrans,

0 commit comments

Comments
 (0)