Skip to content

Commit e833b10

Browse files
committed
cpu: pooling: change uint32_t to int
1 parent af977e7 commit e833b10

File tree

2 files changed

+71
-69
lines changed

2 files changed

+71
-69
lines changed

src/cpu/jit_avx2_pooling_generator_f32.cpp

Lines changed: 70 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,29 @@
1818

1919
#include "jit_avx2_pooling_generator_f32.hpp"
2020

21-
#define ymm_store_mask Ymm(15)
22-
#define ymm_input Ymm(14)
23-
#define ymm_index Ymm(13)
24-
#define xmm_index Xmm(13)
25-
#define ymm_simd Ymm(12)
26-
#define xmm_simd Xmm(12)
27-
#define ymm_simd_stride_w Ymm(11)
28-
#define xmm_simd_stride_w Xmm(11)
29-
#define ymm_ki_offset Ymm(10)
30-
#define xmm_ki_offset Xmm(10)
31-
#define ymm_ji_offset Ymm(9)
32-
#define xmm_ji_offset Xmm(9)
33-
#define ymm_tmp Ymm(8)
34-
#define xmm_tmp Xmm(8)
35-
#define ymm_offset_base Ymm(7)
36-
#define xmm_offset_base Xmm(7)
21+
#define ymm_store_mask Ymm(15)
22+
#define ymm_input Ymm(14)
23+
#define ymm_tmp Ymm(13)
24+
#define xmm_tmp Xmm(13)
25+
#define ymm_index Ymm(12)
26+
#define xmm_index Xmm(12)
27+
#define ymm_c_block Ymm(11)
28+
#define xmm_c_block Xmm(11)
29+
#define ymm_c_block_stride_w Ymm(10)
30+
#define xmm_c_block_stride_w Xmm(10)
31+
#define ymm_ki_offset Ymm(9)
32+
#define xmm_ki_offset Xmm(9)
33+
#define ymm_ji_offset Ymm(8)
34+
#define xmm_ji_offset Xmm(8)
35+
#define ymm_offset_base Ymm(7)
36+
#define xmm_offset_base Xmm(7)
3737

3838
namespace mkldnn {
3939
namespace impl {
4040
namespace cpu {
4141

4242
inline void jit_avx2_pooling_generator_f32::oh_step(
43-
jit_pooling_param_t *params, uint32_t ur_w,
43+
jit_pooling_param_t *params, int ur_w,
4444
int pad_l, int pad_r, const char* kh_lable)
4545
{
4646
using Xbyak::Ymm;
@@ -53,16 +53,17 @@ inline void jit_avx2_pooling_generator_f32::oh_step(
5353
} cvt;
5454
cvt._flt_max = -FLT_MAX;
5555

56-
uint32_t IW = params->iw;
57-
uint32_t KW = params->kw;
58-
uint32_t stride_w = params->stride_w;
56+
int iw = params->iw;
57+
int kw = params->kw;
58+
int stride_w = params->stride_w;
59+
int c_block = params->c_block;
5960

6061
vpxor(ymm_store_mask, ymm_store_mask);
6162

6263
mov(tmp_gpr, cvt._flt_max_int);
6364
movq(xmm_tmp, tmp_gpr);
6465
vbroadcastss(ymm_tmp, xmm_tmp);
65-
for (uint32_t jj = 0; jj < ur_w; jj++)
66+
for (int jj = 0; jj < ur_w; jj++)
6667
vmovaps(Ymm(jj), ymm_tmp);
6768

6869
mov(aux_reg_input , reg_input);
@@ -71,23 +72,23 @@ inline void jit_avx2_pooling_generator_f32::oh_step(
7172
if (this->_is_training) {
7273
vpxor(ymm_ki_offset, ymm_ki_offset);
7374
}
74-
for (uint32_t ki = 0; ki < KW; ki++) {
75-
int jj_start = nstl::max(0, pad_l-(int)ki);
75+
for (int ki = 0; ki < kw; ki++) {
76+
int jj_start = nstl::max(0, pad_l - ki);
7677
int jj_end = (int)ur_w -
77-
nstl::max(0, (int)ki+pad_r - (int)(KW-1));
78+
nstl::max(0, ki + pad_r - (kw-1));
7879
if (this->_is_training) {
7980
vmovaps(ymm_index, ymm_ki_offset);
8081
vmovaps(ymm_ji_offset, ymm_offset_base);
8182
if (jj_start != 0) {
82-
mov(tmp_gpr,(jj_start * stride_w * params->c_block));
83+
mov(tmp_gpr,(jj_start * stride_w * c_block));
8384
movq(xmm_tmp, tmp_gpr);
8485
vpbroadcastd(ymm_tmp, xmm_tmp);
8586
vpaddd(ymm_ji_offset, ymm_ji_offset, ymm_tmp);
8687
}
8788
}
8889
for (int jj = jj_start; jj < jj_end; jj++) {
89-
int aux_input_offset = (ki+jj*stride_w-pad_l)*params->c_block;
90-
if (aux_input_offset > (int)IW*(int)params->c_block)
90+
int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block;
91+
if (aux_input_offset > iw * c_block)
9192
continue;
9293
if (this->_is_training) {
9394
vpaddd(ymm_index, ymm_ki_offset, ymm_ji_offset);
@@ -99,24 +100,23 @@ inline void jit_avx2_pooling_generator_f32::oh_step(
99100
if (this->_is_training) {
100101
vblendvps(Ymm(ur_w+jj), Ymm(ur_w+jj), ymm_index,
101102
ymm_store_mask);
102-
vpaddd(ymm_ji_offset, ymm_ji_offset , ymm_simd_stride_w);
103+
vpaddd(ymm_ji_offset, ymm_ji_offset , ymm_c_block_stride_w);
103104
}
104105
}
105106
if (this->_is_training) {
106-
vpaddd(ymm_ki_offset, ymm_ki_offset , ymm_simd);
107+
vpaddd(ymm_ki_offset, ymm_ki_offset , ymm_c_block);
107108
}
108109
}
109-
add(aux_reg_input, sizeof(float)*IW*params->c_block);
110+
add(aux_reg_input, sizeof(float) * iw * c_block);
110111
inc(kj);
111112
cmp(kj, reg_kh);
112113
jl(kh_lable, T_NEAR);
113114
}
114115

115-
for (uint32_t jj = 0; jj < ur_w; jj++) {
116-
vmovups(YWORD[reg_output + sizeof(float)*jj*params->c_block], Ymm(jj));
116+
for (int jj = 0; jj < ur_w; jj++) {
117+
vmovups(YWORD[reg_output + sizeof(float)*jj*c_block], Ymm(jj));
117118
if (this->_is_training)
118-
vmovdqa(YWORD[reg_index + sizeof(uint32_t)*jj*params->c_block],
119-
Ymm(ur_w+jj));
119+
vmovdqa(YWORD[reg_index + sizeof(int)*jj*c_block], Ymm(ur_w+jj));
120120
}
121121
}
122122

@@ -129,7 +129,16 @@ jit_avx2_pooling_generator_f32::jit_avx2_pooling_generator_f32(
129129
using Xbyak::Ymm;
130130
this->preamble();
131131

132-
int n_oi = params->ow / params->ur_w;
132+
int ow = params->ow;
133+
int iw = params->iw;
134+
int kw = params->kw;
135+
int ur_w = params->ur_w;
136+
int c_block = params->c_block;
137+
int stride_w = params->stride_w;
138+
int l_pad = params->l_pad;
139+
int ur_w_tail = params->ur_w_tail;
140+
141+
int n_oi = ow / ur_w;
133142

134143
mov(reg_input , ptr [ this->param1 ]);
135144
mov(reg_output, ptr [ this->param1 + 8]);
@@ -140,72 +149,65 @@ jit_avx2_pooling_generator_f32::jit_avx2_pooling_generator_f32(
140149
mov(reg_arr_init, ptr [ this->param1 + 80]);
141150

142151
if (this->_is_training) {
143-
mov(tmp_gpr,(params->c_block));
144-
movq(xmm_simd, tmp_gpr);
145-
vpbroadcastd(ymm_simd, xmm_simd);
152+
mov(tmp_gpr,c_block);
153+
movq(xmm_c_block, tmp_gpr);
154+
vpbroadcastd(ymm_c_block, xmm_c_block);
146155

147-
mov(tmp_gpr,(params->stride_w * params->c_block));
148-
movq(xmm_simd_stride_w, tmp_gpr);
149-
vpbroadcastd(ymm_simd_stride_w, xmm_simd_stride_w);
156+
mov(tmp_gpr,(stride_w * c_block));
157+
movq(xmm_c_block_stride_w, tmp_gpr);
158+
vpbroadcastd(ymm_c_block_stride_w, xmm_c_block_stride_w);
150159

151160
vmovdqu(ymm_offset_base, ptr [ reg_arr_init ]);
152-
if (params->l_pad > 0) {
153-
mov(tmp_gpr,(params->l_pad * params->c_block));
161+
if (l_pad > 0) {
162+
mov(tmp_gpr,(l_pad * c_block));
154163
movq(xmm_tmp, tmp_gpr);
155164
vpbroadcastd(ymm_tmp, xmm_tmp);
156165
vpsubd(ymm_offset_base, ymm_offset_base, ymm_tmp);
157166
}
158167
}
159168

160-
int r_pad = nstl::max(0, (int)((params->ow-1)*params->stride_w) +
161-
(int)params->kw - 1 - (int)(params->iw + params->l_pad - 1 ));
162-
int r_pad1 = (int)(params->ur_w*n_oi - 1)*params->stride_w +
163-
params->kw - 1 - (params->iw + params->l_pad - 1);
169+
int r_pad = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1 ));
170+
int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1);
164171
if (r_pad1 > 0) n_oi--;
165172

166-
if (params->l_pad > 0) {
173+
if (l_pad > 0) {
167174
n_oi--;
168175
if (n_oi < 0 && r_pad1 > 0) {
169-
oh_step(params, params->ur_w, params->l_pad, r_pad1,
170-
".kh_loop_oimain_padwl");
176+
oh_step(params, ur_w, l_pad, r_pad1, ".kh_loop_oimain_padwl");
171177
} else {
172-
oh_step(params, params->ur_w, params->l_pad, 0,
173-
".kh_loop_oimain_padwl");
178+
oh_step(params, ur_w, l_pad, 0, ".kh_loop_oimain_padwl");
174179
}
175180

176-
add(reg_input, sizeof(float)*(params->ur_w*params->stride_w -
177-
params->l_pad)*params->c_block);
178-
add(reg_output, sizeof(float)*params->ur_w*params->c_block);
181+
add(reg_input, sizeof(float)*(ur_w*stride_w - l_pad)*c_block);
182+
add(reg_output, sizeof(float)*ur_w*c_block);
179183
if (this->_is_training)
180-
add(reg_index, sizeof(uint32_t)*params->ur_w*params->c_block);
184+
add(reg_index, sizeof(int)*ur_w*c_block);
181185
}
182186

183187
xor_(oi_iter, oi_iter);
184188
if (n_oi > 0) {
185189
L(".ow_loop"); {
186-
oh_step(params, params->ur_w, 0, 0, ".kh_loop_oimain");
187-
add(reg_input,
188-
sizeof(float)*params->ur_w*params->stride_w*params->c_block);
189-
add(reg_output, sizeof(float)*params->ur_w*params->c_block);
190+
oh_step(params, ur_w, 0, 0, ".kh_loop_oimain");
191+
add(reg_input, sizeof(float)*ur_w*stride_w*c_block);
192+
add(reg_output, sizeof(float)*ur_w*c_block);
190193
if (this->_is_training)
191-
add(reg_index, sizeof(uint32_t)*params->ur_w*params->c_block);
194+
add(reg_index, sizeof(int)*ur_w*c_block);
192195

193196
inc(oi_iter);
194197
cmp(oi_iter, n_oi); jl(".ow_loop", T_NEAR);
195198
} L(".ow_loop_end");
196199
}
197200

198201
if (r_pad1 > 0 && n_oi >= 0) {
199-
oh_step(params, params->ur_w, 0, r_pad1, ".kh_loop_oimain_padwr");
200-
add(reg_input,
201-
sizeof(float)*params->ur_w*params->stride_w*params->c_block);
202-
add(reg_output,sizeof(float)*params->ur_w*params->c_block);
202+
oh_step(params, ur_w, 0, r_pad1, ".kh_loop_oimain_padwr");
203+
add(reg_input, sizeof(float)*ur_w*stride_w*c_block);
204+
add(reg_output, sizeof(float)*ur_w*c_block);
203205
if (this->_is_training)
204-
add(reg_index, sizeof(uint32_t) * params->ur_w * params->c_block);
206+
add(reg_index, sizeof(int) * ur_w * c_block);
205207
}
206208

207-
if (params->ur_w_tail != 0)
208-
oh_step(params, params->ur_w_tail, 0, r_pad, ".kh_loop_oitail");
209+
if (ur_w_tail != 0)
210+
oh_step(params, ur_w_tail, 0, r_pad, ".kh_loop_oitail");
209211

210212
this->postamble();
211213
return;

src/cpu/jit_avx2_pooling_generator_f32.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class jit_avx2_pooling_generator_f32 : public jit_generator {
7070

7171
const bool _is_training;
7272

73-
inline void oh_step(jit_pooling_param_t *params, uint32_t ur_w,
73+
inline void oh_step(jit_pooling_param_t *params, int ur_w,
7474
int pad_l, int pad_r, const char* kh_lable);
7575
public:
7676
jit_avx2_pooling_generator_f32(

0 commit comments

Comments
 (0)