277 changes: 205 additions & 72 deletions nntrainer/tensor/cl_operations/cl_kernels/int4_gemv.cl
@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications made by Donghyeon Jeong on September 13 2025:
// - Limit its functionality exclusively to OS_IS_YX_OSV32_ISV2
// - Portability updates (Adreno-compatible) while preserving Intel intrinsics

#if defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
@@ -15,7 +16,7 @@
#define CAT(x, y) __CAT(x, y)

#define unroll_for __attribute__((opencl_unroll_hint)) for
#define CEIL_DIV(a, b) (((a) + (b)-1) / (b))
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
#define ALIGN(a, b) (CEIL_DIV(a, b) * (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
@@ -52,6 +53,19 @@
#define REQD_SUB_GROUP_SIZE(sg_size)
#endif
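
// Note (assumption, the Intel branch is elided above): on Intel this macro
// presumably expands to __attribute__((intel_reqd_sub_group_size(sg_size))),
// the attribute the pre-change kernel declaration used directly; the empty
// fallback means the 16-lane shape must instead come from the launch config.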

// ==========================================================================
// Non-Intel: define logical subgroup mapping to X-dimension (local_id(0))
// Intel path uses cl_intel_subgroups builtins directly.
// ==========================================================================
#if !defined(cl_intel_subgroups)
#define get_sub_group_local_id() ((uint)get_local_id(0))
#define get_sub_group_size() ((uint)get_local_size(0))
#define get_max_sub_group_size() ((uint)get_local_size(0))
#endif
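
// Launch-shape sketch (illustrative, assuming a hypothetical local size of
// (SUBGROUP_SIZE, 1, 16)): with lanes running along X,
//   uint lane = get_sub_group_local_id();  // == get_local_id(0), 0..15
//   uint thr  = get_local_id(2);           // Z slice, used as thr_id below
// so each Z slice of the work-group behaves as one 16-wide logical subgroup.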

// ==========================================================================
// Block-read type plumbing (Intel path unchanged).
// ==========================================================================
#define BLOCK_READ_TYPE_size1 uchar
#define BLOCK_READ_TYPE_size2 ushort
#define BLOCK_READ_TYPE_size4 uint
@@ -102,6 +116,54 @@
#define DT_FILTER_BLOCK_READ8(ptr, offset) BLOCK_READN(char, 8, ptr, offset)
#define DT_FILTER_BLOCK_READ16(ptr, offset) BLOCK_READN(char, 16, ptr, offset)

// ==========================================================================
// Block-read emulation (when Intel block-read intrinsics aren't present).
// ==========================================================================
#define BLOCK_READ_IMPL_1 ret = ptr[idx];

#define BLOCK_READ_IMPL_2 \
ret.s0 = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.s1 = ptr[idx]; \
idx += get_max_sub_group_size();

#define BLOCK_READ_IMPL_4 \
BLOCK_READ_IMPL_2 \
ret.s2 = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.s3 = ptr[idx]; \
idx += get_max_sub_group_size();

#define BLOCK_READ_IMPL_8 \
BLOCK_READ_IMPL_4 \
ret.s4 = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.s5 = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.s6 = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.s7 = ptr[idx]; \
idx += get_max_sub_group_size();

#define BLOCK_READ_IMPL_16 \
BLOCK_READ_IMPL_8 \
ret.s8 = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.s9 = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.sa = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.sb = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.sc = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.sd = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.se = ptr[idx]; \
idx += get_max_sub_group_size(); \
ret.sf = ptr[idx]; \
idx += get_max_sub_group_size();

#define BLOCK_READ_IMPL(vec_size) CAT(BLOCK_READ_IMPL_, vec_size)
#define BLOCK_READ_FUNC_NAME(type_size, vec_size) \
MAKE_VECTOR_TYPE(BLOCK_READ_FUNC(type_size), vec_size)
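
// Illustrative expansion (assuming idx starts at offset plus the lane id, as
// presumably set up in the DECLARE_BLOCK_READ_EMULATION wrappers): for a
// 16-lane subgroup, BLOCK_READ_IMPL_2 makes lane L load ptr[base + L] and
// ptr[base + 16 + L], the same per-lane layout that
// intel_sub_group_block_read2() produces.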
@@ -141,6 +203,19 @@ DECLARE_BLOCK_READ_EMULATION(1, 8)
DECLARE_BLOCK_READ_EMULATION(1, 16)
#endif

// ---- Macro preserving intel_sub_group_block_read() with fallback ----
#if defined(cl_intel_subgroups)
#define SLM_BLOCK_READ_FLOAT(ptr_) \
as_float(intel_sub_group_block_read((const __local uint *)(ptr_)))
#else
#define SLM_BLOCK_READ_FLOAT(ptr_) \
((const __local float *)(ptr_))[get_sub_group_local_id()]
#endif
// --------------------------------------------------------------------
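
// Usage sketch (illustrative): given __local float row[16],
//   float v = SLM_BLOCK_READ_FLOAT(row);
// gives lane L the value row[L] on either path: one subgroup block read on
// Intel, a plain per-lane indexed load on the fallback.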

// ==========================================================================
// GEMV configuration
// ==========================================================================
#define SIMD 16
#define SUBGROUP_SIZE SIMD
#define DECOMPRESSION_GROUP_SIZE SIZE_QUANTIZATION_GROUP
@@ -160,6 +235,59 @@
BLOCK_READN(half, INPUT_TILE_SIZE, ptr, offset)
#define GEMV_FILTER_BLOCK_READ(ptr, offset) BLOCK_READN(char, 16, ptr, offset)

// ==========================================================================
// Non-Intel subgroup broadcast / reduce emulation
// - Subgroups are lanes along X (local_id(0))
// - thr_id is Z-dimension (local_id(2))
// - Each thr_id slice gets its own 16-element buffer
// ==========================================================================
#if !defined(cl_intel_subgroups)

inline float sg_reduce_add_float(float v, __local float *buf_line) {
uint lid = get_sub_group_local_id(); // lane in X-dimension
buf_line[lid] = v;
barrier(CLK_LOCAL_MEM_FENCE);

uint sg_size = SUBGROUP_SIZE; // expected 16
for (uint stride = sg_size >> 1; stride > 0; stride >>= 1) {

---- inline review comment ----

I don't get it:

shouldn't it be

if (get_sub_group_local_id() != 0)
{
   for (uint i = 1; i < SUBGROUP_SIZE; ++i)
   {
       buf_line[0] = buf_line[0] + buf_line[i];
   }
}

as we would want to reduce all, not just some? I simply don't get what's
going on with this stride.

---- inline review comment (follow-up) ----

Oh, I think I get it: to make this work the way you want, the accumulation of
buf_line[0] with the strided element would have to be atomic (a local
atomic_fetch_add); otherwise it races between subgroup threads (the race is
on buf_line[0], between reading it and modifying it).

Have you tested the correctness of this reduction?

if (lid < stride) {
buf_line[lid] = buf_line[lid] + buf_line[lid + stride];
}
barrier(CLK_LOCAL_MEM_FENCE);
}

float result = buf_line[0];
barrier(CLK_LOCAL_MEM_FENCE); // so buf_line can be reused safely later
return result;
}
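
// Worked example (illustrative, 16 lanes) of the stride loop above -- it is
// a standard barrier-synchronized tree reduction, not an atomic accumulation:
//
//   stride = 8: buf_line[0..7] += buf_line[8..15];  barrier;
//   stride = 4: buf_line[0..3] += buf_line[4..7];   barrier;
//   stride = 2: buf_line[0..1] += buf_line[2..3];   barrier;
//   stride = 1: buf_line[0]    += buf_line[1];      barrier;
//
// At every step each active lane writes only buf_line[lid] and reads only
// buf_line[lid + stride], so writer and reader sets are disjoint between
// barriers and no atomics are needed.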

inline half sg_broadcast_half(half v, uint src_lane, __local half *buf_line) {
uint lid = get_sub_group_local_id(); // lane in X-dimension
buf_line[lid] = v;
barrier(CLK_LOCAL_MEM_FENCE);

half result = buf_line[src_lane];
barrier(CLK_LOCAL_MEM_FENCE);
return result;
}

#define SG_BCAST_HALF(val, lane) \
sg_broadcast_half((val), (lane), \
sg_bcast_buf + get_local_id(2) * SUBGROUP_SIZE)

#define SG_REDUCE_ADD_FLOAT(val) \
sg_reduce_add_float((val), sg_reduce_buf + get_local_id(2) * SUBGROUP_SIZE)

#else // Intel: just alias to real sub-group intrinsics

#define SG_BCAST_HALF(val, lane) sub_group_broadcast((val), (lane))
#define SG_REDUCE_ADD_FLOAT(val) sub_group_reduce_add((val))

#endif // !cl_intel_subgroups

// ==========================================================================
// Helper functions
// ==========================================================================
inline int get_4bit_weight_index(int k, int n, int K, int N, int OSV) {
return (n / OSV) * (OSV * K / 2) + (n % OSV) + (k / 2) * OSV;
}
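
// Worked example (illustrative): with K = 64, N = 64, OSV = 32, the packed
// byte for weight (k = 2, n = 33) lands at
//   (33 / 32) * (32 * 64 / 2) + (33 % 32) + (2 / 2) * 32 = 1024 + 1 + 32 = 1057.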
@@ -184,11 +312,13 @@ inline void thread_task_splitter(const int group_num, const int thr_num,
*n_end += *n_start;
}

__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void
fully_connected_gpu_int4_gemv(__global half *input, const __global half *scales,
__global half *output,
const __global char *weights, const int WEIGHTS_K,
const int WEIGHTS_N) {
// ==========================================================================
// Kernel
// ==========================================================================
REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
kernel void fully_connected_gpu_int4_gemv(
__global half *input, const __global half *scales, __global half *output,
const __global char *weights, const int WEIGHTS_K, const int WEIGHTS_N) {
const int SCALE_GROUP_NUM = CEIL_DIV(WEIGHTS_K, SIZE_QUANTIZATION_GROUP);
int ALIGN_WEIGHTS_N = ALIGN(WEIGHTS_N, 32);
int ALIGN_WEIGHTS_K = ALIGN(WEIGHTS_K, SIZE_QUANTIZATION_GROUP);
@@ -204,39 +334,45 @@ fully_connected_gpu_int4_gemv(__global half *input, const __global half *scales,
__local float all_sum_even[16][16]; // [wi_id, thr_id]
__local float all_sum_odd[16][16];

#if !defined(cl_intel_subgroups)
// Non-Intel: subgroup emulation scratch
__local half sg_bcast_buf[SUBGROUP_SIZE * SUBGROUP_SIZE]; // 16 * 16
__local float sg_reduce_buf[SUBGROUP_SIZE * SUBGROUP_SIZE]; // 16 * 16
#endif
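
// Scratch sizing note (illustrative): the fallback adds 16 * 16 halves
// (512 B) and 16 * 16 floats (1 KiB) of local memory on top of the
// all_sum_even / all_sum_odd tiles.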

#if SCALE_ROW_MAJOR
scales += ((n / 32) * 32 + (n % 32) / 2) * SCALE_GROUP_NUM;
const __global half *scales_base =
scales + ((n / 32) * 32 + (n % 32) / 2) * SCALE_GROUP_NUM;
#else
// Scale layout is fbyx
scales += (n / 32) * 32 + (n % 32) / 2;
const __global half *scales_base = scales + (n / 32) * 32 + (n % 32) / 2;
#endif

float2 sum_all = 0;
float2 sum_all = 0.0f;
for (int gk = gk0; gk < gk1; gk++) {
__global half *A = input + gk * DECOMPRESSION_GROUP_SIZE;
int w_id = get_4bit_weight_index(gk * DECOMPRESSION_GROUP_SIZE, n,
ALIGN_WEIGHTS_K, ALIGN_WEIGHTS_N, 32);

const __global char *B = weights + w_id;

GEMV_ACCUMULATOR_VEC_TYPE sum = 0;
GEMV_ACCUMULATOR_VEC_TYPE sum = 0.0f;

#if SCALE_ROW_MAJOR
float scale_0 = convert_float(scales[gk]);
float scale_1 = convert_float(scales[gk + 16 * SCALE_GROUP_NUM]);
float scale_0 = convert_float(scales_base[gk]);
float scale_1 = convert_float(scales_base[gk + 16 * SCALE_GROUP_NUM]);
#else
float scale_0 = convert_float(scales[gk * ALIGN_WEIGHTS_N]);
float scale_1 = convert_float(scales[gk * ALIGN_WEIGHTS_N + 16]);
float scale_0 = convert_float(scales_base[gk * ALIGN_WEIGHTS_N]);
float scale_1 = convert_float(scales_base[gk * ALIGN_WEIGHTS_N + 16]);
#endif

__attribute__((opencl_unroll_hint(4))) for (int g = 0;
g < DECOMPRESSION_GROUP_SIZE;
g += 16, B += 16 * 16) {
GEMV_INPUT_VEC_TYPE input_value =
GEMV_INPUT_BLOCK_READ(A, g); // read 16 elements of A
GEMV_INPUT_VEC_TYPE input_value = GEMV_INPUT_BLOCK_READ(A, g);

GEMV_FILTER_PACKED_VEC_TYPE bx16 = TO_GEMV_FILTER_PACKED_VEC_TYPE(
GEMV_FILTER_BLOCK_READ(B, 0)); // read 16x16 int8 = (16x2)x16 int4
GEMV_FILTER_PACKED_VEC_TYPE bx16 =
TO_GEMV_FILTER_PACKED_VEC_TYPE(GEMV_FILTER_BLOCK_READ(B, 0));

#if WEI_UINT4
GEMV_FILTER_VEC_TYPE i4x16_even =
@@ -245,7 +381,7 @@ fully_connected_gpu_int4_gemv(__global half *input, const __global half *scales,
TO_GEMV_FILTER_VEC_TYPE(as_char16(as_uchar16(bx16) >> 4));
#else
char16 i4x16_even_c16 = (bx16 & (char16)0xF);
char16 i4x16_odd_c16 = (as_char16(as_uchar16(bx16) >> 4));
char16 i4x16_odd_c16 = (as_char16(as_uchar16(bx16) >> (uchar16)4));
i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16,
i4x16_even_c16 > (char16)7);
i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16,
@@ -254,69 +390,66 @@ fully_connected_gpu_int4_gemv(__global half *input, const __global half *scales,
GEMV_FILTER_VEC_TYPE i4x16_odd = TO_GEMV_FILTER_VEC_TYPE(i4x16_odd_c16);
#endif
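
// Illustrative unpack (signed path above): packed byte 0xF2 gives even
// nibble 0x2 -> +2 (<= 7, kept as-is), and odd nibble 0xF -> 15, which
// select() maps to 15 - 16 = -1.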

sum[0] += as_half(sub_group_broadcast(input_value, 0)) * i4x16_even.s0 +
as_half(sub_group_broadcast(input_value, 4)) * i4x16_even.s4 +
as_half(sub_group_broadcast(input_value, 8)) * i4x16_even.s8 +
as_half(sub_group_broadcast(input_value, 12)) * i4x16_even.sc;

sum[1] += as_half(sub_group_broadcast(input_value, 0)) * i4x16_even.s1 +
as_half(sub_group_broadcast(input_value, 4)) * i4x16_even.s5 +
as_half(sub_group_broadcast(input_value, 8)) * i4x16_even.s9 +
as_half(sub_group_broadcast(input_value, 12)) * i4x16_even.sd;

sum[2] += as_half(sub_group_broadcast(input_value, 1)) * i4x16_odd.s0 +
as_half(sub_group_broadcast(input_value, 5)) * i4x16_odd.s4 +
as_half(sub_group_broadcast(input_value, 9)) * i4x16_odd.s8 +
as_half(sub_group_broadcast(input_value, 13)) * i4x16_odd.sc;

sum[3] += as_half(sub_group_broadcast(input_value, 1)) * i4x16_odd.s1 +
as_half(sub_group_broadcast(input_value, 5)) * i4x16_odd.s5 +
as_half(sub_group_broadcast(input_value, 9)) * i4x16_odd.s9 +
as_half(sub_group_broadcast(input_value, 13)) * i4x16_odd.sd;

sum[4] += as_half(sub_group_broadcast(input_value, 2)) * i4x16_even.s2 +
as_half(sub_group_broadcast(input_value, 6)) * i4x16_even.s6 +
as_half(sub_group_broadcast(input_value, 10)) * i4x16_even.sa +
as_half(sub_group_broadcast(input_value, 14)) * i4x16_even.se;

sum[5] += as_half(sub_group_broadcast(input_value, 2)) * i4x16_even.s3 +
as_half(sub_group_broadcast(input_value, 6)) * i4x16_even.s7 +
as_half(sub_group_broadcast(input_value, 10)) * i4x16_even.sb +
as_half(sub_group_broadcast(input_value, 14)) * i4x16_even.sf;

sum[6] += as_half(sub_group_broadcast(input_value, 3)) * i4x16_odd.s2 +
as_half(sub_group_broadcast(input_value, 7)) * i4x16_odd.s6 +
as_half(sub_group_broadcast(input_value, 11)) * i4x16_odd.sa +
as_half(sub_group_broadcast(input_value, 15)) * i4x16_odd.se;

sum[7] += as_half(sub_group_broadcast(input_value, 3)) * i4x16_odd.s3 +
as_half(sub_group_broadcast(input_value, 7)) * i4x16_odd.s7 +
as_half(sub_group_broadcast(input_value, 11)) * i4x16_odd.sb +
as_half(sub_group_broadcast(input_value, 15)) * i4x16_odd.sf;
sum.s0 += as_half(SG_BCAST_HALF(input_value, 0)) * i4x16_even.s0 +
as_half(SG_BCAST_HALF(input_value, 4)) * i4x16_even.s4 +
as_half(SG_BCAST_HALF(input_value, 8)) * i4x16_even.s8 +
as_half(SG_BCAST_HALF(input_value, 12)) * i4x16_even.sc;

sum.s1 += as_half(SG_BCAST_HALF(input_value, 0)) * i4x16_even.s1 +
as_half(SG_BCAST_HALF(input_value, 4)) * i4x16_even.s5 +
as_half(SG_BCAST_HALF(input_value, 8)) * i4x16_even.s9 +
as_half(SG_BCAST_HALF(input_value, 12)) * i4x16_even.sd;

sum.s2 += as_half(SG_BCAST_HALF(input_value, 1)) * i4x16_odd.s0 +
as_half(SG_BCAST_HALF(input_value, 5)) * i4x16_odd.s4 +
as_half(SG_BCAST_HALF(input_value, 9)) * i4x16_odd.s8 +
as_half(SG_BCAST_HALF(input_value, 13)) * i4x16_odd.sc;

sum.s3 += as_half(SG_BCAST_HALF(input_value, 1)) * i4x16_odd.s1 +
as_half(SG_BCAST_HALF(input_value, 5)) * i4x16_odd.s5 +
as_half(SG_BCAST_HALF(input_value, 9)) * i4x16_odd.s9 +
as_half(SG_BCAST_HALF(input_value, 13)) * i4x16_odd.sd;

sum.s4 += as_half(SG_BCAST_HALF(input_value, 2)) * i4x16_even.s2 +
as_half(SG_BCAST_HALF(input_value, 6)) * i4x16_even.s6 +
as_half(SG_BCAST_HALF(input_value, 10)) * i4x16_even.sa +
as_half(SG_BCAST_HALF(input_value, 14)) * i4x16_even.se;

sum.s5 += as_half(SG_BCAST_HALF(input_value, 2)) * i4x16_even.s3 +
as_half(SG_BCAST_HALF(input_value, 6)) * i4x16_even.s7 +
as_half(SG_BCAST_HALF(input_value, 10)) * i4x16_even.sb +
as_half(SG_BCAST_HALF(input_value, 14)) * i4x16_even.sf;

sum.s6 += as_half(SG_BCAST_HALF(input_value, 3)) * i4x16_odd.s2 +
as_half(SG_BCAST_HALF(input_value, 7)) * i4x16_odd.s6 +
as_half(SG_BCAST_HALF(input_value, 11)) * i4x16_odd.sa +
as_half(SG_BCAST_HALF(input_value, 15)) * i4x16_odd.se;

sum.s7 += as_half(SG_BCAST_HALF(input_value, 3)) * i4x16_odd.s3 +
as_half(SG_BCAST_HALF(input_value, 7)) * i4x16_odd.s7 +
as_half(SG_BCAST_HALF(input_value, 11)) * i4x16_odd.sb +
as_half(SG_BCAST_HALF(input_value, 15)) * i4x16_odd.sf;
}

sum_all[0] += (sum[0] + sum[2] + sum[4] + sum[6]) * scale_0;
sum_all[1] += (sum[1] + sum[3] + sum[5] + sum[7]) * scale_1;
sum_all.s0 += (sum.s0 + sum.s2 + sum.s4 + sum.s6) * scale_0;
sum_all.s1 += (sum.s1 + sum.s3 + sum.s5 + sum.s7) * scale_1;
}

all_sum_even[wi_id][thr_id] = sum_all[0];
all_sum_odd[wi_id][thr_id] = sum_all[1];
all_sum_even[wi_id][thr_id] = sum_all.s0;
all_sum_odd[wi_id][thr_id] = sum_all.s1;
barrier(CLK_LOCAL_MEM_FENCE);

float2 sum_value;
sum_value[0] = as_float(
intel_sub_group_block_read((const __local uint *)all_sum_even[thr_id]));
sum_value[1] = as_float(
intel_sub_group_block_read((const __local uint *)all_sum_odd[thr_id]));
sum_value[0] = sub_group_reduce_add(sum_value[0]);
sum_value[1] = sub_group_reduce_add(sum_value[1]);
sum_value.s0 = SLM_BLOCK_READ_FLOAT(all_sum_even[thr_id]);
sum_value.s1 = SLM_BLOCK_READ_FLOAT(all_sum_odd[thr_id]);

sum_value.s0 = SG_REDUCE_ADD_FLOAT(sum_value.s0);
sum_value.s1 = SG_REDUCE_ADD_FLOAT(sum_value.s1);

if (wi_id == 0) {
int cur_n = n + thr_id;

for (int i = 0; i < 2; i++) {
output[cur_n + 16 * i] =
TO_GEMV_OUTPUT_VEC_TYPE(convert_half(sum_value[i]));
}
output[cur_n] = TO_GEMV_OUTPUT_VEC_TYPE(convert_half(sum_value.s0));
output[cur_n + 16] = TO_GEMV_OUTPUT_VEC_TYPE(convert_half(sum_value.s1));
}
}