Skip to content

Using a wrapper byte_array to support both uint8_t and int8_t #1418

@divyegala

Description

@divyegala

Currently, libcuvs instantiates its templates for both uint8_t and int8_t, which bloats the binary. Both types are one byte wide, and int8_t values can be mapped onto the uint8_t range with a fixed offset of 128. I therefore propose combining the two types behind a single wrapper, cuvs::detail::byte_array.

The public API remains the same:

namespace cuvs {

// Public overloads keep their concrete element types; only the internal
// kernels are shared between uint8_t and int8_t (see launch_kernel below).
void foo(raft::device_matrix_view<uint8_t, ...> data, ...);

void foo(raft::device_matrix_view<int8_t, ...> data, ...);

}  // namespace cuvs

Below is the definition of cuvs::detail::byte_array. It is a struct that normalizes int8_t values to the uint8_t range (by adding 128) when they are read, and de-normalizes them (by subtracting 128) when they are written:

// Execution-space qualifiers so the wrapper can be used both on the host and
// inside CUDA kernels; they expand to nothing in host-only compilation.
#ifndef CUVS_BYTE_ARRAY_HD
#if defined(__CUDACC__)
#define CUVS_BYTE_ARRAY_HD __host__ __device__
#else
#define CUVS_BYTE_ARRAY_HD
#endif
#endif

namespace cuvs::detail {

/**
 * @brief Non-owning view over a buffer of either uint8_t or int8_t elements
 * that always exposes element values in the uint8_t ("normalized") range.
 *
 * For signed buffers, reads map int8_t -> uint8_t by adding 128
 * (-128 -> 0, 0 -> 128, 127 -> 255) and writes apply the inverse mapping by
 * subtracting 128. Unsigned buffers are read and written unchanged.
 *
 * The view does not own `data`; the caller must keep the buffer alive.
 */
struct byte_array {
  void* data     = nullptr;
  bool is_signed = false;

  CUVS_BYTE_ARRAY_HD byte_array(void* ptr, bool signed_flag) : data(ptr), is_signed(signed_flag) {}

  /** @brief Reads element `idx` of `ptr` and returns it in the normalized range. */
  CUVS_BYTE_ARRAY_HD static uint8_t load(const void* ptr, bool signed_flag, int64_t idx)
  {
    if (signed_flag) {
      const int8_t val = static_cast<const int8_t*>(ptr)[idx];
      return static_cast<uint8_t>(static_cast<int16_t>(val) + 128);
    }
    return static_cast<const uint8_t*>(ptr)[idx];
  }

  /** @brief Writes normalized value `v` to element `idx` of `ptr`, de-normalizing for int8_t. */
  CUVS_BYTE_ARRAY_HD static void store(void* ptr, bool signed_flag, int64_t idx, uint8_t v)
  {
    if (signed_flag) {
      static_cast<int8_t*>(ptr)[idx] = static_cast<int8_t>(static_cast<int16_t>(v) - 128);
    } else {
      static_cast<uint8_t*>(ptr)[idx] = v;
    }
  }

  /**
   * @brief Proxy for one element, returned by the non-const accessors.
   *
   * A "live" proxy refers to element `idx` of `*parent` and reads/writes
   * through it. A "detached" proxy (`parent == nullptr`) holds a snapshot in
   * `value`; copy construction always produces a detached proxy.
   *
   * Live proxies store a pointer to the byte_array they came from, so they must
   * not outlive it: e.g. `auto r = (arr + 4)[0];` leaves `r` dangling once the
   * temporary returned by `operator+` is destroyed.
   */
  struct byte {
    byte_array* parent = nullptr;
    int64_t idx        = -1;
    uint8_t value      = 0;  // used for detached proxies

    /** @brief Live proxy for element `i` of `p`. */
    CUVS_BYTE_ARRAY_HD byte(byte_array& p, int64_t i) : parent(&p), idx(i) {}

    /** @brief Detached copy holding the current (normalized) value of `other`. */
    CUVS_BYTE_ARRAY_HD byte(const byte& other)
      : parent(nullptr), idx(-1), value(static_cast<uint8_t>(other))
    {
    }

    /**
     * @brief Assigns the value of `other` through this proxy, like `a[i] = a[j]`
     * on a plain array: a live target writes the referenced element, a detached
     * target updates its local value. The target is never re-bound.
     *
     * No move operations are declared, so rvalue proxies (such as the result of
     * `arr[j]`) bind to these copy operations instead of a deleted overload.
     */
    CUVS_BYTE_ARRAY_HD byte& operator=(const byte& other)
    {
      return *this = static_cast<uint8_t>(other);
    }

    /** @brief Normalized value of the referenced element (or of the local snapshot). */
    CUVS_BYTE_ARRAY_HD operator uint8_t() const
    {
      return parent ? load(parent->data, parent->is_signed, idx) : value;
    }

    /** @brief Stores a normalized value into the referenced element (or the local snapshot). */
    CUVS_BYTE_ARRAY_HD byte& operator=(uint8_t normalized_value)
    {
      if (parent) {
        store(parent->data, parent->is_signed, idx, normalized_value);
      } else {
        value = normalized_value;
      }
      return *this;
    }
  };

  /** @brief Non-const element access: returns a live proxy. */
  CUVS_BYTE_ARRAY_HD byte operator[](int64_t idx) { return byte(*this, idx); }

  /** @brief Const element access: returns the normalized value directly. */
  CUVS_BYTE_ARRAY_HD uint8_t operator[](int64_t idx) const { return load(data, is_signed, idx); }

  /** @brief Dereference, equivalent to `(*this)[0]`. */
  CUVS_BYTE_ARRAY_HD uint8_t operator*() const { return (*this)[0]; }
  CUVS_BYTE_ARRAY_HD byte operator*() { return byte(*this, 0); }

  /** @brief View advanced by `offset` elements, keeping the same signedness. */
  CUVS_BYTE_ARRAY_HD byte_array operator+(int64_t offset) const
  {
    // Both element types are one byte wide, so the byte offset equals the element offset.
    return byte_array(static_cast<uint8_t*>(data) + offset, is_signed);
  }

  /** @brief Views compare equal when they point at the same address (signedness is not compared). */
  CUVS_BYTE_ARRAY_HD bool operator==(const byte_array& other) const { return data == other.data; }
  CUVS_BYTE_ARRAY_HD bool operator!=(const byte_array& other) const { return !(*this == other); }
};

}  // namespace cuvs::detail

This would change the kernel definition to:

// DataT is cuvs::detail::byte_array for the one-byte element types and a raw T*
// otherwise. Launching with T = uint8_t for both uint8_t and int8_t data (see
// launch_kernel) therefore yields a single instantiation of the kernel body.
template <typename T,
          typename DataT = std::conditional_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
                                              cuvs::detail::byte_array,
                                              T*>>
__global__ void kernel(DataT data, ...)
{
  ...
}

And kernel instantiation would now look like:

// T is the user-facing element type: uint8_t, int8_t, or any other supported type.
template <typename T>
void launch_kernel(raft::device_matrix_view<T, ...> data, ...)
{
  if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>) {
    // Both one-byte types are wrapped and dispatched to the single uint8_t
    // instantiation; byte_array records whether the underlying data is signed.
    // NOTE(review): byte_array takes a non-const void*, and a const-qualified T
    // does not match the checks above; views over const data need extra handling.
    cuvs::detail::byte_array data_b{data.data_handle(), std::is_same_v<T, int8_t>};
    kernel<uint8_t><<<...>>>(data_b, ...);
  } else {
    // Normal instantiation: DataT defaults to T*. T must be given explicitly
    // because it cannot be deduced from the defaulted DataT parameter.
    kernel<T><<<...>>>(data.data_handle(), ...);
  }
}

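This approach lets us instantiate each kernel once, for uint8_t, and do all arithmetic in the normalized uint8_t range without modifying user data. At search time, arithmetic is still done in uint8_t, and values are written back as int8_t when the user's data is int8_t.

Note that the +128 offset preserves ordering and pairwise differences, so difference-based distances such as L2 are unchanged by normalization. It does not preserve products, however. Inner-product and cosine metrics computed on normalized values differ from those computed on the original int8_t values, and would need a correction term.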

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    Status

    Todo

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions