
Commit 5538706

junjiang-lab authored and copybara-github committed
Add impl for aten.mean.dim and lowering.
PiperOrigin-RevId: 764338440
1 parent 92c6eb6 commit 5538706

File tree

5 files changed (+141 / -31 lines changed)


ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 18 additions & 0 deletions
@@ -110,6 +110,11 @@ def _aten_bitwise_and_tensor_decomp(x, y):
   return torch.ops.tfl.logical_and(x, y)


+@register_decomp(torch.ops.aten.mean.dim)
+def _aten_mean_dim_decomp(x, dim, keepdim=False):
+  return torch.ops.tfl.mean(x, dim, keepdim)
+
+
 @register_decomp(torch.ops.aten.gt.Tensor)
 def _aten_gt_tensor_decomp(x, y):
   return torch.ops.tfl.greater(x, y)
@@ -203,6 +208,19 @@ def _aten_cat_decomp(tensors, dim=0):
   return torch.ops.tfl.concatenation(processed_tensors, dim)


+@register_decomp(torch.ops.aten.full_like.default)
+def _aten_full_like_decomp(
+    x,
+    fill_value,
+    dtype=None,
+    layout=None,
+    device=None,
+    pin_memory=None,
+    memory_format=None,
+):
+  return torch.ops.tfl.fill(tuple(x.shape), fill_value)
+
+
 @register_decomp(torch.ops.aten.view.default)
 def _aten_view_decomp(x, shape: Sequence[int]):
   return torch.ops.tfl.reshape(x, shape)
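
For context, the new `_aten_mean_dim_decomp` simply forwards an exported `aten.mean.dim` node to the `tfl.mean` custom op. A minimal sketch of the kind of graph it targets (plain `torch.export`, shown here only for illustration and not part of this change):

import torch

class MeanDim(torch.nn.Module):
  def forward(self, x):
    # Tracing this produces an aten.mean.dim call_function node, which the
    # decomposition above rewrites to torch.ops.tfl.mean(x, dim, keepdim).
    return x.mean(dim=1, keepdim=True)

ep = torch.export.export(MeanDim(), (torch.randn(10, 10),))
print([n.target for n in ep.graph.nodes if n.op == "call_function"])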

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 75 additions & 30 deletions
@@ -177,6 +177,27 @@ def _tfl_logical_and_lowering(
   )


+@lower(torch.ops.tfl.mean.default)
+def _tfl_mean_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    dims: int | ir.Value | Sequence[int | ir.Value],
+    keepdim: bool = False,
+) -> ir.Value:
+  if isinstance(dims, int) or isinstance(dims, ir.Value):
+    dims_ir_value = lowering_utils.convert_to_ir_value(dims)
+  else:
+    dims_ir_value = lowering_utils.convert_shape_to_ir_value(dims)
+  return _ir_operation(
+      "tfl.mean",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, dims_ir_value],
+      attributes={
+          "keep_dims": ir.BoolAttr.get(keepdim),
+      },
+  )
+
+
 @lower(torch.ops.tfl.greater.default)
 def _tfl_greater_lowering(
     lctx: LoweringContext,
@@ -317,45 +338,69 @@ def _tfl_concatenation_lowering(
   )


+@lower(torch.ops.tfl.fill.default)
+def _tfl_fill_lowering(
+    lctx: LoweringContext,
+    dims: Sequence[int | ir.Value],
+    fill_value: ir.Value,
+) -> ir.Value:
+  dims_ir_value = lowering_utils.convert_shape_to_ir_value(dims)
+  fill_value_ir_value = lowering_utils.convert_to_ir_value(fill_value)
+
+  # Ensure fill_value_ir_value is a scalar (0-D tensor) for TFLite Fill op.
+  # The TFLite Fill kernel expects the value to be a 0-D tensor.
+  if isinstance(fill_value_ir_value.type, ir.RankedTensorType):
+    tensor_type = fill_value_ir_value.type
+    # If it's a 1-D tensor with a single element, reshape to 0-D.
+    if list(tensor_type.shape) == [1]:
+      scalar_type = ir.RankedTensorType.get([], tensor_type.element_type)
+      fill_value_ir_value = stablehlo.reshape(scalar_type, fill_value_ir_value)
+
+  # Determine the target element type from the node's output definition.
+  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
+  if not result_types or not isinstance(result_types[0], ir.RankedTensorType):
+    raise ValueError(
+        "tfl.fill: Unable to determine result tensor type or result is not a"
+        " ranked tensor."
+    )
+  target_element_type = result_types[0].element_type
+
+  # Ensure fill_value_ir_value is a RankedTensorType to access its properties.
+  if not isinstance(fill_value_ir_value.type, ir.RankedTensorType):
+    raise TypeError(
+        "tfl.fill: fill_value_ir_value expected to be RankedTensorType, got"
+        f" {fill_value_ir_value.type}"
+    )
+
+  current_fill_tensor_type = fill_value_ir_value.type
+  current_element_type = current_fill_tensor_type.element_type
+
+  # If the element type of the (scalar) fill_value doesn't match the target
+  # output element type, cast fill_value_ir_value to the target_element_type
+  # while maintaining its current shape (which should be scalar).
+  if current_element_type != target_element_type:
+    cast_to_type = ir.RankedTensorType.get(
+        current_fill_tensor_type.shape, target_element_type
+    )
+    fill_value_ir_value = stablehlo.convert(cast_to_type, fill_value_ir_value)
+
+  return _ir_operation(
+      "tfl.fill",
+      results=result_types,
+      operands=[dims_ir_value, fill_value_ir_value],
+  )
+
+
 @lower(torch.ops.tfl.reshape.default)
 def _tfl_reshape_lowering(
     lctx: LoweringContext,
     x: ir.Value,
     shape: Sequence[int | ir.Value],
 ) -> ir.Value:
-  # Check if all elements in the shape sequence are integers.
-  if not shape or all(isinstance(dim, int) for dim in shape):
-    # If all are integers, create a constant numpy array.
-    # Assuming int32 is the required type for TFLite shape tensors.
-    shape_ir_value = lowering_utils.numpy_array_constant(
-        np.array(shape, dtype=np.int32)
-    )
-  else:
-    # Handle mixed int and ir.Value shape sequence
-    processed_dims = []
-    for dim in shape:
-      if isinstance(dim, int):
-        # Convert int to a constant 1D tensor
-        shape_ir_value = lowering_utils.numpy_array_constant(
-            np.array([dim], dtype=np.int32)
-        )
-        processed_dims.append(shape_ir_value)
-      else:
-        assert isinstance(dim, ir.Value)
-        # Convert ir.Value to a constant 1D tensor
-        new_type = ir.RankedTensorType.get([1], dim.type.element_type)
-        reshape_dim = stablehlo.reshape(new_type, dim)
-        processed_dims.append(reshape_dim)
-
-    shape_ir_value = stablehlo.concatenate(
-        processed_dims,
-        dimension=0,
-    )
-
   return _ir_operation(
       "tfl.reshape",
       results=lowering_utils.node_meta_to_ir_types(lctx.node),
-      operands=[x, shape_ir_value],
+      operands=[x, lowering_utils.convert_shape_to_ir_value(shape)],
   )
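
A note on the casting block in `_tfl_fill_lowering`: `aten.full_like` takes its output dtype from the input tensor, not from the Python fill value, so the scalar operand handed to `tfl.fill` may need a dtype cast to match the result type. A small eager-mode illustration of that dtype rule (plain PyTorch, independent of the lowering):

import torch

x_f32 = torch.randn(10, 10)
x_i64 = torch.randint(0, 10, (10, 10), dtype=torch.int64)
# full_like inherits the input dtype, which is why the lowering casts the
# scalar fill value (via stablehlo.convert) to the result element type.
print(torch.full_like(x_f32, 0.123).dtype)  # torch.float32
print(torch.full_like(x_i64, 123).dtype)    # torch.int64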

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 12 additions & 0 deletions
@@ -68,6 +68,13 @@ def tfl_logical_and(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
   return torch.logical_and(x, y)


+@custom_op_with_fake(
+    "tfl::mean", schema="(Tensor x, Any dims, bool keepdim) -> Tensor"
+)
+def tfl_mean(x: torch.Tensor, dims: Any, keepdim: bool = False) -> torch.Tensor:
+  return torch.mean(x, dims, keepdim)
+
+
 @custom_op_with_fake("tfl::greater")
 def tfl_greater(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
   return torch.gt(x, y)
@@ -123,6 +130,11 @@ def tfl_concatenation(
   return torch.cat(tensors, dim=dim)


+@custom_op_with_fake("tfl::fill", schema="(int[] x, Any y) -> Tensor")
+def tfl_fill(dims: Sequence[int], fill_value: Any) -> torch.Tensor:
+  return torch.full(dims, fill_value)
+
+
 def _normalize_shape(
     tensor_input: torch.Tensor, shape: Sequence[int]
 ) -> Sequence[int]:
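
A rough eager-mode sketch of the new custom ops, assuming that importing the torch_tfl ops module registers them under `torch.ops.tfl`; the exact import path and call results below are illustrative, not guaranteed by this diff:

import torch
# Assumed to register the tfl::mean and tfl::fill custom ops; adjust the
# import to however this package actually exposes its op registrations.
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops  # noqa: F401

x = torch.randn(10, 10)
print(torch.ops.tfl.mean(x, [0], True).shape)  # expected torch.Size([1, 10]), same as torch.mean
print(torch.ops.tfl.fill([2, 3], 1.5))         # expected 2x3 tensor filled with 1.5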

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 5 additions & 0 deletions
@@ -134,6 +134,9 @@ def _assert_export_and_close(
     ("aten_pow_Tensor_Scalar_0", torch.ops.aten.pow.Tensor_Scalar, (rnd(torch.float32, (10, 10)), np.random.rand(),), dict()),
     ("aten_pow_Tensor_Tensor_0", torch.ops.aten.pow.Tensor_Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
     ("aten_bitwise_and_Tensor_0", torch.ops.aten.bitwise_and.Tensor, (rnd(torch.bool, (10, 10)), rnd(torch.bool, (10, 10)),), dict()),
+    ("aten_mean_dim_0", torch.ops.aten.mean.dim, (rnd(torch.float32, (10, 10)), 0), dict()),
+    ("aten_mean_dim_1", torch.ops.aten.mean.dim, (rnd(torch.float32, (10, 10)), 0, True), dict()),
+    ("aten_mean_dim_2", torch.ops.aten.mean.dim, (rnd(torch.float32, (10, 10)), 1), dict()),
     ("aten_gt_Tensor_0", torch.ops.aten.gt.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
     ("aten_gt_Tensor_1", torch.ops.aten.gt.Tensor, (rnd(torch.float32, (1, 10)), rnd(torch.float32, (10, 1)),), dict()),
     ("aten_lt_Tensor_0", torch.ops.aten.lt.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
@@ -163,6 +166,8 @@ def _assert_export_and_close(
     ("aten_cat_2", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0, 10))], 0,), dict()),
     ("aten_cat_3", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0,))], 0,), dict()),
     ("aten_cat_4", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10))],), dict()),
+    ("aten_full_like_0", torch.ops.aten.full_like.default, (rnd(torch.float32, (10, 10)), 0.123,), dict()),
+    ("aten_full_like_1", torch.ops.aten.full_like.default, (rnd(torch.int64, (10, 10)), 123,), dict()),
     ("aten_view_0", torch.ops.aten.view.default, (rnd(torch.float32, (10, 10)), [1, 100],), dict()),
     ("aten_view_1", torch.ops.aten.view.default, (rnd(torch.float32, (1, 10)), [10, 1],), dict()),
     ("aten_view_2", torch.ops.aten.view.default, (rnd(torch.float32, (10, 10)), [2, 5, 10],), dict()),

ai_edge_torch/odml_torch/lowerings/utils.py

Lines changed: 31 additions & 1 deletion
@@ -17,7 +17,7 @@
 from collections.abc import Callable
 import functools
 import numbers
-from typing import Any, Optional, Union
+from typing import Any, Optional, Sequence, Union
 from ai_edge_torch.odml_torch import export_utils
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo as stablehlo
@@ -281,3 +281,33 @@ def convert_to_ir_value(
   if isinstance(value, ir.Value):
     return value
   raise TypeError(f"Unsupported type for conversion to ir.Value: {type(value)}")
+
+
+def convert_shape_to_ir_value(
+    shape: Sequence[int],
+) -> ir.Value:
+  # Check if all elements in the shape sequence are integers.
+  if not shape or all(isinstance(dim, int) for dim in shape):
+    # If all are integers, create a constant numpy array.
+    # Assuming int32 is the required type for TFLite shape tensors.
+    shape_ir_value = numpy_array_constant(np.array(shape, dtype=np.int32))
+  else:
+    # Handle mixed int and ir.Value shape sequence
+    processed_dims = []
+    for dim in shape:
+      if isinstance(dim, int):
+        # Convert int to a constant 1D tensor
+        shape_ir_value = numpy_array_constant(np.array([dim], dtype=np.int32))
+        processed_dims.append(shape_ir_value)
+      else:
+        assert isinstance(dim, ir.Value)
+        # Convert ir.Value to a constant 1D tensor
+        new_type = ir.RankedTensorType.get([1], dim.type.element_type)
+        reshape_dim = stablehlo.reshape(new_type, dim)
+        processed_dims.append(reshape_dim)
+
+    shape_ir_value = stablehlo.concatenate(
+        processed_dims,
+        dimension=0,
+    )
+  return shape_ir_value
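
To make the branching in `convert_shape_to_ir_value` concrete, here is a hypothetical pure-numpy analogue (illustration only, not the MLIR-producing code): an all-static shape becomes one int32 constant, while a shape mixing static ints and dynamic dimensions is assembled per dimension and concatenated into a single 1-D int32 vector.

import numpy as np

def shape_to_i32_vector(shape, dynamic_dims):
  # Hypothetical stand-in: strings model ir.Value dims, resolved via a lookup.
  if not shape or all(isinstance(dim, int) for dim in shape):
    return np.array(shape, dtype=np.int32)
  pieces = []
  for dim in shape:
    if isinstance(dim, int):
      pieces.append(np.array([dim], dtype=np.int32))
    else:
      pieces.append(np.array([dynamic_dims[dim]], dtype=np.int32))
  return np.concatenate(pieces, axis=0)

print(shape_to_i32_vector([2, 5, 10], {}))          # [ 2  5 10]
print(shape_to_i32_vector([2, "n", 10], {"n": 7}))  # [ 2  7 10]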
