@@ -2110,6 +2110,102 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=2))
+class ReplaceGeluWithApproximateGeluPass(ExportPass):
+    """
+    Replace the gelu op with an approximate gelu op. The approximate gelu op
+    is more efficient on DSP backends.
+    """
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in {
+            exir_ops.edge.aten.gelu.default,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi)) as
+        # 0.5 * x * (1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3)))
+
+        # Get 0.5 * x
+        half = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (args[0], 0.5),
+            {},
+            meta,
+        )
+
+        scaled = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (args[0], 0.044715),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x^2 (note that we use mul.Tensor twice instead of
+        # pow.Tensor because it is much more efficient on DSP backends)
+        scaled_square = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (scaled, args[0]),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x^3
+        scaled_cubed = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (scaled_square, args[0]),
+            {},
+            meta,
+        )
+
+        # Get x + 0.044715 * x^3
+        inner_sum = super().call_operator(
+            exir_ops.edge.aten.add.Tensor,
+            (scaled_cubed, args[0]),
+            {},
+            meta,
+        )
+
+        # Get 0.7978845608028654 * (x + 0.044715 * x^3)
+        scaled_sum = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (inner_sum, 0.7978845608028654),
+            {},
+            meta,
+        )
+
+        # Get torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
+        tanh = super().call_operator(
+            exir_ops.edge.aten.tanh.default,
+            (scaled_sum,),
+            {},
+            meta,
+        )
+
+        # Get 1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
+        # TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
+        outer_sum = super().call_operator(
+            exir_ops.edge.aten.add.Tensor,
+            (tanh, 1.0),
+            {},
+            meta,
+        )
+
+        # Return the final result
+        return super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (half, outer_sum),
+            {},
+            meta,
+        )
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2149,4 +2245,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
+        # ReplaceGeluWithApproximateGeluPass,
     ]
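
As a sanity check (not part of the diff above), the expression the pass builds out of mul/add/tanh ops is the standard tanh approximation of gelu, which should agree closely with PyTorch's built-in approximate gelu. A minimal sketch, assuming a PyTorch version that supports gelu(..., approximate="tanh"):

import torch
import torch.nn.functional as F

x = torch.randn(16)
# Same formula the pass decomposes into elementwise ops
# (0.7978845608028654 is sqrt(2 / pi)).
approx = 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)))
# PyTorch's tanh-approximate gelu should match to within float tolerance.
assert torch.allclose(approx, F.gelu(x, approximate="tanh"), atol=1e-6)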