Arm backend: Add WhyNoPartitionReporter and report rejected nodes #8963

Merged 2 commits on Mar 20, 2025
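Note: the WhyNoPartitionReporter class itself is not among the files shown in this diff; the changed checks below only call self.reporter.report_reject(node, reason). A minimal sketch of the interface those calls appear to assume (everything beyond report_reject is hypothetical):

import logging

from torch import fx

logger = logging.getLogger(__name__)


class WhyNoPartitionReporter:
    """Collects human-readable reasons why nodes were rejected from partitioning."""

    def __init__(self) -> None:
        # Keep the first rejection reason recorded for each node.
        self._rejections: dict[fx.Node, str] = {}

    def report_reject(self, node: fx.Node, reason: str) -> None:
        if node not in self._rejections:
            self._rejections[node] = reason
        logger.debug("Rejected %s: %s", node.name, reason)

    def get_report(self) -> str:
        # One line per rejected node, suitable for logging to the user.
        return "\n".join(f"{n.name}: {r}" for n, r in self._rejections.items())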
17 changes: 17 additions & 0 deletions backends/arm/operator_support/convolution_support.py
@@ -34,6 +34,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

for pad in output_padding:
if pad != 0:
self.reporter.report_reject(
node, "Convolutions with non-zero output padding not implemented."
)
return False

# Hardware specific constraints
@@ -56,19 +59,33 @@ def _is_node_supported_u55(self, node: fx.Node):
# Depthwise convolution
for dim in shape_in[1:]:
if not 1 <= dim <= 65536:
self.reporter.report_reject(
node,
f"Depthwise convolution must have CWH <= 65536, got {dim})",
)
return False
else:
# Convolution
if not 1 <= C_in <= 65536:
self.reporter.report_reject(
node, f"Convolution must have C <= 65536, got {C_in})"
)
return False

kernel_w = kernel[2]
kernel_h = kernel[3] if len(kernel) > 3 else 1
# Kernel condition misses constraint on sum of absolute weights
if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096:
self.reporter.report_reject(
node,
f"Convolution needs to have kernel_y<=64, kernel_x*kernel_y<=4096, got kernel ({kernel_w}, {kernel_h})",
)
return False

if not self._stride_condition(node):
self.reporter.report_reject(
node, "Failed condition on stride, pad and dilation combination."
)
return False

return True
44 changes: 42 additions & 2 deletions backends/arm/operator_support/pool_2d_support.py
@@ -54,12 +54,35 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
if len(node.args) > 3:
# Padding case
if not all(1 <= k <= 8 for k in kernel):
self.reporter.report_reject(
node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}"
)
return False
else:
if not kernel_check(kernel):
self.reporter.report_reject(
node,
f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
)
return False

return dim_check(shape) and shape[0] == 1 and stride_check(stride)
if not dim_check(shape):
self.reporter.report_reject(
node,
f"Avgpool2d needs N == 1, rest dims <= 65536, got shape {list(shape)}",
)
return False
if not stride_check(stride):
self.reporter.report_reject(
node, f"Avgpool2d needs stride <= 3, got {stride}"
)
return False
if shape[0] != 1:
self.reporter.report_reject(
node, f"Avgpool2d needs N==1, got N=={shape[0]}"
)
return False
return True


@register_tosa_support_check
@@ -82,4 +105,21 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
kernel = cast(tuple[int, int], node.args[1])
stride = cast(tuple[int, int], node.args[2])

return kernel_check(kernel) and dim_check(shape) and stride_check(stride)
if not kernel_check(kernel):
self.reporter.report_reject(
node,
f"Maxpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
)
return False
if not dim_check(shape):
self.reporter.report_reject(
node,
f"Maxpool2d needs N == 1, rest dims <= 65536, got shape {list(shape)}",
)
return False
if not stride_check(stride):
self.reporter.report_reject(
node, f"Maxpool2d needs stride <= 3, got {stride}"
)
return False
return True
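The helper predicates kernel_check, dim_check, and stride_check used above are defined earlier in pool_2d_support.py and are not part of this diff. Reconstructed from the reject messages alone, they plausibly behave like this sketch (the exact bounds and the (kernel_y, kernel_x) ordering are assumptions):

def kernel_check(kernel: tuple[int, int]) -> bool:
    # Limits taken from the reject messages above (assumed, not verified).
    kernel_y, kernel_x = kernel
    return 1 <= kernel_y < 256 and 1 <= kernel_x * kernel_y <= 65536


def stride_check(stride: tuple[int, int]) -> bool:
    # Strides above 3 are rejected.
    return all(1 <= s <= 3 for s in stride)


def dim_check(shape) -> bool:
    # Batch dim must be 1; every remaining dim must fit in [1, 65536].
    return shape[0] == 1 and all(1 <= d <= 65536 for d in shape[1:])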
5 changes: 5 additions & 0 deletions backends/arm/operator_support/reduce_sum_support.py
@@ -34,6 +34,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

for dim in dim_list:
if not 1 <= input_shape[dim] <= 65536:
self.reporter.report_reject(
node, f"sum needs dims < 65536, got shape {input_shape}"
)
return False

# We can't be certain of which dim is the last in memory yet,
@@ -45,7 +48,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
for length in input_shape[dim + 1 :]:
post_R_product *= length
if not 1 <= pre_R_product <= 65536:
self.reporter.report_reject(node, "Failed dim check")
return False
if not 1 <= post_R_product <= 65536:
self.reporter.report_reject(node, "Failed dim check")
return False
return True
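For concreteness, a worked example of the products checked above (from the variable names, pre_R_product is the product of the dims before the reduced dim and post_R_product the product of the dims after it): reducing dim 1 of an input of shape (2, 128, 1024) gives pre_R_product = 2 and post_R_product = 1024; both fall within [1, 65536], so the node is accepted.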
26 changes: 13 additions & 13 deletions backends/arm/operator_support/to_copy_support.py
@@ -75,9 +75,6 @@ def is_node_tosa_supported(
) -> bool:
assert node.target in self.targets

if tosa_spec not in self.tosa_specs:
return False

assert tosa_spec.support_integer()
supported_dtypes = (
self.ALL_SUPPORTED_TYPES
Expand All @@ -97,30 +94,32 @@ def is_node_tosa_supported(
assert isinstance(input_val, torch._subclasses.FakeTensor)
input_dtype = input_val.dtype
if input_dtype not in supported_dtypes:
logger.info(
f"Input dtype {input_val.dtype} is not supported in "
f"{node.target.name()}." # type: ignore[union-attr] # pyre-ignore[16]
self.reporter.report_reject(
node,
f"Input dtype {input_val.dtype} is not supported in {node.target}.",
)
return False

# Check output type
output_val = node.meta["val"]
assert isinstance(output_val, torch._subclasses.FakeTensor)
if output_val.dtype not in supported_dtypes[input_dtype]:
logger.info(
self.reporter.report_reject(
node,
f"Output dtype {output_val.dtype} is not supported in "
f"{node.target.name()} for input dtype {input_dtype}. " # type: ignore[union-attr] # pyre-ignore[16]
f"{node.target} for input dtype {input_dtype}. "
f"Supported output types: "
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}",
)
return False

# Check memory format (to_copy)
if "memory_format" in node.kwargs:
if node.kwargs["memory_format"] in (torch.preserve_format,):
logger.info(
self.reporter.report_reject(
node,
f"Argument 'memory_format' is not supported for "
f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16]
f"{node.target} right now.",
)
return False

@@ -129,9 +128,10 @@ def is_node_tosa_supported(
dim_order = node.kwargs["dim_order"]
# pyre-ignore[6]
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
logger.info(
self.reporter.report_reject(
node,
f"Argument {dim_order=} is not supported for "
f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16]
f"{node.target} right now.",
)
return False
