Skip to content

Commit 92bdda8

Browse files
Add number_fn to unary prims (#1918)
1 parent b30e22d commit 92bdda8

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

thunder/core/prims.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2110,24 +2110,28 @@ def _make_elementwise_unary_prim(
21102110
asin = _make_elementwise_unary_prim(
21112111
PrimIDs.ASIN,
21122112
"asin",
2113+
number_fn=math.asin,
21132114
supported_input_dtypes=fp_math_dtypes,
21142115
)
21152116

21162117
asinh = _make_elementwise_unary_prim(
21172118
PrimIDs.ASINH,
21182119
"asinh",
2120+
number_fn=math.asinh,
21192121
supported_input_dtypes=fp_math_dtypes,
21202122
)
21212123

21222124
atan = _make_elementwise_unary_prim(
21232125
PrimIDs.ATAN,
21242126
"atan",
2127+
number_fn=math.atan,
21252128
supported_input_dtypes=fp_math_dtypes,
21262129
)
21272130

21282131
atanh = _make_elementwise_unary_prim(
21292132
PrimIDs.ATANH,
21302133
"atanh",
2134+
number_fn=math.atanh,
21312135
supported_input_dtypes=fp_math_dtypes,
21322136
)
21332137

@@ -2153,24 +2157,28 @@ def _make_elementwise_unary_prim(
21532157
cos = _make_elementwise_unary_prim(
21542158
PrimIDs.COS,
21552159
"cos",
2160+
number_fn=math.cos,
21562161
supported_input_dtypes=fp_math_dtypes,
21572162
)
21582163

21592164
cosh = _make_elementwise_unary_prim(
21602165
PrimIDs.COSH,
21612166
"cosh",
2167+
number_fn=math.cosh,
21622168
supported_input_dtypes=fp_math_dtypes,
21632169
)
21642170

21652171
erf = _make_elementwise_unary_prim(
21662172
PrimIDs.ERF,
21672173
"erf",
2174+
number_fn=math.erf,
21682175
supported_input_dtypes=fp_math_dtypes,
21692176
)
21702177

21712178
erfc = _make_elementwise_unary_prim(
21722179
PrimIDs.ERFC,
21732180
"erfc",
2181+
number_fn=math.erfc,
21742182
supported_input_dtypes=fp_math_dtypes,
21752183
)
21762184

@@ -2193,15 +2201,24 @@ def _make_elementwise_unary_prim(
21932201
supported_input_dtypes=fp_math_dtypes,
21942202
)
21952203

2204+
2205+
def _exp2_number(a: Number) -> Number:
2206+
if hasattr(math, "exp2"):
2207+
return math.exp2(a)
2208+
return 2**a
2209+
2210+
21962211
exp2 = _make_elementwise_unary_prim(
21972212
PrimIDs.EXP2,
21982213
"exp2",
2214+
number_fn=_exp2_number,
21992215
supported_input_dtypes=fp_math_dtypes,
22002216
)
22012217

22022218
expm1 = _make_elementwise_unary_prim(
22032219
PrimIDs.EXPM1,
22042220
"expm1",
2221+
number_fn=math.expm1,
22052222
supported_input_dtypes=fp_math_dtypes,
22062223
)
22072224

@@ -2233,24 +2250,28 @@ def _make_elementwise_unary_prim(
22332250
log = _make_elementwise_unary_prim(
22342251
PrimIDs.LOG,
22352252
"log",
2253+
number_fn=math.log,
22362254
supported_input_dtypes=fp_math_dtypes,
22372255
)
22382256

22392257
log10 = _make_elementwise_unary_prim(
22402258
PrimIDs.LOG10,
22412259
"log10",
2260+
number_fn=math.log10,
22422261
supported_input_dtypes=fp_math_dtypes,
22432262
)
22442263

22452264
log1p = _make_elementwise_unary_prim(
22462265
PrimIDs.LOG1P,
22472266
"log1p",
2267+
number_fn=math.log1p,
22482268
supported_input_dtypes=fp_math_dtypes,
22492269
)
22502270

22512271
log2 = _make_elementwise_unary_prim(
22522272
PrimIDs.LOG2,
22532273
"log2",
2274+
number_fn=math.log2,
22542275
supported_input_dtypes=fp_math_dtypes,
22552276
)
22562277

@@ -2317,30 +2338,35 @@ def _signbit_number(a: Number) -> bool:
23172338
sin = _make_elementwise_unary_prim(
23182339
PrimIDs.SIN,
23192340
"sin",
2341+
number_fn=math.sin,
23202342
supported_input_dtypes=fp_math_dtypes,
23212343
)
23222344

23232345
sinh = _make_elementwise_unary_prim(
23242346
PrimIDs.SINH,
23252347
"sinh",
2348+
number_fn=math.sinh,
23262349
supported_input_dtypes=fp_math_dtypes,
23272350
)
23282351

23292352
sqrt = _make_elementwise_unary_prim(
23302353
PrimIDs.SQRT,
23312354
"sqrt",
2355+
number_fn=math.sqrt,
23322356
supported_input_dtypes=fp_math_dtypes,
23332357
)
23342358

23352359
tan = _make_elementwise_unary_prim(
23362360
PrimIDs.TAN,
23372361
"tan",
2362+
number_fn=math.tan,
23382363
supported_input_dtypes=fp_math_dtypes,
23392364
)
23402365

23412366
tanh = _make_elementwise_unary_prim(
23422367
PrimIDs.TANH,
23432368
"tanh",
2369+
number_fn=math.tanh,
23442370
supported_input_dtypes=fp_math_dtypes,
23452371
)
23462372

0 commit comments

Comments
 (0)