INT4 GEMV OpenCL Kernel for Adreno GPU Compatibility #3569
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| // SPDX-License-Identifier: Apache-2.0 | ||
| // Modifications made by Donghyeon Jeong on September 13 2025: | ||
| // - Limit its functionality exclusively to OS_IS_YX_OSV32_ISV2 | ||
| // - Portability updates (Adreno-compatible) while preserving Intel intrinsics: | ||
|
|
||
| #if defined(cl_khr_fp16) | ||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | ||
|
|
@@ -15,7 +16,7 @@ | |
| #define CAT(x, y) __CAT(x, y) | ||
|
|
||
| #define unroll_for __attribute__((opencl_unroll_hint)) for | ||
| #define CEIL_DIV(a, b) (((a) + (b)-1) / (b)) | ||
| #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) | ||
| #define ALIGN(a, b) (CEIL_DIV(a, b) * (b)) | ||
| #define MIN(a, b) ((a) < (b) ? (a) : (b)) | ||
| #define MAX(a, b) ((a) > (b) ? (a) : (b)) | ||
|
|
@@ -52,6 +53,19 @@ | |
| #define REQD_SUB_GROUP_SIZE(sg_size) | ||
| #endif | ||
|
|
||
| // ========================================================================== | ||
| // Non-Intel: define logical subgroup mapping to X-dimension (local_id(0)) | ||
| // Intel path uses cl_intel_subgroups builtins directly. | ||
| // ========================================================================== | ||
| #if !defined(cl_intel_subgroups) | ||
| #define get_sub_group_local_id() ((uint)get_local_id(0)) | ||
| #define get_sub_group_size() ((uint)get_local_size(0)) | ||
| #define get_max_sub_group_size() ((uint)get_local_size(0)) | ||
| #endif | ||
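A note on the fallback above: it only behaves like a real 16-wide sub-group when the work-group's X dimension equals SUBGROUP_SIZE (the 16x16 __local scratch later in the kernel suggests a (16, 1, 16) local size, though the dispatch code is not part of this diff). A minimal probe sketch under that assumption, with the fallback define repeated so it stands alone:

```c
// Illustrative probe, not part of the patch. Assumes the host enqueues with
// local size (16, 1, 1) so get_local_id(0) enumerates one logical sub-group.
#if !defined(cl_intel_subgroups)
#define get_sub_group_local_id() ((uint)get_local_id(0))
#endif

__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void lane_id_probe(__global uint *lane_out) {
    // Non-Intel: this is get_local_id(0). Intel: the hardware sub-group lane,
    // which matches 0..15 only if the compiler actually picks a 16-wide sub-group.
    lane_out[get_global_id(0)] = get_sub_group_local_id();
}
```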
|
|
||
| // ========================================================================== | ||
| // Block-read type plumbing (Intel path unchanged). | ||
| // ========================================================================== | ||
| #define BLOCK_READ_TYPE_size1 uchar | ||
| #define BLOCK_READ_TYPE_size2 ushort | ||
| #define BLOCK_READ_TYPE_size4 uint | ||
|
|
@@ -102,6 +116,54 @@ | |
| #define DT_FILTER_BLOCK_READ8(ptr, offset) BLOCK_READN(char, 8, ptr, offset) | ||
| #define DT_FILTER_BLOCK_READ16(ptr, offset) BLOCK_READN(char, 16, ptr, offset) | ||
|
|
||
| // ========================================================================== | ||
| // Block-read emulation (when Intel block-read intrinsics aren't present). | ||
| // ========================================================================== | ||
| #define BLOCK_READ_IMPL_1 ret = ptr[idx]; | ||
|
|
||
| #define BLOCK_READ_IMPL_2 \ | ||
| ret.s0 = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.s1 = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); | ||
|
|
||
| #define BLOCK_READ_IMPL_4 \ | ||
| BLOCK_READ_IMPL_2 \ | ||
| ret.s2 = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.s3 = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); | ||
|
|
||
| #define BLOCK_READ_IMPL_8 \ | ||
| BLOCK_READ_IMPL_4 \ | ||
| ret.s4 = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.s5 = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.s6 = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.s7 = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); | ||
|
|
||
| #define BLOCK_READ_IMPL_16 \ | ||
| BLOCK_READ_IMPL_8 \ | ||
| ret.s8 = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.s9 = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.sa = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.sb = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.sc = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.sd = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.se = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); \ | ||
| ret.sf = ptr[idx]; \ | ||
| idx += get_max_sub_group_size(); | ||
|
|
||
| #define BLOCK_READ_IMPL(vec_size) CAT(BLOCK_READ_IMPL_, vec_size) | ||
| #define BLOCK_READ_FUNC_NAME(type_size, vec_size) \ | ||
| MAKE_VECTOR_TYPE(BLOCK_READ_FUNC(type_size), vec_size) | ||
|
|
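For orientation, a DECLARE_BLOCK_READ_EMULATION instantiation built from the BLOCK_READ_IMPL_* bodies above presumably looks roughly like the sketch below for a 2-element ushort read. The wrapper itself is outside this hunk, so the function name and the idx initialization are assumptions based on the usual block-read convention of starting each lane at offset plus its lane id:

```c
// Illustrative expansion only; the real function is generated by
// DECLARE_BLOCK_READ_EMULATION and named through BLOCK_READ_FUNC_NAME.
inline ushort2 emulated_block_read_us2(const __global ushort *ptr, uint offset) {
    ushort2 ret;
    uint idx = offset + get_sub_group_local_id(); // assumed: each lane starts at its own slot
    ret.s0 = ptr[idx];
    idx += get_max_sub_group_size();              // BLOCK_READ_IMPL_2: stride by sub-group width
    ret.s1 = ptr[idx];
    idx += get_max_sub_group_size();
    return ret;
}
```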
@@ -141,6 +203,19 @@ DECLARE_BLOCK_READ_EMULATION(1, 8) | |
| DECLARE_BLOCK_READ_EMULATION(1, 16) | ||
| #endif | ||
|
|
||
| // ---- Macro preserving intel_sub_group_block_read() with fallback ---- | ||
| #if defined(cl_intel_subgroups) | ||
| #define SLM_BLOCK_READ_FLOAT(ptr_) \ | ||
| as_float(intel_sub_group_block_read((const __local uint *)(ptr_))) | ||
| #else | ||
| #define SLM_BLOCK_READ_FLOAT(ptr_) \ | ||
| ((const __local float *)(ptr_))[get_sub_group_local_id()] | ||
| #endif | ||
| // -------------------------------------------------------------------- | ||
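Both branches give each lane the float at its own index within the 16-float SLM row; the fallback just spells that out as a plain indexed load. A small sketch of the fallback behaviour (illustrative, assuming a 16-wide work-group row as above):

```c
// Illustrative only: the non-Intel form of SLM_BLOCK_READ_FLOAT is a per-lane
// indexed load from the shared row, which is what the Intel block read hands
// each lane anyway.
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void slm_row_read_probe(__global float *out) {
    __local float row[16];
    uint lane = get_local_id(0);          // emulated sub-group lane
    row[lane] = (float)lane;              // fill the row cooperatively
    barrier(CLK_LOCAL_MEM_FENCE);
    out[get_global_id(0)] = ((const __local float *)row)[lane];
}
```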
|
|
||
| // ========================================================================== | ||
| // GEMV configuration | ||
| // ========================================================================== | ||
| #define SIMD 16 | ||
| #define SUBGROUP_SIZE SIMD | ||
| #define DECOMPRESSION_GROUP_SIZE SIZE_QUANTIZATION_GROUP | ||
|
|
@@ -160,6 +235,59 @@ DECLARE_BLOCK_READ_EMULATION(1, 16) | |
| BLOCK_READN(half, INPUT_TILE_SIZE, ptr, offset) | ||
| #define GEMV_FILTER_BLOCK_READ(ptr, offset) BLOCK_READN(char, 16, ptr, offset) | ||
|
|
||
| // ========================================================================== | ||
| // Non-Intel subgroup broadcast / reduce emulation | ||
| // - Subgroups are lanes along X (local_id(0)) | ||
| // - thr_id is Z-dimension (local_id(2)) | ||
| // - Each thr_id slice gets its own 16-element buffer | ||
| // ========================================================================== | ||
| #if !defined(cl_intel_subgroups) | ||
|
|
||
| inline float sg_reduce_add_float(float v, __local float *buf_line) { | ||
| uint lid = get_sub_group_local_id(); // lane in X-dimension | ||
| buf_line[lid] = v; | ||
| barrier(CLK_LOCAL_MEM_FENCE); | ||
|
|
||
| uint sg_size = SUBGROUP_SIZE; // expected 16 | ||
| for (uint stride = sg_size >> 1; stride > 0; stride >>= 1) { | ||
|
Review comment: I don't get it: shouldn't we be reducing across all of them, not just some? Review comment (follow-up): Oh, I think I get it. To make it work the way you want, the accumulation of buf_line[0] with the strided element would have to be atomic (a local atomic fetch-add); otherwise it races between sub-group threads (the race is between the read of buf_line[0] and the modification of buf_line[0]). Have you tested the correctness of this reduction?
||
| if (lid < stride) { | ||
| buf_line[lid] = buf_line[lid] + buf_line[lid + stride]; | ||
| } | ||
| barrier(CLK_LOCAL_MEM_FENCE); | ||
| } | ||
|
|
||
| float result = buf_line[0]; | ||
| barrier(CLK_LOCAL_MEM_FENCE); // reuse buffer_line safely later | ||
| return result; | ||
| } | ||
|
|
||
| inline half sg_broadcast_half(half v, uint src_lane, __local half *buf_line) { | ||
| uint lid = get_sub_group_local_id(); // lane in X-dimension | ||
| buf_line[lid] = v; | ||
| barrier(CLK_LOCAL_MEM_FENCE); | ||
|
|
||
| half result = buf_line[src_lane]; | ||
| barrier(CLK_LOCAL_MEM_FENCE); | ||
| return result; | ||
| } | ||
|
|
||
| #define SG_BCAST_HALF(val, lane) \ | ||
| sg_broadcast_half((val), (lane), \ | ||
| sg_bcast_buf + get_local_id(2) * SUBGROUP_SIZE) | ||
|
|
||
| #define SG_REDUCE_ADD_FLOAT(val) \ | ||
| sg_reduce_add_float((val), sg_reduce_buf + get_local_id(2) * SUBGROUP_SIZE) | ||
|
|
||
| #else // Intel: just alias to real sub-group intrinsics | ||
|
|
||
| #define SG_BCAST_HALF(val, lane) sub_group_broadcast((val), (lane)) | ||
| #define SG_REDUCE_ADD_FLOAT(val) sub_group_reduce_add((val)) | ||
|
|
||
| #endif // !cl_intel_subgroups | ||
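One usage detail the #else branch hides: on non-Intel devices, SG_BCAST_HALF and SG_REDUCE_ADD_FLOAT expect sg_bcast_buf and sg_reduce_buf to be declared in the calling scope (the kernel below does exactly that). A minimal usage sketch under that assumption, with cl_khr_fp16 enabled as at the top of the file and the macros above in scope:

```c
// Illustrative only: shows the scratch-buffer contract of the emulated macros.
// On Intel devices the same source lowers to the real sub-group intrinsics and
// the __local scratch is simply unused.
__kernel __attribute__((reqd_work_group_size(16, 1, 1)))
void sg_macro_probe(__global const half *in, __global float *out) {
#if !defined(cl_intel_subgroups)
    __local half  sg_bcast_buf[SUBGROUP_SIZE];   // one 16-lane row; get_local_id(2) == 0 here
    __local float sg_reduce_buf[SUBGROUP_SIZE];
#endif
    half  v     = in[get_global_id(0)];
    half  lane0 = SG_BCAST_HALF(v, 0);           // every lane sees lane 0's value
    float total = SG_REDUCE_ADD_FLOAT((float)v); // every lane sees the 16-lane sum
    out[get_global_id(0)] = (float)lane0 + total;
}
```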
|
|
||
| // ========================================================================== | ||
| // Helper functions | ||
| // ========================================================================== | ||
| inline int get_4bit_weight_index(int k, int n, int K, int N, int OSV) { | ||
| return (n / OSV) * (OSV * K / 2) + (n % OSV) + (k / 2) * OSV; | ||
| } | ||
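To make the OS_IS_YX_OSV32_ISV2 addressing concrete, a worked example with hypothetical shapes (K = 128, N = 64, OSV = 32; the helper above is assumed in scope): each 32-column output slice owns OSV * K / 2 = 2048 packed bytes, and consecutive k-pairs of one column sit OSV bytes apart.

```c
// Illustrative probe, not part of the patch; shapes are made up for the example.
__kernel void weight_index_probe(__global int *out) {
    out[0] = get_4bit_weight_index(0,  0, 128, 64, 32); // == 0    : slice 0, column 0, k-pair 0
    out[1] = get_4bit_weight_index(2,  0, 128, 64, 32); // == 32   : same column, next k-pair
    out[2] = get_4bit_weight_index(0, 33, 128, 64, 32); // == 2049 : slice 1, its column 1
}
```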
|
|
@@ -184,11 +312,13 @@ inline void thread_task_splitter(const int group_num, const int thr_num, | |
| *n_end += *n_start; | ||
| } | ||
|
|
||
| __attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void | ||
| fully_connected_gpu_int4_gemv(__global half *input, const __global half *scales, | ||
| __global half *output, | ||
| const __global char *weights, const int WEIGHTS_K, | ||
| const int WEIGHTS_N) { | ||
| // ========================================================================== | ||
| // Kernel | ||
| // ========================================================================== | ||
| REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE) | ||
| kernel void fully_connected_gpu_int4_gemv( | ||
| __global half *input, const __global half *scales, __global half *output, | ||
| const __global char *weights, const int WEIGHTS_K, const int WEIGHTS_N) { | ||
| const int SCALE_GROUP_NUM = CEIL_DIV(WEIGHTS_K, SIZE_QUANTIZATION_GROUP); | ||
| int ALIGN_WEIGHTS_N = ALIGN(WEIGHTS_N, 32); | ||
| int ALIGN_WEIGHTS_K = ALIGN(WEIGHTS_K, SIZE_QUANTIZATION_GROUP); | ||
|
|
@@ -204,39 +334,45 @@ fully_connected_gpu_int4_gemv(__global half *input, const __global half *scales, | |
| __local float all_sum_even[16][16]; // [wi_id, thr_id] | ||
| __local float all_sum_odd[16][16]; | ||
|
|
||
| #if !defined(cl_intel_subgroups) | ||
| // Non-Intel: subgroup emulation scratch | ||
| __local half sg_bcast_buf[SUBGROUP_SIZE * SUBGROUP_SIZE]; // 16 * 16 | ||
| __local float sg_reduce_buf[SUBGROUP_SIZE * SUBGROUP_SIZE]; // 16 * 16 | ||
| #endif | ||
|
|
||
| #if SCALE_ROW_MAJOR | ||
| scales += ((n / 32) * 32 + (n % 32) / 2) * SCALE_GROUP_NUM; | ||
| const __global half *scales_base = | ||
| scales + ((n / 32) * 32 + (n % 32) / 2) * SCALE_GROUP_NUM; | ||
| #else | ||
| // Scale layout is fbyx | ||
| scales += (n / 32) * 32 + (n % 32) / 2; | ||
| const __global half *scales_base = scales + (n / 32) * 32 + (n % 32) / 2; | ||
| #endif | ||
|
|
||
| float2 sum_all = 0; | ||
| float2 sum_all = 0.0f; | ||
| for (int gk = gk0; gk < gk1; gk++) { | ||
| __global half *A = input + gk * DECOMPRESSION_GROUP_SIZE; | ||
| int w_id = get_4bit_weight_index(gk * DECOMPRESSION_GROUP_SIZE, n, | ||
| ALIGN_WEIGHTS_K, ALIGN_WEIGHTS_N, 32); | ||
|
|
||
| const __global char *B = weights + w_id; | ||
|
|
||
| GEMV_ACCUMULATOR_VEC_TYPE sum = 0; | ||
| GEMV_ACCUMULATOR_VEC_TYPE sum = 0.0f; | ||
|
|
||
| #if SCALE_ROW_MAJOR | ||
| float scale_0 = convert_float(scales[gk]); | ||
| float scale_1 = convert_float(scales[gk + 16 * SCALE_GROUP_NUM]); | ||
| float scale_0 = convert_float(scales_base[gk]); | ||
| float scale_1 = convert_float(scales_base[gk + 16 * SCALE_GROUP_NUM]); | ||
| #else | ||
| float scale_0 = convert_float(scales[gk * ALIGN_WEIGHTS_N]); | ||
| float scale_1 = convert_float(scales[gk * ALIGN_WEIGHTS_N + 16]); | ||
| float scale_0 = convert_float(scales_base[gk * ALIGN_WEIGHTS_N]); | ||
| float scale_1 = convert_float(scales_base[gk * ALIGN_WEIGHTS_N + 16]); | ||
| #endif | ||
|
|
||
| __attribute__((opencl_unroll_hint(4))) for (int g = 0; | ||
| g < DECOMPRESSION_GROUP_SIZE; | ||
| g += 16, B += 16 * 16) { | ||
| GEMV_INPUT_VEC_TYPE input_value = | ||
| GEMV_INPUT_BLOCK_READ(A, g); // read 16 elements of A | ||
| GEMV_INPUT_VEC_TYPE input_value = GEMV_INPUT_BLOCK_READ(A, g); | ||
|
|
||
| GEMV_FILTER_PACKED_VEC_TYPE bx16 = TO_GEMV_FILTER_PACKED_VEC_TYPE( | ||
| GEMV_FILTER_BLOCK_READ(B, 0)); // read 16x16 int8 = (16x2)x16 int4 | ||
| GEMV_FILTER_PACKED_VEC_TYPE bx16 = | ||
| TO_GEMV_FILTER_PACKED_VEC_TYPE(GEMV_FILTER_BLOCK_READ(B, 0)); | ||
|
|
||
| #if WEI_UINT4 | ||
| GEMV_FILTER_VEC_TYPE i4x16_even = | ||
|
|
@@ -245,7 +381,7 @@ fully_connected_gpu_int4_gemv(__global half *input, const __global half *scales, | |
| TO_GEMV_FILTER_VEC_TYPE(as_char16(as_uchar16(bx16) >> 4)); | ||
| #else | ||
| char16 i4x16_even_c16 = (bx16 & (char16)0xF); | ||
| char16 i4x16_odd_c16 = (as_char16(as_uchar16(bx16) >> 4)); | ||
| char16 i4x16_odd_c16 = (as_char16(as_uchar16(bx16) >> (uchar16)4)); | ||
| i4x16_even_c16 = select(i4x16_even_c16, i4x16_even_c16 - (char16)16, | ||
| i4x16_even_c16 > (char16)7); | ||
| i4x16_odd_c16 = select(i4x16_odd_c16, i4x16_odd_c16 - (char16)16, | ||
|
|
@@ -254,69 +390,66 @@ fully_connected_gpu_int4_gemv(__global half *input, const __global half *scales, | |
| GEMV_FILTER_VEC_TYPE i4x16_odd = TO_GEMV_FILTER_VEC_TYPE(i4x16_odd_c16); | ||
| #endif | ||
|
|
||
| sum[0] += as_half(sub_group_broadcast(input_value, 0)) * i4x16_even.s0 + | ||
| as_half(sub_group_broadcast(input_value, 4)) * i4x16_even.s4 + | ||
| as_half(sub_group_broadcast(input_value, 8)) * i4x16_even.s8 + | ||
| as_half(sub_group_broadcast(input_value, 12)) * i4x16_even.sc; | ||
|
|
||
| sum[1] += as_half(sub_group_broadcast(input_value, 0)) * i4x16_even.s1 + | ||
| as_half(sub_group_broadcast(input_value, 4)) * i4x16_even.s5 + | ||
| as_half(sub_group_broadcast(input_value, 8)) * i4x16_even.s9 + | ||
| as_half(sub_group_broadcast(input_value, 12)) * i4x16_even.sd; | ||
|
|
||
| sum[2] += as_half(sub_group_broadcast(input_value, 1)) * i4x16_odd.s0 + | ||
| as_half(sub_group_broadcast(input_value, 5)) * i4x16_odd.s4 + | ||
| as_half(sub_group_broadcast(input_value, 9)) * i4x16_odd.s8 + | ||
| as_half(sub_group_broadcast(input_value, 13)) * i4x16_odd.sc; | ||
|
|
||
| sum[3] += as_half(sub_group_broadcast(input_value, 1)) * i4x16_odd.s1 + | ||
| as_half(sub_group_broadcast(input_value, 5)) * i4x16_odd.s5 + | ||
| as_half(sub_group_broadcast(input_value, 9)) * i4x16_odd.s9 + | ||
| as_half(sub_group_broadcast(input_value, 13)) * i4x16_odd.sd; | ||
|
|
||
| sum[4] += as_half(sub_group_broadcast(input_value, 2)) * i4x16_even.s2 + | ||
| as_half(sub_group_broadcast(input_value, 6)) * i4x16_even.s6 + | ||
| as_half(sub_group_broadcast(input_value, 10)) * i4x16_even.sa + | ||
| as_half(sub_group_broadcast(input_value, 14)) * i4x16_even.se; | ||
|
|
||
| sum[5] += as_half(sub_group_broadcast(input_value, 2)) * i4x16_even.s3 + | ||
| as_half(sub_group_broadcast(input_value, 6)) * i4x16_even.s7 + | ||
| as_half(sub_group_broadcast(input_value, 10)) * i4x16_even.sb + | ||
| as_half(sub_group_broadcast(input_value, 14)) * i4x16_even.sf; | ||
|
|
||
| sum[6] += as_half(sub_group_broadcast(input_value, 3)) * i4x16_odd.s2 + | ||
| as_half(sub_group_broadcast(input_value, 7)) * i4x16_odd.s6 + | ||
| as_half(sub_group_broadcast(input_value, 11)) * i4x16_odd.sa + | ||
| as_half(sub_group_broadcast(input_value, 15)) * i4x16_odd.se; | ||
|
|
||
| sum[7] += as_half(sub_group_broadcast(input_value, 3)) * i4x16_odd.s3 + | ||
| as_half(sub_group_broadcast(input_value, 7)) * i4x16_odd.s7 + | ||
| as_half(sub_group_broadcast(input_value, 11)) * i4x16_odd.sb + | ||
| as_half(sub_group_broadcast(input_value, 15)) * i4x16_odd.sf; | ||
| sum.s0 += as_half(SG_BCAST_HALF(input_value, 0)) * i4x16_even.s0 + | ||
| as_half(SG_BCAST_HALF(input_value, 4)) * i4x16_even.s4 + | ||
| as_half(SG_BCAST_HALF(input_value, 8)) * i4x16_even.s8 + | ||
| as_half(SG_BCAST_HALF(input_value, 12)) * i4x16_even.sc; | ||
|
|
||
| sum.s1 += as_half(SG_BCAST_HALF(input_value, 0)) * i4x16_even.s1 + | ||
| as_half(SG_BCAST_HALF(input_value, 4)) * i4x16_even.s5 + | ||
| as_half(SG_BCAST_HALF(input_value, 8)) * i4x16_even.s9 + | ||
| as_half(SG_BCAST_HALF(input_value, 12)) * i4x16_even.sd; | ||
|
|
||
| sum.s2 += as_half(SG_BCAST_HALF(input_value, 1)) * i4x16_odd.s0 + | ||
| as_half(SG_BCAST_HALF(input_value, 5)) * i4x16_odd.s4 + | ||
| as_half(SG_BCAST_HALF(input_value, 9)) * i4x16_odd.s8 + | ||
| as_half(SG_BCAST_HALF(input_value, 13)) * i4x16_odd.sc; | ||
|
|
||
| sum.s3 += as_half(SG_BCAST_HALF(input_value, 1)) * i4x16_odd.s1 + | ||
| as_half(SG_BCAST_HALF(input_value, 5)) * i4x16_odd.s5 + | ||
| as_half(SG_BCAST_HALF(input_value, 9)) * i4x16_odd.s9 + | ||
| as_half(SG_BCAST_HALF(input_value, 13)) * i4x16_odd.sd; | ||
|
|
||
| sum.s4 += as_half(SG_BCAST_HALF(input_value, 2)) * i4x16_even.s2 + | ||
| as_half(SG_BCAST_HALF(input_value, 6)) * i4x16_even.s6 + | ||
| as_half(SG_BCAST_HALF(input_value, 10)) * i4x16_even.sa + | ||
| as_half(SG_BCAST_HALF(input_value, 14)) * i4x16_even.se; | ||
|
|
||
| sum.s5 += as_half(SG_BCAST_HALF(input_value, 2)) * i4x16_even.s3 + | ||
| as_half(SG_BCAST_HALF(input_value, 6)) * i4x16_even.s7 + | ||
| as_half(SG_BCAST_HALF(input_value, 10)) * i4x16_even.sb + | ||
| as_half(SG_BCAST_HALF(input_value, 14)) * i4x16_even.sf; | ||
|
|
||
| sum.s6 += as_half(SG_BCAST_HALF(input_value, 3)) * i4x16_odd.s2 + | ||
| as_half(SG_BCAST_HALF(input_value, 7)) * i4x16_odd.s6 + | ||
| as_half(SG_BCAST_HALF(input_value, 11)) * i4x16_odd.sa + | ||
| as_half(SG_BCAST_HALF(input_value, 15)) * i4x16_odd.se; | ||
|
|
||
| sum.s7 += as_half(SG_BCAST_HALF(input_value, 3)) * i4x16_odd.s3 + | ||
| as_half(SG_BCAST_HALF(input_value, 7)) * i4x16_odd.s7 + | ||
| as_half(SG_BCAST_HALF(input_value, 11)) * i4x16_odd.sb + | ||
| as_half(SG_BCAST_HALF(input_value, 15)) * i4x16_odd.sf; | ||
| } | ||
|
|
||
| sum_all[0] += (sum[0] + sum[2] + sum[4] + sum[6]) * scale_0; | ||
| sum_all[1] += (sum[1] + sum[3] + sum[5] + sum[7]) * scale_1; | ||
| sum_all.s0 += (sum.s0 + sum.s2 + sum.s4 + sum.s6) * scale_0; | ||
| sum_all.s1 += (sum.s1 + sum.s3 + sum.s5 + sum.s7) * scale_1; | ||
| } | ||
|
|
||
| all_sum_even[wi_id][thr_id] = sum_all[0]; | ||
| all_sum_odd[wi_id][thr_id] = sum_all[1]; | ||
| all_sum_even[wi_id][thr_id] = sum_all.s0; | ||
| all_sum_odd[wi_id][thr_id] = sum_all.s1; | ||
| barrier(CLK_LOCAL_MEM_FENCE); | ||
|
|
||
| float2 sum_value; | ||
| sum_value[0] = as_float( | ||
| intel_sub_group_block_read((const __local uint *)all_sum_even[thr_id])); | ||
| sum_value[1] = as_float( | ||
| intel_sub_group_block_read((const __local uint *)all_sum_odd[thr_id])); | ||
| sum_value[0] = sub_group_reduce_add(sum_value[0]); | ||
| sum_value[1] = sub_group_reduce_add(sum_value[1]); | ||
| sum_value.s0 = SLM_BLOCK_READ_FLOAT(all_sum_even[thr_id]); | ||
| sum_value.s1 = SLM_BLOCK_READ_FLOAT(all_sum_odd[thr_id]); | ||
|
|
||
| sum_value.s0 = SG_REDUCE_ADD_FLOAT(sum_value.s0); | ||
| sum_value.s1 = SG_REDUCE_ADD_FLOAT(sum_value.s1); | ||
|
|
||
| if (wi_id == 0) { | ||
| int cur_n = n + thr_id; | ||
|
|
||
| for (int i = 0; i < 2; i++) { | ||
| output[cur_n + 16 * i] = | ||
| TO_GEMV_OUTPUT_VEC_TYPE(convert_half(sum_value[i])); | ||
| } | ||
| output[cur_n] = TO_GEMV_OUTPUT_VEC_TYPE(convert_half(sum_value.s0)); | ||
| output[cur_n + 16] = TO_GEMV_OUTPUT_VEC_TYPE(convert_half(sum_value.s1)); | ||
| } | ||
| } | ||