@@ -2110,24 +2110,28 @@ def _make_elementwise_unary_prim(
21102110asin = _make_elementwise_unary_prim (
21112111 PrimIDs .ASIN ,
21122112 "asin" ,
2113+ number_fn = math .asin ,
21132114 supported_input_dtypes = fp_math_dtypes ,
21142115)
21152116
21162117asinh = _make_elementwise_unary_prim (
21172118 PrimIDs .ASINH ,
21182119 "asinh" ,
2120+ number_fn = math .asinh ,
21192121 supported_input_dtypes = fp_math_dtypes ,
21202122)
21212123
21222124atan = _make_elementwise_unary_prim (
21232125 PrimIDs .ATAN ,
21242126 "atan" ,
2127+ number_fn = math .atan ,
21252128 supported_input_dtypes = fp_math_dtypes ,
21262129)
21272130
21282131atanh = _make_elementwise_unary_prim (
21292132 PrimIDs .ATANH ,
21302133 "atanh" ,
2134+ number_fn = math .atanh ,
21312135 supported_input_dtypes = fp_math_dtypes ,
21322136)
21332137
@@ -2153,24 +2157,28 @@ def _make_elementwise_unary_prim(
21532157cos = _make_elementwise_unary_prim (
21542158 PrimIDs .COS ,
21552159 "cos" ,
2160+ number_fn = math .cos ,
21562161 supported_input_dtypes = fp_math_dtypes ,
21572162)
21582163
21592164cosh = _make_elementwise_unary_prim (
21602165 PrimIDs .COSH ,
21612166 "cosh" ,
2167+ number_fn = math .cosh ,
21622168 supported_input_dtypes = fp_math_dtypes ,
21632169)
21642170
21652171erf = _make_elementwise_unary_prim (
21662172 PrimIDs .ERF ,
21672173 "erf" ,
2174+ number_fn = math .erf ,
21682175 supported_input_dtypes = fp_math_dtypes ,
21692176)
21702177
21712178erfc = _make_elementwise_unary_prim (
21722179 PrimIDs .ERFC ,
21732180 "erfc" ,
2181+ number_fn = math .erfc ,
21742182 supported_input_dtypes = fp_math_dtypes ,
21752183)
21762184
@@ -2193,15 +2201,24 @@ def _make_elementwise_unary_prim(
21932201 supported_input_dtypes = fp_math_dtypes ,
21942202)
21952203
2204+
2205+ def _exp2_number (a : Number ) -> Number :
2206+ if hasattr (math , "exp2" ):
2207+ return math .exp2 (a )
2208+ return 2 ** a
2209+
2210+
21962211exp2 = _make_elementwise_unary_prim (
21972212 PrimIDs .EXP2 ,
21982213 "exp2" ,
2214+ number_fn = _exp2_number ,
21992215 supported_input_dtypes = fp_math_dtypes ,
22002216)
22012217
22022218expm1 = _make_elementwise_unary_prim (
22032219 PrimIDs .EXPM1 ,
22042220 "expm1" ,
2221+ number_fn = math .expm1 ,
22052222 supported_input_dtypes = fp_math_dtypes ,
22062223)
22072224
@@ -2233,24 +2250,28 @@ def _make_elementwise_unary_prim(
22332250log = _make_elementwise_unary_prim (
22342251 PrimIDs .LOG ,
22352252 "log" ,
2253+ number_fn = math .log ,
22362254 supported_input_dtypes = fp_math_dtypes ,
22372255)
22382256
22392257log10 = _make_elementwise_unary_prim (
22402258 PrimIDs .LOG10 ,
22412259 "log10" ,
2260+ number_fn = math .log10 ,
22422261 supported_input_dtypes = fp_math_dtypes ,
22432262)
22442263
22452264log1p = _make_elementwise_unary_prim (
22462265 PrimIDs .LOG1P ,
22472266 "log1p" ,
2267+ number_fn = math .log1p ,
22482268 supported_input_dtypes = fp_math_dtypes ,
22492269)
22502270
22512271log2 = _make_elementwise_unary_prim (
22522272 PrimIDs .LOG2 ,
22532273 "log2" ,
2274+ number_fn = math .log2 ,
22542275 supported_input_dtypes = fp_math_dtypes ,
22552276)
22562277
@@ -2317,30 +2338,35 @@ def _signbit_number(a: Number) -> bool:
23172338sin = _make_elementwise_unary_prim (
23182339 PrimIDs .SIN ,
23192340 "sin" ,
2341+ number_fn = math .sin ,
23202342 supported_input_dtypes = fp_math_dtypes ,
23212343)
23222344
23232345sinh = _make_elementwise_unary_prim (
23242346 PrimIDs .SINH ,
23252347 "sinh" ,
2348+ number_fn = math .sinh ,
23262349 supported_input_dtypes = fp_math_dtypes ,
23272350)
23282351
23292352sqrt = _make_elementwise_unary_prim (
23302353 PrimIDs .SQRT ,
23312354 "sqrt" ,
2355+ number_fn = math .sqrt ,
23322356 supported_input_dtypes = fp_math_dtypes ,
23332357)
23342358
23352359tan = _make_elementwise_unary_prim (
23362360 PrimIDs .TAN ,
23372361 "tan" ,
2362+ number_fn = math .tan ,
23382363 supported_input_dtypes = fp_math_dtypes ,
23392364)
23402365
23412366tanh = _make_elementwise_unary_prim (
23422367 PrimIDs .TANH ,
23432368 "tanh" ,
2369+ number_fn = math .tanh ,
23442370 supported_input_dtypes = fp_math_dtypes ,
23452371)
23462372
0 commit comments