
Commit a23d0c1

mcremon-meta authored and facebook-github-bot committed
Add approximate gelu replacement to opt level 2 (#10129)

Summary:
Pull Request resolved: #10129

As titled. Gelu is prohibitively expensive to run on DSPs because of the std::erf call in the function. The PyTorch approximate version uses a `tanh`-based approximation instead, which is faster on the ASR encoder 27M model, for example. BUCK files seem to be a problem (even with just oncall commands, the linter is complaining).

Differential Revision: D72935935
1 parent 4022ff1 commit a23d0c1
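
For reference, exact GELU relies on erf, while the pass introduced in this commit rewrites it with the standard tanh approximation. Below is a minimal standalone sketch (plain PyTorch, independent of the pass) showing the two side by side; the tensor shape and variable names are illustrative only.

import torch
import torch.nn.functional as F

x = torch.randn(2, 1, 64)

# Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))); the erf call is the costly part on DSPs.
exact = F.gelu(x)

# Tanh approximation, matching the decomposition added by this commit:
# 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
approx = 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x**3)))

# PyTorch also exposes the same approximation directly.
builtin_approx = F.gelu(x, approximate="tanh")

print(torch.max((exact - approx).abs()))           # small approximation error
print(torch.max((builtin_approx - approx).abs()))  # essentially zero (same formula)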

File tree

2 files changed: +133 -0 lines changed

Diff for: backends/cadence/aot/replace_ops.py

+97

@@ -2110,6 +2110,102 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=2))
+class ReplaceGeluWithApproximateGeluPass(ExportPass):
+    """
+    Replace the gelu op with an approximate gelu op. The approximate gelu op
+    is more efficient on DSP backends.
+    """
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in {
+            exir_ops.edge.aten.gelu.default,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi)) as
+        # 0.5 * x * (1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3)))
+
+        # Get 0.5 * x
+        half = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (args[0], 0.5),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x
+        scaled = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (args[0], 0.044715),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x^2 (note that we use mul.Tensor twice instead of
+        # pow.Tensor because it is much more efficient on DSP backends)
+        scaled_square = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (scaled, args[0]),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x^3
+        scaled_cubed = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (scaled_square, args[0]),
+            {},
+            meta,
+        )
+
+        # Get x + 0.044715 * x^3
+        inner_sum = super().call_operator(
+            exir_ops.edge.aten.add.Tensor,
+            (scaled_cubed, args[0]),
+            {},
+            meta,
+        )
+
+        # Get 0.7978845608028654 * (x + 0.044715 * x^3)
+        scaled_sum = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (inner_sum, 0.7978845608028654),
+            {},
+            meta,
+        )
+
+        # Get torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
+        tanh = super().call_operator(
+            exir_ops.edge.aten.tanh.default,
+            (scaled_sum,),
+            {},
+            meta,
+        )
+
+        # Get 1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
+        # TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
+        outer_sum = super().call_operator(
+            exir_ops.edge.aten.add.Tensor,
+            (tanh, 1.0),
+            {},
+            meta,
+        )
+
+        # Return the final result
+        return super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (half, outer_sum),
+            {},
+            meta,
+        )
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2149,4 +2245,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
+        # ReplaceGeluWithApproximateGeluPass,
     ]
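
Note that the new pass is registered at opt_level=2 but left commented out in the CadenceReplaceOpsInGraph list above, so for now it has to be applied explicitly. A minimal sketch of doing so, assuming the module path executorch.backends.cadence.aot.replace_ops and an already-exported edge program named edge_program (both are assumptions here; the call pattern mirrors the test below):

from typing import cast

from executorch.backends.cadence.aot.replace_ops import (
    ReplaceGeluWithApproximateGeluPass,
)
from torch.fx.passes.infra.pass_base import PassResult

# `edge_program` is assumed to come from an earlier export_to_edge(...) call,
# as in the test below.
graph_module = edge_program.exported_program().graph_module

# Run the pass directly; it returns a PassResult wrapping the rewritten graph.
result = cast(PassResult, ReplaceGeluWithApproximateGeluPass()(graph_module))
graph_module = result.graph_module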

Diff for: backends/cadence/aot/tests/test_replace_ops_passes.py

+36

@@ -29,6 +29,7 @@
     ReplaceConvWithIm2RowAndLinear,
     ReplaceEmptyTensorsWithFullPass,
     ReplaceFunctionallyEquivalentOpTargets,
+    ReplaceGeluWithApproximateGeluPass,
     ReplaceIm2RowWithViewPass,
     ReplaceLinearWithFullyConnectedOpPass,
     ReplaceMMWithAddMMPass,
@@ -1301,6 +1302,41 @@ def forward(self, cond: torch.Tensor):
             1,
         )
 
+    def test_replace_aten_gelu_with_approximate_gelu(self):
+        class Gelu(torch.nn.Module):
+            def forward(self, input):
+                return torch.nn.functional.gelu(input)
+
+        inputs = torch.randn(2, 1, 64)
+
+        graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module
+
+        p = ReplaceGeluWithApproximateGeluPass()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that the aten.gelu op was decomposed
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.gelu.default,
+            ),
+            0,
+        )
+
+        # The decomposition should have one tanh, 2 adds and 6 muls
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
+            1,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
+            2,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
+            6,
+        )
+
 
 class TestReplaceIm2rowWithViewPass(unittest.TestCase):
     def test_no_replacement_for_conv(self):
