@@ -21,6 +21,7 @@
     is_dynamic_qdq,
     is_per_channel,
     is_per_channel_group,
+    is_per_tensor,
     is_qparam,
     is_quant,
 )
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False
 
         is_valid, _ = self.get_deps(node, ep)
-        if not is_valid:
-            why(node, "Failed to get valid dependent nodes.")
         return is_valid
 
     def get_node_and_deps(
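The generic "Failed to get valid dependent nodes." message is dropped here because each check inside get_deps now reports its own, more specific reason. For readers unfamiliar with the helper, a minimal sketch of a why()-style reporter, assuming it simply logs the rejected node and the reason (the real helper lives in the partitioner's utilities; this signature and logger are assumptions):

import logging

import torch

logger = logging.getLogger(__name__)


def why(node: torch.fx.Node, reason: str) -> None:
    # Record a human-readable reason this node failed a partitioning
    # constraint, so users can see why a GEMM was not delegated to XNNPACK.
    logger.debug("Skipping %s: %s", node.format_node(), reason)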
@@ -123,6 +122,7 @@ def get_deps(
         precision = self._detect_precision(node)
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
+            why(node, f"Unsupported precision type {precision}")
             return (False, [])
         _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
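For context, a sketch of the precision gate this hunk instruments, assuming ConfigPrecisionType is a plain enum and each config advertises the precisions it can lower (names mirror the diff; the enum values are assumptions):

from enum import Enum


class ConfigPrecisionType(Enum):
    # Assumed shape of the real enum referenced throughout this file.
    FP32 = 1
    STATIC_QUANT = 2
    DYNAMIC_QUANT = 3


def supported_precision_types():
    # A config that only lowers FP32 and statically quantized GEMMs would
    # return these two; any other detected precision is rejected with why().
    return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]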
@@ -143,27 +143,42 @@ def _get_weight_deps(
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
             if not is_param_node(ep, weight_node):
-                return (False, [])  # weight must be a static param
+                why(node, "Expected weight to be a static param")
+                return (False, [])
             gemm_deps.append(weight_node)
 
             return (True, gemm_deps)
         else:
             # Quantized Weight deps
             dequant_node = get_input_node(node, self.weight_idx)
             if not is_dequant(dequant_node):
+                why(node, "Expected weight to have a dequantized node")
                 return False, []
             gemm_deps.append(dequant_node)
             weight = get_input_node(dequant_node, 0)
             if not is_param_node(ep, weight):
+                why(node, "Expected weight to be a static param")
                 return False, []
             gemm_deps.append(weight)
 
+            if (
+                is_per_tensor(dequant_node)
+                and precision == ConfigPrecisionType.DYNAMIC_QUANT
+            ):
+                why(
+                    node,
+                    "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations",
+                )
+                return False, []
+
             if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
                 if len(dequant_node.all_input_nodes) < 2:
                     # Expected channel quantized to have scale/zp nodes
+                    why(node, "Expected channel quantized to have scale/zp nodes")
                     return False, []
 
                 gemm_deps.extend(dequant_node.all_input_nodes[1:3])
+
             return (True, gemm_deps)
 
     def _get_output_deps(
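The substantive addition in this hunk rejects per-tensor quantized weights when activations are dynamically quantized, since XNNPACK's dynamically quantized GEMM kernels expect per-channel (or per-channel-group) weight scales. An illustrative sketch of what an is_per_tensor-style predicate can look like; the real helper is imported from quant_utils above, and the target-name matching here is an assumption:

import torch


def is_per_tensor(node: torch.fx.Node) -> bool:
    # Heuristic sketch: per-tensor dequantize ops usually encode their
    # granularity in the op target name, e.g. "...dequantize_per_tensor...".
    return node.op == "call_function" and "per_tensor" in str(node.target)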
@@ -174,7 +189,7 @@ def _get_output_deps(
             # Look for fused activations and tail end quant node
             node_users = list(node.users.keys())
             if len(node_users) != 1:
-                # Expect quantized node to have a single output (fused act or dequant)
+                why(node, "Expected quantized node to have a single output")
                 return False, []
 
             # Check if the quantized pattern has a fused activation
@@ -190,6 +205,7 @@ def _get_output_deps(
 
             if not is_quant(n_output):
                 # Expected gemm_node --> fused_act (optional) --> dequant
+                why(node, "Expected output node to have a dequantized node")
                 return (False, [])
             gemm_deps.append(n_output)
         elif precision == ConfigPrecisionType.FP32:
@@ -219,7 +235,8 @@ def _get_bias_deps(
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
-                    return (False, [])  # bias node must be a static param
+                    why(node, "Expected bias to be a static param")
+                    return (False, [])
                 gemm_deps.append(bias_node)
 
         return (True, gemm_deps)
@@ -233,7 +250,7 @@ def _get_act_deps(
         else:
             dq_input = get_input_node(node, self.act_idx)
             if not is_dequant(dq_input):
-                # Expected static quant input to be dequant node
+                why(node, "Expected act input to be dequant node")
                 return False, []
             gemm_deps.append(dq_input)
             if precision == ConfigPrecisionType.STATIC_QUANT:
@@ -243,27 +260,28 @@ def _get_act_deps(
             # q input node
             q_input = get_input_node(dq_input, 0)
             if not is_quant(q_input):
+                why(node, "Expected dequant input to be quant node")
                 return (False, [])
 
             gemm_deps.append(q_input)
             q_input_args = q_input.args
             if is_affine_qdq(q_input):
                 q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
             if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
-                # expected to find getitem node from choose qparam
+                why(node, "expected to find getitem node from choose qparam")
                 return (False, [])
 
             getitem1 = q_input_args[1]
             getitem2 = q_input_args[2]
 
             if not (is_getitem(getitem1) and is_getitem(getitem2)):
-                # expected getitem node from choose qparam
+                why(node, "expected getitem node from choose qparam")
                 return (False, [])
 
             gemm_deps.extend([getitem1, getitem2])
             choose_qparam = get_input_node(getitem1, 0)
             if not is_qparam(choose_qparam):
-                # expected to find choose_qparam node
+                why(node, "expected to find choose_qparam node")
                 return (False, [])
             gemm_deps.append(choose_qparam)
             return (True, gemm_deps)
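To help follow the dynamic-quantization path above, the node chain these checks walk, drawn schematically (op names abbreviated; every validated node is appended to gemm_deps so the whole pattern is partitioned as one unit):

# activation ──> choose_qparams ──> getitem1 (scale)
#                              └──> getitem2 (zero_point)
#
# activation + scale + zero_point ──> quantize ──> dequantize ──> gemm
#
# is_qparam, is_getitem, and is_quant each verify one link in this chain,
# and each failure now reports its own why() message instead of a comment.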
@@ -471,6 +489,7 @@ def find_partition_args(input_node):
             # there can only be a single output node in partition
             or len(src_partition.output_nodes) != 1
         ):
+            why(node, "invalid source partition")
             return (False, [])
 
         # map addmm's args to the source partition linear's inputs and users
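Taken together, a rejected GEMM now produces a targeted message at the point of failure rather than a generic one. A hedged end-to-end sketch of where these messages surface, using ExecuTorch's public lowering API; that why() reports through standard logging is an assumption about wiring, not something this diff shows:

import logging

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge

logging.basicConfig(level=logging.DEBUG)

model = torch.nn.Linear(16, 16).eval()
exported = torch.export.export(model, (torch.randn(1, 16),))

# Partitioning runs the constraint checks patched above; any GEMM that fails
# one of them is skipped, and the why() reason explains which check tripped.
edge = to_edge(exported).to_backend(XnnpackPartitioner())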