Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions include/ck/utility/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP

#include "functional2.hpp"
#include "sequence.hpp"
#include "type.hpp"

namespace ck {

Expand All @@ -32,7 +31,10 @@ struct Array
{
static_assert(T::Size() == Size(), "wrong! size not the same");

static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
for(index_t i = 0; i < NSize; i++)
{
mData[i] = a[i];
}

return *this;
}
Expand Down
65 changes: 49 additions & 16 deletions include/ck/utility/sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <ostream>
#endif

#include "ck/utility/integral_constant.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/math.hpp"
Expand Down Expand Up @@ -296,29 +296,62 @@ struct uniform_sequence_gen
};

// reverse inclusive scan (with init) sequence
template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan;
namespace impl {
template <typename Seq, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan_impl;

template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
template <index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan_impl<Sequence<Is...>, Reduce, Init>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have unit tests for the reverse inclusive scan?

Copy link
Collaborator Author

@CongMa13 CongMa13 Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, we have.

test/util/unit_sequence.cpp

ReverseInclusiveScan, ReverseExclusiveScan

{
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
template <index_t Size>
static constexpr Array<index_t, Size> compute_array()
{
Array<index_t, Size> values = {Is...};
Array<index_t, Size> result = {0};
result.At(Size - 1) = Reduce{}(values[Size - 1], Init);
for(index_t i = Size - 1; i > 0; --i)
{
result.At(i - 1) = Reduce{}(values[i - 1], result[i]);
}
return result;
}

static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
template <index_t... Indices>
static constexpr auto compute(Sequence<Indices...>)
{
constexpr index_t size = sizeof...(Is);

using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
};
if constexpr(size == 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's hard code for a few of the shorter sequences:

if constexpr(size == 0)
{
return Sequence<>{};
}
else if constexpr(size == 1)
{
constexpr index_t values[1] = {Is...};
return Sequence<Reduce{}(values[0], Init)>{};
}
else if constexpr(size == 2)
{
constexpr index_t values[2] = {Is...};
constexpr index_t r1 = Reduce{}(values[1], Init);
constexpr index_t r0 = Reduce{}(values[0], r1);
return Sequence<r0, r1>{};
}

{
return Sequence<>{};
}
else if constexpr(size == 1)
{
constexpr index_t values[1] = {Is...};
return Sequence<Reduce{}(values[0], Init)>{};
}
else if constexpr(size == 2)
{
constexpr index_t values[2] = {Is...};
constexpr index_t r1 = Reduce{}(values[1], Init);
constexpr index_t r0 = Reduce{}(values[0], r1);
return Sequence<r0, r1>{};
}
else
{
constexpr Array<index_t, size> arr = compute_array<size>();
return Sequence<arr[Indices]...>{};
}
}

template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
{
using type = Sequence<Reduce{}(I, Init)>;
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
};
} // namespace impl

template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
template <typename Seq, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan
{
using type = Sequence<>;
using type = typename impl::sequence_reverse_inclusive_scan_impl<Seq, Reduce, Init>::type;
};

// split sequence
Expand Down
2 changes: 2 additions & 0 deletions test/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: MIT

add_gtest_executable(unit_sequence unit_sequence.cpp)
add_gtest_executable(unit_array unit_array.cpp)
if(result EQUAL 0)
target_link_libraries(unit_sequence PRIVATE utility)
target_link_libraries(unit_array PRIVATE utility)
endif()
198 changes: 198 additions & 0 deletions test/util/unit_array.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <gtest/gtest.h>
#include "ck/utility/array.hpp"

using namespace ck;

// Test basic Array construction and properties
TEST(Array, BasicConstruction)
{
using Arr = Array<index_t, 5>;
EXPECT_EQ(Arr::Size(), 5);
}

TEST(Array, InitListConstruction)
{
using Arr = Array<index_t, 5>;
Arr value{1, 2, 3, 4, 5};
EXPECT_EQ(value[0], 1);
EXPECT_EQ(value[4], 5);
}

// Test At() method
TEST(Array, AtMethod)
{
Array<int, 3> arr{10, 20, 30};
EXPECT_EQ(arr.At(0), 10);
EXPECT_EQ(arr.At(1), 20);
EXPECT_EQ(arr.At(2), 30);

// Test non-const At() for modification
arr.At(1) = 25;
EXPECT_EQ(arr.At(1), 25);
}

// Test const At() method
TEST(Array, ConstAtMethod)
{
const Array<int, 3> arr{10, 20, 30};
EXPECT_EQ(arr.At(0), 10);
EXPECT_EQ(arr.At(1), 20);
EXPECT_EQ(arr.At(2), 30);
}

// Test operator[]
TEST(Array, OperatorBracket)
{
const Array<int, 4> arr{5, 10, 15, 20};
EXPECT_EQ(arr[0], 5);
EXPECT_EQ(arr[1], 10);
EXPECT_EQ(arr[2], 15);
EXPECT_EQ(arr[3], 20);
}

// Test operator()
TEST(Array, OperatorParenthesis)
{
Array<int, 3> arr{1, 2, 3};
EXPECT_EQ(arr(0), 1);
EXPECT_EQ(arr(1), 2);
EXPECT_EQ(arr(2), 3);

// Test modification through operator()
arr(1) = 99;
EXPECT_EQ(arr(1), 99);
}

// Test operator= assignment
TEST(Array, Assignment)
{
Array<int, 3> arr1{1, 2, 3};
Array<int, 3> arr2{0, 0, 0};

arr2 = arr1;

EXPECT_EQ(arr2[0], 1);
EXPECT_EQ(arr2[1], 2);
EXPECT_EQ(arr2[2], 3);
}

// Test iterators
TEST(Array, Iterators)
{
Array<int, 5> arr{1, 2, 3, 4, 5};

// Test begin() and end()
int sum = 0;
for(auto it = arr.begin(); it != arr.end(); ++it)
{
sum += *it;
}
EXPECT_EQ(sum, 15);

// Test range-based for loop
sum = 0;
for(auto val : arr)
{
sum += val;
}
EXPECT_EQ(sum, 15);
}

// Test const iterators
TEST(Array, ConstIterators)
{
const Array<int, 4> arr{10, 20, 30, 40};

int sum = 0;
for(auto it = arr.begin(); it != arr.end(); ++it)
{
sum += *it;
}
EXPECT_EQ(sum, 100);

// Test const range-based for loop
sum = 0;
for(auto val : arr)
{
sum += val;
}
EXPECT_EQ(sum, 100);
}

// Test make_array() helper function
TEST(Array, MakeArray)
{
auto arr = make_array(1, 2, 3, 4, 5);

EXPECT_EQ(arr.Size(), 5);
EXPECT_EQ(arr[0], 1);
EXPECT_EQ(arr[1], 2);
EXPECT_EQ(arr[2], 3);
EXPECT_EQ(arr[3], 4);
EXPECT_EQ(arr[4], 5);
}

// Test make_array() with different types
TEST(Array, MakeArrayFloats)
{
auto arr = make_array(1.5f, 2.5f, 3.5f);

EXPECT_EQ(arr.Size(), 3);
EXPECT_FLOAT_EQ(arr[0], 1.5f);
EXPECT_FLOAT_EQ(arr[1], 2.5f);
EXPECT_FLOAT_EQ(arr[2], 3.5f);
}

// Test empty Array<T, 0>
TEST(Array, EmptyArray)
{
using EmptyArr = Array<int, 0>;
EXPECT_EQ(EmptyArr::Size(), 0);

// Test make_array() for empty array
auto empty = make_array<int>();
EXPECT_EQ(empty.Size(), 0);
}

// Test Array with different data types
TEST(Array, DifferentTypes)
{
Array<float, 3> float_arr{1.1f, 2.2f, 3.3f};
EXPECT_FLOAT_EQ(float_arr[0], 1.1f);
EXPECT_FLOAT_EQ(float_arr[1], 2.2f);
EXPECT_FLOAT_EQ(float_arr[2], 3.3f);

Array<double, 2> double_arr{1.23, 4.56};
EXPECT_DOUBLE_EQ(double_arr[0], 1.23);
EXPECT_DOUBLE_EQ(double_arr[1], 4.56);
}

// Test Array modification through iterators
TEST(Array, ModifyThroughIterators)
{
Array<int, 3> arr{1, 2, 3};

for(auto it = arr.begin(); it != arr.end(); ++it)
{
*it *= 2;
}

EXPECT_EQ(arr[0], 2);
EXPECT_EQ(arr[1], 4);
EXPECT_EQ(arr[2], 6);
}

// Test single element Array
TEST(Array, SingleElement)
{
Array<int, 1> arr{42};
EXPECT_EQ(arr.Size(), 1);
EXPECT_EQ(arr[0], 42);

auto single = make_array(100);
EXPECT_EQ(single.Size(), 1);
EXPECT_EQ(single[0], 100);
}