@@ -240,20 +240,27 @@ status_t brgemm_blocking(brgemm_desc_t *brg) {
240240 brg->ldb = brg->load_dim / brg->ld_block ;
241241 brg->ldb_tail = brg->load_dim % brg->ld_block ;
242242
243+ const int max_vpad = nstl::max (
244+ brg->brgattr .max_top_vpad , brg->brgattr .max_bottom_vpad );
245+
243246 int adj_ld_block2 = calculate_ldb_params (brg, 4 );
244247 int max_bcast_block = calculate_max_bcast_block (brg, adj_ld_block2);
245-
246248 // reduce 'ld_block2' to allow a larger 'bd_block'
247- const int max_vpad = nstl::max (
248- brg->brgattr .max_top_vpad , brg->brgattr .max_bottom_vpad );
249249 if (is_superset (brg->isa_impl , avx2) && max_bcast_block < max_vpad) {
250- adj_ld_block2 = calculate_ldb_params (brg, 2 );
251- max_bcast_block = calculate_max_bcast_block (brg, adj_ld_block2);
250+ for (int try_ld_block2 = 2 ; try_ld_block2 > 0 ; --try_ld_block2) {
251+ adj_ld_block2 = calculate_ldb_params (brg, try_ld_block2);
252+ max_bcast_block = calculate_max_bcast_block (brg, adj_ld_block2);
253+ if (max_bcast_block >= max_vpad) break ;
254+ }
255+ // bcast block in brgemm kernel should be no smaller than the virtual
256+ // padding to avoid possible functional issues
257+ if (max_bcast_block < max_vpad) return status::unimplemented;
252258 }
253259
254- const int min_block = 1 ;
260+ const int min_block = nstl::max (1 , max_vpad);
261+
255262 float best_bd_block_eff = 0 .f ;
256- brg->bd_block = 1 ;
263+ brg->bd_block = max_bcast_block ;
257264 for (int bd_block = max_bcast_block; bd_block >= min_block;
258265 bd_block--) {
259266 const auto bd_block_disb = static_cast <float >(brg->bcast_dim )
0 commit comments