Skip to content

Commit 4d78aaf

Browse files
Dubtsov, Roman S and vpirogov
authored and committed
common: fix zero-padding for tensors of small rank
1 parent c36887b commit 4d78aaf

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/common/memory_zero_pad.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*******************************************************************************/
1616

17-
#include <assert.h>
17+
#include <cassert>
1818

1919
#include "mkldnn_thread.hpp"
2020
#include "mkldnn_traits.hpp"
@@ -61,12 +61,14 @@ void typed_zero_pad_blk(
6161
const int c_tail_s = C_blocked ? dims[2] % blksize : 0;
6262
assert(a_tail_s || b_tail_s || c_tail_s);
6363

64+
const int ndims = m_d.ndims();
65+
assert(1 <= ndims && ndims <= 6);
6466
const int A = A_blocked ? pdims[0] / blksize : dims[0];
65-
const int B = B_blocked ? pdims[1] / blksize : dims[1];
66-
const int C = C_blocked ? pdims[2] / blksize : dims[2];
67-
const int D = m_d.ndims() > 3 ? dims[3] : 1;
68-
const int E = m_d.ndims() > 4 ? dims[4] : 1;
69-
const int F = m_d.ndims() > 5 ? dims[5] : 1;
67+
const int B = ndims <= 1 ? 1 : B_blocked ? pdims[1] / blksize : dims[1];
68+
const int C = ndims <= 2 ? 1 : C_blocked ? pdims[2] / blksize : dims[2];
69+
const int D = ndims <= 3 ? 1 : dims[3];
70+
const int E = ndims <= 4 ? 1 : dims[4];
71+
const int F = ndims <= 5 ? 1 : dims[5];
7072
const int inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1;
7173

7274
auto zeroize_tail = [&](data_t *d, const int tail_s) {

0 commit comments

Comments
 (0)