Skip to content

[SYCL] Add marray support to common + some math functions #8631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
182 changes: 182 additions & 0 deletions sycl/include/sycl/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,27 @@ __SYCL_MATH_FUNCTION_OVERLOAD_FM(log2)
__SYCL_MATH_FUNCTION_OVERLOAD_FM(log10)
__SYCL_MATH_FUNCTION_OVERLOAD_FM(sqrt)
__SYCL_MATH_FUNCTION_OVERLOAD_FM(rsqrt)
__SYCL_MATH_FUNCTION_OVERLOAD_FM(fabs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to outline such tiny additions into a separate PR, because we can merge them almost right away.


#undef __SYCL_MATH_FUNCTION_OVERLOAD_FM
#undef __SYCL_MATH_FUNCTION_OVERLOAD_IMPL

template <typename T, size_t N>
inline __SYCL_ALWAYS_INLINE
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<int, N>>
ilogb(marray<T, N> x) __NOEXC {
marray<int, N> res;
for (size_t i = 0; i < N / 2; i++) {
vec<int, 2> partial_res =
__sycl_std::__invoke_ilogb<vec<int, 2>>(detail::to_vec2(x, i * 2));
std::memcpy(&res[i * 2], &partial_res, sizeof(vec<int, 2>));
}
if (N % 2) {
res[N - 1] = __sycl_std::__invoke_ilogb<int>(x[N - 1]);
}
return res;
}

#define __SYCL_MATH_FUNCTION_2_OVERLOAD_IMPL(NAME) \
marray<T, N> res; \
for (size_t i = 0; i < N / 2; i++) { \
Expand Down Expand Up @@ -170,6 +187,98 @@ inline __SYCL_ALWAYS_INLINE

#undef __SYCL_MATH_FUNCTION_2_OVERLOAD_IMPL

#define __SYCL_MATH_FUNCTION_2_SGENFLOAT_Y_OVERLOAD(NAME) \
template <typename T, size_t N> \
inline __SYCL_ALWAYS_INLINE \
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>> \
NAME(marray<T, N> x, T y) __NOEXC { \
marray<T, N> res; \
sycl::vec<T, 2> y_vec{y, y}; \
for (size_t i = 0; i < N / 2; i++) { \
auto partial_res = __sycl_std::__invoke_##NAME<vec<T, 2>>( \
detail::to_vec2(x, i * 2), y_vec); \
std::memcpy(&res[i * 2], &partial_res, sizeof(vec<T, 2>)); \
} \
if (N % 2) { \
res[N - 1] = __sycl_std::__invoke_##NAME<T>(x[N - 1], y_vec[0]); \
} \
return res; \
}

__SYCL_MATH_FUNCTION_2_SGENFLOAT_Y_OVERLOAD(fmax)
// clang-format off
__SYCL_MATH_FUNCTION_2_SGENFLOAT_Y_OVERLOAD(fmin)

#undef __SYCL_MATH_FUNCTION_2_SGENFLOAT_Y_OVERLOAD

template <typename T, size_t N>
inline __SYCL_ALWAYS_INLINE
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
ldexp(marray<T, N> x, marray<int, N> k) __NOEXC {
// clang-format on
marray<T, N> res;
for (size_t i = 0; i < N; i++) {
res[i] = __sycl_std::__invoke_ldexp<T>(x[i], k[i]);
}
return res;
}

template <typename T, size_t N>
inline __SYCL_ALWAYS_INLINE
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
ldexp(marray<T, N> x, int k) __NOEXC {
marray<T, N> res;
for (size_t i = 0; i < N; i++) {
res[i] = __sycl_std::__invoke_ldexp<T>(x[i], k);
}
return res;
}

#define __SYCL_MATH_FUNCTION_2_GENINT_Y_OVERLOAD_IMPL(NAME) \
marray<T, N> res; \
for (size_t i = 0; i < N; i++) { \
res[i] = __sycl_std::__invoke_##NAME<T>(x[i], y[i]); \
} \
return res;

template <typename T, size_t N>
inline __SYCL_ALWAYS_INLINE
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
pown(marray<T, N> x, marray<int, N> y) __NOEXC {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we can add one more high-level macro to generalize definitions of pown, rootn and ldexp. Similar to __SYCL_MATH_FUNCTION_2_OVERLOAD(atan2)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically yes, the reason I didn't do this is to stick to code style in this part of the file.

__SYCL_MATH_FUNCTION_2_GENINT_Y_OVERLOAD_IMPL(pown)
}

template <typename T, size_t N>
inline __SYCL_ALWAYS_INLINE
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
rootn(marray<T, N> x, marray<int, N> y) __NOEXC {
__SYCL_MATH_FUNCTION_2_GENINT_Y_OVERLOAD_IMPL(rootn)
}

#undef __SYCL_MATH_FUNCTION_2_GENINT_Y_OVERLOAD_IMPL

#define __SYCL_MATH_FUNCTION_2_INT_Y_OVERLOAD_IMPL(NAME) \
marray<T, N> res; \
for (size_t i = 0; i < N; i++) { \
res[i] = __sycl_std::__invoke_##NAME<T>(x[i], y); \
} \
return res;

template <typename T, size_t N>
inline __SYCL_ALWAYS_INLINE
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
pown(marray<T, N> x, int y) __NOEXC {
__SYCL_MATH_FUNCTION_2_INT_Y_OVERLOAD_IMPL(pown)
}

template <typename T, size_t N>
inline __SYCL_ALWAYS_INLINE
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
rootn(marray<T, N> x,
int y) __NOEXC{__SYCL_MATH_FUNCTION_2_INT_Y_OVERLOAD_IMPL(rootn)}

#undef __SYCL_MATH_FUNCTION_2_INT_Y_OVERLOAD_IMPL

#define __SYCL_MATH_FUNCTION_3_OVERLOAD(NAME) \
template <typename T, size_t N> \
inline __SYCL_ALWAYS_INLINE \
Expand Down Expand Up @@ -789,6 +898,78 @@ detail::enable_if_t<detail::is_svgenfloat<T>::value, T> sign(T x) __NOEXC {
return __sycl_std::__invoke_sign<T>(x);
}

// marray common functions

// TODO: can be optimized in the way math functions are optimized (usage of
// vec<T, 2>)
#define __SYCL_MARRAY_COMMON_FUNCTION_OVERLOAD_IMPL(NAME, ...) \
T res; \
for (int i = 0; i < T::size(); i++) { \
res[i] = NAME(__VA_ARGS__); \
} \
return res;

#define __SYCL_MARRAY_COMMON_FUNCTION_UNOP_OVERLOAD(NAME, ARG, ...) \
template <typename T, \
typename = std::enable_if_t<detail::is_mgenfloat<T>::value>> \
T NAME(ARG) __NOEXC { \
__SYCL_MARRAY_COMMON_FUNCTION_OVERLOAD_IMPL(NAME, __VA_ARGS__) \
}

__SYCL_MARRAY_COMMON_FUNCTION_UNOP_OVERLOAD(degrees, T radians, radians[i])
__SYCL_MARRAY_COMMON_FUNCTION_UNOP_OVERLOAD(radians, T degrees, degrees[i])
__SYCL_MARRAY_COMMON_FUNCTION_UNOP_OVERLOAD(sign, T x, x[i])

#undef __SYCL_MARRAY_COMMON_FUNCTION_UNOP_OVERLOAD

#define __SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(NAME, ARG1, ARG2, ...) \
template <typename T, \
typename = std::enable_if_t<detail::is_mgenfloat<T>::value>> \
T NAME(ARG1, ARG2) __NOEXC { \
__SYCL_MARRAY_COMMON_FUNCTION_OVERLOAD_IMPL(NAME, __VA_ARGS__) \
}

__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(min, T x, T y, x[i], y[i])
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(min, T x,
detail::marray_element_type<T> y,
x[i], y)
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(max, T x, T y, x[i], y[i])
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(max, T x,
detail::marray_element_type<T> y,
x[i], y)
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(step, T edge, T x, edge[i], x[i])
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(
step, detail::marray_element_type<T> edge, T x, edge, x[i])

#undef __SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD

#define __SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(NAME, ARG1, ARG2, ARG3, \
...) \
template <typename T, \
typename = std::enable_if_t<detail::is_mgenfloat<T>::value>> \
T NAME(ARG1, ARG2, ARG3) __NOEXC { \
__SYCL_MARRAY_COMMON_FUNCTION_OVERLOAD_IMPL(NAME, __VA_ARGS__) \
}

__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(clamp, T x, T minval, T maxval,
x[i], minval[i], maxval[i])
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(
clamp, T x, detail::marray_element_type<T> minval,
detail::marray_element_type<T> maxval, x[i], minval, maxval)
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(mix, T x, T y, T a, x[i], y[i],
a[i])
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(mix, T x, T y,
detail::marray_element_type<T> a,
x[i], y[i], a)
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(smoothstep, T edge0, T edge1, T x,
edge0[i], edge1[i], x[i])
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(
smoothstep, detail::marray_element_type<T> edge0,
detail::marray_element_type<T> edge1, T x, edge0, edge1, x[i])

#undef __SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD
#undef __SYCL_MARRAY_COMMON_FUNCTION_OVERLOAD_IMPL

/* --------------- 4.13.4 Integer functions. --------------------------------*/
// ugeninteger abs (geninteger x)
template <typename T>
Expand Down Expand Up @@ -1724,6 +1905,7 @@ __SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(cos)
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(tan)
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp)
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp2)
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp10)
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log)
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log2)
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log10)
Expand Down