Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions ai_edge_torch/_convert/test/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,78 @@ def forward(self, a, b, c):
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
)

def test_convert_conv2d_x1(self):
"""Tests conversion of a simple Conv2d module."""

class Conv2d(nn.Module):

def __init__(self):
super().__init__()
self.conv = nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
)

def forward(self, x):
return self.conv(x)

args = (torch.randn((1, 3, 224, 224)),)
torch_module = Conv2d().eval()
edge_model = ai_edge_torch.convert(torch_module, args)

tmp_dir_name = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")
tmp_dir_path = os.path.join(tmp_dir_name, "conv2d_x1.tflite")
edge_model.export(tmp_dir_path)

self.assertTrue(
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
)

def test_convert_conv2d_add(self):
"""Tests conversion of Conv2d layers with add ops."""

class Conv2d_add(nn.Module):

def __init__(self):
super().__init__()
self.convs = nn.ModuleList()
self.convs.append(
nn.Conv2d(
in_channels=3,
out_channels=16,
kernel_size=3,
stride=1,
padding=1,
)
)
for _ in range(14):
self.convs.append(
nn.Conv2d(
in_channels=16,
out_channels=16,
kernel_size=3,
stride=1,
padding=1,
)
)

def forward(self, x):
x = self.convs[0](x)
for i in range(1, 15):
x = x + self.convs[i](x)
return x

args = (torch.randn((1, 3, 224, 224)),)
torch_module = Conv2d_add().eval()
edge_model = ai_edge_torch.convert(torch_module, args)

tmp_dir_name = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")
tmp_dir_path = os.path.join(tmp_dir_name, "conv2d_add_x14.tflite")
edge_model.export(tmp_dir_path)

self.assertTrue(
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
)

def test_convert_resnet18(self):
args = (torch.randn(4, 3, 224, 224),)
torch_module = torchvision.models.resnet18().eval()
Expand Down
40 changes: 39 additions & 1 deletion ai_edge_torch/testing/model_coverage/model_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,34 @@ def _torch_tensors_to_np(*argv):
raise ValueError("Unsupported torch.tensor type.")


def _print_diff(tensor_idx, tflite_out, torch_out):
"""Prints difference details between two tensors."""
diff = np.abs(tflite_out - torch_out)
max_abs_diff = np.max(diff)
avg_abs_diff = np.mean(diff)
print(f"Tensor {tensor_idx} difference:")
print(f" PyTorch result: {torch_out}")
print(f" TFLite result: {tflite_out}")
print(f" Difference: {diff}")
print(f" Max absolute difference: {max_abs_diff}")
print(f" Mean absolute difference: {avg_abs_diff}")
top10_diffs = np.sort(diff.flatten())[-10:][::-1]
print(f" Top 10 differences: {top10_diffs}")
nonzero_indices = np.abs(torch_out) > 0
if np.any(nonzero_indices):
rel_diff = diff[nonzero_indices] / np.abs(torch_out[nonzero_indices])
max_rel_diff_percent = np.max(rel_diff) * 100
mean_rel_diff_percent = np.mean(rel_diff) * 100
print(
" Max relative difference (for non-zero golden values):"
f" {max_rel_diff_percent:.2f}%"
)
print(
" Mean relative difference (for non-zero golden values):"
f" {mean_rel_diff_percent:.2f}%"
)


def compare_tflite_torch(
edge_model: model.Model,
torch_eval_func: Callable,
Expand All @@ -72,7 +100,7 @@ def compare_tflite_torch(
*,
num_valid_inputs: int = 1,
signature_name: str = None,
atol: float = 1e-5,
atol: float = 1e-4,
rtol: float = 1e-5
):
"""Compares torch models and TFLite models.
Expand Down Expand Up @@ -140,6 +168,16 @@ def get_edge_output(inputs):
for out, golden_out in zip(output, golden_output)
])
if not is_equal:
print("TFLite and PyTorch results are different.")
if not is_output_len_eq:
print(
"Output length mismatch:"
f" TFLite {len(output)}, PyTorch {len(golden_output)}"
)
return False
for i, (out, golden_out) in enumerate(zip(output, golden_output)):
if not np.allclose(out, golden_out, atol=atol, rtol=rtol):
_print_diff(i, out, golden_out)
return False

return True
Loading