diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py
index 0d0a32200e8..b07ae82f98f 100644
--- a/backends/arm/operator_support/convolution_support.py
+++ b/backends/arm/operator_support/convolution_support.py
@@ -34,6 +34,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
 
         for pad in output_padding:
             if pad != 0:
+                self.reporter.report_reject(
+                    node, "Convolutions with non-zero output padding not implemented."
+                )
                 return False
 
         # Hardware specific constraints
@@ -56,19 +59,33 @@ def _is_node_supported_u55(self, node: fx.Node):
             # Depthwise convolution
             for dim in shape_in[1:]:
                 if not 1 <= dim <= 65536:
+                    self.reporter.report_reject(
+                        node,
+                        f"Depthwise convolution must have CWH <= 65536, got {dim}",
+                    )
                     return False
         else:
             # Convolution
             if not 1 <= C_in <= 65536:
+                self.reporter.report_reject(
+                    node, f"Convolution must have C <= 65536, got {C_in}"
+                )
                 return False
 
         kernel_w = kernel[2]
         kernel_h = kernel[3] if len(kernel) > 3 else 1
         # Kernel condition misses constraint on sum of absolute weights
         if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096:
+            self.reporter.report_reject(
+                node,
+                f"Convolution needs to have kernel_y<=64, kernel_x*kernel_y<=4096, got kernel ({kernel_w}, {kernel_h})",
+            )
             return False
 
         if not self._stride_condition(node):
+            self.reporter.report_reject(
+                node, "Failed condition on stride, pad and dilation combination."
+            )
             return False
 
         return True
diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py
index c1dd143a4fc..8291ede8ad9 100644
--- a/backends/arm/operator_support/pool_2d_support.py
+++ b/backends/arm/operator_support/pool_2d_support.py
@@ -54,12 +54,35 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
         if len(node.args) > 3:
             # Padding case
             if not all(1 <= k <= 8 for k in kernel):
+                self.reporter.report_reject(
+                    node, f"Avgpool2d with padding needs kernel dims <= 8, got {kernel}"
+                )
                 return False
         else:
             if not kernel_check(kernel):
+                self.reporter.report_reject(
+                    node,
+                    f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
+                )
                 return False
 
-        return dim_check(shape) and shape[0] == 1 and stride_check(stride)
+        if not dim_check(shape):
+            self.reporter.report_reject(
+                node,
+                f"Avgpool2d needs N == 1, rest dims <= 65536, got shape {list(shape)}",
+            )
+            return False
+        if not stride_check(stride):
+            self.reporter.report_reject(
+                node, f"Avgpool2d needs stride <= 3, got {stride}"
+            )
+            return False
+        if shape[0] != 1:
+            self.reporter.report_reject(
+                node, f"Avgpool2d needs N==1, got N=={shape[0]}"
+            )
+            return False
+        return True
 
 
 @register_tosa_support_check
@@ -82,4 +105,21 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
         kernel = cast(tuple[int, int], node.args[1])
         stride = cast(tuple[int, int], node.args[2])
 
-        return kernel_check(kernel) and dim_check(shape) and stride_check(stride)
+        if not kernel_check(kernel):
+            self.reporter.report_reject(
+                node,
+                f"Maxpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
+            )
+            return False
+        if not dim_check(shape):
+            self.reporter.report_reject(
+                node,
+                f"Maxpool2d needs N == 1, rest dims <= 65536, got shape {list(shape)}",
+            )
+            return False
+        if not stride_check(stride):
+            self.reporter.report_reject(
+                node, f"Maxpool2d needs stride <= 3, got {stride}"
+            )
+            return False
+        return True
diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py
index 8345d69caaa..37a71d7264c 100644
--- a/backends/arm/operator_support/reduce_sum_support.py
+++ b/backends/arm/operator_support/reduce_sum_support.py
@@ -34,6 +34,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
 
         for dim in dim_list:
             if not 1 <= input_shape[dim] <= 65536:
+                self.reporter.report_reject(
+                    node, f"sum needs dims <= 65536, got shape {input_shape}"
+                )
                 return False
 
             # We can't be certain of which dim is the last in memory yet,
@@ -45,7 +48,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
             for length in input_shape[dim + 1 :]:
                 post_R_product *= length
             if not 1 <= pre_R_product <= 65536:
+                self.reporter.report_reject(
+                    node, f"Product of dims before reduced dim must be <= 65536, got {pre_R_product}"
+                )
                 return False
             if not 1 <= post_R_product <= 65536:
+                self.reporter.report_reject(
+                    node, f"Product of dims after reduced dim must be <= 65536, got {post_R_product}"
+                )
                 return False
         return True
diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py
index c81c8e58a29..7926b3dc053 100644
--- a/backends/arm/operator_support/to_copy_support.py
+++ b/backends/arm/operator_support/to_copy_support.py
@@ -75,9 +75,6 @@ def is_node_tosa_supported(
     ) -> bool:
         assert node.target in self.targets
 
-        if tosa_spec not in self.tosa_specs:
-            return False
-
         assert tosa_spec.support_integer()
         supported_dtypes = (
             self.ALL_SUPPORTED_TYPES
@@ -97,9 +94,9 @@ def is_node_tosa_supported(
         assert isinstance(input_val, torch._subclasses.FakeTensor)
         input_dtype = input_val.dtype
         if input_dtype not in supported_dtypes:
-            logger.info(
-                f"Input dtype {input_val.dtype} is not supported in "
-                f"{node.target.name()}."  # type: ignore[union-attr]  # pyre-ignore[16]
+            self.reporter.report_reject(
+                node,
+                f"Input dtype {input_val.dtype} is not supported in {node.target}.",
             )
             return False
 
@@ -107,20 +104,22 @@ def is_node_tosa_supported(
         output_val = node.meta["val"]
         assert isinstance(output_val, torch._subclasses.FakeTensor)
         if output_val.dtype not in supported_dtypes[input_dtype]:
-            logger.info(
+            self.reporter.report_reject(
+                node,
                 f"Output dtype {output_val.dtype} is not supported in "
-                f"{node.target.name()} for input dtype {input_dtype}. "  # type: ignore[union-attr]  # pyre-ignore[16]
+                f"{node.target} for input dtype {input_dtype}. "
                 f"Supported output types: "
-                f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
+                f"{''.join(str(t) for t in supported_dtypes[input_dtype])}",
             )
             return False
 
         # Check memory format (to_copy)
         if "memory_format" in node.kwargs:
             if node.kwargs["memory_format"] in (torch.preserve_format,):
-                logger.info(
+                self.reporter.report_reject(
+                    node,
                     f"Argument 'memory_format' is not supported for "
-                    f"{node.target.name()} right now."  # type: ignore[union-attr]  # pyre-ignore[16]
+                    f"{node.target} right now.",
                 )
                 return False
 
@@ -129,9 +128,10 @@ def is_node_tosa_supported(
             dim_order = node.kwargs["dim_order"]
             # pyre-ignore[6]
             if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
-                logger.info(
+                self.reporter.report_reject(
+                    node,
                     f"Argument {dim_order=} is not supported for "
-                    f"{node.target.name()} right now."  # type: ignore[union-attr]  # pyre-ignore[16]
+                    f"{node.target} right now.",
                 )
                 return False
 
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 223b5d40ea1..dfd8024e4b3 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -19,6 +19,7 @@
 )
 from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
 from executorch.exir import ExportedProgram
+from executorch.exir.backend.utils import WhyNoPartitionReporter
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export.graph_signature import InputKind
 from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
@@ -30,8 +31,9 @@ class SupportedTOSAOperatorCheck(OperatorSupportBase):
     Supported OP for TOSA lowering
     """
 
-    def __init__(self, tosa_spec: TosaSpecification):
+    def __init__(self, tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporter):
         self.tosa_spec = tosa_spec
+        self.reporter = reporter
 
     # Should be populated by subclass implementation
     tosa_specs: list[TosaSpecification] = []
@@ -86,23 +88,42 @@ def get_registered_tosa_support_checks(
 def tosa_support_factory(
     tosa_spec: TosaSpecification,
     exported_program: ExportedProgram,
+    reporter: WhyNoPartitionReporter,
     additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
 ) -> OperatorSupportBase:
-    negative_checks: list[OperatorSupportBase] = [CheckInt64Inputs(exported_program)]
+    """Generates an OperatorSupport check chain depending on the given `tosa_spec`.
+    Additional negative checks can be supplied to reject further nodes from partitioning.
+    """
+    # Positive checks: Add nodes to partitioning
+    positive_checks: list[OperatorSupportBase] = [
+        BaseTOSASupportList(),
+        *[
+            check(tosa_spec, reporter)
+            for check in get_registered_tosa_support_checks(tosa_spec)
+        ],
+    ]
+
+    # Negative checks: Remove nodes from partitioning
+    negative_checks: list[OperatorSupportBase] = [
+        CheckInt64Inputs(exported_program, reporter),
+        *[
+            reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
+            for check in (additional_checks if additional_checks else [])
+        ],
+    ]
+
     if not tosa_spec.support_float():
-        negative_checks.append(NeedsDecompositionCheck())
-        negative_checks.append(CheckProperQuantization())
-        negative_checks.append(EthosU55NotSupported(tosa_spec))
+        negative_checks.append(NeedsDecompositionCheck(reporter))
+        negative_checks.append(CheckProperQuantization(reporter))
+    if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
+        negative_checks.append(EthosU55NotSupported(reporter))
+
     return chain(
-        any_chain(
-            BaseTOSASupportList(),
-            *(
-                check(tosa_spec)
-                for check in get_registered_tosa_support_checks(tosa_spec)
-            ),
+        reporter.wrap_check(
+            any_chain(*positive_checks),
+            "Not included in BaseTOSASupportList or a registered tosa_support_check",
         ),
         *negative_checks,
-        *additional_checks if additional_checks else [],
     )
 
 
@@ -186,39 +207,39 @@ def is_node_supported(
 
 class EthosU55NotSupported(OperatorSupportBase):
     """
-    Certain operators are not supported on U55. These are listed in `unsupported` in
-    is_node_supported().
+    Certain operators are not supported on U55. These are listed in `unsupported_ops`.
     """
 
-    def __init__(self, tosa_spec: TosaSpecification):
-        self.tosa_spec = tosa_spec
+    unsupported_ops = [
+        exir_ops.edge.aten.any.default,
+        exir_ops.edge.aten.any.dim,
+        exir_ops.edge.aten.any.dims,
+        exir_ops.edge.aten.bitwise_and.Tensor,
+        exir_ops.edge.aten.bitwise_or.Tensor,
+        exir_ops.edge.aten.bitwise_xor.Tensor,
+        exir_ops.edge.aten.logical_and.default,
+        exir_ops.edge.aten.logical_or.default,
+        exir_ops.edge.aten.logical_xor.default,
+        exir_ops.edge.aten.logical_not.default,
+        exir_ops.edge.aten.amax.default,
+        exir_ops.edge.aten.amin.default,
+        exir_ops.edge.aten.eq.Tensor,
+        exir_ops.edge.aten.ge.Tensor,
+        exir_ops.edge.aten.gt.Tensor,
+        exir_ops.edge.aten.le.Tensor,
+        exir_ops.edge.aten.lt.Tensor,
+    ]
+
+    def __init__(self, reporter: WhyNoPartitionReporter):
+        self.reporter = reporter
 
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
-        if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
-            unsupported_ops = [
-                exir_ops.edge.aten.any.default,
-                exir_ops.edge.aten.any.dim,
-                exir_ops.edge.aten.any.dims,
-                exir_ops.edge.aten.bitwise_and.Tensor,
-                exir_ops.edge.aten.bitwise_or.Tensor,
-                exir_ops.edge.aten.bitwise_xor.Tensor,
-                exir_ops.edge.aten.logical_and.default,
-                exir_ops.edge.aten.logical_or.default,
-                exir_ops.edge.aten.logical_xor.default,
-                exir_ops.edge.aten.logical_not.default,
-                exir_ops.edge.aten.amax.default,
-                exir_ops.edge.aten.amin.default,
-                exir_ops.edge.aten.eq.Tensor,
-                exir_ops.edge.aten.ge.Tensor,
-                exir_ops.edge.aten.gt.Tensor,
-                exir_ops.edge.aten.le.Tensor,
-                exir_ops.edge.aten.lt.Tensor,
-            ]
 
-            if node.target in unsupported_ops:
-                return False
+        if node.target in self.unsupported_ops:
+            self.reporter.report_reject(node, "Op is not supported on U55.")
+            return False
 
         return True
 
@@ -230,6 +251,9 @@ class NeedsDecompositionCheck(OperatorSupportBase):
     that need to be decomposed.
     """
 
+    def __init__(self, reporter: WhyNoPartitionReporter):
+        self.reporter = reporter
+
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
@@ -238,22 +262,27 @@ def is_node_supported(
             return True
         if node.target == exir_ops.edge.aten.mean.dim:
             dim = node.args[1]
-            return dim == [-1, -2]
-        needs_decomp = node.target in [
-            exir_ops.edge.aten.div.Tensor,
-            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
-            exir_ops.edge.aten.native_layer_norm.default,
-            exir_ops.edge.aten.mean.dim,
-            exir_ops.edge.aten._softmax.default,
-            exir_ops.edge.aten._log_softmax.default,
-            exir_ops.edge.aten.var.correction,
-            exir_ops.edge.aten.var.dim,
-            exir_ops.edge.aten.add.Scalar,
-            exir_ops.edge.aten.sub.Scalar,
-            exir_ops.edge.aten.mul.Scalar,
-            exir_ops.edge.aten.div.Scalar,
-        ]
-        return not needs_decomp
+            needs_decomp = dim != [-1, -2]
+        else:
+            needs_decomp = node.target in [
+                exir_ops.edge.aten.div.Tensor,
+                exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+                exir_ops.edge.aten.native_layer_norm.default,
+                exir_ops.edge.aten.mean.dim,
+                exir_ops.edge.aten._softmax.default,
+                exir_ops.edge.aten._log_softmax.default,
+                exir_ops.edge.aten.var.correction,
+                exir_ops.edge.aten.var.dim,
+                exir_ops.edge.aten.add.Scalar,
+                exir_ops.edge.aten.sub.Scalar,
+                exir_ops.edge.aten.mul.Scalar,
+                exir_ops.edge.aten.div.Scalar,
+            ]
+        if needs_decomp:
+            self.reporter.report_reject(node, "Needs to be decomposed.")
+            return False
+        else:
+            return True
 
 
 class CheckProperQuantization(OperatorSupportBase):
@@ -266,6 +295,9 @@ class CheckProperQuantization(OperatorSupportBase):
     dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
     q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
 
+    def __init__(self, reporter: WhyNoPartitionReporter):
+        self.reporter = reporter
+
     def _is_matmul_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ):
@@ -294,14 +326,23 @@ def _is_matmul_node_supported(
                     for input_node in matched_partition.input_nodes
                 )
                 if not input_quantized:
+                    self.reporter.report_reject(
+                        node, "One or more matmul inputs were not quantized."
+                    )
                     return False
                 output_quantized = all(
                     output_node_user.target == self.q_op
                     for output_node_user in matched_partition.output_nodes[0].users
                 )
                 if not output_quantized:
+                    self.reporter.report_reject(
+                        node, "One or more matmul outputs were not quantized."
+                    )
                     return False
             else:
+                self.reporter.report_reject(
+                    node, "Node did not match any matmul source partition."
+                )
                 return False
 
         return True
@@ -367,6 +408,7 @@ def is_node_supported(
         )
 
         if not input_quantized:
+            self.reporter.report_reject(node, "One or more inputs were not quantized.")
             return False
 
         all_q_users = all(
@@ -376,18 +418,22 @@ def is_node_supported(
         output_quantized = output_quantized or all_q_users or not is_floating_point
 
         if not output_quantized:
+            self.reporter.report_reject(node, "One or more outputs were not quantized.")
             return False
         return True
 
 
 class CheckInt64Inputs(OperatorSupportBase):
 
-    def __init__(self, exported_program: ExportedProgram):
+    def __init__(
+        self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
+    ):
         self.input_names = [
             spec.arg.name
             for spec in exported_program.graph_signature.input_specs
             if spec.kind == InputKind.USER_INPUT
         ]
+        self.reporter = reporter
         super().__init__()
 
     def is_node_supported(
@@ -402,5 +448,9 @@ def is_node_supported(
             ):
                 tensor = get_first_fake_tensor(input_node)
                 if tensor.dtype == torch.int64:
+                    self.reporter.report_reject(
+                        node,
+                        f"Had int64 input {input_node.name} that couldn't be handled.",
+                    )
                     return False
         return True
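
As a quick illustration of the new reporter plumbing above, here is a minimal sketch of wrapping an
extra negative check so its rejections show up in the report; the `RejectEverything` check is
hypothetical and only stands in for a user-supplied additional check:

    from executorch.exir.backend.utils import WhyNoPartitionReporter
    from torch.fx.passes.operator_support import OperatorSupportBase

    class RejectEverything(OperatorSupportBase):
        # Toy negative check: rejects every node it is asked about.
        def is_node_supported(self, submodules, node) -> bool:
            return False

    reporter = WhyNoPartitionReporter()
    # Mirrors how tosa_support_factory wraps `additional_checks`: any node this
    # check rejects is recorded in the reporter under the given message.
    wrapped_check = reporter.wrap_check(
        RejectEverything(), "Rejected by RejectEverything"
    )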
diff --git a/backends/arm/test/misc/test_custom_partition.py b/backends/arm/test/misc/test_custom_partition.py
index 8d73e1c7836..00bc4d306ae 100644
--- a/backends/arm/test/misc/test_custom_partition.py
+++ b/backends/arm/test/misc/test_custom_partition.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
+
 import torch
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.arm_tester import ArmTester
@@ -37,7 +39,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         return self.nested(a, b)
 
 
-def test_single_reject():
+def test_single_reject(caplog):
+    caplog.set_level(logging.INFO)
+
     module = CustomPartitioning()
     inputs = module.inputs
     compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
@@ -57,6 +61,7 @@ def test_single_reject():
         .run_method_and_compare_outputs(inputs=inputs)
     )
     assert check.has_rejected_node()
+    assert "Rejected by DontPartition" in caplog.text
 
 
 def test_multiple_reject():
@@ -83,7 +88,9 @@ def test_multiple_reject():
     assert check.has_rejected_node()
 
 
-def test_torch_op_reject():
+def test_torch_op_reject(caplog):
+    caplog.set_level(logging.INFO)
+
     module = CustomPartitioning()
     inputs = module.inputs
     compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
@@ -103,6 +110,7 @@ def test_torch_op_reject():
         .run_method_and_compare_outputs(inputs=inputs)
     )
     assert check.has_rejected_node()
+    assert "Rejected by DontPartition" in caplog.text
 
 
 def test_string_op_reject():
@@ -128,7 +136,9 @@ def test_string_op_reject():
     assert check.has_rejected_node()
 
 
-def test_name_reject():
+def test_name_reject(caplog):
+    caplog.set_level(logging.INFO)
+
     module = CustomPartitioning()
     inputs = module.inputs
     compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
@@ -148,6 +158,7 @@ def test_name_reject():
         .run_method_and_compare_outputs(inputs=inputs)
     )
     assert check.has_rejected_node()
+    assert "Rejected by DontPartitionName" in caplog.text
 
 
 def test_module_reject():
@@ -172,7 +183,9 @@ def test_module_reject():
     assert check.has_rejected_node()
 
 
-def test_inexact_module_reject():
+def test_inexact_module_reject(caplog):
+    caplog.set_level(logging.INFO)
+
     module = NestedModule()
     inputs = module.inputs
     compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
@@ -192,6 +205,7 @@ def test_inexact_module_reject():
         .run_method_and_compare_outputs(inputs=inputs)
     )
     assert check.has_rejected_node()
+    assert "Rejected by DontPartitionModule" in caplog.text
 
 
 def test_module_instance_reject():
diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py
index 228998d82f5..a53bf6fc725 100644
--- a/backends/arm/tosa_partitioner.py
+++ b/backends/arm/tosa_partitioner.py
@@ -25,7 +25,7 @@
     Partitioner,
     PartitionResult,
 )
-from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
@@ -33,7 +33,7 @@
 
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
+logger.setLevel(logging.INFO)
 TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
 if TOSA_DBG_VERBOSE:
     logging.basicConfig(level=logging.INFO)
@@ -78,9 +78,13 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:  # no
 
         logger.info(f"Partitioning for {self.delegation_spec.backend_id}: {tosa_spec}")
 
+        reporter = WhyNoPartitionReporter()
+        operator_support = tosa_support_factory(
+            tosa_spec, exported_program, reporter, self.additional_checks
+        )
         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            tosa_support_factory(tosa_spec, exported_program, self.additional_checks),
+            operator_support,
             allows_single_node_partition=True,
         )
         partition_list = capability_partitioner.propose_partitions()
@@ -119,14 +123,17 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
                         if is_partitioned(input):
                             continue
                         if get_first_fake_tensor(input).dtype.is_floating_point:
-                            logger.info(
-                                f"Not partitioning {node.name} becuase input {input.name} has floating point dtype."
+                            reporter.report_reject(
+                                node,
+                                f"Was first node in partition and input {input.name} had fp dtype.",
                             )
                             del node.meta["delegation_tag"]
                             break
 
         tag_constant_data(exported_program)
-
+        logger.info(f"The following nodes were rejected for {tosa_spec}:")
+        logger.info("\n" + reporter.get_table_report())
+        logger.info("(Placeholders and outputs are not included in this list)")
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
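
Note that the reject table is emitted at INFO level by TOSAPartitioner.partition(), so it only
appears if INFO logging is routed to a handler. A minimal sketch for a local run (the format
string is only an example); per the module above, setting TOSA_DBG_VERBOSE=1 has the same effect:

    import logging

    # Route INFO logs, including the reporter's rejected-node table, to the console.
    logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s")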
diff --git a/exir/backend/utils.py b/exir/backend/utils.py
index 9487c59a848..eb9aeb19756 100644
--- a/exir/backend/utils.py
+++ b/exir/backend/utils.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -8,7 +9,7 @@
 
 import logging
 import operator
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 from functools import lru_cache
 from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
 
@@ -22,9 +23,11 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 
 from executorch.exir.lowered_backend_module import create_submodule_from_nodes
+from tabulate import tabulate
 from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.fx.experimental.symbolic_shapes import has_free_symbols
 from torch.fx.node import Node
+from torch.fx.passes.operator_support import OperatorSupportBase
 from torch.fx.passes.utils.source_matcher_utils import SourcePartition
 
 T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
@@ -569,3 +572,90 @@ def __call__(self, node: torch.fx.Node, reason: str) -> None:
 
     def __str__(self) -> str:
         return f"WhyNoPartition: Node {self.node} was not partitioned because {self.reason}."
+
+
+class WhyNoPartitionReporter:
+    """
+    Helper class for partitioners to gather, in a single report, the reasons why nodes were not lowered.
+    If a node is reported multiple times, only the first report is included.
+
+    Example usage:
+
+        # In your backend partitioner file(s)
+        reporter = WhyNoPartitionReporter()
+
+        # hypothetical function that checks if a node can be lowered
+        if not can_be_lowered(node):
+            reporter.report_reject(node, "This node was not lowered because ...")
+
+        # Back in partitioner
+        logger.info(reporter.get_table_report())
+    """
+
+    def __init__(self):
+        self._rejected_nodes: OrderedDict[torch.fx.Node, str] = (
+            OrderedDict()
+        )  # {Rejected node: reason}
+
+    def report_reject(self, node: torch.fx.Node, reason: str):
+        """Report a node that was rejected from a partition, along with the reason why."""
+        if node not in self._rejected_nodes:
+            self._rejected_nodes[node] = reason
+
+    def get_table_report(self) -> str:
+        """Returns a string containing a table listing all rejected nodes.
+        The table looks something like this:
+        ╒══════════════════════════╤══════════════════════════╤═════════════════════════════════════╤═════════════════════════════════════╕
+        │ Node name                │ Target                   │ Torch func                          │ Reason                              │
+        ╞══════════════════════════╪══════════════════════════╪═════════════════════════════════════╪═════════════════════════════════════╡
+        │ aten_convolution_default │ aten.convolution.default │ ('conv2d_1', 'builtin_function_or_m │ Convolution needs to have           │
+        │                          │                          │ ethod.conv2d')                      │ kernel_y<=64,                       │
+        │                          │                          │                                     │ kernel_x*kernel_y<=4096, got kernel │
+        │                          │                          │                                     │ (2, 65)                             │
+        ╘══════════════════════════╧══════════════════════════╧═════════════════════════════════════╧═════════════════════════════════════╛
+        """
+        reject_report = []
+        for node in self._rejected_nodes:
+            if node.op == "placeholder" or node.op == "output":
+                continue
+            if not (target := getattr(node.target, "_op", None)):
+                target = node.target
+            torch_fn = node.meta.get("torch_fn", "-")
+            reject_report.append(
+                [node.name, target, torch_fn, self._rejected_nodes[node]]
+            )
+        if len(reject_report) > 0:
+            return tabulate(
+                reject_report,
+                ["Node name", "Target", "Torch func", "Reason"],
+                tablefmt="fancy_grid",
+                maxcolwidths=35,
+            )
+        else:
+            return "No nodes rejected."
+
+    def wrap_check(
+        self, operator_support: OperatorSupportBase, message: str
+    ) -> OperatorSupportBase:
+        """Wrap the operator_support, reporting rejects with the specified message."""
+        return ReportRejected(operator_support, self, message)
+
+
+class ReportRejected(OperatorSupportBase):
+    """Class for wrapping an OperatorSupportBase, reporting rejects with the specified message to `reporter`."""
+
+    def __init__(
+        self,
+        operator_support: OperatorSupportBase,
+        reporter: WhyNoPartitionReporter,
+        message: str,
+    ):
+        self.operator_support = operator_support
+        self.reporter = reporter
+        self.message = message
+
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        is_supported = self.operator_support.is_node_supported(submodules, node)
+        if not is_supported:
+            self.reporter.report_reject(node, self.message)
+        return is_supported
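
To make the new helper concrete, a small self-contained sketch exercising it on a toy
torch.fx graph; the module and the reject reasons are illustrative only:

    import torch
    import torch.fx

    from executorch.exir.backend.utils import WhyNoPartitionReporter

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    gm = torch.fx.symbolic_trace(Toy())
    reporter = WhyNoPartitionReporter()
    for node in gm.graph.nodes:
        if node.op == "call_function":
            reporter.report_reject(node, f"Rejecting {node.target} for illustration.")
            # A second report for the same node is ignored; only the first reason is kept.
            reporter.report_reject(node, "This reason is never shown.")
    # Placeholders and outputs are filtered out of the table automatically.
    print(reporter.get_table_report())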