Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/ck_tile/core/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
#elif defined(CK_USE_OCP_FP8)
// Host code: respect the build-time CK_USE_OCP_FP8 flag
#define CK_TILE_USE_OCP_FP8 1
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
Expand Down
79 changes: 73 additions & 6 deletions include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,86 @@ struct CShuffleEpilogue
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
// XOR swizzle to eliminate LDS bank conflicts
// AMD LDS architecture:
// - MI300 (gfx942): 32 banks, bank = (address_in_bytes / 4) % 32
// - MI350 (gfx950): 64 banks, bank = (address_in_bytes / 4) % 64
//
// Problem: With 2-byte FP16 elements and 4-byte banks, adjacent columns
// (N, N+1) share the same bank. When MFMA warp distribution has adjacent
// threads accessing adjacent columns, this causes 2-way bank conflicts.
//
// Solution: XOR on N1 (low bit of N) to interleave even/odd columns
// into different physical rows. This spreads adjacent columns to
// different bank regions, eliminating conflicts.
//
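// Strategy (as intended): M' = M ^ (N & 1)
// - Even columns (N=0,2,4...) stay in physical rows M
// - Odd columns (N=1,3,5...) go to physical rows M^1
// NOTE(review): confirm which index make_xor_transform actually rewrites. If it
// keeps the first dimension (M) and XORs the second (N1) with M, the effective
// mapping is N1' = N1 ^ (M & 1) rather than the formula above.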

// N is contiguous dimension
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
constexpr index_t N1_size = 2; // Split N into even/odd
constexpr index_t N0_size = NPerIterationShuffle / N1_size;

// Step 1: Create 3D descriptor [M, N0, N1]
constexpr auto lds_desc_3d = make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<N0_size>{}, number<N1_size>{}),
make_tuple(number<NPerIterationShuffle>{}, number<N1_size>{}, number<1>{}),
number<N1_size>{},
number<1>{});

// Step 2: Apply XOR between M and N1 (the low bit of N)
// This interleaves even/odd columns into different physical rows
constexpr auto lds_desc_xor = transform_tensor_descriptor(
lds_desc_3d,
make_tuple(make_xor_transform(
make_tuple(number<MPerIterationShuffle>{}, number<N1_size>{})),
make_pass_through_transform(number<N0_size>{})),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));

// Step 3: Merge N0 and N1 back to N
return transform_tensor_descriptor(
lds_desc_xor,
make_tuple(make_pass_through_transform(number<MPerIterationShuffle>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<N0_size>{}, number<N1_size>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
// M is contiguous dimension
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
constexpr index_t M1_size = 2; // Split M into even/odd
constexpr index_t M0_size = MPerIterationShuffle / M1_size;

// Step 1: Create 3D descriptor [M0, M1, N]
constexpr auto lds_desc_3d = make_naive_tensor_descriptor(
make_tuple(number<M0_size>{}, number<M1_size>{}, number<NPerIterationShuffle>{}),
make_tuple(number<M1_size>{}, number<1>{}, number<MPerIterationShuffle>{}),
number<M1_size>{},
number<1>{});

// Step 2: Apply XOR between M1 (the low bit of M) and N
// This interleaves even/odd rows into different physical columns
constexpr auto lds_desc_xor = transform_tensor_descriptor(
lds_desc_3d,
make_tuple(make_pass_through_transform(number<M0_size>{}),
make_xor_transform(
make_tuple(number<M1_size>{}, number<NPerIterationShuffle>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));

// Step 3: Merge M0 and M1 back to M
return transform_tensor_descriptor(
lds_desc_xor,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<M0_size>{}, number<M1_size>{})),
make_pass_through_transform(number<NPerIterationShuffle>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
Expand Down
125 changes: 100 additions & 25 deletions test/ck_tile/epilogue/test_cshuffle_epilogue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,51 @@

using namespace ck_tile;

class CShuffleEpilogueTest : public ::testing::Test
// Test configuration template for parameterized tests
template <typename DataType_,
index_t MPerBlock_,
index_t NPerBlock_,
index_t MWave_,
index_t NWave_,
index_t MPerXdl_,
index_t NPerXdl_,
index_t KPerXdl_>
struct TileConfig
{
protected:
void SetUp() override {}
using DataType = DataType_;
static constexpr index_t kMPerBlock = MPerBlock_;
static constexpr index_t kNPerBlock = NPerBlock_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
};

TEST_F(CShuffleEpilogueTest, BasicHalfTest)
// Type-parameterized test fixture
template <typename Config>
class CShuffleEpilogueTypedTest : public ::testing::Test
{
// Basic test configuration with half_t data types
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
};

TYPED_TEST_SUITE_P(CShuffleEpilogueTypedTest);

TYPED_TEST_P(CShuffleEpilogueTypedTest, BasicTest)
{
using Config = TypeParam;
using DataType = typename Config::DataType;
using ADataType = DataType;
using BDataType = DataType;
using AccDataType = float;
using ODataType = ck_tile::half_t;
using ODataType = DataType;

constexpr index_t kMPerBlock = 256;
constexpr index_t kNPerBlock = 256;
constexpr index_t MWave = 2;
constexpr index_t NWave = 2;
constexpr index_t MPerXdl = 32;
constexpr index_t NPerXdl = 32;
constexpr index_t KPerXdl = 8;
constexpr index_t kMPerBlock = Config::kMPerBlock;
constexpr index_t kNPerBlock = Config::kNPerBlock;
constexpr index_t MWave = Config::MWave;
constexpr index_t NWave = Config::NWave;
constexpr index_t MPerXdl = Config::MPerXdl;
constexpr index_t NPerXdl = Config::NPerXdl;
constexpr index_t KPerXdl = Config::KPerXdl;

using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
BDataType,
Expand All @@ -42,12 +66,66 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTest)
KPerXdl>;

auto result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::None);
EXPECT_FLOAT_EQ(result[0], 2.0F) << "Basic CShuffleEpilogue test failed";

if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
{
EXPECT_EQ(result[0], ck_tile::type_convert<ck_tile::fp8_t>(2.f))
<< "CShuffleEpilogue FP8 test failed";
}
else
{
EXPECT_FLOAT_EQ(ck_tile::type_convert<float>(result[0]), 2.0F)
<< "CShuffleEpilogue test failed";
}
}

TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
REGISTER_TYPED_TEST_SUITE_P(CShuffleEpilogueTypedTest, BasicTest);

// Half precision test configurations
// Tile configurations exercised by the typed test suite.
// Alias names follow
//   <Type>Config_<MPerBlock>x<NPerBlock>_<MWave>x<NWave>x1_<MPerXdl>x<NPerXdl>x<KPerXdl>
// and the numbers appear in the same order as the TileConfig template arguments.
// The trailing "x1" of the wave group has no matching TileConfig argument
// (presumably the K-wave count — confirm).
using HalfConfig_256x256_2x2x1_32x32x8 = TileConfig<half_t, 256, 256, 2, 2, 32, 32, 8>;
using HalfConfig_128x128_1x4x1_16x16x16 = TileConfig<half_t, 128, 128, 1, 4, 16, 16, 16>;
using HalfConfig_128x128_2x2x1_16x16x16 = TileConfig<half_t, 128, 128, 2, 2, 16, 16, 16>;
using HalfConfig_128x128_4x1x1_16x16x16 = TileConfig<half_t, 128, 128, 4, 1, 16, 16, 16>;
using HalfConfig_128x128_2x2x1_32x32x16 = TileConfig<half_t, 128, 128, 2, 2, 32, 32, 16>;

// FP8 test configurations (same naming scheme; all use a 128x128 block tile)
using FP8Config_128x128_2x2x1_16x16x16 = TileConfig<fp8_t, 128, 128, 2, 2, 16, 16, 16>;
using FP8Config_128x128_1x4x1_16x16x16 = TileConfig<fp8_t, 128, 128, 1, 4, 16, 16, 16>;
using FP8Config_128x128_4x1x1_16x16x16 = TileConfig<fp8_t, 128, 128, 4, 1, 16, 16, 16>;
using FP8Config_128x128_2x2x1_32x32x16 = TileConfig<fp8_t, 128, 128, 2, 2, 32, 32, 16>;
using FP8Config_128x128_2x2x1_16x16x32 = TileConfig<fp8_t, 128, 128, 2, 2, 16, 16, 32>;
using FP8Config_128x128_2x2x1_32x32x32 = TileConfig<fp8_t, 128, 128, 2, 2, 32, 32, 32>;
using FP8Config_128x128_2x2x1_16x16x64 = TileConfig<fp8_t, 128, 128, 2, 2, 16, 16, 64>;

// Type list of every half_t configuration; each entry becomes one instantiation
// of the typed test suite.
using HalfTestTypes = ::testing::Types<HalfConfig_256x256_2x2x1_32x32x8,
HalfConfig_128x128_1x4x1_16x16x16,
HalfConfig_128x128_2x2x1_16x16x16,
HalfConfig_128x128_4x1x1_16x16x16,
HalfConfig_128x128_2x2x1_32x32x16>;

// Type list of every fp8_t configuration.
using FP8TestTypes = ::testing::Types<FP8Config_128x128_2x2x1_16x16x16,
FP8Config_128x128_1x4x1_16x16x16,
FP8Config_128x128_4x1x1_16x16x16,
FP8Config_128x128_2x2x1_32x32x16,
FP8Config_128x128_2x2x1_16x16x32,
FP8Config_128x128_2x2x1_32x32x32,
FP8Config_128x128_2x2x1_16x16x64>;

// Instantiate CShuffleEpilogueTypedTest once per type list, using the prefixes
// "Half" and "FP8" for the two instantiations.
// The -Wused-but-marked-unused suppression is scoped to these two lines only;
// presumably the gtest instantiation macro expands to code that trips that
// warning under clang — confirm before removing the pragma guard.
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wused-but-marked-unused"
INSTANTIATE_TYPED_TEST_SUITE_P(Half, CShuffleEpilogueTypedTest, HalfTestTypes);
INSTANTIATE_TYPED_TEST_SUITE_P(FP8, CShuffleEpilogueTypedTest, FP8TestTypes);
#pragma clang diagnostic pop
// clang-format on

// Additional tests for scale operations (not parameterized due to different verification logic).
// Stateless gtest fixture: it declares no members or setup and only serves as
// the suite name for the TEST_F scale cases.
class CShuffleEpilogueScaleTest : public ::testing::Test
{
};

TEST_F(CShuffleEpilogueScaleTest, HalfTestWithRowColScale)
{
// Basic test configuration with half_t data types
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
Expand Down Expand Up @@ -75,14 +153,12 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)

auto result =
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::RowCol);
EXPECT_FLOAT_EQ(result[0], 2.0F) << "RowCol CShuffleEpilogue test failed: first element not 2";
EXPECT_FLOAT_EQ(result[1], 4.0F)
<< "RowCol CShuffleEpilogue test failed: second element not 2*2";
EXPECT_FLOAT_EQ(result[0], 2.0F) << "RowCol scale test failed: first element not 2";
EXPECT_FLOAT_EQ(result[1], 4.0F) << "RowCol scale test failed: second element not 2*2";
}

TEST_F(CShuffleEpilogueTest, BasicHalfTestWithTensorScale)
TEST_F(CShuffleEpilogueScaleTest, HalfTestWithTensorScale)
{
// Basic test configuration with half_t data types
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
Expand Down Expand Up @@ -110,8 +186,7 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTestWithTensorScale)

auto result =
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::Tensor);
EXPECT_FLOAT_EQ(result[0], 4.0F)
<< "TensorScale CShuffleEpilogue test failed: first element not 2*2=4";
EXPECT_FLOAT_EQ(result[0], 4.0F) << "Tensor scale test failed: first element not 2*2=4";
}

int main(int argc, char** argv)
Expand Down
8 changes: 6 additions & 2 deletions test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,13 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res
auto acc_tile =
make_static_distributed_tensor<typename Epilogue::AccDataType>(lds_distribution_encode);

// Fill acc_tile with a simple pattern
// Fill acc_tile with a simple pattern - fill entire buffer to ensure correct
// output regardless of tile distribution
auto& acc_buffer = acc_tile.get_thread_buffer();
acc_buffer[0] = 2.0F;
for(index_t i = 0; i < acc_buffer.size(); i++)
{
acc_buffer[i] = 2.0F;
}

// Create output tensor view
auto output_tensor_view =
Expand Down