@@ -231,6 +231,14 @@ struct gen_gemm_t : public gpu_gemm_t {
231231 VERBOSE_SHAPE_RESTRICTION);
232232 }
233233
234+ auto &wei_scales = attr ()->scales_ .get (DNNL_ARG_WEIGHTS);
235+ auto &src_scales = attr ()->scales_ .get (DNNL_ARG_SRC);
236+
237+ if (quant_enabled_ && !wei_scales.has_default_groups ())
238+ wei_scales_2d_ = true ;
239+ if (quant_enabled_ && !src_scales.has_default_groups ())
240+ src_scales_2d_ = true ;
241+
234242 if (!attr ()->zero_points_ .has_default_values ()) {
235243 if (!attr_zps.has_default_values (DNNL_ARG_A)) {
236244 const int cmask_a = attr_zps.get_mask (DNNL_ARG_A);
@@ -262,6 +270,12 @@ struct gen_gemm_t : public gpu_gemm_t {
262270 VDISPATCH_GEMM (utils::one_of (cmask_a, 0 , mask_per_oc,
263271 mask_per_ic),
264272 VERBOSE_UNSUPPORTED_ZP_CFG);
273+ // Weights zp can only be performantly enabled during upconversion
274+ // for cases that perform decompression.
275+ VDISPATCH_GEMM (wei_decomp_
276+ || utils::one_of (d->a_type (), s4, u4)
277+ || !wei_scales_2d_,
278+ VERBOSE_UNSUPPORTED_ZP_CFG);
265279 }
266280 }
267281
@@ -307,14 +321,6 @@ struct gen_gemm_t : public gpu_gemm_t {
307321 if (swap_ab_) std::swap (ao_dims_, bo_dims_);
308322 }
309323
310- auto &wei_scales = attr ()->scales_ .get (DNNL_ARG_WEIGHTS);
311- auto &src_scales = attr ()->scales_ .get (DNNL_ARG_SRC);
312-
313- if (quant_enabled_ && !wei_scales.has_default_groups ())
314- wei_scales_2d_ = true ;
315- if (quant_enabled_ && !src_scales.has_default_groups ())
316- src_scales_2d_ = true ;
317-
318324 for (auto s : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
319325 if (attr ()->scales_ .has_default_values (s)) continue ;
320326
@@ -399,6 +405,7 @@ struct gen_gemm_t : public gpu_gemm_t {
399405 : data_type::s32;
400406 if (swap_ab_) std::swap (ao_type, bo_type);
401407 bool int_acc = utils::one_of (eff_a_type (), s8, u8 );
408+ int_acc &= !wei_scales_2d_;
402409 auto co_type = with_bias () ? d->bias_type ()
403410 : with_sum_ab () ? d->sum_ab_type
404411 : int_acc ? s32
0 commit comments