Commit 5a5dc87

Arm backend: Add WhyNoPartitionReporter and report rejected nodes
With more complex checks for whether nodes are supported or not, it can be difficult for a user to understand why a certain node was rejected from partitioning. The WhyNoPartitionReporter attempts to mitigate this by producing a table of all rejected nodes, each with the reason it was rejected. The reasoning behind using a class rather than log statements is:
- The table keeps all reject information in the same place.
- Uniform formatting.
- Only the first reject report of a node is logged.
- Easier to change how and when the output is logged/dumped to file.

Signed-off-by: Erik Lundell <[email protected]>
Change-Id: I3765a0db4bfe530b34a7c5d48b04faa5d08e51de
1 parent b0c2c7c commit 5a5dc87
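The WhyNoPartitionReporter implementation itself is not visible in this excerpt; only its call sites are. As a rough sketch of the interface those call sites rely on (the class name and the report_reject(node, reason) signature are taken from the diffs; the internal storage, the get_table_report name, and the table rendering below are assumptions for illustration), it could look something like this:

import torch.fx as fx


class WhyNoPartitionReporter:
    """Collects one rejection reason per node and renders all of them as a table."""

    def __init__(self):
        # Keyed on node, so only the first reject report of a node is kept.
        self._rejects: dict[fx.Node, str] = {}

    def report_reject(self, node: fx.Node, reason: str):
        if node not in self._rejects:
            self._rejects[node] = reason

    def get_table_report(self) -> str:
        # Uniform formatting: one row per rejected node, all in one place.
        header = f"{'Rejected node':<40} | Reason"
        rows = [f"{n.name:<40} | {r}" for n, r in self._rejects.items()]
        return "\n".join([header, "-" * len(header), *rows])

Each operator support check then holds a reporter and calls self.reporter.report_reject(node, reason) on every failed condition, as the diffs below do, so how and when the table is logged or dumped to file stays a single decision in one place.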

File tree

8 files changed: +285 −76 lines changed


backends/arm/operator_support/convolution_support.py

+17
@@ -34,6 +34,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

         for pad in output_padding:
             if pad != 0:
+                self.reporter.report_reject(
+                    node, "Convolutions with non-zero output padding not implemented."
+                )
                 return False

         # Hardware specific constraints
@@ -56,19 +59,33 @@ def _is_node_supported_u55(self, node: fx.Node):
             # Depthwise convolution
             for dim in shape_in[1:]:
                 if not 1 <= dim <= 65536:
+                    self.reporter.report_reject(
+                        node,
+                        f"Depthwise convolution must have CWH <= 65536, got {dim})",
+                    )
                     return False
         else:
             # Convolution
             if not 1 <= C_in <= 65536:
+                self.reporter.report_reject(
+                    node, f"Convolution must have C <= 65536, got {C_in})"
+                )
                 return False

         kernel_w = kernel[2]
         kernel_h = kernel[3] if len(kernel) > 3 else 1
         # Kernel condition misses constraint on sum of absolute weights
         if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096:
+            self.reporter.report_reject(
+                node,
+                f"Convolution needs to have kernel_y<=64, kernel_x*kernel_y<=4096, got kernel ({kernel_w}, {kernel_h})",
+            )
             return False

         if not self._stride_condition(node):
+            self.reporter.report_reject(
+                node, "Failed condition on stride, pad and dilation combination."
+            )
             return False

         return True

backends/arm/operator_support/pool_2d_support.py

+42 −2

@@ -54,12 +54,35 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
         if len(node.args) > 3:
             # Padding case
             if not all(1 <= k <= 8 for k in kernel):
+                self.reporter.report_reject(
+                    node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}"
+                )
                 return False
         else:
             if not kernel_check(kernel):
+                self.reporter.report_reject(
+                    node,
+                    f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
+                )
                 return False

-        return dim_check(shape) and shape[0] == 1 and stride_check(stride)
+        if not dim_check(shape):
+            self.reporter.report_reject(
+                node,
+                f"Avgpool2d needs N == 1, rest dims <= 65536, got shape {list(shape)}",
+            )
+            return False
+        if not stride_check(stride):
+            self.reporter.report_reject(
+                node, f"Avgpool2d needs stride <= 3, got {stride}"
+            )
+            return False
+        if not shape[0] == 1:
+            self.reporter.report_reject(
+                node, f"Avgpool2d needs N==1, got N=={shape[0]}"
+            )
+            return False
+        return True


 @register_tosa_support_check
@@ -82,4 +105,21 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
         kernel = cast(tuple[int, int], node.args[1])
         stride = cast(tuple[int, int], node.args[2])

-        return kernel_check(kernel) and dim_check(shape) and stride_check(stride)
+        if not kernel_check(kernel):
+            self.reporter.report_reject(
+                node,
+                f"Maxpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
+            )
+            return False
+        if not dim_check(shape):
+            self.reporter.report_reject(
+                node,
+                f"Maxpool2d needs N == 1, rest dims <= 65536, got shape {list(shape)}",
+            )
+            return False
+        if not stride_check(stride):
+            self.reporter.report_reject(
+                node, f"Maxpool2d needs stride <= 3, got {stride}"
+            )
+            return False
+        return True

backends/arm/operator_support/reduce_sum_support.py

+5
@@ -34,6 +34,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

         for dim in dim_list:
             if not 1 <= input_shape[dim] <= 65536:
+                self.reporter.report_reject(
+                    node, f"sum needs dims < 65536, got shape {input_shape}"
+                )
                 return False

         # We can't be certain of which dim is the last in memory yet,
@@ -45,7 +48,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
             for length in input_shape[dim + 1 :]:
                 post_R_product *= length
             if not 1 <= pre_R_product <= 65536:
+                self.reporter.report_reject(node, "Failed dim check")
                 return False
             if not 1 <= post_R_product <= 65536:
+                self.reporter.report_reject(node, "Failed dim check")
                 return False
         return True

backends/arm/operator_support/to_copy_support.py

+13 −13

@@ -75,9 +75,6 @@ def is_node_tosa_supported(
     ) -> bool:
         assert node.target in self.targets

-        if tosa_spec not in self.tosa_specs:
-            return False
-
         assert tosa_spec.support_integer()
         supported_dtypes = (
             self.ALL_SUPPORTED_TYPES
@@ -97,30 +94,32 @@ def is_node_tosa_supported(
         assert isinstance(input_val, torch._subclasses.FakeTensor)
         input_dtype = input_val.dtype
         if input_dtype not in supported_dtypes:
-            logger.info(
-                f"Input dtype {input_val.dtype} is not supported in "
-                f"{node.target.name()}."  # type: ignore[union-attr] # pyre-ignore[16]
+            self.reporter.report_reject(
+                node,
+                f"Input dtype {input_val.dtype} is not supported in {node.target}.",
             )
             return False

         # Check output type
         output_val = node.meta["val"]
         assert isinstance(output_val, torch._subclasses.FakeTensor)
         if output_val.dtype not in supported_dtypes[input_dtype]:
-            logger.info(
+            self.reporter.report_reject(
+                node,
                 f"Output dtype {output_val.dtype} is not supported in "
-                f"{node.target.name()} for input dtype {input_dtype}. "  # type: ignore[union-attr] # pyre-ignore[16]
+                f"{node.target} for input dtype {input_dtype}. "
                 f"Supported output types: "
-                f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
+                f"{''.join(str(t) for t in supported_dtypes[input_dtype])}",
             )
             return False

         # Check memory format (to_copy)
         if "memory_format" in node.kwargs:
             if node.kwargs["memory_format"] in (torch.preserve_format,):
-                logger.info(
+                self.reporter.report_reject(
+                    node,
                     f"Argument 'memory_format' is not supported for "
-                    f"{node.target.name()} right now."  # type: ignore[union-attr] # pyre-ignore[16]
+                    f"{node.target} right now.",
                 )
                 return False

@@ -129,9 +128,10 @@ def is_node_tosa_supported(
             dim_order = node.kwargs["dim_order"]
             # pyre-ignore[6]
             if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
-                logger.info(
+                self.reporter.report_reject(
+                    node,
                     f"Argument {dim_order=} is not supported for "
-                    f"{node.target.name()} right now."  # type: ignore[union-attr] # pyre-ignore[16]
+                    f"{node.target} right now.",
                 )
                 return False
