Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/ck_tile/38_block_scale_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()

list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global constructors to add kernel instances
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)

if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
// template <typename T>
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;

void abquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
Expand Down Expand Up @@ -135,4 +134,5 @@ void abquant_quantgrouped_instance_factory(
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ using GemmConfig = GemmConfigQuantDecode<T>;
// template <typename T>
// using GemmConfig = GemmConfigQuantPrefill<T>;

void aquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings(
{"fp8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
Expand Down Expand Up @@ -56,4 +55,5 @@ void aquant_quantgrouped_instance_factory(
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
template <typename T>
using GemmConfig = GemmConfigPreshuffleQuantDecode<T>;

void aquant_quantgrouped_preshufflequant_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings(
{"fp8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
Expand Down Expand Up @@ -52,4 +51,5 @@ void aquant_quantgrouped_preshufflequant_instance_factory(
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_bf16fp4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
ck_tile::pk_fp4_raw_t,
ck_tile::bf16_t,
Expand All @@ -38,4 +37,5 @@ void bquant_quantgrouped_bf16fp4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
#ifndef CK_GFX950_SUPPORT
Expand Down Expand Up @@ -55,4 +54,5 @@ void bquant_quantgrouped_bf8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
Expand Down Expand Up @@ -57,4 +56,5 @@ void bquant_quantgrouped_bf8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
#ifndef CK_GFX950_SUPPORT
Expand Down Expand Up @@ -55,4 +54,5 @@ void bquant_quantgrouped_fp8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
Expand Down Expand Up @@ -57,4 +56,5 @@ void bquant_quantgrouped_fp8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
Expand Down Expand Up @@ -50,4 +49,5 @@ void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
Expand Down Expand Up @@ -54,4 +53,5 @@ void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
Expand Down Expand Up @@ -50,4 +49,5 @@ void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
Expand Down Expand Up @@ -54,4 +53,5 @@ void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
Expand Down Expand Up @@ -47,4 +46,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
Expand Down Expand Up @@ -49,4 +48,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
Expand Down Expand Up @@ -47,4 +46,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
Expand Down Expand Up @@ -49,4 +48,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
Expand Down Expand Up @@ -52,4 +51,5 @@ void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
Expand Down Expand Up @@ -56,4 +55,5 @@ void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
Expand Down Expand Up @@ -52,4 +51,5 @@ void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
Expand Down Expand Up @@ -56,4 +55,5 @@ void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();
Loading