Add mixed dtype check for XNNPACK partitioner #9533

Closed
16 changes: 14 additions & 2 deletions backends/xnnpack/partition/config/xnnpack_config.py
@@ -144,9 +144,10 @@ def check_common_constraints(
return True

def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
# Check inputs are valid dtypes
# Check inputs are valid and have the same dtypes
# Gather all args which are nodes
args_to_check = []
reference_dtype = None
for arg in node.args:
if isinstance(arg, list) or isinstance(arg, tuple):
for item in arg:
@@ -174,11 +175,17 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
if arg_val.dtype not in valid_dtypes:
return False

if reference_dtype is None:
reference_dtype = arg_val.dtype
elif arg_val.dtype != reference_dtype:
return False

return True

def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
# Check outputs are valid dtype
# Check outputs are valid and have the same dtypes
node_val = node.meta.get("val", None)
reference_dtype = None
if node_val is None:
return True

@@ -192,6 +199,11 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
if val.dtype not in valid_dtypes:
return False

if reference_dtype is None:
reference_dtype = val.dtype
elif val.dtype != reference_dtype:
return False

return True

def _check_node_has_valid_dtype(self, node):
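The partitioner change boils down to remembering the first dtype seen across a node's tensor inputs (and, separately, its outputs) and bailing out as soon as a later tensor deviates from it. Below is a minimal standalone sketch of that pattern; the helper name all_dtypes_match and the bare list-of-tensors interface are illustrative only, not the partitioner's actual node/metadata plumbing.

import torch

def all_dtypes_match(tensors, valid_dtypes):
    # Mirror the reference_dtype bookkeeping added above: reject if any dtype
    # is unsupported or differs from the first dtype encountered.
    reference_dtype = None
    for t in tensors:
        if t.dtype not in valid_dtypes:
            return False
        if reference_dtype is None:
            reference_dtype = t.dtype
        elif t.dtype != reference_dtype:
            return False
    return True

valid = {torch.float32, torch.float16}
print(all_dtypes_match([torch.randn(2), torch.randn(2)], valid))                    # True
print(all_dtypes_match([torch.randn(2), torch.randn(2).to(torch.float16)], valid))  # False: mixed dtypes
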
37 changes: 30 additions & 7 deletions backends/xnnpack/test/ops/test_add.py
@@ -42,19 +42,26 @@ def forward(self, x):
out2 = x + self._constant2 + self._constant3
return out1, out2

def _test_add(self, inputs):
(
def _test_add(self, inputs, mixed_dtype=False):
tester = (
Tester(self.Add(), inputs)
.export()
.check_count({"torch.ops.aten.add.Tensor": 4})
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

if mixed_dtype:
# Inverse check for mixed-dtype: original node remains and no delegate node
tester.check_count(
{"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4}
)
tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
else:
tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
tester.check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])

(tester.to_executorch().serialize().run_method_and_compare_outputs())

def test_fp16_add(self):
inputs = (torch.randn(1).to(torch.float16), torch.randn(1).to(torch.float16))
self._test_add(inputs)
@@ -237,3 +244,19 @@ def forward(self, x, z):
.serialize()
.run_method_and_compare_outputs()
)

def test_fp32_add_with_mixed_dtype(self):
test_cases = [
torch.bfloat16,
torch.float16,
torch.int8,
]
for dtype in test_cases:
with self.subTest(dtype=str(dtype)):
inputs = (
torch.randn(1, 1, 4, 4).to(torch.float32),
torch.randn(1, 1, 4, 4).to(dtype),
)
# Set mixed_dtype=True to verify that
# no delegate node is inserted and the original node remains in the graph
self._test_add(inputs, mixed_dtype=True)
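
Outside the Tester harness, the same behavior can be reproduced with the standard ExecuTorch lowering flow. A rough sketch, assuming the usual torch.export / to_edge_transform_and_lower entry points (the Add module and the final printout are illustrative only): with one fp32 and one fp16 input, the aten.add node should stay in the edge graph rather than being folded into an executorch_call_delegate call.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# fp32 + fp16 inputs: with this change, the partitioner should skip the add node.
example_inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4).to(torch.float16))
exported = torch.export.export(Add(), example_inputs)
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])

# Inspect the lowered graph: expect aten.add to remain and no delegate node.
print(edge.exported_program().graph_module.code)
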
34 changes: 27 additions & 7 deletions backends/xnnpack/test/ops/test_cat.py
@@ -23,7 +23,9 @@ def forward(self, *args):
x = torch.cat(xs, dim=self.dim)
return x + x # Quantize by propagation.

def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
def _test_cat(
self, module, inputs, cat_num=1, quant=False, quant_ops=2, mixed_dtype=False
):
for legacy_mode in (True, False):
tester = Tester(module, inputs)

@@ -53,15 +55,17 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
if quant:
tester.check_not(["torch.ops.quantized_decomposed"])

(
if mixed_dtype:
# Inverse check for mixed-dtype: original node remains and no delegate node
tester.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
else:
tester.check_count(
{"torch.ops.higher_order.executorch_call_delegate": 1}
)
.check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)
tester.check_not(["executorch_exir_dialects_edge__ops_aten_cat"])

(tester.to_executorch().serialize().run_method_and_compare_outputs())

def test_fp16_cat2(self):
"""
@@ -249,3 +253,19 @@ def forward(self, x, y):
def _test_qs8_cat_nhwc2(self):
inputs = (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
self._test_cat(self.CatNhwc(), inputs, quant=True, quant_ops=4)

def test_fp32_cat_with_mixed_dtype(self):
test_cases = [
torch.bfloat16,
torch.float16,
torch.int8,
]
for dtype in test_cases:
with self.subTest(dtype=str(dtype)):
inputs = (
torch.randn(1, 2, 3).to(torch.float32),
torch.randn(1, 2, 3).to(dtype),
)
# Set mixed_dtype=True to verify that
# no delegate node is inserted and the original node remains in the graph
self._test_cat(self.Cat(), inputs, mixed_dtype=True)
37 changes: 30 additions & 7 deletions backends/xnnpack/test/ops/test_div.py
@@ -27,19 +27,26 @@ def forward(self, x):
z = x / x
return z

def _test_div(self, inputs):
(
def _test_div(self, inputs, mixed_dtype=False):
tester = (
Tester(self.Div(), inputs)
.export()
.check_count({"torch.ops.aten.div.Tensor": 1})
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

if mixed_dtype:
# Inverse check for mixed-dtype: original node remains and no delegate node
tester.check_count(
{"executorch_exir_dialects_edge__ops_aten_div_Tensor": 1}
)
tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
else:
tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
tester.check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"])

(tester.to_executorch().serialize().run_method_and_compare_outputs())

def test_fp16_div(self):
# Adding 4 to move distribution away from 0, 4 Std Dev should be far enough
inputs = (
@@ -67,3 +74,19 @@ def test_fp32_div_single_input(self):
.serialize()
.run_method_and_compare_outputs()
)

def test_fp32_div_with_mixed_dtype(self):
test_cases = [
torch.bfloat16,
torch.float16,
torch.int8,
]
for dtype in test_cases:
with self.subTest(dtype=str(dtype)):
inputs = (
(torch.randn(1) + 4).to(torch.float32),
(torch.randn(1) + 4).to(dtype),
)
# Set mixed_dtype=True to verify that
# no delegate node is inserted and the original node remains in the graph
self._test_div(inputs, mixed_dtype=True)
37 changes: 30 additions & 7 deletions backends/xnnpack/test/ops/test_multiply.py
@@ -31,19 +31,26 @@ def forward(self, x, y):
z = x * y
return torch.nn.functional.relu(z)

def _test_mul(self, inputs):
(
def _test_mul(self, inputs, mixed_dtype=False):
tester = (
Tester(self.Mul(), inputs)
.export()
.check_count({"torch.ops.aten.mul.Tensor": 1})
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

if mixed_dtype:
# Inverse check for mixed-dtype: original node remains and no delegate node
tester.check_count(
{"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1}
)
tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
else:
tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
tester.check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"])

(tester.to_executorch().serialize().run_method_and_compare_outputs())

def test_fp16_mul(self):
inputs = (
torch.randn((1, 3)).to(torch.float16),
@@ -144,3 +151,19 @@ def test_qs8_mul_relu(self):
.serialize()
.run_method_and_compare_outputs()
)

def test_fp32_mul_with_mixed_dtype(self):
test_cases = [
torch.bfloat16,
torch.float16,
torch.int8,
]
for dtype in test_cases:
with self.subTest(dtype=str(dtype)):
inputs = (
torch.randn(1, 1, 4, 4).to(torch.float32),
torch.randn(1, 1, 4, 4).to(dtype),
)
# Set mixed_dtype=True to verify that
# no delegate node is inserted and the original node remains in the graph
self._test_mul(inputs, mixed_dtype=True)
37 changes: 30 additions & 7 deletions backends/xnnpack/test/ops/test_sub.py
@@ -27,19 +27,26 @@ def forward(self, x):
z = x - x
return z

def _test_sub(self, inputs):
(
def _test_sub(self, inputs, mixed_dtype=False):
tester = (
Tester(self.Sub(), inputs)
.export()
.check_count({"torch.ops.aten.sub.Tensor": 1})
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

if mixed_dtype:
# Inverse check for mixed-dtype: original node remains and no delegate node
tester.check_count(
{"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1}
)
tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
else:
tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
tester.check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])

(tester.to_executorch().serialize().run_method_and_compare_outputs())

def test_fp16_sub(self):
inputs = (
torch.randn((1, 3)).to(torch.float16),
@@ -149,3 +156,19 @@ def forward(self, x, y):
.serialize()
.run_method_and_compare_outputs()
)

def test_fp32_sub_with_mixed_dtype(self):
test_cases = [
torch.bfloat16,
torch.float16,
torch.int8,
]
for dtype in test_cases:
with self.subTest(dtype=str(dtype)):
inputs = (
torch.randn(1, 1, 4, 4).to(torch.float32),
torch.randn(1, 1, 4, 4).to(dtype),
)
# Set mixed_dtype=True to verify that
# no delegate node is inserted and the original node remains in the graph
self._test_sub(inputs, mixed_dtype=True)