Add approximate gelu replacement to opt level 2 #10129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 1 commit · Apr 14, 2025
97 changes: 97 additions & 0 deletions backends/cadence/aot/replace_ops.py
@@ -2110,6 +2110,102 @@ def call_operator(
return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=2))
class ReplaceGeluWithApproximateGeluPass(ExportPass):
"""
Replace the gelu op with an approximate gelu op. The approximate gelu op
is more efficient on DSP backends.
"""

def call_operator(
self,
op,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op not in {
exir_ops.edge.aten.gelu.default,
}:
return super().call_operator(op, args, kwargs, meta)

# Compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi)) as
# 0.5 * x * (1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3)))
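# (For reference, the exact gelu is x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)));
# the tanh expression is the standard approximation, matching
# torch.nn.functional.gelu(..., approximate="tanh").)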

# Get 0.5 * x
half = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(args[0], 0.5),
{},
meta,
)

# Get 0.044715 * x
scaled = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(args[0], 0.044715),
{},
meta,
)

# Get 0.044715 * x^2 (note that we use mul.Tensor twice instead of pow.Tensor
# because it is much more efficient on DSP backends)
scaled_square = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(scaled, args[0]),
{},
meta,
)

# Get 0.044715 * x^3
scaled_cubed = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(scaled_square, args[0]),
{},
meta,
)

# Get x + 0.044715 * x^3
inner_sum = super().call_operator(
exir_ops.edge.aten.add.Tensor,
(scaled_cubed, args[0]),
{},
meta,
)

# Get 0.7978845608028654 * (x + 0.044715 * x^3)
scaled_sum = super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(inner_sum, 0.7978845608028654),
{},
meta,
)

# Get torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
tanh = super().call_operator(
exir_ops.edge.aten.tanh.default,
(scaled_sum,),
{},
meta,
)

# Get 1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
# TODO(): Check why this does not work properly with integer values (e.g. 1 instead of 1.0)
outer_sum = super().call_operator(
exir_ops.edge.aten.add.Tensor,
(tanh, 1.0),
{},
meta,
)

# Return the final result: half * outer_sum
return super().call_operator(
exir_ops.edge.aten.mul.Tensor,
(half, outer_sum),
{},
meta,
)


# This class encapsulates all the functions that replace/switch one op in the
# graph with another.
class CadenceReplaceOpsInGraph:
@@ -2149,4 +2245,5 @@ class CadenceReplaceOpsInGraph:
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
ReplaceWhereWithFullArgsWithWhereScalar,
# ReplaceGeluWithApproximateGeluPass,
]
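
For context (not part of the diff): the decomposition above mirrors the standard tanh-based GELU approximation. A minimal standalone sanity check, assuming a recent PyTorch where the approximate="tanh" option is available, could compare the manual expansion against the built-in op:

import math

import torch

x = torch.randn(2, 1, 64)

# Manual expansion, mirroring the ops emitted by the pass:
# 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
approx = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

# PyTorch's built-in tanh approximation should match the manual expansion closely.
reference = torch.nn.functional.gelu(x, approximate="tanh")
assert torch.allclose(approx, reference, atol=1e-6)

# The exact (erf-based) gelu differs from the approximation only slightly.
exact = torch.nn.functional.gelu(x)
print((approx - exact).abs().max())
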
36 changes: 36 additions & 0 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -29,6 +29,7 @@
ReplaceConvWithIm2RowAndLinear,
ReplaceEmptyTensorsWithFullPass,
ReplaceFunctionallyEquivalentOpTargets,
ReplaceGeluWithApproximateGeluPass,
ReplaceIm2RowWithViewPass,
ReplaceLinearWithFullyConnectedOpPass,
ReplaceMMWithAddMMPass,
@@ -1301,6 +1302,41 @@ def forward(self, cond: torch.Tensor):
1,
)

def test_replace_aten_gelu_with_approximate_gelu(self):
class Gelu(torch.nn.Module):
def forward(self, input):
return torch.nn.functional.gelu(input)

inputs = torch.randn(2, 1, 64)

graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module

p = ReplaceGeluWithApproximateGeluPass()
graph_after_passes = cast(PassResult, p(graph_module)).graph_module

# Assert that aten.gelu op was decomposed
self.assertEqual(
count_node(
graph_after_passes,
exir_ops.edge.aten.gelu.default,
),
0,
)

# The decomposition should have one tanh, two adds, and six muls
# (see the tally sketch after this test)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
1,
)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
2,
)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
6,
)
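
The expected counts above come from enumerating the nine ops the pass emits for a single gelu node. A small standalone sketch (not part of the test) that tallies them:

from collections import Counter

# Ops emitted by ReplaceGeluWithApproximateGeluPass for one gelu node,
# using the intermediate names from the pass itself:
emitted = [
    "mul",   # half          = 0.5 * x
    "mul",   # scaled        = 0.044715 * x
    "mul",   # scaled_square = scaled * x
    "mul",   # scaled_cubed  = scaled_square * x
    "add",   # inner_sum     = scaled_cubed + x
    "mul",   # scaled_sum    = 0.7978845608028654 * inner_sum
    "tanh",  # tanh(scaled_sum)
    "add",   # outer_sum     = 1.0 + tanh(...)
    "mul",   # result        = half * outer_sum
]
print(Counter(emitted))  # Counter({'mul': 6, 'add': 2, 'tanh': 1})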


class TestReplaceIm2rowWithViewPass(unittest.TestCase):
def test_no_replacement_for_conv(self):