Skip to content

[SYCL] fix asserts after logical operation changes #18411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: sycl
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions sycl/include/sycl/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,19 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
reduce_over_group(Group g, T x, BinaryOperation binary_op) {

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(x, x)), bool>
: std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(
std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match reduction accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
#if defined(__NVPTX__)
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
Expand Down Expand Up @@ -291,9 +301,18 @@ std::enable_if_t<
std::is_convertible_v<V, T>),
T>
reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, x)), bool>
: std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(
std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match reduction accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
return binary_op(init, reduce_over_group(g, T(x), binary_op));
#else
Expand Down Expand Up @@ -341,9 +360,18 @@ std::enable_if_t<
detail::is_native_op<T, BinaryOperation>::value),
T>
joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, *first)), bool>
: std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(
std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match reduction accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
T partial = detail::identity_for_ga_op<T, BinaryOperation>();
sycl::detail::for_each(
Expand Down Expand Up @@ -679,8 +707,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(x, x)), bool>
: std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif
#ifdef __SYCL_DEVICE_ONLY__
#if defined(__NVPTX__)
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
Expand Down Expand Up @@ -752,8 +788,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(x, x)), bool>
: std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif
T result;
typename detail::get_scalar_binary_op<BinaryOperation>::type
scalar_binary_op{};
Expand All @@ -775,8 +819,17 @@ std::enable_if_t<
std::is_convertible_v<V, T>),
T>
exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, x)), bool>
: std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
typename Group::linear_id_type local_linear_id =
sycl::detail::get_local_linear_id(g);
Expand Down Expand Up @@ -831,8 +884,17 @@ std::enable_if_t<
OutPtr>
joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init,
BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, *first)), bool>
: std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
Expand Down Expand Up @@ -883,9 +945,22 @@ std::enable_if_t<
OutPtr>
joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
using binary_op_t =
std::remove_cv_t<std::remove_reference_t<decltype(*first)>>;
static_assert(
(std::is_same_v<BinaryOperation, sycl::logical_or<binary_op_t>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<binary_op_t>>)
? std::is_same_v<decltype(binary_op(binary_op_t(), binary_op_t())),
bool>
: std::is_same_v<decltype(binary_op(*first, *first)),
typename detail::remove_pointer<OutPtr>::type>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(*first, *first)),
typename detail::remove_pointer<OutPtr>::type>,
"Result type of binary_op must match scan accumulation type.");
#endif
using T = typename detail::remove_pointer<OutPtr>::type;
T init = detail::identity_for_ga_op<T, BinaryOperation>();
return joint_exclusive_scan(g, first, last, result, init, binary_op);
Expand All @@ -903,8 +978,19 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(x, x)), bool>
: std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else

static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
#if defined(__NVPTX__)
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
Expand Down Expand Up @@ -972,8 +1058,18 @@ std::enable_if_t<
std::is_convertible_v<V, T>),
T>
inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, x)), bool>
: std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
T y = x;
if (sycl::detail::get_local_linear_id(g) == 0) {
Expand Down Expand Up @@ -1022,8 +1118,17 @@ std::enable_if_t<
OutPtr>
joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
BinaryOperation binary_op, T init) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
? std::is_same_v<decltype(binary_op(init, *first)), bool>
: std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#else
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#endif

#ifdef __SYCL_DEVICE_ONLY__
ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
Expand Down Expand Up @@ -1071,9 +1176,22 @@ std::enable_if_t<
OutPtr>
joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
BinaryOperation binary_op) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
using binary_op_t =
std::remove_cv_t<std::remove_reference_t<decltype(*first)>>;
static_assert(
(std::is_same_v<BinaryOperation, sycl::logical_or<binary_op_t>> ||
std::is_same_v<BinaryOperation, sycl::logical_and<binary_op_t>>)
? std::is_same_v<decltype(binary_op(binary_op_t(), binary_op_t())),
bool>
: std::is_same_v<decltype(binary_op(*first, *first)),
typename detail::remove_pointer<OutPtr>::type>,
"Result type of binary_op must match scan accumulation type.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

must match scan accumulation type

What is the accumulation type? Can we use it directly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understand, the accumulation type is the type of input/output data type, i.e part of typename InPtr:

using binary_op_t = std::remove_cv_t<std::remove_reference_t<decltype(*first)>>;
or
typename detail::remove_pointer<OutPtr>::type>

depending on __INTEL_PREVIEW_BREAKING_CHANGES

Copy link
Contributor

@aelovikov-intel aelovikov-intel May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use it directly? Or, in other words, the current condition of the static_assert doesn't match its message. Why are you changing the condition of the assert vs changing the accumulator type? How confident are you in your change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it isn't. You still have std::is_same_v<decltype(binary_op(binary_op_t(), binary_op_t())), bool>. Either change the assert message (and explain why that would be a valid change!) or always compare with the actual result type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original problem is based on changes of return type of sycl::logical_and/sycl::logical_or.
#17239

So, when one of these structs is passed as binary operator and -fpreview-changes flag is used, the return type is not the same as the type passed to sycl::logical_and/sycl::logical_or.
That is why original assert does not work and was updated for this case

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds like you've treated symptoms instead of root causing the issue. The original assert had a very specific wording that doesn't take place anymore. Do you know why it was worded like that instead of a more loose wording that would have matched your changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aelovikov-intel I just know why assert failed: some of binary operations return bool type instead of binary_op_t. These operations are logical_or, logical_and. I was wondering if you explain me 'why it was worded like that' and what is a new string to represent the right output if assert fails

#else
static_assert(std::is_same_v<decltype(binary_op(*first, *first)),
typename detail::remove_pointer<OutPtr>::type>,
"Result type of binary_op must match scan accumulation type.");
#endif

using T = typename detail::remove_pointer<OutPtr>::type;
T init = detail::identity_for_ga_op<T, BinaryOperation>();
Expand Down
Loading