Skip to content

[SYCL] Add marray support to rest math built-in functions #8912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 6, 2023
90 changes: 90 additions & 0 deletions sycl/include/sycl/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <sycl/detail/builtins.hpp>
#include <sycl/detail/common.hpp>
#include <sycl/detail/generic_type_traits.hpp>
#include <sycl/pointers.hpp>
#include <sycl/types.hpp>

// TODO Decide whether to mark functions with this attribute.
Expand Down Expand Up @@ -775,6 +776,95 @@ detail::enable_if_t<detail::is_svgenfloat<T>::value, T> trunc(T x) __NOEXC {
return __sycl_std::__invoke_trunc<T>(x);
}

// other marray math functions

// TODO: can be optimized in the way marray math functions above are optimized
// (usage of vec<T, 2>)
#define __SYCL_MARRAY_MATH_FUNCTION_W_GENPTR_ARG_OVERLOAD_IMPL(NAME, ARGPTR, \
...) \
marray<T, N> res; \
for (int j = 0; j < N; j++) { \
res[j] = \
NAME(__VA_ARGS__, \
address_space_cast<AddressSpace, IsDecorated, \
detail::marray_element_t<T2>>(&(*ARGPTR)[j])); \
} \
return res;

#define __SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENFLOATPTR_OVERLOAD( \
NAME, ARG1, ARG2, ...) \
template <typename T, size_t N, typename T2, \
access::address_space AddressSpace, access::decorated IsDecorated> \
std::enable_if_t< \
detail::is_svgenfloat<T>::value && \
detail::is_genfloatptr_marray<T2, AddressSpace, IsDecorated>::value, \
marray<T, N>> \
NAME(marray<T, N> ARG1, multi_ptr<T2, AddressSpace, IsDecorated> ARG2) \
__NOEXC { \
__SYCL_MARRAY_MATH_FUNCTION_W_GENPTR_ARG_OVERLOAD_IMPL(NAME, ARG2, \
__VA_ARGS__) \
}

__SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENFLOATPTR_OVERLOAD(fract, x, iptr,
x[j])
__SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENFLOATPTR_OVERLOAD(modf, x, iptr,
x[j])
__SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENFLOATPTR_OVERLOAD(sincos, x,
cosval, x[j])

#undef __SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_GENFLOATPTR_OVERLOAD

#define __SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENINTPTR_OVERLOAD( \
NAME, ARG1, ARG2, ...) \
template <typename T, size_t N, typename T2, \
access::address_space AddressSpace, access::decorated IsDecorated> \
std::enable_if_t< \
detail::is_svgenfloat<T>::value && \
detail::is_genintptr_marray<T2, AddressSpace, IsDecorated>::value, \
marray<T, N>> \
NAME(marray<T, N> ARG1, multi_ptr<T2, AddressSpace, IsDecorated> ARG2) \
__NOEXC { \
__SYCL_MARRAY_MATH_FUNCTION_W_GENPTR_ARG_OVERLOAD_IMPL(NAME, ARG2, \
__VA_ARGS__) \
}

__SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENINTPTR_OVERLOAD(frexp, x, exp,
x[j])
__SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENINTPTR_OVERLOAD(lgamma_r, x, signp,
x[j])

#undef __SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_GENINTPTR_OVERLOAD

#define __SYCL_MARRAY_MATH_FUNCTION_REMQUO_OVERLOAD(NAME, ...) \
template <typename T, size_t N, typename T2, \
access::address_space AddressSpace, access::decorated IsDecorated> \
std::enable_if_t< \
detail::is_svgenfloat<T>::value && \
detail::is_genintptr_marray<T2, AddressSpace, IsDecorated>::value, \
marray<T, N>> \
NAME(marray<T, N> x, marray<T, N> y, \
multi_ptr<T2, AddressSpace, IsDecorated> quo) __NOEXC { \
__SYCL_MARRAY_MATH_FUNCTION_W_GENPTR_ARG_OVERLOAD_IMPL(NAME, quo, \
__VA_ARGS__) \
}

__SYCL_MARRAY_MATH_FUNCTION_REMQUO_OVERLOAD(remquo, x[j], y[j])

#undef __SYCL_MARRAY_MATH_FUNCTION_REMQUO_OVERLOAD

#undef __SYCL_MARRAY_MATH_FUNCTION_W_GENPTR_ARG_OVERLOAD_IMPL

template <typename T, size_t N>
std::enable_if_t<detail::is_nan_type<T>::value,
marray<detail::nan_return_t<T>, N>>
nan(marray<T, N> nancode) __NOEXC {
marray<detail::nan_return_t<T>, N> res;
for (int j = 0; j < N; j++) {
res[j] = nan(nancode[j]);
}
return res;
}

/* --------------- 4.13.5 Common functions. ---------------------------------*/
// svgenfloat clamp (svgenfloat x, svgenfloat minval, svgenfloat maxval)
template <typename T>
Expand Down
19 changes: 19 additions & 0 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,30 @@ using is_genintptr = bool_constant<
is_pointer<T>::value && is_genint<remove_pointer_t<T>>::value &&
is_address_space_compliant<T, gvl::nonconst_address_space_list>::value>;

template <typename T, access::address_space AddressSpace,
access::decorated IsDecorated>
using is_genintptr_marray = bool_constant<
std::is_same<T, sycl::marray<marray_element_t<T>, T::size()>>::value &&
is_genint<marray_element_t<remove_pointer_t<T>>>::value &&
is_address_space_compliant<multi_ptr<T, AddressSpace, IsDecorated>,
gvl::nonconst_address_space_list>::value &&
(IsDecorated == access::decorated::yes ||
IsDecorated == access::decorated::no)>;

template <typename T>
using is_genfloatptr = bool_constant<
is_pointer<T>::value && is_genfloat<remove_pointer_t<T>>::value &&
is_address_space_compliant<T, gvl::nonconst_address_space_list>::value>;

template <typename T, access::address_space AddressSpace,
access::decorated IsDecorated>
using is_genfloatptr_marray = bool_constant<
is_mgenfloat<T>::value &&
is_address_space_compliant<multi_ptr<T, AddressSpace, IsDecorated>,
gvl::nonconst_address_space_list>::value &&
(IsDecorated == access::decorated::yes ||
IsDecorated == access::decorated::no)>;

template <typename T>
using is_genptr = bool_constant<
is_pointer<T>::value && is_gentype<remove_pointer_t<T>>::value &&
Expand Down
122 changes: 122 additions & 0 deletions sycl/test-e2e/Basic/built-ins/marray_common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out
// RUN: %ACC_RUN_PLACEHOLDER %t.out

#ifdef _WIN32
#define _USE_MATH_DEFINES // To use math constants
#include <cmath>
#endif

#include <sycl/sycl.hpp>

#define TEST(FUNC, MARRAY_ELEM_TYPE, DIM, EXPECTED, DELTA, ...) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to use a macro here? Seems like we can get away with a function template, which would be preferrable. The same goes for the other test file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to implement this in a separate PR for all marray builtin tests, including these.

{ \
{ \
MARRAY_ELEM_TYPE result[DIM]; \
{ \
sycl::buffer<MARRAY_ELEM_TYPE> b(result, sycl::range{DIM}); \
deviceQueue.submit([&](sycl::handler &cgh) { \
sycl::accessor res_access{b, cgh}; \
cgh.single_task([=]() { \
sycl::marray<MARRAY_ELEM_TYPE, DIM> res = FUNC(__VA_ARGS__); \
for (int i = 0; i < DIM; i++) \
res_access[i] = res[i]; \
}); \
}); \
} \
for (int i = 0; i < DIM; i++) \
assert(abs(result[i] - EXPECTED[i]) <= DELTA); \
} \
}

#define EXPECTED(TYPE, ...) ((TYPE[]){__VA_ARGS__})

int main() {
sycl::queue deviceQueue;
sycl::device dev = deviceQueue.get_device();

sycl::marray<float, 2> ma1{1.0f, 2.0f};
sycl::marray<float, 2> ma2{1.0f, 2.0f};
sycl::marray<float, 2> ma3{3.0f, 2.0f};
sycl::marray<double, 2> ma4{1.0, 2.0};
sycl::marray<float, 3> ma5{M_PI, M_PI, M_PI};
sycl::marray<double, 3> ma6{M_PI, M_PI, M_PI};
sycl::marray<sycl::half, 3> ma7{M_PI, M_PI, M_PI};
sycl::marray<float, 2> ma8{0.3f, 0.6f};
sycl::marray<double, 2> ma9{5.0, 8.0};
sycl::marray<float, 3> ma10{180, 180, 180};
sycl::marray<double, 3> ma11{180, 180, 180};
sycl::marray<sycl::half, 3> ma12{180, 180, 180};
sycl::marray<sycl::half, 3> ma13{181, 179, 181};
sycl::marray<float, 2> ma14{+0.0f, -0.6f};
sycl::marray<double, 2> ma15{-0.0, 0.6f};

// sycl::clamp
TEST(sycl::clamp, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, ma2, ma3);
TEST(sycl::clamp, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, 1.0f, 3.0f);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::clamp, double, 2, EXPECTED(double, 1.0, 2.0), 0, ma4, 1.0, 3.0);
// sycl::degrees
TEST(sycl::degrees, float, 3, EXPECTED(float, 180, 180, 180), 0, ma5);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::degrees, double, 3, EXPECTED(double, 180, 180, 180), 0, ma6);
if (dev.has(sycl::aspect::fp16))
TEST(sycl::degrees, sycl::half, 3, EXPECTED(sycl::half, 180, 180, 180), 0.2,
ma7);
// sycl::max
TEST(sycl::max, float, 2, EXPECTED(float, 3.0f, 2.0f), 0, ma1, ma3);
TEST(sycl::max, float, 2, EXPECTED(float, 1.5f, 2.0f), 0, ma1, 1.5f);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::max, double, 2, EXPECTED(double, 1.5, 2.0), 0, ma4, 1.5);
// sycl::min
TEST(sycl::min, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, ma3);
TEST(sycl::min, float, 2, EXPECTED(float, 1.0f, 1.5f), 0, ma1, 1.5f);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::min, double, 2, EXPECTED(double, 1.0, 1.5), 0, ma4, 1.5);
// sycl::mix
TEST(sycl::mix, float, 2, EXPECTED(float, 1.6f, 2.0f), 0, ma1, ma3, ma8);
TEST(sycl::mix, float, 2, EXPECTED(float, 1.4f, 2.0f), 0, ma1, ma3, 0.2);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::mix, double, 2, EXPECTED(double, 3.0, 5.0), 0, ma4, ma9, 0.5);
// sycl::radians
TEST(sycl::radians, float, 3, EXPECTED(float, M_PI, M_PI, M_PI), 0, ma10);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::radians, double, 3, EXPECTED(double, M_PI, M_PI, M_PI), 0, ma11);
if (dev.has(sycl::aspect::fp16))
TEST(sycl::radians, sycl::half, 3, EXPECTED(sycl::half, M_PI, M_PI, M_PI),
0.002, ma12);
// sycl::step
TEST(sycl::step, float, 2, EXPECTED(float, 1.0f, 1.0f), 0, ma1, ma3);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::step, double, 2, EXPECTED(double, 1.0, 1.0), 0, ma4, ma9);
if (dev.has(sycl::aspect::fp16))
TEST(sycl::step, sycl::half, 3, EXPECTED(sycl::half, 1.0, 0.0, 1.0), 0,
ma12, ma13);
TEST(sycl::step, float, 2, EXPECTED(float, 1.0f, 0.0f), 0, 2.5f, ma3);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::step, double, 2, EXPECTED(double, 0.0f, 1.0f), 0, 6.0f, ma9);
// sycl::smoothstep
TEST(sycl::smoothstep, float, 2, EXPECTED(float, 1.0f, 1.0f), 0, ma8, ma1,
ma2);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::smoothstep, double, 2, EXPECTED(double, 1.0, 1.0f), 0.00000001,
ma4, ma9, ma9);
if (dev.has(sycl::aspect::fp16))
TEST(sycl::smoothstep, sycl::half, 3, EXPECTED(sycl::half, 1.0, 1.0, 1.0),
0, ma7, ma12, ma13);
TEST(sycl::smoothstep, float, 2, EXPECTED(float, 0.0553936f, 0.0f), 0.0000001,
2.5f, 6.0f, ma3);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::smoothstep, double, 2, EXPECTED(double, 0.0f, 1.0f), 0, 6.0f,
8.0f, ma9);
// sign
TEST(sycl::sign, float, 2, EXPECTED(float, +0.0f, -1.0f), 0, ma14);
if (dev.has(sycl::aspect::fp64))
TEST(sycl::sign, double, 2, EXPECTED(double, -0.0, 1.0), 0, ma15);
if (dev.has(sycl::aspect::fp16))
TEST(sycl::sign, sycl::half, 3, EXPECTED(sycl::half, 1.0, 1.0, 1.0), 0,
ma12);

return 0;
}
Loading