1818
1919#include " jit_avx2_pooling_generator_f32.hpp"
2020
21- #define ymm_store_mask Ymm (15 )
22- #define ymm_input Ymm (14 )
23- #define ymm_index Ymm (13 )
24- #define xmm_index Xmm (13 )
25- #define ymm_simd Ymm (12 )
26- #define xmm_simd Xmm (12 )
27- #define ymm_simd_stride_w Ymm (11 )
28- #define xmm_simd_stride_w Xmm (11 )
29- #define ymm_ki_offset Ymm (10 )
30- #define xmm_ki_offset Xmm (10 )
31- #define ymm_ji_offset Ymm (9 )
32- #define xmm_ji_offset Xmm (9 )
33- #define ymm_tmp Ymm (8 )
34- #define xmm_tmp Xmm (8 )
35- #define ymm_offset_base Ymm (7 )
36- #define xmm_offset_base Xmm (7 )
21+ #define ymm_store_mask Ymm (15 )
22+ #define ymm_input Ymm (14 )
23+ #define ymm_tmp Ymm (13 )
24+ #define xmm_tmp Xmm (13 )
25+ #define ymm_index Ymm (12 )
26+ #define xmm_index Xmm (12 )
27+ #define ymm_c_block Ymm (11 )
28+ #define xmm_c_block Xmm (11 )
29+ #define ymm_c_block_stride_w Ymm (10 )
30+ #define xmm_c_block_stride_w Xmm (10 )
31+ #define ymm_ki_offset Ymm (9 )
32+ #define xmm_ki_offset Xmm (9 )
33+ #define ymm_ji_offset Ymm (8 )
34+ #define xmm_ji_offset Xmm (8 )
35+ #define ymm_offset_base Ymm (7 )
36+ #define xmm_offset_base Xmm (7 )
3737
3838namespace mkldnn {
3939namespace impl {
4040namespace cpu {
4141
4242inline void jit_avx2_pooling_generator_f32::oh_step (
43- jit_pooling_param_t *params, uint32_t ur_w,
43+ jit_pooling_param_t *params, int ur_w,
4444 int pad_l, int pad_r, const char * kh_lable)
4545{
4646 using Xbyak::Ymm;
@@ -53,16 +53,17 @@ inline void jit_avx2_pooling_generator_f32::oh_step(
5353 } cvt;
5454 cvt._flt_max = -FLT_MAX;
5555
56- uint32_t IW = params->iw ;
57- uint32_t KW = params->kw ;
58- uint32_t stride_w = params->stride_w ;
56+ int iw = params->iw ;
57+ int kw = params->kw ;
58+ int stride_w = params->stride_w ;
59+ int c_block = params->c_block ;
5960
6061 vpxor (ymm_store_mask, ymm_store_mask);
6162
6263 mov (tmp_gpr, cvt._flt_max_int );
6364 movq (xmm_tmp, tmp_gpr);
6465 vbroadcastss (ymm_tmp, xmm_tmp);
65- for (uint32_t jj = 0 ; jj < ur_w; jj++)
66+ for (int jj = 0 ; jj < ur_w; jj++)
6667 vmovaps (Ymm (jj), ymm_tmp);
6768
6869 mov (aux_reg_input , reg_input);
@@ -71,23 +72,23 @@ inline void jit_avx2_pooling_generator_f32::oh_step(
7172 if (this ->_is_training ) {
7273 vpxor (ymm_ki_offset, ymm_ki_offset);
7374 }
74- for (uint32_t ki = 0 ; ki < KW ; ki++) {
75- int jj_start = nstl::max (0 , pad_l-( int ) ki);
75+ for (int ki = 0 ; ki < kw ; ki++) {
76+ int jj_start = nstl::max (0 , pad_l - ki);
7677 int jj_end = (int )ur_w -
77- nstl::max (0 , ( int )ki+ pad_r - (int )(KW -1 ));
78+ nstl::max (0 , ki + pad_r - (kw -1 ));
7879 if (this ->_is_training ) {
7980 vmovaps (ymm_index, ymm_ki_offset);
8081 vmovaps (ymm_ji_offset, ymm_offset_base);
8182 if (jj_start != 0 ) {
82- mov (tmp_gpr,(jj_start * stride_w * params-> c_block ));
83+ mov (tmp_gpr,(jj_start * stride_w * c_block));
8384 movq (xmm_tmp, tmp_gpr);
8485 vpbroadcastd (ymm_tmp, xmm_tmp);
8586 vpaddd (ymm_ji_offset, ymm_ji_offset, ymm_tmp);
8687 }
8788 }
8889 for (int jj = jj_start; jj < jj_end; jj++) {
89- int aux_input_offset = (ki+jj*stride_w-pad_l)*params-> c_block ;
90- if (aux_input_offset > ( int )IW*( int )params-> c_block )
90+ int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block;
91+ if (aux_input_offset > iw * c_block)
9192 continue ;
9293 if (this ->_is_training ) {
9394 vpaddd (ymm_index, ymm_ki_offset, ymm_ji_offset);
@@ -99,24 +100,23 @@ inline void jit_avx2_pooling_generator_f32::oh_step(
99100 if (this ->_is_training ) {
100101 vblendvps (Ymm (ur_w+jj), Ymm (ur_w+jj), ymm_index,
101102 ymm_store_mask);
102- vpaddd (ymm_ji_offset, ymm_ji_offset , ymm_simd_stride_w );
103+ vpaddd (ymm_ji_offset, ymm_ji_offset , ymm_c_block_stride_w );
103104 }
104105 }
105106 if (this ->_is_training ) {
106- vpaddd (ymm_ki_offset, ymm_ki_offset , ymm_simd );
107+ vpaddd (ymm_ki_offset, ymm_ki_offset , ymm_c_block );
107108 }
108109 }
109- add (aux_reg_input, sizeof (float )*IW*params-> c_block );
110+ add (aux_reg_input, sizeof (float ) * iw * c_block);
110111 inc (kj);
111112 cmp (kj, reg_kh);
112113 jl (kh_lable, T_NEAR);
113114 }
114115
115- for (uint32_t jj = 0 ; jj < ur_w; jj++) {
116- vmovups (YWORD[reg_output + sizeof (float )*jj*params-> c_block ], Ymm (jj));
116+ for (int jj = 0 ; jj < ur_w; jj++) {
117+ vmovups (YWORD[reg_output + sizeof (float )*jj*c_block], Ymm (jj));
117118 if (this ->_is_training )
118- vmovdqa (YWORD[reg_index + sizeof (uint32_t )*jj*params->c_block ],
119- Ymm (ur_w+jj));
119+ vmovdqa (YWORD[reg_index + sizeof (int )*jj*c_block], Ymm (ur_w+jj));
120120 }
121121}
122122
@@ -129,7 +129,16 @@ jit_avx2_pooling_generator_f32::jit_avx2_pooling_generator_f32(
129129 using Xbyak::Ymm;
130130 this ->preamble ();
131131
132- int n_oi = params->ow / params->ur_w ;
132+ int ow = params->ow ;
133+ int iw = params->iw ;
134+ int kw = params->kw ;
135+ int ur_w = params->ur_w ;
136+ int c_block = params->c_block ;
137+ int stride_w = params->stride_w ;
138+ int l_pad = params->l_pad ;
139+ int ur_w_tail = params->ur_w_tail ;
140+
141+ int n_oi = ow / ur_w;
133142
134143 mov (reg_input , ptr [ this ->param1 ]);
135144 mov (reg_output, ptr [ this ->param1 + 8 ]);
@@ -140,72 +149,65 @@ jit_avx2_pooling_generator_f32::jit_avx2_pooling_generator_f32(
140149 mov (reg_arr_init, ptr [ this ->param1 + 80 ]);
141150
142151 if (this ->_is_training ) {
143- mov (tmp_gpr,(params-> c_block ) );
144- movq (xmm_simd , tmp_gpr);
145- vpbroadcastd (ymm_simd, xmm_simd );
152+ mov (tmp_gpr,c_block);
153+ movq (xmm_c_block , tmp_gpr);
154+ vpbroadcastd (ymm_c_block, xmm_c_block );
146155
147- mov (tmp_gpr,(params-> stride_w * params-> c_block ));
148- movq (xmm_simd_stride_w , tmp_gpr);
149- vpbroadcastd (ymm_simd_stride_w, xmm_simd_stride_w );
156+ mov (tmp_gpr,(stride_w * c_block));
157+ movq (xmm_c_block_stride_w , tmp_gpr);
158+ vpbroadcastd (ymm_c_block_stride_w, xmm_c_block_stride_w );
150159
151160 vmovdqu (ymm_offset_base, ptr [ reg_arr_init ]);
152- if (params-> l_pad > 0 ) {
153- mov (tmp_gpr,(params-> l_pad * params-> c_block ));
161+ if (l_pad > 0 ) {
162+ mov (tmp_gpr,(l_pad * c_block));
154163 movq (xmm_tmp, tmp_gpr);
155164 vpbroadcastd (ymm_tmp, xmm_tmp);
156165 vpsubd (ymm_offset_base, ymm_offset_base, ymm_tmp);
157166 }
158167 }
159168
160- int r_pad = nstl::max (0 , (int )((params->ow -1 )*params->stride_w ) +
161- (int )params->kw - 1 - (int )(params->iw + params->l_pad - 1 ));
162- int r_pad1 = (int )(params->ur_w *n_oi - 1 )*params->stride_w +
163- params->kw - 1 - (params->iw + params->l_pad - 1 );
169+ int r_pad = nstl::max (0 , ((ow-1 )*stride_w) + kw - 1 - (iw + l_pad - 1 ));
170+ int r_pad1 = (ur_w*n_oi - 1 )*stride_w + kw - 1 - (iw + l_pad - 1 );
164171 if (r_pad1 > 0 ) n_oi--;
165172
166- if (params-> l_pad > 0 ) {
173+ if (l_pad > 0 ) {
167174 n_oi--;
168175 if (n_oi < 0 && r_pad1 > 0 ) {
169- oh_step (params, params->ur_w , params->l_pad , r_pad1,
170- " .kh_loop_oimain_padwl" );
176+ oh_step (params, ur_w, l_pad, r_pad1, " .kh_loop_oimain_padwl" );
171177 } else {
172- oh_step (params, params->ur_w , params->l_pad , 0 ,
173- " .kh_loop_oimain_padwl" );
178+ oh_step (params, ur_w, l_pad, 0 , " .kh_loop_oimain_padwl" );
174179 }
175180
176- add (reg_input, sizeof (float )*(params->ur_w *params->stride_w -
177- params->l_pad )*params->c_block );
178- add (reg_output, sizeof (float )*params->ur_w *params->c_block );
181+ add (reg_input, sizeof (float )*(ur_w*stride_w - l_pad)*c_block);
182+ add (reg_output, sizeof (float )*ur_w*c_block);
179183 if (this ->_is_training )
180- add (reg_index, sizeof (uint32_t )*params-> ur_w *params-> c_block );
184+ add (reg_index, sizeof (int )* ur_w*c_block);
181185 }
182186
183187 xor_ (oi_iter, oi_iter);
184188 if (n_oi > 0 ) {
185189 L (" .ow_loop" ); {
186- oh_step (params, params->ur_w , 0 , 0 , " .kh_loop_oimain" );
187- add (reg_input,
188- sizeof (float )*params->ur_w *params->stride_w *params->c_block );
189- add (reg_output, sizeof (float )*params->ur_w *params->c_block );
190+ oh_step (params, ur_w, 0 , 0 , " .kh_loop_oimain" );
191+ add (reg_input, sizeof (float )*ur_w*stride_w*c_block);
192+ add (reg_output, sizeof (float )*ur_w*c_block);
190193 if (this ->_is_training )
191- add (reg_index, sizeof (uint32_t )*params-> ur_w *params-> c_block );
194+ add (reg_index, sizeof (int )* ur_w*c_block);
192195
193196 inc (oi_iter);
194197 cmp (oi_iter, n_oi); jl (" .ow_loop" , T_NEAR);
195198 } L (" .ow_loop_end" );
196199 }
197200
198201 if (r_pad1 > 0 && n_oi >= 0 ) {
199- oh_step (params, params->ur_w , 0 , r_pad1, " .kh_loop_oimain_padwr" );
200- add (reg_input,
201- sizeof (float )*params->ur_w *params->stride_w *params->c_block );
202- add (reg_output,sizeof (float )*params->ur_w *params->c_block );
202+ oh_step (params, ur_w, 0 , r_pad1, " .kh_loop_oimain_padwr" );
203+ add (reg_input, sizeof (float )*ur_w*stride_w*c_block);
204+ add (reg_output, sizeof (float )*ur_w*c_block);
203205 if (this ->_is_training )
204- add (reg_index, sizeof (uint32_t ) * params-> ur_w * params-> c_block );
206+ add (reg_index, sizeof (int ) * ur_w * c_block);
205207 }
206208
207- if (params-> ur_w_tail != 0 )
208- oh_step (params, params-> ur_w_tail , 0 , r_pad, " .kh_loop_oitail" );
209+ if (ur_w_tail != 0 )
210+ oh_step (params, ur_w_tail, 0 , r_pad, " .kh_loop_oitail" );
209211
210212 this ->postamble ();
211213 return ;
0 commit comments