Skip to content

Commit 715b476

Browse files
kealan-barbierikarturov
authored andcommitted
xe: jit: gemm: fixup gemm acc, zp limits
1 parent 5422e6e commit 715b476

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

src/gpu/intel/jit/gemm/gen_gemm.hpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,14 @@ struct gen_gemm_t : public gpu_gemm_t {
231231
VERBOSE_SHAPE_RESTRICTION);
232232
}
233233

234+
auto &wei_scales = attr()->scales_.get(DNNL_ARG_WEIGHTS);
235+
auto &src_scales = attr()->scales_.get(DNNL_ARG_SRC);
236+
237+
if (quant_enabled_ && !wei_scales.has_default_groups())
238+
wei_scales_2d_ = true;
239+
if (quant_enabled_ && !src_scales.has_default_groups())
240+
src_scales_2d_ = true;
241+
234242
if (!attr()->zero_points_.has_default_values()) {
235243
if (!attr_zps.has_default_values(DNNL_ARG_A)) {
236244
const int cmask_a = attr_zps.get_mask(DNNL_ARG_A);
@@ -262,6 +270,12 @@ struct gen_gemm_t : public gpu_gemm_t {
262270
VDISPATCH_GEMM(utils::one_of(cmask_a, 0, mask_per_oc,
263271
mask_per_ic),
264272
VERBOSE_UNSUPPORTED_ZP_CFG);
273+
// Weights zp can only be performantly enabled during upconversion
274+
// for cases that perform decompression.
275+
VDISPATCH_GEMM(wei_decomp_
276+
|| utils::one_of(d->a_type(), s4, u4)
277+
|| !wei_scales_2d_,
278+
VERBOSE_UNSUPPORTED_ZP_CFG);
265279
}
266280
}
267281

@@ -307,14 +321,6 @@ struct gen_gemm_t : public gpu_gemm_t {
307321
if (swap_ab_) std::swap(ao_dims_, bo_dims_);
308322
}
309323

310-
auto &wei_scales = attr()->scales_.get(DNNL_ARG_WEIGHTS);
311-
auto &src_scales = attr()->scales_.get(DNNL_ARG_SRC);
312-
313-
if (quant_enabled_ && !wei_scales.has_default_groups())
314-
wei_scales_2d_ = true;
315-
if (quant_enabled_ && !src_scales.has_default_groups())
316-
src_scales_2d_ = true;
317-
318324
for (auto s : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
319325
if (attr()->scales_.has_default_values(s)) continue;
320326

@@ -399,6 +405,7 @@ struct gen_gemm_t : public gpu_gemm_t {
399405
: data_type::s32;
400406
if (swap_ab_) std::swap(ao_type, bo_type);
401407
bool int_acc = utils::one_of(eff_a_type(), s8, u8);
408+
int_acc &= !wei_scales_2d_;
402409
auto co_type = with_bias() ? d->bias_type()
403410
: with_sum_ab() ? d->sum_ab_type
404411
: int_acc ? s32

0 commit comments

Comments
 (0)