Skip to content

Commit 6981558

Browse files
chunnienccopybara-github
authored andcommitted
enable tfl direct lowering on shlo custom_call carrier
PiperOrigin-RevId: 736594309
1 parent 175a583 commit 6981558

File tree

3 files changed

+72
-11
lines changed

3 files changed

+72
-11
lines changed

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,52 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""Torch-TFL op to MLIR lowerings."""
16+
from typing import Sequence
17+
from ai_edge_torch import odml_torch
18+
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
1619
from ai_edge_torch.odml_torch.lowerings import registry
20+
from ai_edge_torch.odml_torch.lowerings import utils as lowering_utils
21+
from jax._src.lib.mlir import ir
22+
from jax._src.lib.mlir.dialects import hlo as stablehlo
23+
import torch
1724

1825
lower = registry.lower
26+
LoweringContext = odml_torch.lowerings.context.LoweringContext
27+
28+
29+
def _ir_operation(
30+
name: str,
31+
results: Sequence[ir.Type],
32+
operands: Sequence[ir.Value] | None = None,
33+
attributes: dict[str, ir.Attribute] | None = None,
34+
):
35+
"""Helper function to create an IR operation in StableHLO CustomCall carrier."""
36+
attributes = ir.DictAttr.get(attributes)
37+
return stablehlo.custom_call(
38+
result=results,
39+
inputs=operands,
40+
call_target_name=ir.StringAttr.get(name),
41+
has_side_effect=ir.BoolAttr.get(False),
42+
backend_config=attributes,
43+
api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 4),
44+
)
45+
46+
47+
@lower(torch.ops.tfl.batch_matmul.default)
48+
def _tfl_batch_matmul_lowering(
49+
lctx: LoweringContext,
50+
x: ir.Value,
51+
y: ir.Value,
52+
adj_x: bool = False,
53+
adj_y: bool = False,
54+
) -> ir.Value:
55+
return _ir_operation(
56+
"tfl.batch_matmul",
57+
results=lowering_utils.node_meta_to_ir_types(lctx.node),
58+
operands=[x, y],
59+
attributes={
60+
"adj_x": ir.BoolAttr.get(adj_x),
61+
"adj_y": ir.BoolAttr.get(adj_y),
62+
"asymmetric_quantize_inputs": ir.BoolAttr.get(False),
63+
},
64+
)

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,8 @@
2121

2222
@custom_op_with_fake("tfl::batch_matmul")
2323
def tfl_batch_matmul(
24-
x: torch.Tensor,
25-
y: torch.Tensor,
26-
adj_x: bool = False,
27-
adj_y: bool = False,
28-
asymmetric_quantize_inputs: bool = False,
24+
x: torch.Tensor, y: torch.Tensor, adj_x: bool = False, adj_y: bool = False
2925
) -> torch.Tensor:
30-
if asymmetric_quantize_inputs:
31-
raise NotImplementedError(
32-
"asymmetric_quantize_inputs=True is not implemented"
33-
)
3426
if x.ndim < 2 or y.ndim < 2:
3527
raise ValueError("Input tensors must have at least 2 dimensions.")
3628
if adj_x:

ai_edge_torch/odml_torch/lowerings/utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@
1818
import numbers
1919
from typing import Any
2020
from typing import Optional
21-
21+
from ai_edge_torch.odml_torch import export_utils
2222
from jax._src.lib.mlir import ir
2323
from jax._src.lib.mlir.dialects import hlo as stablehlo
2424
import numpy as np
2525
import torch
26+
import torch.utils._pytree as pytree
2627

2728

28-
def torch_dtype_to_ir_element_type(dtype):
29+
def torch_dtype_to_ir_element_type(dtype) -> ir.Type:
30+
"""Builds ir.Type from torch dtype."""
2931
ty_get = {
3032
torch.double: ir.F64Type.get,
3133
torch.float32: ir.F32Type.get,
@@ -39,6 +41,27 @@ def torch_dtype_to_ir_element_type(dtype):
3941
return ty_get()
4042

4143

44+
def node_meta_to_ir_types(node: torch.fx.Node) -> list[ir.Type]:
45+
"""Builds IR result types from torch FX node meta."""
46+
tensor_meta = node.meta.get("tensor_meta") or node.meta.get("val")
47+
if not tensor_meta:
48+
raise RuntimeError(f"{node.name} does not have tensor meta")
49+
50+
tensor_meta_list, _ = pytree.tree_flatten(
51+
[tensor_meta],
52+
is_leaf=lambda x: hasattr(x, "dtype") and hasattr(x, "shape"),
53+
)
54+
results = []
55+
for meta in tensor_meta_list:
56+
shape = [
57+
export_utils.IR_DYNAMIC if export_utils.is_torch_dynamic(dim) else dim
58+
for dim in meta.shape
59+
]
60+
elty = torch_dtype_to_ir_element_type(meta.dtype)
61+
results.append(ir.RankedTensorType.get(shape, elty))
62+
return results
63+
64+
4265
def splat(val, ty, shape=tuple(), *, loc: Optional[Any] = None):
4366
if isinstance(ty, ir.IntegerType):
4467
if ty.width == 1:

0 commit comments

Comments
 (0)