Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 45 additions & 35 deletions include/ck/utility/data_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,25 @@ using f4_t = unsigned _BitInt(4);
using f6_t = _BitInt(6); // e2m3 format
using bf6_t = unsigned _BitInt(6); // e3m2 format

// scalar_type
template <typename TV>
struct scalar_type;
/**
* @brief Wrapper for native vector type
* @tparam T The element type of the vector
* @tparam Rank The number of elements in the vector
*/
template <typename T, index_t Rank>
using NativeVectorT = T __attribute__((ext_vector_type(Rank)));

/**
* @brief Mapping of incoming type to local native storage type and vector size
* @tparam T Incoming data type
*/
template <typename T>
struct scalar_type
{
// Basic data type mapping to unsigned _BitInt of appropriate size
using type = unsigned _BitInt(8 * sizeof(T));
static constexpr index_t vector_size = 1;
};
Comment on lines +49 to +55
Copy link

Copilot AI Jan 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The introduction of a default template specialization for scalar_type that uses unsigned _BitInt(8 * sizeof(T)) is a significant change. While this provides a sensible fallback for types without explicit specializations, it may cause issues with certain types. For example, pointer types, function types, or complex class types would get this BitInt mapping which may not be the intended behavior. The previous design likely intentionally required explicit specializations to ensure only valid types were used. Consider either documenting this behavior more explicitly or adding static assertions to catch problematic type usages at compile time.

Copilot uses AI. Check for mistakes.

struct f4x2_pk_t
{
Expand Down Expand Up @@ -89,7 +105,7 @@ struct f6_pk_t
static constexpr index_t vector_size =
(packed_size * num_bits_elem) / num_bits_vec_elem; // 3 or 6 element_type units

using storage_type = element_type __attribute__((ext_vector_type(vector_size)));
using storage_type = NativeVectorT<element_type, vector_size>;
storage_type data_{storage_type(0)}; // packed data

using type = f6_pk_t<BitType, packed_size>;
Expand All @@ -110,7 +126,7 @@ struct f6_pk_t

// Broadcast single initialization value to all packed elements
__host__ __device__ f6_pk_t(const int8_t v)
: f6_pk_t(static_cast<int8_t __attribute__((ext_vector_type(packed_size)))>(v))
: f6_pk_t(static_cast<NativeVectorT<int8_t, packed_size>>(v))
{
// TODO: consider removing initialization similar to vector_type<T, 256>
}
Expand Down Expand Up @@ -191,12 +207,6 @@ struct pk_i4_t
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
};

inline constexpr auto next_pow2(uint32_t x)
{
// Precondition: x > 1.
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}

// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
template <typename T>
Expand All @@ -208,10 +218,6 @@ inline constexpr bool is_native_type()
is_same_v<T, _BitInt(8)> || is_same_v<T, unsigned _BitInt(8)> || is_same<T, bool>::value;
}

// scalar_type
template <typename TV>
struct scalar_type;

// is_scalar_type
template <typename TV>
struct is_scalar_type
Expand All @@ -224,14 +230,13 @@ template <typename X, typename Y>
using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<Y>>::type>;

template <typename T, index_t N>
struct scalar_type<T __attribute__((ext_vector_type(N)))>
template <>
struct scalar_type<bool>
{
using type = T;
static constexpr index_t vector_size = N;
using type = bool;
static constexpr index_t vector_size = 1;
};

//
template <>
struct scalar_type<double>
{
Expand Down Expand Up @@ -293,87 +298,92 @@ struct scalar_type<int4_t>
template <>
struct scalar_type<pk_i4_t>
{
using type = pk_i4_t;
using type = typename pk_i4_t::type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<f8_fnuz_t>
{
using type = f8_fnuz_t::data_type;
using type = typename f8_fnuz_t::data_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bf8_fnuz_t>
{
using type = bf8_fnuz_t::data_type;
using type = typename bf8_fnuz_t::data_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<f8_ocp_t>
{
using type = f8_ocp_t::data_type;
using type = typename f8_ocp_t::data_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bf8_ocp_t>
{
using type = bf8_ocp_t::data_type;
using type = typename bf8_ocp_t::data_type;
static constexpr index_t vector_size = 1;
};

#ifndef CK_CODE_GEN_RTC
template <>
struct scalar_type<e8m0_bexp_t>
{
using type = e8m0_bexp_t::type;
using type = typename e8m0_bexp_t::type;
static constexpr index_t vector_size = 1;
};
#endif

template <>
struct scalar_type<f4x2_pk_t>
{
using type = f4x2_pk_t::type;
using type = typename f4x2_pk_t::type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<f6x32_pk_t>
{
using type = f6x32_pk_t::storage_type;
using type = typename f6x32_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bf6x32_pk_t>
{
using type = bf6x32_pk_t::storage_type;
using type = typename bf6x32_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<f6x16_pk_t>
{
using type = f6x16_pk_t::storage_type;
using type = typename f6x16_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bf6x16_pk_t>
{
using type = bf6x16_pk_t::storage_type;
using type = typename bf6x16_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<bool>
/**
* @brief scalar_type trait override for NativeVectorT
* @tparam T The vector type
* @tparam Rank The number of elements in the vector
*/
template <typename T, index_t Rank>
struct scalar_type<NativeVectorT<T, Rank>>
{
using type = bool;
static constexpr index_t vector_size = 1;
using type = T;
static constexpr index_t vector_size = Rank;
};

template <typename T>
Expand Down
Loading
Loading