Skip to content

[SYCL] fix asserts after logical operation changes #18411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: sycl
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions sycl/include/sycl/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,19 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
reduce_over_group(Group g, T x, BinaryOperation binary_op) {

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(x, x)), bool>
: std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(
std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match reduction accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
#if defined(__NVPTX__)
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
Expand Down Expand Up @@ -291,9 +301,18 @@ std::enable_if_t<
std::is_convertible_v<V, T>),
T>
reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, x)), bool>
: std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(
std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match reduction accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
return binary_op(init, reduce_over_group(g, T(x), binary_op));
#else
Expand Down Expand Up @@ -341,9 +360,18 @@ std::enable_if_t<
detail::is_native_op<T, BinaryOperation>::value),
T>
joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, *first)), bool>
: std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(
std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match reduction accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
T partial = detail::identity_for_ga_op<T, BinaryOperation>();
sycl::detail::for_each(
Expand Down Expand Up @@ -679,8 +707,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(x, x)), bool>
: std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif
#ifdef __SYCL_DEVICE_ONLY__
#if defined(__NVPTX__)
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
Expand Down Expand Up @@ -752,8 +788,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(x, x)), bool>
: std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif
T result;
typename detail::get_scalar_binary_op<BinaryOperation>::type
scalar_binary_op{};
Expand All @@ -775,8 +819,17 @@ std::enable_if_t<
std::is_convertible_v<V, T>),
T>
exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, x)), bool>
: std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
typename Group::linear_id_type local_linear_id =
sycl::detail::get_local_linear_id(g);
Expand Down Expand Up @@ -831,8 +884,17 @@ std::enable_if_t<
OutPtr>
joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init,
BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, *first)), bool>
: std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
Expand Down Expand Up @@ -883,9 +945,22 @@ std::enable_if_t<
OutPtr>
joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
using binary_op_t =
std::remove_cv_t<std::remove_reference_t<decltype(*first)>>;
static_assert(
(std::is_same_v<BinaryOperation, sycl::logical_or<binary_op_t>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<binary_op_t>>)
? std::is_same_v<decltype(binary_op(binary_op_t(), binary_op_t())),
bool>
: std::is_same_v<decltype(binary_op(binary_op_t(), binary_op_t())),
binary_op_t>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(*first, *first)),
typename detail::remove_pointer<OutPtr>::type>,
"Result type of binary_op must match scan accumulation type.");
#endif
using T = typename detail::remove_pointer<OutPtr>::type;
T init = detail::identity_for_ga_op<T, BinaryOperation>();
return joint_exclusive_scan(g, first, last, result, init, binary_op);
Expand All @@ -903,8 +978,19 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(x, x)), bool>
: std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else

static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
#if defined(__NVPTX__)
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
Expand Down Expand Up @@ -972,8 +1058,18 @@ std::enable_if_t<
std::is_convertible_v<V, T>),
T>
inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, x)), bool>
: std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
T y = x;
if (sycl::detail::get_local_linear_id(g) == 0) {
Expand Down Expand Up @@ -1022,8 +1118,17 @@ std::enable_if_t<
OutPtr>
joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
BinaryOperation binary_op, T init) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, *first)), bool>
: std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
Expand Down Expand Up @@ -1071,9 +1176,22 @@ std::enable_if_t<
OutPtr>
joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
using binary_op_t =
std::remove_cv_t<std::remove_reference_t<decltype(*first)>>;
static_assert(
(std::is_same_v<BinaryOperation, sycl::logical_or<binary_op_t>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<binary_op_t>>)
? std::is_same_v<decltype(binary_op(binary_op_t(), binary_op_t())),
bool>
: std::is_same_v<decltype(binary_op(binary_op_t(), binary_op_t())),
binary_op_t>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(*first, *first)),
typename detail::remove_pointer<OutPtr>::type>,
"Result type of binary_op must match scan accumulation type.");
#endif

using T = typename detail::remove_pointer<OutPtr>::type;
T init = detail::identity_for_ga_op<T, BinaryOperation>();
Expand Down
Loading