[XNNPACK] resolve ambiguity around 2d affine quantized tensors #8958

Merged (1 commit) on Mar 15, 2025
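What the diffs below resolve, in short: for a rank-2 tensor, a per-token affine quantize/dequantize node (block_size (1, K), i.e. one scale per row) is structurally identical to a per-channel-group node whose group spans the whole row (group_size == K), so block metadata alone cannot tell a dynamically quantized activation apart from a statically quantized weight. The fix routes on where the quantization parameters come from instead. A minimal sketch of the collision in plain PyTorch; this is illustrative code, not ExecuTorch's (the /127 int8 scale rule and all sizes are assumptions):

import torch

# Not ExecuTorch code: a plain-PyTorch sketch of the 2D metadata collision.
M, K = 4, 32
activation = torch.randn(M, K)  # would be quantized dynamically, at runtime
weight = torch.randn(M, K)      # would be quantized statically, at export time

per_token_block = (1, K)        # activation reading: one scale per token/row
per_group_block = (1, K)        # weight reading: group_size == K
assert per_token_block == per_group_block    # identical block metadata

act_scales = activation.abs().amax(dim=-1, keepdim=True) / 127.0
wgt_scales = weight.abs().amax(dim=-1, keepdim=True) / 127.0
assert act_scales.shape == wgt_scales.shape  # [M, 1] either way

Because the metadata collides, the handlers below key on dynamism instead: scales computed by a choose-qparams node in the graph mean activation; scales baked in as parameters mean weight.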
6 changes: 4 additions & 2 deletions backends/xnnpack/operators/op_dynamic_dequantize_ops.py
@@ -13,6 +13,7 @@
 )
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph
 from executorch.backends.xnnpack.utils.quant_utils import (
+    is_dynamic_qdq,
     is_per_channel_group,
     is_per_token,
 )
@@ -92,7 +93,8 @@ def define_node(
         """
         We always define dequantize affine nodes because they are always explicit
         """
-        if is_per_channel_group(node):
+        is_dynamic = is_dynamic_qdq(node)
+        if is_per_channel_group(node) and not is_dynamic:
             check_or_raise(
                 is_param_node(self._exported_program, node.all_input_nodes[0]),
                 f"Expected quantize affine node with per-token semantics to be used "
@@ -103,7 +105,7 @@ def define_node(
             return

         check_or_raise(
-            is_per_token(node),
+            is_per_token(node) and is_dynamic,
             "Expecting Affine Dequantized Op to have per-token semantics",
         )
         # This must be a per-token affine dequantized node, so let us serialize as such
6 changes: 4 additions & 2 deletions backends/xnnpack/operators/op_dynamic_quantize_ops.py
@@ -18,6 +18,7 @@
     XNode,
 )
 from executorch.backends.xnnpack.utils.quant_utils import (
+    is_dynamic_qdq,
     is_per_channel_group,
     is_per_token,
 )
@@ -138,13 +139,14 @@ def define_node(
         """
         We always define quantize affine nodes because they are always explicit
         """
-        if is_per_channel_group(node):
+        is_dynamic = is_dynamic_qdq(node)
+        if is_per_channel_group(node) and not is_dynamic:
             # Affine quantized was recognized as per channel group which means that it should
             # be skipped as this means it is used in front of a weight node
             return

         check_or_raise(
-            is_per_token(node),
+            is_per_token(node) and is_dynamic,
             "Encountered affine quantized op which does not have per-token semantics",
         )
         # Treat this node as dynamic per-token quantization
49 changes: 25 additions & 24 deletions backends/xnnpack/test/ops/test_linear.py
@@ -645,31 +645,32 @@ def _test_qd8_per_token_weight_per_channel_group_int4(
         bl_sizes = [32, 32, 32, 64]
         N_sizes = [2, 17, 92, 128]

-        for use_bias in [True, False]:
-            for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
-                lin_mod = BaseLinear(
-                    in_size=M,
-                    input_channels=K,
-                    output_channels=N,
-                    dtype=dtype,
-                    use_bias=use_bias,
-                )
+        for input_rank in range(2, 4):
+            for use_bias in [True, False]:
+                for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
+                    lin_mod = BaseLinear(
+                        in_size=M,
+                        input_channels=K,
+                        output_channels=N,
+                        dtype=dtype,
+                        use_bias=use_bias,
+                    )

-                inputs = lin_mod.get_inputs()
-                # Half requires slightly higher atol, but if you look at error it is not that bad:
-                # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
-                # -- Model vs. Reference --
-                # Numel: 4, 4
-                # Median: -0.05023193359375, -0.0516357421875
-                # Mean: 0.2373046875, 0.237060546875
-                # Max: 1.0078125, 1.0078125
-                # Min: -0.08465576171875, -0.08441162109375
-                atol = (
-                    1e-2 if dtype == torch.half else 5e-3
-                )  # TODO(T212995726): Investigate right atol for rand[n] inputs
-                self._test_groupwise_dq_linear(
-                    lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=atol
-                )
+                    inputs = lin_mod.get_inputs(rank=input_rank)
+                    # Half requires slightly higher atol, but if you look at error it is not that bad:
+                    # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
+                    # -- Model vs. Reference --
+                    # Numel: 4, 4
+                    # Median: -0.05023193359375, -0.0516357421875
+                    # Mean: 0.2373046875, 0.237060546875
+                    # Max: 1.0078125, 1.0078125
+                    # Min: -0.08465576171875, -0.08441162109375
+                    atol = (
+                        1e-2 if dtype == torch.half else 5e-3
+                    )  # TODO(T212995726): Investigate right atol for rand[n] inputs
+                    self._test_groupwise_dq_linear(
+                        lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=atol
+                    )

     def test_fp16_linear(self):
         for use_bias in (True, False):
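The only functional changes in this test are the new outer loop over input_rank in range(2, 4) and passing rank=input_rank to get_inputs, so the groupwise linear path is now exercised with both rank-2 and rank-3 activations; rank 2 is exactly the ambiguous case sketched above. As a reminder that linear layers batch over all leading dims (plain PyTorch, arbitrary sizes):

import torch

lin = torch.nn.Linear(32, 8)            # only the last dim must be 32
for shape in ((4, 32), (2, 4, 32)):     # rank-2 and rank-3 activations
    y = lin(torch.randn(shape))
    print(len(shape), tuple(y.shape))   # 2 (4, 8) / 3 (2, 4, 8)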
26 changes: 22 additions & 4 deletions backends/xnnpack/utils/quant_utils.py
@@ -47,12 +47,30 @@


 def is_dynamic_qdq(node: torch.fx.Node) -> bool:
-    if node.op != "call_function":
+    # check has dynamic qdq name
+    if not (is_quant(node) or is_dequant(node)):
         return False
-    node_name = format_target_name(node.target.__name__)  # pyre-ignore
-    is_dynamic_affine = is_per_token(node) and not is_per_channel_group(node)
+
+    # check scales and zp are dynamically chosen
+    node_input_args = node.args
+    if is_affine_qdq(node):
+        node_input_args = extract_qdq_affine_op_args_for_decomposed_ops(node)
+
+    scale = node_input_args[1]
+    zp = node_input_args[2]
+    if not (isinstance(scale, torch.fx.Node) and isinstance(zp, torch.fx.Node)):
+        return False
+
+    if not (scale.target == operator.getitem and zp.target == operator.getitem):
+        return False
+
+    scale_choose_qparam = scale.all_input_nodes[0]
+    zp_choose_qparam = zp.all_input_nodes[0]
+
+    if not (is_qparam(scale_choose_qparam) and is_qparam(zp_choose_qparam)):
+        return False

-    return node_name in _DYNAMIC_OPS or is_dynamic_affine
+    return True


 def is_qparam(node: torch.fx.Node) -> bool:
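The rewritten is_dynamic_qdq replaces the name-based test with a structural one: a quantize/dequantize node counts as dynamic only when its scale and zero_point arguments are operator.getitem projections of a choose-qparams node, i.e. the qparams are computed in the graph at runtime rather than baked in as parameters. A toy reconstruction of that graph shape with torch.fx; toy_choose_qparams is a hypothetical stand-in for the real decomposed op:

import operator

import torch
import torch.fx

# Registered as a leaf: traced as a single node with a tuple output, so
# indexing its result yields operator.getitem nodes, like a real choose_qparams.
@torch.fx.wrap
def toy_choose_qparams(x):
    scale = x.abs().amax() / 127.0
    return scale, torch.zeros((), dtype=torch.int64)

class DynQuant(torch.nn.Module):
    def forward(self, x):
        qparams = toy_choose_qparams(x)
        scale = qparams[0]  # call_function node, target operator.getitem
        zp = qparams[1]     # call_function node, target operator.getitem
        return torch.round(x / scale) + zp

gm = torch.fx.symbolic_trace(DynQuant())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.getitem:
        # Both getitems trace back to the qparam node: the structural
        # signature that marks this q/dq as dynamic.
        print(node.name, "<-", node.all_input_nodes[0].name)

A static weight quantize would instead carry concrete scale/zero_point values (or parameter nodes), failing the isinstance and getitem checks, which is how the 2D weight/activation ambiguity is broken regardless of tensor rank.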