Skip to content

Commit 175a583

Browse files
chunnienccopybara-github
authored andcommitted
experimental initial torch-tfl setup
PiperOrigin-RevId: 736573537
1 parent d135da4 commit 175a583

File tree

8 files changed

+331
-0
lines changed

8 files changed

+331
-0
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Torch-TFL ops definitions, decompositions, and lowerings."""
16+
from ai_edge_torch.odml_torch.experimental.torch_tfl import _decomps
17+
from ai_edge_torch.odml_torch.experimental.torch_tfl import _lowerings
18+
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
19+
20+
decomps = _decomps.decomps
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Torch ops to Torch-TFL decompositions."""
16+
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
17+
import torch
18+
19+
decomps = {}
20+
21+
22+
def register_decomp(op):
23+
global decomps
24+
ops = [op]
25+
if isinstance(op, torch._ops.OpOverloadPacket):
26+
ops = [getattr(op, overload) for overload in op.overloads()]
27+
28+
def register(decomp_fn):
29+
for op in ops:
30+
decomps[op] = decomp_fn
31+
return decomp_fn
32+
33+
return register
34+
35+
36+
@register_decomp(torch.ops.aten.mm.default)
37+
def _aten_mm_decomp(x, y):
38+
return torch.ops.tfl.batch_matmul(x, y)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Torch-TFL op to MLIR lowerings."""
16+
from ai_edge_torch.odml_torch.lowerings import registry
17+
18+
lower = registry.lower
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Torch-TFL op definitions and fake implementations."""
16+
from ai_edge_torch.odml_torch.experimental.torch_tfl import torch_library_utils
17+
import torch
18+
19+
custom_op_with_fake = torch_library_utils.custom_op_with_fake
20+
21+
22+
@custom_op_with_fake("tfl::batch_matmul")
23+
def tfl_batch_matmul(
24+
x: torch.Tensor,
25+
y: torch.Tensor,
26+
adj_x: bool = False,
27+
adj_y: bool = False,
28+
asymmetric_quantize_inputs: bool = False,
29+
) -> torch.Tensor:
30+
if asymmetric_quantize_inputs:
31+
raise NotImplementedError(
32+
"asymmetric_quantize_inputs=True is not implemented"
33+
)
34+
if x.ndim < 2 or y.ndim < 2:
35+
raise ValueError("Input tensors must have at least 2 dimensions.")
36+
if adj_x:
37+
x = torch.transpose(x, -1, -2)
38+
if adj_y:
39+
y = torch.transpose(y, -1, -2)
40+
return torch.matmul(x, y)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Numerical validation tests for torch ops and Torch-TFL ops."""
16+
from ai_edge_torch import testing
17+
from ai_edge_torch.odml_torch.experimental import torch_tfl
18+
import torch
19+
from torch.utils import _pytree as pytree
20+
21+
from absl.testing import absltest as googletest
22+
from absl.testing import parameterized
23+
24+
export_with_tensor_inputs_only = testing.export_with_tensor_inputs_only
25+
26+
27+
def rnd(dtype, shape, min_v=None, max_v=None):
28+
"""Shortcut for creating a random torch tensor."""
29+
if dtype in (torch.int32, torch.int64, torch.bool):
30+
min_v = min_v if min_v else 1
31+
max_v = max_v if max_v else 10
32+
return torch.randint(min_v, max_v, shape).to(dtype)
33+
else:
34+
min_v = min_v if min_v else 0.0
35+
max_v = max_v if max_v else 1.0
36+
return (torch.rand(shape) * (max_v - min_v) + min_v).to(dtype)
37+
38+
39+
class TestTorchTFLImpls(parameterized.TestCase):
40+
"""Numerical validation tests for torch ops and Torch-TFL ops.
41+
42+
The op test suite is forked from
43+
ai_edge_torch/odml_torch/test/test_core_aten_ops.py. Eventually, we should
44+
merge the two test suites.
45+
"""
46+
47+
def setUp(self):
48+
super().setUp()
49+
torch.manual_seed(0)
50+
51+
def _assert_export_and_close(
52+
self, func, args, kwargs, atol=1e-3, rtol=1e-5, equal_nan=True
53+
):
54+
"""Assert func, args, and kwargs can be lowered and pass numerical validation."""
55+
with self.subTest("torch_eval"):
56+
expected = func(*args, **kwargs)
57+
58+
with self.subTest("export_and_decompse"):
59+
exported_program = export_with_tensor_inputs_only(func, args, kwargs)
60+
exported_program = exported_program.run_decompositions(
61+
torch_tfl.decomps
62+
)
63+
64+
with self.subTest("decomp_eval"):
65+
args, kwargs = exported_program.example_inputs
66+
actual = exported_program.module()(*args, **kwargs)
67+
68+
with self.subTest("torch_lower_eval_diff:" + str(atol)):
69+
expected_flat, expected_spec = pytree.tree_flatten(expected)
70+
actual_flat, actual_spec = pytree.tree_flatten(actual)
71+
72+
self.assertEqual(expected_spec, actual_spec)
73+
for v1, v2 in zip(expected_flat, actual_flat):
74+
torch.testing.assert_close(
75+
v1, v2, atol=atol, rtol=rtol, equal_nan=equal_nan
76+
)
77+
78+
@parameterized.named_parameters(
79+
# fmt: off
80+
# pyformat: disabledef
81+
("aten_mm_0", torch.ops.aten.mm.default, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
82+
("aten_mm_1", torch.ops.aten.mm.default, (rnd(torch.float32, (2, 10)), rnd(torch.float32, (10, 5)),), dict()),
83+
# fmt: on
84+
# pyformat: enable
85+
)
86+
def test_op(self, op, args, kwargs):
87+
self._assert_export_and_close(op, args, kwargs)
88+
89+
90+
if __name__ == "__main__":
91+
googletest.main()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Utility functions for defining custom ops in torch library."""
16+
from typing import Callable, Iterable
17+
import torch
18+
19+
20+
def custom_op_with_fake(
21+
name: str,
22+
*,
23+
mutates_args: str | Iterable[str] = (),
24+
schema: str | None = None,
25+
):
26+
"""Defines a custom op with a FakeTensor implementation using the same function."""
27+
28+
def register(fn: Callable[..., object]):
29+
op = torch.library.custom_op(
30+
name,
31+
mutates_args=mutates_args,
32+
schema=schema,
33+
)(fn)
34+
torch.library.register_fake(name)(fn)
35+
return op
36+
37+
return register

ai_edge_torch/testing/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
from ai_edge_torch.testing import export
16+
from ai_edge_torch.testing import model_coverage
17+
18+
export_with_tensor_inputs_only = export.export_with_tensor_inputs_only

ai_edge_torch/testing/export.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Torch export utilities for testing."""
16+
17+
from collections.abc import Callable
18+
from typing import Any
19+
20+
import torch
21+
from torch.utils import _pytree as pytree
22+
23+
24+
def export_with_tensor_inputs_only(
25+
model: Callable[..., Any],
26+
args: tuple[Any, ...],
27+
kwargs: dict[str, Any],
28+
) -> torch.export.ExportedProgram:
29+
"""Exports a PyTorch model, treating only tensor inputs as export inputs.
30+
31+
This function takes a PyTorch model and its input arguments (positional and
32+
keyword) and exports it using `torch.export.export`. However, it modifies
33+
the export process such that only the `torch.Tensor` arguments in the
34+
inputs are considered as export inputs to the exported graph. All other
35+
argument types (e.g., scalars, lists, tuples containing non-tensors) are
36+
treated as constants.
37+
38+
This is useful for testing scenarios where you want to export a model but
39+
want to avoid issues that might arise from non-tensor inputs
40+
being treated as variables, or when you specifically want to focus on the
41+
graph structure based on tensor operations.
42+
43+
Args:
44+
model: The PyTorch `nn.Module` to be exported.
45+
args: A tuple of positional arguments to be passed to the model's `forward`
46+
method.
47+
kwargs: A dictionary of keyword arguments to be passed to the model's
48+
`forward` method.
49+
50+
Returns:
51+
torch.export.ExportedProgram: The exported program representing the model
52+
computation with only tensor inputs being export inputs.
53+
"""
54+
flatten_args, treespec = pytree.tree_flatten([args, kwargs])
55+
56+
export_args = []
57+
indices = []
58+
for i, arg in enumerate(flatten_args):
59+
if isinstance(arg, torch.Tensor):
60+
export_args.append(arg)
61+
indices.append(i)
62+
63+
class ModuleWrapper(torch.nn.Module):
64+
65+
def __init__(self, func, original_args, original_kwargs):
66+
super().__init__()
67+
self.original_args = list(flatten_args)
68+
self.func = func
69+
70+
def forward(self, *export_args):
71+
flatten_args = self.original_args.copy()
72+
for i, arg in zip(indices, export_args):
73+
flatten_args[i] = arg
74+
args, kwargs = pytree.tree_unflatten(flatten_args, treespec)
75+
return self.func(*args, **kwargs)
76+
77+
export_args = tuple(export_args)
78+
export_kwargs = {}
79+
return torch.export.export(
80+
ModuleWrapper(model, args, kwargs).eval(),
81+
export_args,
82+
export_kwargs,
83+
)

0 commit comments

Comments
 (0)