Skip to content

Commit 31d3545

Browse files
authored
Arm backend: Add CEIL Operator (#9267)
Arm backend: support for CEIL op - Update unary operator factory with CEIL op - Rename and refactor test_floor to handle similar ops Signed-off-by: Madeleine Dunn <[email protected]>
1 parent 88d9616 commit 31d3545

File tree

6 files changed

+157
-82
lines changed

6 files changed

+157
-82
lines changed

backends/arm/_passes/insert_table_ops.py

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class InsertTableOpsPass(ExportPass):
4141
"""
4242

4343
table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
44+
exir_ops.edge.aten.ceil.default: torch.ceil,
4445
exir_ops.edge.aten.exp.default: torch.exp,
4546
exir_ops.edge.aten.floor.default: torch.floor,
4647
exir_ops.edge.aten.log.default: torch.log,

backends/arm/operator_support/tosa_supported_operators.py

+1
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def is_node_supported(
147147
exir_ops.edge.aten.bitwise_xor.Tensor,
148148
exir_ops.edge.aten.expand_copy.default,
149149
exir_ops.edge.aten.cat.default,
150+
exir_ops.edge.aten.ceil.default,
150151
exir_ops.edge.aten.clamp.default,
151152
exir_ops.edge.aten.bmm.default,
152153
exir_ops.edge.aten.permute_copy.default,

backends/arm/operators/ops_unary.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,6 @@ def define_node(
5353
register_node_visitor(UnaryOperator)
5454

5555

56+
unary_operator_factory("aten.ceil.default", TosaOp.Op().CEIL)
5657
unary_operator_factory("aten.floor.default", TosaOp.Op().FLOOR)
5758
unary_operator_factory("aten.logical_not.default", TosaOp.Op().LOGICAL_NOT)

backends/arm/quantizer/quantization_annotator.py

+1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def _match_pattern(
126126

127127
_one_to_one = [
128128
torch.ops.aten.abs.default,
129+
torch.ops.aten.ceil.default,
129130
torch.ops.aten.exp.default,
130131
torch.ops.aten.floor.default,
131132
torch.ops.aten.log.default,

backends/arm/test/ops/test_floor.py

-82
This file was deleted.

backends/arm/test/ops/test_unary.py

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Tuple
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
19+
input_t1 = Tuple[torch.Tensor] # Input x
20+
21+
22+
class Ceil(torch.nn.Module):
23+
def forward(self, x: torch.Tensor):
24+
return torch.ceil(x)
25+
26+
op_name = "ceil"
27+
aten_op = "torch.ops.aten.ceil.default"
28+
exir_op = "executorch_exir_dialects_edge__ops_aten_ceil_default"
29+
30+
31+
class Floor(torch.nn.Module):
32+
def forward(self, x: torch.Tensor):
33+
return torch.floor(x)
34+
35+
op_name = "floor"
36+
aten_op = "torch.ops.aten.floor.default"
37+
exir_op = "executorch_exir_dialects_edge__ops_aten_floor_default"
38+
39+
40+
zeros = torch.zeros(1, 10, 10, 10)
41+
ones = torch.ones(10, 10, 10)
42+
rand = torch.rand(10, 10) - 0.5
43+
randn_pos = torch.randn(1, 4, 4, 4) + 10
44+
randn_neg = torch.randn(1, 4, 4, 4) - 10
45+
ramp = torch.arange(-16, 16, 0.2)
46+
47+
48+
test_data = {
49+
"ceil_zeros": (
50+
Ceil(),
51+
zeros,
52+
),
53+
"floor_zeros": (
54+
Floor(),
55+
zeros,
56+
),
57+
"ceil_ones": (
58+
Ceil(),
59+
ones,
60+
),
61+
"floor_ones": (
62+
Floor(),
63+
ones,
64+
),
65+
"ceil_rand": (
66+
Ceil(),
67+
rand,
68+
),
69+
"floor_rand": (
70+
Floor(),
71+
rand,
72+
),
73+
"ceil_randn_pos": (
74+
Ceil(),
75+
randn_pos,
76+
),
77+
"floor_randn_pos": (
78+
Floor(),
79+
randn_pos,
80+
),
81+
"ceil_randn_neg": (
82+
Ceil(),
83+
randn_neg,
84+
),
85+
"floor_randn_neg": (
86+
Floor(),
87+
randn_neg,
88+
),
89+
"ceil_ramp": (
90+
Ceil(),
91+
ramp,
92+
),
93+
"floor_ramp": (
94+
Floor(),
95+
ramp,
96+
),
97+
}
98+
99+
100+
@common.parametrize("test_data", test_data)
101+
def test_unary_tosa_MI(test_data: input_t1):
102+
module = test_data[0]
103+
pipeline = TosaPipelineMI[input_t1](
104+
module, (test_data[1],), module.aten_op, module.exir_op
105+
)
106+
pipeline.run()
107+
108+
109+
@common.parametrize("test_data", test_data)
110+
def test_unary_tosa_BI(test_data: input_t1):
111+
module = test_data[0]
112+
pipeline = TosaPipelineBI[input_t1](
113+
module, (test_data[1],), module.aten_op, module.exir_op
114+
)
115+
pipeline.run()
116+
117+
118+
@common.parametrize("test_data", test_data)
119+
def test_unary_u55_BI(test_data: input_t1):
120+
module = test_data[0]
121+
pipeline = EthosU55PipelineBI[input_t1](
122+
module, (test_data[1],), module.aten_op, module.exir_op, run_on_fvp=False
123+
)
124+
pipeline.run()
125+
126+
127+
@common.parametrize("test_data", test_data)
128+
def test_unary_u85_BI(test_data: input_t1):
129+
module = test_data[0]
130+
pipeline = EthosU85PipelineBI[input_t1](
131+
module, (test_data[1],), module.aten_op, module.exir_op, run_on_fvp=False
132+
)
133+
pipeline.run()
134+
135+
136+
@common.parametrize("test_data", test_data)
137+
@common.SkipIfNoCorstone300
138+
def test_unary_u55_BI_on_fvp(test_data: input_t1):
139+
module = test_data[0]
140+
pipeline = EthosU55PipelineBI[input_t1](
141+
module, (test_data[1],), module.aten_op, module.exir_op, run_on_fvp=True
142+
)
143+
pipeline.run()
144+
145+
146+
@common.parametrize("test_data", test_data)
147+
@common.SkipIfNoCorstone320
148+
def test_unary_u85_BI_on_fvp(test_data: input_t1):
149+
module = test_data[0]
150+
pipeline = EthosU85PipelineBI[input_t1](
151+
module, (test_data[1],), module.aten_op, module.exir_op, run_on_fvp=True
152+
)
153+
pipeline.run()

0 commit comments

Comments
 (0)