Commit 2eae802

[ExecuTorch][XNNPACK] Don't partition per_tensor weights with qd8 (#8927)
This is not supported, so we shouldn't partition it. Add an expectedFailure test to indicate that this is not supported.

Differential Revision: [D70343584](https://our.internmc.facebook.com/intern/diff/D70343584/)
ghstack-source-id: 269356867
Pull Request resolved: #8891
Co-authored-by: Digant Desai <[email protected]>
1 parent 7aa6494 commit 2eae802
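
In user-facing terms, the unsupported pattern is a dynamically quantized (qd8) linear whose weights are quantized per tensor rather than per channel. The sketch below contrasts the two quantizer configurations: the `get_symmetric_quantization_config` flags come from the test added in this diff, while the quantizer class usage and import paths are assumptions that may differ between ExecuTorch releases.

```python
# Sketch only: which quantizer configuration hits the new rejection path.
# Import paths and quantizer usage are assumptions; adjust to your release.
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Rejected by the partitioner after this change: dynamically quantized
# activations combined with per-tensor weights (no XNNPACK qd8 kernel for this).
unsupported_config = get_symmetric_quantization_config(
    is_per_channel=False,
    is_dynamic=True,
)

# Still partitioned and delegated: dynamically quantized activations with
# per-channel weights (the qd8 per-channel paths exercised by the tests below).
supported_config = get_symmetric_quantization_config(
    is_per_channel=True,
    is_dynamic=True,
)

quantizer = XNNPACKQuantizer()
quantizer.set_global(supported_config)
```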

3 files changed: +111, -9 lines


backends/xnnpack/partition/config/gemm_configs.py (+28, -9)

```diff
@@ -21,6 +21,7 @@
     is_dynamic_qdq,
     is_per_channel,
     is_per_channel_group,
+    is_per_tensor,
     is_qparam,
     is_quant,
 )
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False
 
         is_valid, _ = self.get_deps(node, ep)
-        if not is_valid:
-            why(node, "Failed to get valid dependent nodes.")
         return is_valid
 
     def get_node_and_deps(
@@ -123,6 +122,7 @@ def get_deps(
         precision = self._detect_precision(node)
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
+            why(node, f"Unsupported precision type {precision}")
             return (False, [])
         _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
@@ -143,27 +143,42 @@ def _get_weight_deps(
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
             if not is_param_node(ep, weight_node):
-                return (False, [])  # weight must be a static param
+                why(node, "Expected weight to be a static param")
+                return (False, [])
             gemm_deps.append(weight_node)
 
             return (True, gemm_deps)
         else:
             # Quantized Weight deps
             dequant_node = get_input_node(node, self.weight_idx)
             if not is_dequant(dequant_node):
+                why(node, "Expected weight to have a dequantized node")
                 return False, []
             gemm_deps.append(dequant_node)
             weight = get_input_node(dequant_node, 0)
             if not is_param_node(ep, weight):
+                why(node, "Expected weight to be a static param")
                 return False, []
             gemm_deps.append(weight)
 
+            if (
+                is_per_tensor(dequant_node)
+                and precision == ConfigPrecisionType.DYNAMIC_QUANT
+            ):
+                why(
+                    node,
+                    "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations",
+                )
+                return False, []
+
             if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
                 if len(dequant_node.all_input_nodes) < 2:
                     # Expected channel quantized to have scale/zp nodes
+                    why(node, "Expected channel quantized to have scale/zp nodes")
                     return False, []
 
                 gemm_deps.extend(dequant_node.all_input_nodes[1:3])
+
             return (True, gemm_deps)
 
     def _get_output_deps(
@@ -174,7 +189,7 @@ def _get_output_deps(
             # Look for fused activations and tail end quant node
             node_users = list(node.users.keys())
             if len(node_users) != 1:
-                # Expect quantized node to have a single output (fused act or dequant)
+                why(node, "Expected quantized node to have a single output")
                 return False, []
 
             # Check if the quantized pattern has a fused activation
@@ -190,6 +205,7 @@
 
             if not is_quant(n_output):
                 # Expected gemm_node --> fused_act (optional) --> dequant
+                why(node, "Expected output node to have a dequantized node")
                 return (False, [])
             gemm_deps.append(n_output)
         elif precision == ConfigPrecisionType.FP32:
@@ -219,7 +235,8 @@ def _get_bias_deps(
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
-                    return (False, [])  # bias node must be a static param
+                    why(node, "Expected bias to be a static param")
+                    return (False, [])
                 gemm_deps.append(bias_node)
 
         return (True, gemm_deps)
@@ -233,7 +250,7 @@ def _get_act_deps(
         else:
             dq_input = get_input_node(node, self.act_idx)
             if not is_dequant(dq_input):
-                # Expected static quant input to be dequant node
+                why(node, "Expected act input to be dequant node")
                 return False, []
             gemm_deps.append(dq_input)
             if precision == ConfigPrecisionType.STATIC_QUANT:
@@ -243,27 +260,28 @@
             # q input node
             q_input = get_input_node(dq_input, 0)
             if not is_quant(q_input):
+                why(node, "Expected dequant input to be quant node")
                 return (False, [])
 
             gemm_deps.append(q_input)
             q_input_args = q_input.args
             if is_affine_qdq(q_input):
                 q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
             if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
-                # expected to find getitem node from choose qparam
+                why(node, "expected to find getitem node from choose qparam")
                 return (False, [])
 
             getitem1 = q_input_args[1]
             getitem2 = q_input_args[2]
 
             if not (is_getitem(getitem1) and is_getitem(getitem2)):
-                # expected getitem node from choose qparam
+                why(node, "expected getitem node from choose qparam")
                 return (False, [])
 
             gemm_deps.extend([getitem1, getitem2])
             choose_qparam = get_input_node(getitem1, 0)
             if not is_qparam(choose_qparam):
-                # expected to find choose_qparam node
+                why(node, "expected to find choose_qparam node")
                 return (False, [])
             gemm_deps.append(choose_qparam)
             return (True, gemm_deps)
@@ -471,6 +489,7 @@ def find_partition_args(input_node):
             # there can only be a single output node in partition
             or len(src_partition.output_nodes) != 1
         ):
+            why(node, "invalid source partition")
             return (False, [])
 
         # map addmm's args to the source partition linear's inputs and users
```
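
Condensed, the new guard in `_get_weight_deps` is a single predicate: a quantized weight whose dequantize node is per-tensor cannot be combined with dynamically quantized activations, so the GEMM config declines the node and leaves it un-delegated. A minimal restatement as a standalone helper is below (the function name is hypothetical, the import paths are assumptions, and the real method also records the reason via `why` and returns the dependency list):

```python
# Hypothetical condensation of the check added to _get_weight_deps above.
# Import paths are assumptions and may vary between ExecuTorch releases.
import torch

from executorch.backends.xnnpack.partition.config.xnnpack_config import ConfigPrecisionType
from executorch.backends.xnnpack.utils.quant_utils import is_per_tensor


def weight_rejected_for_dynamic_quant(
    dequant_node: torch.fx.Node, precision: ConfigPrecisionType
) -> bool:
    # XNNPACK has no dynamically quantized GEMM kernel for per-tensor weights,
    # so such nodes must stay out of the XNNPACK partition.
    return is_per_tensor(dequant_node) and precision == ConfigPrecisionType.DYNAMIC_QUANT
```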

backends/xnnpack/test/ops/test_linear.py (+74)

```diff
@@ -539,6 +539,66 @@ def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
                 uses_bias=uses_bias,
             )
 
+    def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.float):
+        for uses_bias in (False, True):
+            module = BaseLinear(
+                in_size=8,
+                input_channels=13,
+                output_channels=17,
+                dtype=dtype,
+                use_bias=uses_bias,
+            )
+            inputs = module.get_inputs()
+            dynamic_shapes = ({1: torch.export.Dim("batch", max=100)},)
+
+            quant_config = get_symmetric_quantization_config(
+                is_per_channel=False,
+                is_dynamic=True,
+            )
+
+            for legacy_partitioner in (True, False):
+                for per_op_mode in (True, False):
+                    # Every combination should fail to partition Linear or [add]mm.
+                    DynamicallyQuantizedPartitioner = XnnpackPartitioner(
+                        config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+                        per_op_mode=per_op_mode,
+                    )
+
+                    tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
+                    tester.quantize(Quantize(quantization_config=quant_config))
+                    tester.export()
+
+                    if legacy_partitioner:
+                        tester.to_edge()
+                        tester.partition(
+                            Partition(DynamicallyQuantizedPartitioner)
+                        ).dump_artifact()
+                        # should have [add]mm node
+                        if uses_bias:
+                            tester.check(
+                                [
+                                    "executorch_exir_dialects_edge__ops_aten_addmm_default",
+                                ]
+                            )
+                        else:
+                            tester.check(
+                                [
+                                    "executorch_exir_dialects_edge__ops_aten_mm_default",
+                                ]
+                            )
+                    else:
+                        tester.to_edge_transform_and_lower(
+                            ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
+                        ).dump_artifact()
+                        # should not have a delegate node
+                        tester.check_not(
+                            [
+                                "torch.ops.higher_order.executorch_call_delegate",
+                            ]
+                        )
+        # No need to run the model, since it should fail to partition.
+        return
+
     def _test_qd8_per_channel_4w_linear(self, dtype: torch.dtype = torch.float):
         qconfig = self._get_4b_dqconfig()
         input_channels = [2, 63]
@@ -697,10 +757,24 @@ def test_qs8_linear(self):
     def test_qd8_f16_per_channel_linear(self):
         self._test_qd8_per_channel_linear(dtype=torch.half)
 
+    def test_qd8_f16_per_tensor_linear(self):
+        """
+        XNNPACK doesn't support per_tensor quantized weights for the dynamically quantized linear op.
+        This test verifies that such a linear op is not lowered to the XNNPACK delegate.
+        """
+        self._test_qd8_linear_per_tensor_unsupported(dtype=torch.half)
+
     # Tests for q[dp]8-f32-qc8w
     def test_qd8_f32_per_channel_linear(self):
         self._test_qd8_per_channel_linear(dtype=torch.float)
 
+    def test_qd8_f32_per_tensor_linear(self):
+        """
+        XNNPACK doesn't support per_tensor quantized weights for the dynamically quantized linear op.
+        This test verifies that such a linear op is not lowered to the XNNPACK delegate.
+        """
+        self._test_qd8_linear_per_tensor_unsupported(dtype=torch.float)
+
     # Tests for q[dp]8-f16-qc4w
     def test_linear_qd8_f16_per_channel_int4(self):
         self._test_qd8_per_channel_4w_linear(dtype=torch.half)
```
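
To run just the new negative tests, something like the following sketch should work. The test class name `TestLinear` and the use of the standard `unittest` runner are assumptions; the repo's own runner (for example pytest or Buck) can target the same module path from the diff above.

```python
# Sketch: load and run only the new per-tensor qd8 linear tests.
# "TestLinear" is an assumed class name; adjust to the class in test_linear.py.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromNames(
    [
        "backends.xnnpack.test.ops.test_linear.TestLinear.test_qd8_f32_per_tensor_linear",
        "backends.xnnpack.test.ops.test_linear.TestLinear.test_qd8_f16_per_tensor_linear",
    ]
)
unittest.TextTestRunner(verbosity=2).run(suite)
```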

backends/xnnpack/utils/quant_utils.py (+9)

```diff
@@ -89,6 +89,15 @@ def is_per_channel(node: torch.fx.Node) -> bool:
     return is_per_channel or is_affine_per_channel_group
 
 
+def is_per_tensor(node: torch.fx.Node) -> bool:
+    if not (is_quant(node) or is_dequant(node)):
+        return False
+
+    is_per_tensor = "per_tensor" in node.target.__name__  # pyre-ignore
+
+    return is_per_tensor and not (is_per_channel(node))
+
+
 def is_affine_qdq(node: torch.fx.Node) -> bool:
     if not (is_quant(node) or is_dequant(node)):
         return False
```