Skip to content

Commit 250c498

Browse files
authored
[SYCL][libclc][CUDA] Add native math extension (#5747)
This patch extends the native math definitions to include builtins that are outside the current SYCL specification. In particular, it adds a ``tanh`` builtin for floats and halfs and an ``exp2`` builtin for halfs, both of which are mapped to instructions introduced for ``sm_75`` and above. The corresponding tests are in intel/llvm-test-suite#895.
1 parent 0272ec2 commit 250c498

File tree

13 files changed

+428
-17
lines changed

13 files changed

+428
-17
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.def

+8
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ BUILTIN(__nvvm_saturate_d, "dd", "")
205205
BUILTIN(__nvvm_ex2_approx_ftz_f, "ff", "")
206206
BUILTIN(__nvvm_ex2_approx_f, "ff", "")
207207
BUILTIN(__nvvm_ex2_approx_d, "dd", "")
208+
TARGET_BUILTIN(__nvvm_ex2_approx_f16, "hh", "", AND(SM_75, PTX70))
209+
TARGET_BUILTIN(__nvvm_ex2_approx_f16x2, "V2hV2h", "", AND(SM_75, PTX70))
208210

209211
BUILTIN(__nvvm_lg2_approx_ftz_f, "ff", "")
210212
BUILTIN(__nvvm_lg2_approx_f, "ff", "")
@@ -218,6 +220,12 @@ BUILTIN(__nvvm_sin_approx_f, "ff", "")
218220
BUILTIN(__nvvm_cos_approx_ftz_f, "ff", "")
219221
BUILTIN(__nvvm_cos_approx_f, "ff", "")
220222

223+
// Tanh
224+
225+
TARGET_BUILTIN(__nvvm_tanh_approx_f, "ff", "", AND(SM_75,PTX70))
226+
TARGET_BUILTIN(__nvvm_tanh_approx_f16, "hh", "", AND(SM_75, PTX70))
227+
TARGET_BUILTIN(__nvvm_tanh_approx_f16x2, "V2hV2h", "", AND(SM_75, PTX70))
228+
221229
// Fma
222230

223231
BUILTIN(__nvvm_fma_rn_ftz_f, "ffff", "")

libclc/generic/include/clcmacro.h

+27-17
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@
99
#ifndef __CLC_MACRO_H
1010
#define __CLC_MACRO_H
1111

12-
#define _CLC_UNARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE) \
13-
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x) { \
14-
return (RET_TYPE##2)(FUNCTION(x.x), FUNCTION(x.y)); \
15-
} \
16-
\
12+
#define _CLC_UNARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE) \
1713
DECLSPEC RET_TYPE##3 FUNCTION(ARG1_TYPE##3 x) { \
1814
return (RET_TYPE##3)(FUNCTION(x.x), FUNCTION(x.y), FUNCTION(x.z)); \
1915
} \
@@ -30,12 +26,14 @@
3026
return (RET_TYPE##16)(FUNCTION(x.lo), FUNCTION(x.hi)); \
3127
}
3228

33-
#define _CLC_BINARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
34-
ARG2_TYPE) \
35-
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x, ARG2_TYPE##2 y) { \
36-
return (RET_TYPE##2)(FUNCTION(x.x, y.x), FUNCTION(x.y, y.y)); \
29+
// Defines the 2-element vector overload of FUNCTION by calling the scalar
// FUNCTION on x.x and x.y, then expands _CLC_UNARY_VECTORIZE_HAVE2 with the
// same arguments to define the wider vector overloads.
#define _CLC_UNARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE) \
30+
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x) { \
31+
return (RET_TYPE##2)(FUNCTION(x.x), FUNCTION(x.y)); \
3732
} \
38-
\
33+
_CLC_UNARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE)
34+
35+
#define _CLC_BINARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
36+
ARG2_TYPE) \
3937
DECLSPEC RET_TYPE##3 FUNCTION(ARG1_TYPE##3 x, ARG2_TYPE##3 y) { \
4038
return (RET_TYPE##3)(FUNCTION(x.x, y.x), FUNCTION(x.y, y.y), \
4139
FUNCTION(x.z, y.z)); \
@@ -53,6 +51,14 @@
5351
return (RET_TYPE##16)(FUNCTION(x.lo, y.lo), FUNCTION(x.hi, y.hi)); \
5452
}
5553

54+
// Defines the 2-element vector overload of a two-argument FUNCTION by calling
// the scalar FUNCTION component-wise on (x.x, y.x) and (x.y, y.y), then expands
// _CLC_BINARY_VECTORIZE_HAVE2 with the same arguments to define the wider
// vector overloads.
#define _CLC_BINARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
55+
ARG2_TYPE) \
56+
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x, ARG2_TYPE##2 y) { \
57+
return (RET_TYPE##2)(FUNCTION(x.x, y.x), FUNCTION(x.y, y.y)); \
58+
} \
59+
_CLC_BINARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
60+
ARG2_TYPE)
61+
5662
#define _CLC_V_S_V_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
5763
ARG2_TYPE) \
5864
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE x, ARG2_TYPE##2 y) { \
@@ -76,13 +82,8 @@
7682
return (RET_TYPE##16)(FUNCTION(x, y.lo), FUNCTION(x, y.hi)); \
7783
}
7884

79-
#define _CLC_TERNARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
80-
ARG2_TYPE, ARG3_TYPE) \
81-
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x, ARG2_TYPE##2 y, \
82-
ARG3_TYPE##2 z) { \
83-
return (RET_TYPE##2)(FUNCTION(x.x, y.x, z.x), FUNCTION(x.y, y.y, z.y)); \
84-
} \
85-
\
85+
#define _CLC_TERNARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
86+
ARG2_TYPE, ARG3_TYPE) \
8687
DECLSPEC RET_TYPE##3 FUNCTION(ARG1_TYPE##3 x, ARG2_TYPE##3 y, \
8788
ARG3_TYPE##3 z) { \
8889
return (RET_TYPE##3)(FUNCTION(x.x, y.x, z.x), FUNCTION(x.y, y.y, z.y), \
@@ -107,6 +108,15 @@
107108
FUNCTION(x.hi, y.hi, z.hi)); \
108109
}
109110

111+
// Defines the 2-element vector overload of a three-argument FUNCTION by
// calling the scalar FUNCTION component-wise on the .x and .y lanes of x, y
// and z, then expands _CLC_TERNARY_VECTORIZE_HAVE2 with the same arguments to
// define the wider vector overloads.
#define _CLC_TERNARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
112+
ARG2_TYPE, ARG3_TYPE) \
113+
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x, ARG2_TYPE##2 y, \
114+
ARG3_TYPE##2 z) { \
115+
return (RET_TYPE##2)(FUNCTION(x.x, y.x, z.x), FUNCTION(x.y, y.y, z.y)); \
116+
} \
117+
_CLC_TERNARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
118+
ARG2_TYPE, ARG3_TYPE)
119+
110120
#define _CLC_V_S_S_V_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
111121
ARG2_TYPE, ARG3_TYPE) \
112122
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE x, ARG2_TYPE y, ARG3_TYPE##2 z) { \

libclc/generic/include/spirv/spirv_builtins.h

+43
Original file line numberDiff line numberDiff line change
@@ -15776,6 +15776,21 @@ _CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec8_fp32_t
1577615776
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec16_fp32_t
1577715777
__spirv_ocl_native_exp2(__clc_vec16_fp32_t);
1577815778

15779+
#ifdef cl_khr_fp16
15780+
_CLC_OVERLOAD
15781+
_CLC_DECL _CLC_CONSTFN __clc_fp16_t __clc_native_exp2(__clc_fp16_t);
15782+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec2_fp16_t
15783+
__clc_native_exp2(__clc_vec2_fp16_t);
15784+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec3_fp16_t
15785+
__clc_native_exp2(__clc_vec3_fp16_t);
15786+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec4_fp16_t
15787+
__clc_native_exp2(__clc_vec4_fp16_t);
15788+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec8_fp16_t
15789+
__clc_native_exp2(__clc_vec8_fp16_t);
15790+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec16_fp16_t
15791+
__clc_native_exp2(__clc_vec16_fp16_t);
15792+
#endif
15793+
1577915794
_CLC_OVERLOAD
1578015795
_CLC_DECL _CLC_CONSTFN __clc_fp32_t __spirv_ocl_native_log(__clc_fp32_t);
1578115796
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec2_fp32_t
@@ -19077,6 +19092,34 @@ _CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec16_fp16_t
1907719092
__spirv_ocl_tanh(__clc_vec16_fp16_t);
1907819093
#endif
1907919094

19095+
_CLC_OVERLOAD
19096+
_CLC_DECL _CLC_CONSTFN __clc_fp32_t __clc_native_tanh(__clc_fp32_t);
19097+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec2_fp32_t
19098+
__clc_native_tanh(__clc_vec2_fp32_t);
19099+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec3_fp32_t
19100+
__clc_native_tanh(__clc_vec3_fp32_t);
19101+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec4_fp32_t
19102+
__clc_native_tanh(__clc_vec4_fp32_t);
19103+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec8_fp32_t
19104+
__clc_native_tanh(__clc_vec8_fp32_t);
19105+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec16_fp32_t
19106+
__clc_native_tanh(__clc_vec16_fp32_t);
19107+
19108+
#ifdef cl_khr_fp16
19109+
_CLC_OVERLOAD
19110+
_CLC_DECL _CLC_CONSTFN __clc_fp16_t __clc_native_tanh(__clc_fp16_t);
19111+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec2_fp16_t
19112+
__clc_native_tanh(__clc_vec2_fp16_t);
19113+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec3_fp16_t
19114+
__clc_native_tanh(__clc_vec3_fp16_t);
19115+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec4_fp16_t
19116+
__clc_native_tanh(__clc_vec4_fp16_t);
19117+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec8_fp16_t
19118+
__clc_native_tanh(__clc_vec8_fp16_t);
19119+
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec16_fp16_t
19120+
__clc_native_tanh(__clc_vec16_fp16_t);
19121+
#endif
19122+
1908019123
_CLC_OVERLOAD
1908119124
_CLC_DECL _CLC_CONSTFN __clc_fp32_t __spirv_ocl_tanpi(__clc_fp32_t);
1908219125
_CLC_OVERLOAD _CLC_DECL _CLC_CONSTFN __clc_vec2_fp32_t

libclc/generic/libspirv/float16.cl

+60
Original file line numberDiff line numberDiff line change
@@ -4344,6 +4344,36 @@ __spirv_ocl_exp2(__clc_vec16_float16_t args_0) {
43444344
return __spirv_ocl_exp2(as_half16(args_0));
43454345
}
43464346

4347+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_fp16_t
4348+
__clc_native_exp2(__clc_float16_t args_0) {
4349+
return __clc_native_exp2(as_half(args_0));
4350+
}
4351+
4352+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_vec2_fp16_t
4353+
__clc_native_exp2(__clc_vec2_float16_t args_0) {
4354+
return __clc_native_exp2(as_half2(args_0));
4355+
}
4356+
4357+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_vec3_fp16_t
4358+
__clc_native_exp2(__clc_vec3_float16_t args_0) {
4359+
return __clc_native_exp2(as_half3(args_0));
4360+
}
4361+
4362+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_vec4_fp16_t
4363+
__clc_native_exp2(__clc_vec4_float16_t args_0) {
4364+
return __clc_native_exp2(as_half4(args_0));
4365+
}
4366+
4367+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_vec8_fp16_t
4368+
__clc_native_exp2(__clc_vec8_float16_t args_0) {
4369+
return __clc_native_exp2(as_half8(args_0));
4370+
}
4371+
4372+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_vec16_fp16_t
4373+
__clc_native_exp2(__clc_vec16_float16_t args_0) {
4374+
return __clc_native_exp2(as_half16(args_0));
4375+
}
4376+
43474377
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_fp16_t
43484378
__spirv_ocl_expm1(__clc_float16_t args_0) {
43494379
return __spirv_ocl_expm1(as_half(args_0));
@@ -6613,6 +6643,36 @@ __spirv_ocl_tanh(__clc_vec16_float16_t args_0) {
66136643
return __spirv_ocl_tanh(as_half16(args_0));
66146644
}
66156645

6646+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_fp16_t
6647+
__clc_native_tanh(__clc_float16_t args_0) {
6648+
return __clc_native_tanh(as_half(args_0));
6649+
}
6650+
6651+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_vec2_fp16_t
6652+
__clc_native_tanh(__clc_vec2_float16_t args_0) {
6653+
return __clc_native_tanh(as_half2(args_0));
6654+
}
6655+
6656+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_vec3_fp16_t
6657+
__clc_native_tanh(__clc_vec3_float16_t args_0) {
6658+
return __clc_native_tanh(as_half3(args_0));
6659+
}
6660+
6661+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_vec4_fp16_t
6662+
__clc_native_tanh(__clc_vec4_float16_t args_0) {
6663+
return __clc_native_tanh(as_half4(args_0));
6664+
}
6665+
6666+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_vec8_fp16_t
6667+
__clc_native_tanh(__clc_vec8_float16_t args_0) {
6668+
return __clc_native_tanh(as_half8(args_0));
6669+
}
6670+
6671+
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_vec16_fp16_t
6672+
__clc_native_tanh(__clc_vec16_float16_t args_0) {
6673+
return __clc_native_tanh(as_half16(args_0));
6674+
}
6675+
66166676
_CLC_OVERLOAD _CLC_DEF _CLC_CONSTFN __clc_fp16_t
66176677
__spirv_ocl_tanpi(__clc_float16_t args_0) {
66186678
return __spirv_ocl_tanpi(as_half(args_0));

libclc/ptx-nvidiacl/libspirv/SOURCES

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ math/native_rsqrt.cl
5353
math/native_sin.cl
5454
math/native_sqrt.cl
5555
math/native_tan.cl
56+
math/native_tanh.cl
5657
math/nextafter.cl
5758
math/pow.cl
5859
math/remainder.cl

libclc/ptx-nvidiacl/libspirv/math/native_exp2.cl

+30
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,34 @@
1414
#define __CLC_FUNCTION __spirv_ocl_native_exp2
1515
#define __CLC_BUILTIN __nv_exp2
1616
#define __CLC_BUILTIN_F __CLC_XCONCAT(__CLC_BUILTIN, f)
17+
18+
#ifdef cl_khr_fp16
19+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
20+
21+
int __clc_nvvm_reflect_arch();
22+
#define __USE_HALF_EXP2_APPROX (__clc_nvvm_reflect_arch() >= 750)
23+
24+
// Scalar half overload of native exp2.
// When __USE_HALF_EXP2_APPROX holds (__clc_nvvm_reflect_arch() >= 750) the
// __nvvm_ex2_approx_f16 builtin is used directly; otherwise x is cast to float
// and forwarded to the float __spirv_ocl_native_exp2 overload, whose result is
// converted back to half by the return.
_CLC_DEF _CLC_OVERLOAD half __clc_native_exp2(half x) {
25+
return (__USE_HALF_EXP2_APPROX) ? __nvvm_ex2_approx_f16(x)
26+
: __spirv_ocl_native_exp2((float)x);
27+
}
28+
29+
// half2 overload of native exp2.
// When __USE_HALF_EXP2_APPROX holds (__clc_nvvm_reflect_arch() >= 750) both
// lanes are handled by a single __nvvm_ex2_approx_f16x2 call; otherwise each
// lane is cast to float, passed to __spirv_ocl_native_exp2, and the two
// results are packed back into a half2.
_CLC_DEF _CLC_OVERLOAD half2 __clc_native_exp2(half2 x) {
30+
return (__USE_HALF_EXP2_APPROX)
31+
? __nvvm_ex2_approx_f16x2(x)
32+
: (half2)(__spirv_ocl_native_exp2((float)x.x),
33+
__spirv_ocl_native_exp2((float)x.y));
34+
}
35+
36+
_CLC_UNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __clc_native_exp2,
37+
half)
38+
39+
#undef __USE_HALF_EXP2_APPROX
40+
41+
#endif // cl_khr_fp16
42+
43+
// Undef halfs before including unary builtins, as they are handled above.
44+
#ifdef cl_khr_fp16
45+
#undef cl_khr_fp16
46+
#endif // cl_khr_fp16
1747
#include <math/unary_builtin.inc>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include <spirv/spirv.h>
9+
10+
#include "../../include/libdevice.h"
11+
#include <clcmacro.h>
12+
13+
extern int __clc_nvvm_reflect_arch();
14+
15+
#define __USE_TANH_APPROX (__clc_nvvm_reflect_arch() >= 750)
16+
17+
// Scalar float overload of native tanh.
// When __USE_TANH_APPROX holds (__clc_nvvm_reflect_arch() >= 750) the
// __nvvm_tanh_approx_f builtin is used; otherwise the call falls back to
// __nv_tanhf.
_CLC_DEF _CLC_OVERLOAD float __clc_native_tanh(float x) {
18+
return (__USE_TANH_APPROX) ? __nvvm_tanh_approx_f(x) : __nv_tanhf(x);
19+
}
20+
21+
_CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, float, __clc_native_tanh, float)
22+
23+
#ifdef cl_khr_fp16
24+
25+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
26+
27+
// Scalar half overload of native tanh.
// When __USE_TANH_APPROX holds (__clc_nvvm_reflect_arch() >= 750) the
// __nvvm_tanh_approx_f16 builtin is used; otherwise the call falls back to
// __nv_tanhf.
// NOTE(review): __nv_tanhf presumably takes and returns float, so the half
// argument and the result would be implicitly converted on the fallback path
// -- confirm against libdevice.h.
_CLC_DEF _CLC_OVERLOAD half __clc_native_tanh(half x) {
28+
return (__USE_TANH_APPROX) ? __nvvm_tanh_approx_f16(x) : __nv_tanhf(x);
29+
}
30+
31+
// half2 overload of native tanh.
// When __USE_TANH_APPROX holds (__clc_nvvm_reflect_arch() >= 750) both lanes
// are handled by a single __nvvm_tanh_approx_f16x2 call; otherwise
// __nv_tanhf is called on x.x and x.y separately and the two results are
// packed into a half2.
_CLC_DEF _CLC_OVERLOAD half2 __clc_native_tanh(half2 x) {
32+
return (__USE_TANH_APPROX) ? __nvvm_tanh_approx_f16x2(x)
33+
: (half2)(__nv_tanhf(x.x), __nv_tanhf(x.y));
34+
}
35+
36+
_CLC_UNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __clc_native_tanh, half)
37+
38+
#endif
39+
40+
#undef __USE_TANH_APPROX
41+

llvm/include/llvm/IR/IntrinsicsNVVM.td

+11
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,17 @@ let TargetPrefix = "nvvm" in {
854854
def int_nvvm_cos_approx_f : GCCBuiltin<"__nvvm_cos_approx_f">,
855855
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
856856

857+
//
858+
// Tanh
859+
//
860+
861+
def int_nvvm_tanh_approx_f : GCCBuiltin<"__nvvm_tanh_approx_f">,
862+
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
863+
def int_nvvm_tanh_approx_f16 : GCCBuiltin<"__nvvm_tanh_approx_f16">,
864+
DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty], [IntrNoMem]>;
865+
def int_nvvm_tanh_approx_f16x2 : GCCBuiltin<"__nvvm_tanh_approx_f16x2">,
866+
DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty], [IntrNoMem]>;
867+
857868
//
858869
// Fma
859870
//

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

+11
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,17 @@ def INT_NVVM_COS_APPROX_FTZ_F : F_MATH_1<"cos.approx.ftz.f32 \t$dst, $src0;",
933933
def INT_NVVM_COS_APPROX_F : F_MATH_1<"cos.approx.f32 \t$dst, $src0;",
934934
Float32Regs, Float32Regs, int_nvvm_cos_approx_f>;
935935

936+
//
937+
// Tanh
938+
//
939+
940+
def INT_NVVM_TANH_APPROX_F : F_MATH_1<"tanh.approx.f32 \t$dst, $src0;",
941+
Float32Regs, Float32Regs, int_nvvm_tanh_approx_f>;
942+
def INT_NVVM_TANH_APPROX_F16 : F_MATH_1<"tanh.approx.f16 \t$dst, $src0;",
943+
Float16Regs, Float16Regs, int_nvvm_tanh_approx_f16>;
944+
def INT_NVVM_TANH_APPROX_F16X2 : F_MATH_1<"tanh.approx.f16x2 \t$dst, $src0;",
945+
Float16x2Regs, Float16x2Regs, int_nvvm_tanh_approx_f16x2>;
946+
936947
//
937948
// Fma
938949
//

0 commit comments

Comments
 (0)