Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp"

#include "ck/host_utility/hip_check_error.hpp"

using ::ck::hip_check_error;

template <ck::index_t... Is>
Expand Down
8 changes: 8 additions & 0 deletions example/59_grouped_gemm_multi_ABD/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@ add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm

add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp)
add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8)

add_custom_target(example_grouped_gemm_wmma_multi_abd)

add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16 grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16)

add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp)
add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8)

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

#include "ck/host_utility/hip_check_error.hpp"

using ::ck::DeviceMem;
using ::ck::hip_check_error;
using ::ck::HostTensorDescriptor;
Expand Down Expand Up @@ -220,8 +222,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co

for(int i = 0; i < group_count; i++)
{
a0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
a0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i]));

b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
Expand All @@ -232,21 +234,12 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
d0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));

c_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));

a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(),
a0_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(A0DataType));

b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data(),
b0_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(B0DataType));

b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data(),
b1_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(B1DataType));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));

a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data());
b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data());
b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data());
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
c_tensors_device[i]->SetZero();

Expand Down Expand Up @@ -398,7 +391,7 @@ int main(int argc, char* argv[])
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: k_batch (>0)\n");
exit(0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

#include "ck/host_utility/hip_check_error.hpp"

using ::ck::DeviceMem;
using ::ck::hip_check_error;
using ::ck::HostTensorDescriptor;
Expand Down Expand Up @@ -47,9 +49,9 @@ using B0DataType = F16;
using BsDataType = ck::Tuple<B0DataType>;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D0DataType = F16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F32;
using EDataType = F16;

using A0Layout = Row;
using A1Layout = Row;
Expand Down Expand Up @@ -210,31 +212,24 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co

for(int i = 0; i < group_count; i++)
{
a0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
a0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i]));

a1_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i]));
a1_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i]));

b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));

d0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));

c_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));

a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(),
a0_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(A0DataType));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));

a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(),
a1_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(A1DataType));
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(B0DataType));
a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data());
a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
c_tensors_device[i]->SetZero();

Expand Down Expand Up @@ -394,7 +389,7 @@ int main(int argc, char* argv[])
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: k_batch (>0)\n");
exit(0);
}
Expand Down
Loading