-
Notifications
You must be signed in to change notification settings - Fork 132
Description
Currently, libcuvs instantiates its templates for both uint8_t and int8_t, which causes binary size bloat. Given how similar the two types are, and how easily values of one can be normalized to the other, I propose that we combine the two types behind a wrapper, cuvs::detail::byte_array.
The public API remains the same:
namespace cuvs {
// Public overloads stay exactly as they are today: one per 8-bit element type.
// The `...` placeholders stand for the remaining template arguments and
// function parameters, which this proposal does not change.
void foo(raft::device_matrix_view<uint8_t, ...> data, ...);
void foo(raft::device_matrix_view<int8_t, ...> data, ...);
}  // namespace cuvs
cuvs::detail::byte_array is defined below. It is a struct that always normalizes int8_t values to uint8_t (by adding 128) when returning them, and de-normalizes (by subtracting 128) when assigning:
namespace cuvs::detail {

/**
 * Pointer-like, type-erased view over a buffer of 8-bit integers.
 *
 * The buffer holds either `uint8_t` or `int8_t` elements (selected by
 * `is_signed`). Every read yields the value in a common `uint8_t` domain:
 * unsigned storage is returned unchanged, signed storage is shifted by +128
 * so that [-128, 127] maps onto [0, 255]. Every write applies the inverse
 * shift before storing.
 *
 * The view does not own `data` and performs no bounds checking.
 *
 * NOTE(review): to be callable from a `__global__` kernel, these members would
 * presumably need `__host__ __device__` qualifiers -- confirm against the
 * project's host/device macro conventions.
 */
struct byte_array {
  void* data     = nullptr;  // first element of the viewed buffer (not owned)
  bool is_signed = false;    // true when the buffer stores int8_t elements

  /**
   * @param ptr          start of the buffer to view
   * @param signed_flag  true if `ptr` points at int8_t data, false for uint8_t
   */
  byte_array(void* ptr, bool signed_flag) : data(ptr), is_signed(signed_flag) {}

  /** Reads element `i` of `base` and returns it normalized to uint8_t. */
  static uint8_t read_normalized(const void* base, bool signed_storage, int64_t i)
  {
    if (signed_storage) {
      const int8_t raw = static_cast<const int8_t*>(base)[i];
      // Widen before adding so the +128 shift cannot overflow int8_t.
      return static_cast<uint8_t>(static_cast<int16_t>(raw) + 128);
    }
    return static_cast<const uint8_t*>(base)[i];
  }

  /** Stores a normalized value into element `i` of `base`, de-normalizing if needed. */
  static void write_normalized(void* base, bool signed_storage, int64_t i, uint8_t normalized_value)
  {
    if (signed_storage) {
      static_cast<int8_t*>(base)[i] =
        static_cast<int8_t>(static_cast<int16_t>(normalized_value) - 128);
    } else {
      static_cast<uint8_t*>(base)[i] = normalized_value;
    }
  }

  /**
   * Proxy for a single element.
   *
   * A *live* proxy (`parent != nullptr`) reads from and writes to the array.
   * A *detached* proxy (`parent == nullptr`) carries its own `value`.
   *
   * A live proxy stores a pointer to the `byte_array` object it came from, so
   * it must not outlive that object (e.g. do not keep the result of
   * `*(arr + k)` beyond the full expression, since `arr + k` is a temporary).
   */
  struct byte {
    byte_array* parent = nullptr;
    int64_t idx        = -1;
    uint8_t value      = 0;  // only meaningful for detached proxies

    /** Creates a live proxy for element `i` of `p`. */
    byte(byte_array& p, int64_t i) : parent(&p), idx(i) {}

    /** Copy construction snapshots the current value into a detached proxy. */
    byte(const byte& other) : parent(nullptr), idx(-1), value(static_cast<uint8_t>(other)) {}

    /**
     * Assignment from another proxy assigns the *value*, exactly like
     * assignment from `uint8_t`: a live target writes through to the array, a
     * detached target updates its local value. The target is never re-bound.
     *
     * This is what makes `arr[i] = arr[j]` copy an element. No move assignment
     * is declared: a deleted one would win overload resolution for the rvalue
     * `arr[j]` and make that expression ill-formed, whereas without it rvalues
     * bind to this overload.
     */
    byte& operator=(const byte& other) { return *this = static_cast<uint8_t>(other); }

    // Proxies are not movable; prvalues returned by operator[] / operator* rely
    // on guaranteed copy elision (C++17) instead.
    byte(byte&& other) = delete;

    /** Reads the element (or the local value when detached), normalized to uint8_t. */
    operator uint8_t() const
    {
      if (parent == nullptr) { return value; }
      return byte_array::read_normalized(parent->data, parent->is_signed, idx);
    }

    /** Writes a normalized value to the element (or the local value when detached). */
    byte& operator=(uint8_t normalized_value)
    {
      if (parent == nullptr) {
        value = normalized_value;
      } else {
        byte_array::write_normalized(parent->data, parent->is_signed, idx, normalized_value);
      }
      return *this;
    }
  };

  /** Non-const index access: returns a live proxy for element `idx`. */
  byte operator[](int64_t idx) { return byte(*this, idx); }

  /** Const index access: returns the normalized value of element `idx`. */
  uint8_t operator[](int64_t idx) const { return read_normalized(data, is_signed, idx); }

  /** Dereference (like `*ptr`): element 0 as a value (const) or live proxy (non-const). */
  uint8_t operator*() const { return (*this)[0]; }
  byte operator*() { return byte(*this, 0); }

  /**
   * Pointer arithmetic: a view starting `offset` elements further on.
   * Both element types are one byte wide, so the signed and unsigned cases
   * advance the pointer identically; only `is_signed` needs to be carried over.
   */
  byte_array operator+(int64_t offset) const
  {
    return byte_array(static_cast<uint8_t*>(data) + offset, is_signed);
  }

  // Equality compares the viewed address only, as for raw pointers;
  // `is_signed` does not take part in the comparison.
  bool operator==(const byte_array& other) const { return data == other.data; }
  bool operator!=(const byte_array& other) const { return !(*this == other); }
};

}  // namespace cuvs::detail
This would change the kernel definition to:
// For 8-bit element types the kernel receives a cuvs::detail::byte_array view
// instead of a raw pointer; every other element type keeps the plain `T*`.
template <typename T,
          typename DataT = std::conditional_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
                                              cuvs::detail::byte_array,
                                              T*>>
__global__ void kernel(DataT data, ...) {
...
}
And kernel instantiation would now look like:
// Dispatches to the kernel instantiation for element type T.
// For the 8-bit path, T is still uint8_t or int8_t at this point; both are
// funnelled into the single kernel<uint8_t> instantiation through a
// cuvs::detail::byte_array whose signedness flag records the original type.
template <typename T>
void launch_kernel(raft::device_matrix_view<T, ...> data, ...) {
  if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>) {
    cuvs::detail::byte_array data_b{data.data_handle(), std::is_same_v<T, int8_t>};
    kernel<uint8_t><<<...>>>(data_b, ...);
  } else {
    // Normal instantiation. T must be named explicitly: it appears only in
    // DataT's default argument and so cannot be deduced from the call, while
    // DataT is then deduced as T* from the pointer argument.
    kernel<T><<<...>>>(data.data_handle(), ...);
  }
}
This approach will allow us to instantiate kernels and do all arithmetic in uint8_t without having to modify user data. At search time, we still do the arithmetic in uint8_t and can write data back as int8_t if that is what the user requests.
Metadata
Metadata
Assignees
Labels
Type
Projects
Status