diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py
index 20018610fce..f247f0631cf 100644
--- a/backends/xnnpack/partition/config/xnnpack_config.py
+++ b/backends/xnnpack/partition/config/xnnpack_config.py
@@ -144,9 +144,10 @@ def check_common_constraints(
         return True
 
     def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
-        # Check inputs are valid dtypes
+        # Check inputs are valid and have the same dtypes
         # Gather all args which are nodes
         args_to_check = []
+        reference_dtype = None
         for arg in node.args:
             if isinstance(arg, list) or isinstance(arg, tuple):
                 for item in arg:
@@ -174,11 +175,17 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
             if arg_val.dtype not in valid_dtypes:
                 return False
 
+            if reference_dtype is None:
+                reference_dtype = arg_val.dtype
+            elif arg_val.dtype != reference_dtype:
+                return False
+
         return True
 
     def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
-        # Check outputs are valid dtype
+        # Check outputs are valid and have the same dtypes
         node_val = node.meta.get("val", None)
+        reference_dtype = None
         if node_val is None:
             return True
 
@@ -192,6 +199,11 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
             if val.dtype not in valid_dtypes:
                 return False
 
+            if reference_dtype is None:
+                reference_dtype = val.dtype
+            elif val.dtype != reference_dtype:
+                return False
+
         return True
 
     def _check_node_has_valid_dtype(self, node):
diff --git a/backends/xnnpack/test/ops/test_add.py b/backends/xnnpack/test/ops/test_add.py
index 29a87df1303..4cd6532756d 100644
--- a/backends/xnnpack/test/ops/test_add.py
+++ b/backends/xnnpack/test/ops/test_add.py
@@ -42,19 +42,26 @@ def forward(self, x):
             out2 = x + self._constant2 + self._constant3
             return out1, out2
 
-    def _test_add(self, inputs):
-        (
+    def _test_add(self, inputs, mixed_dtype=False):
+        tester = (
             Tester(self.Add(), inputs)
             .export()
             .check_count({"torch.ops.aten.add.Tensor": 4})
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
-            .to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs()
         )
 
+        if mixed_dtype:
+            # Inverse check for mixed-dtype: original node remains and no delegate node
+            tester.check_count(
+                {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4}
+            )
+            tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
+        else:
+            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            tester.check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
+
+        (tester.to_executorch().serialize().run_method_and_compare_outputs())
+
     def test_fp16_add(self):
         inputs = (torch.randn(1).to(torch.float16), torch.randn(1).to(torch.float16))
         self._test_add(inputs)
@@ -237,3 +244,19 @@ def forward(self, x, z):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_add_with_mixed_dtype(self):
+        test_cases = [
+            torch.bfloat16,
+            torch.float16,
+            torch.int8,
+        ]
+        for dtype in test_cases:
+            with self.subTest(dtype=str(dtype)):
+                inputs = (
+                    torch.randn(1, 1, 4, 4).to(torch.float32),
+                    torch.randn(1, 1, 4, 4).to(dtype),
+                )
+                # Set mixed_dtype=True to verify that
+                # no delegate node is inserted and the original node remains in the graph
+                self._test_add(inputs, mixed_dtype=True)
diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py
index dd551ea3fa7..4455667e952 100644
--- a/backends/xnnpack/test/ops/test_cat.py
+++ b/backends/xnnpack/test/ops/test_cat.py
@@ -23,7 +23,9 @@ def forward(self, *args):
             x = torch.cat(xs, dim=self.dim)
             return x + x  # Quantize by propagation.
 
-    def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
+    def _test_cat(
+        self, module, inputs, cat_num=1, quant=False, quant_ops=2, mixed_dtype=False
+    ):
         for legacy_mode in (True, False):
             tester = Tester(module, inputs)
 
@@ -53,15 +55,17 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
             if quant:
                 tester.check_not(["torch.ops.quantized_decomposed"])
 
-            (
+            if mixed_dtype:
+                # Inverse check for mixed-dtype: original node remains and no delegate node
+                tester.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
+                tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
+            else:
                 tester.check_count(
                     {"torch.ops.higher_order.executorch_call_delegate": 1}
                 )
-                .check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
-                .to_executorch()
-                .serialize()
-                .run_method_and_compare_outputs()
-            )
+                tester.check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
+
+            (tester.to_executorch().serialize().run_method_and_compare_outputs())
 
     def test_fp16_cat2(self):
         """
@@ -249,3 +253,19 @@ def forward(self, x, y):
     def _test_qs8_cat_nhwc2(self):
         inputs = (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
         self._test_cat(self.CatNhwc(), inputs, quant=True, quant_ops=4)
+
+    def test_fp32_cat_with_mixed_dtype(self):
+        test_cases = [
+            torch.bfloat16,
+            torch.float16,
+            torch.int8,
+        ]
+        for dtype in test_cases:
+            with self.subTest(dtype=str(dtype)):
+                inputs = (
+                    torch.randn(1, 2, 3).to(torch.float32),
+                    torch.randn(1, 2, 3).to(dtype),
+                )
+                # Set mixed_dtype=True to verify that
+                # no delegate node is inserted and the original node remains in the graph
+                self._test_cat(self.Cat(), inputs, mixed_dtype=True)
diff --git a/backends/xnnpack/test/ops/test_div.py b/backends/xnnpack/test/ops/test_div.py
index 9bca5feed48..4dce39d6dfa 100644
--- a/backends/xnnpack/test/ops/test_div.py
+++ b/backends/xnnpack/test/ops/test_div.py
@@ -27,19 +27,26 @@ def forward(self, x):
             z = x / x
             return z
 
-    def _test_div(self, inputs):
-        (
+    def _test_div(self, inputs, mixed_dtype=False):
+        tester = (
             Tester(self.Div(), inputs)
             .export()
             .check_count({"torch.ops.aten.div.Tensor": 1})
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"])
-            .to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs()
         )
 
+        if mixed_dtype:
+            # Inverse check for mixed-dtype: original node remains and no delegate node
+            tester.check_count(
+                {"executorch_exir_dialects_edge__ops_aten_div_Tensor": 1}
+            )
+            tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
+        else:
+            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            tester.check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"])
+
+        (tester.to_executorch().serialize().run_method_and_compare_outputs())
+
     def test_fp16_div(self):
         # Adding 4 to move distribution away from 0, 4 Std Dev should be far enough
         inputs = (
@@ -67,3 +74,19 @@ def test_fp32_div_single_input(self):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_div_with_mixed_dtype(self):
+        test_cases = [
+            torch.bfloat16,
+            torch.float16,
+            torch.int8,
+        ]
+        for dtype in test_cases:
+            with self.subTest(dtype=str(dtype)):
+                inputs = (
+                    (torch.randn(1) + 4).to(torch.float32),
+                    (torch.randn(1) + 4).to(dtype),
+                )
+                # Set mixed_dtype=True to verify that
+                # no delegate node is inserted and the original node remains in the graph
+                self._test_div(inputs, mixed_dtype=True)
diff --git a/backends/xnnpack/test/ops/test_multiply.py b/backends/xnnpack/test/ops/test_multiply.py
index db50bc5dd44..99d78ee28e1 100644
--- a/backends/xnnpack/test/ops/test_multiply.py
+++ b/backends/xnnpack/test/ops/test_multiply.py
@@ -31,19 +31,26 @@ def forward(self, x, y):
             z = x * y
             return torch.nn.functional.relu(z)
 
-    def _test_mul(self, inputs):
-        (
+    def _test_mul(self, inputs, mixed_dtype=False):
+        tester = (
             Tester(self.Mul(), inputs)
             .export()
             .check_count({"torch.ops.aten.mul.Tensor": 1})
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"])
-            .to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs()
         )
 
+        if mixed_dtype:
+            # Inverse check for mixed-dtype: original node remains and no delegate node
+            tester.check_count(
+                {"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1}
+            )
+            tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
+        else:
+            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            tester.check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"])
+
+        (tester.to_executorch().serialize().run_method_and_compare_outputs())
+
     def test_fp16_mul(self):
         inputs = (
             torch.randn((1, 3)).to(torch.float16),
@@ -144,3 +151,19 @@ def test_qs8_mul_relu(self):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_mul_with_mixed_dtype(self):
+        test_cases = [
+            torch.bfloat16,
+            torch.float16,
+            torch.int8,
+        ]
+        for dtype in test_cases:
+            with self.subTest(dtype=str(dtype)):
+                inputs = (
+                    torch.randn(1, 1, 4, 4).to(torch.float32),
+                    torch.randn(1, 1, 4, 4).to(dtype),
+                )
+                # Set mixed_dtype=True to verify that
+                # no delegate node is inserted and the original node remains in the graph
+                self._test_mul(inputs, mixed_dtype=True)
diff --git a/backends/xnnpack/test/ops/test_sub.py b/backends/xnnpack/test/ops/test_sub.py
index fb3d3d3f948..952cce20cfc 100644
--- a/backends/xnnpack/test/ops/test_sub.py
+++ b/backends/xnnpack/test/ops/test_sub.py
@@ -27,19 +27,26 @@ def forward(self, x):
             z = x - x
             return z
 
-    def _test_sub(self, inputs):
-        (
+    def _test_sub(self, inputs, mixed_dtype=False):
+        tester = (
             Tester(self.Sub(), inputs)
             .export()
             .check_count({"torch.ops.aten.sub.Tensor": 1})
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
-            .to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs()
         )
 
+        if mixed_dtype:
+            # Inverse check for mixed-dtype: original node remains and no delegate node
+            tester.check_count(
+                {"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1}
+            )
+            tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
+        else:
+            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            tester.check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
+
+        (tester.to_executorch().serialize().run_method_and_compare_outputs())
+
     def test_fp16_sub(self):
         inputs = (
             torch.randn((1, 3)).to(torch.float16),
@@ -149,3 +156,19 @@ def forward(self, x, y):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_sub_with_mixed_dtype(self):
+        test_cases = [
+            torch.bfloat16,
+            torch.float16,
+            torch.int8,
+        ]
+        for dtype in test_cases:
+            with self.subTest(dtype=str(dtype)):
+                inputs = (
+                    torch.randn(1, 1, 4, 4).to(torch.float32),
+                    torch.randn(1, 1, 4, 4).to(dtype),
+                )
+                # Set mixed_dtype=True to verify that
+                # no delegate node is inserted and the original node remains in the graph
+                self._test_sub(inputs, mixed_dtype=True)
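
For reference only, and not part of the patch above: a minimal standalone Python sketch of the same-dtype rule that the partitioner change applies to a node's inputs and outputs. The helper name inputs_share_dtype and the example dtype set are illustrative assumptions, not ExecuTorch API.

import torch

def inputs_share_dtype(tensors, valid_dtypes):
    # Mirror of the reference_dtype pattern in the patch: reject any unsupported
    # dtype, and reject the whole group as soon as two tensors disagree (mixed dtype).
    reference_dtype = None
    for t in tensors:
        if t.dtype not in valid_dtypes:
            return False
        if reference_dtype is None:
            reference_dtype = t.dtype
        elif t.dtype != reference_dtype:
            return False
    return True

valid = {torch.float32, torch.float16, torch.bfloat16}
same = (torch.randn(2, 2), torch.randn(2, 2))
mixed = (torch.randn(2, 2), torch.randn(2, 2).to(torch.bfloat16))
print(inputs_share_dtype(same, valid))   # True  -> op can be delegated
print(inputs_share_dtype(mixed, valid))  # False -> op stays in the edge graph

Under this rule a float32 tensor paired with a bfloat16, float16, or int8 tensor is never lowered, which is what the new test_fp32_*_with_mixed_dtype tests assert by checking that the aten edge op remains and no executorch_call_delegate node appears.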