Skip to content

Commit e2f19b3

Browse files
shaofeiqilhez
andauthored
opencl: refactor expm1 and softplus (ggml-org#19404)
* opencl: refactor expm1 * opencl: refactor softplus * opencl: use h for half literals --------- Co-authored-by: Li He <lih@qti.qualcomm.com>
1 parent 983559d commit e2f19b3

File tree

3 files changed

+319
-282
lines changed

3 files changed

+319
-282
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 144 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -548,10 +548,10 @@ struct ggml_backend_opencl_context {
548548
cl_kernel kernel_pad;
549549
cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;
550550
cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc;
551-
cl_kernel kernel_expm1_f32_nd;
552-
cl_kernel kernel_expm1_f16_nd;
553-
cl_kernel kernel_softplus_f32_nd;
554-
cl_kernel kernel_softplus_f16_nd;
551+
cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc;
552+
cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc;
553+
cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc;
554+
cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc;
555555
cl_kernel kernel_upscale;
556556
cl_kernel kernel_upscale_bilinear;
557557
cl_kernel kernel_concat_f32;
@@ -1980,20 +1980,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
19801980
#else
19811981
const std::string kernel_src = read_file("expm1.cl");
19821982
#endif
1983-
cl_program prog;
1984-
if (!kernel_src.empty()) {
1985-
prog =
1986-
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1987-
CL_CHECK((backend_ctx->kernel_expm1_f32_nd = clCreateKernel(prog, "kernel_expm1_f32_nd", &err), err));
1988-
CL_CHECK((backend_ctx->kernel_expm1_f16_nd = clCreateKernel(prog, "kernel_expm1_f16_nd", &err), err));
1989-
GGML_LOG_CONT(".");
1990-
} else {
1991-
GGML_LOG_WARN("ggml_opencl: expm1 kernel source not found or empty. Expm1 operation will not be available.\n");
1992-
prog = nullptr;
1993-
backend_ctx->kernel_expm1_f32_nd = nullptr;
1994-
backend_ctx->kernel_expm1_f16_nd = nullptr;
1995-
}
1983+
cl_program prog =
1984+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1985+
CL_CHECK((backend_ctx->kernel_expm1_f32 = clCreateKernel(prog, "kernel_expm1_f32", &err), err));
1986+
CL_CHECK((backend_ctx->kernel_expm1_f32_4 = clCreateKernel(prog, "kernel_expm1_f32_4", &err), err));
1987+
CL_CHECK((backend_ctx->kernel_expm1_f32_nc = clCreateKernel(prog, "kernel_expm1_f32_nc", &err), err));
1988+
CL_CHECK((backend_ctx->kernel_expm1_f16 = clCreateKernel(prog, "kernel_expm1_f16", &err), err));
1989+
CL_CHECK((backend_ctx->kernel_expm1_f16_4 = clCreateKernel(prog, "kernel_expm1_f16_4", &err), err));
1990+
CL_CHECK((backend_ctx->kernel_expm1_f16_nc = clCreateKernel(prog, "kernel_expm1_f16_nc", &err), err));
19961991
CL_CHECK(clReleaseProgram(prog));
1992+
GGML_LOG_CONT(".");
19971993
}
19981994

19991995
// softplus
@@ -2005,20 +2001,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
20052001
#else
20062002
const std::string kernel_src = read_file("softplus.cl");
20072003
#endif
2008-
cl_program prog;
2009-
if (!kernel_src.empty()) {
2010-
prog =
2011-
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
2012-
CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err));
2013-
CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err));
2014-
GGML_LOG_CONT(".");
2015-
} else {
2016-
GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n");
2017-
prog = nullptr;
2018-
backend_ctx->kernel_softplus_f32_nd = nullptr;
2019-
backend_ctx->kernel_softplus_f16_nd = nullptr;
2020-
}
2004+
cl_program prog =
2005+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
2006+
CL_CHECK((backend_ctx->kernel_softplus_f32 = clCreateKernel(prog, "kernel_softplus_f32", &err), err));
2007+
CL_CHECK((backend_ctx->kernel_softplus_f32_4 = clCreateKernel(prog, "kernel_softplus_f32_4", &err), err));
2008+
CL_CHECK((backend_ctx->kernel_softplus_f32_nc = clCreateKernel(prog, "kernel_softplus_f32_nc", &err), err));
2009+
CL_CHECK((backend_ctx->kernel_softplus_f16 = clCreateKernel(prog, "kernel_softplus_f16", &err), err));
2010+
CL_CHECK((backend_ctx->kernel_softplus_f16_4 = clCreateKernel(prog, "kernel_softplus_f16_4", &err), err));
2011+
CL_CHECK((backend_ctx->kernel_softplus_f16_nc = clCreateKernel(prog, "kernel_softplus_f16_nc", &err), err));
20212012
CL_CHECK(clReleaseProgram(prog));
2013+
GGML_LOG_CONT(".");
20222014
}
20232015

20242016
// upscale
@@ -3465,11 +3457,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
34653457
case GGML_UNARY_OP_TANH:
34663458
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
34673459
case GGML_UNARY_OP_EXPM1:
3468-
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
3469-
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
3460+
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
34703461
case GGML_UNARY_OP_SOFTPLUS:
3471-
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
3472-
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
3462+
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
34733463
default:
34743464
return false;
34753465
}
@@ -7396,18 +7386,8 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons
73967386
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
73977387
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
73987388

7399-
cl_ulong offset0_abs = extra0->offset + src0->view_offs;
7400-
cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
7401-
7402-
cl_kernel kernel;
7403-
if (dst->type == GGML_TYPE_F32) {
7404-
kernel = backend_ctx->kernel_expm1_f32_nd;
7405-
} else if (dst->type == GGML_TYPE_F16) {
7406-
kernel = backend_ctx->kernel_expm1_f16_nd;
7407-
} else {
7408-
GGML_ASSERT(false && "Unsupported type for ggml_cl_expm1");
7409-
}
7410-
GGML_ASSERT(kernel != nullptr);
7389+
cl_ulong offset0 = extra0->offset + src0->view_offs;
7390+
cl_ulong offsetd = extrad->offset + dst->view_offs;
74117391

74127392
const int ne00 = src0->ne[0];
74137393
const int ne01 = src0->ne[1];
@@ -7419,70 +7399,74 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons
74197399
const cl_ulong nb02 = src0->nb[2];
74207400
const cl_ulong nb03 = src0->nb[3];
74217401

7422-
const int ne10 = dst->ne[0];
7423-
const int ne11 = dst->ne[1];
7424-
const int ne12 = dst->ne[2];
7425-
const int ne13 = dst->ne[3];
7402+
const cl_ulong nb0 = dst->nb[0];
7403+
const cl_ulong nb1 = dst->nb[1];
7404+
const cl_ulong nb2 = dst->nb[2];
7405+
const cl_ulong nb3 = dst->nb[3];
74267406

7427-
const cl_ulong nb10 = dst->nb[0];
7428-
const cl_ulong nb11 = dst->nb[1];
7429-
const cl_ulong nb12 = dst->nb[2];
7430-
const cl_ulong nb13 = dst->nb[3];
7407+
cl_kernel kernel;
74317408

7432-
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
7433-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
7434-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
7435-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
7409+
if (ggml_is_contiguous(src0)) {
7410+
// Handle contiguous input
7411+
int n = ggml_nelements(dst);
7412+
if (n % 4 == 0) {
7413+
if (src0->type == GGML_TYPE_F32) {
7414+
kernel = backend_ctx->kernel_expm1_f32_4;
7415+
} else {
7416+
kernel = backend_ctx->kernel_expm1_f16_4;
7417+
}
7418+
n /= 4;
7419+
} else {
7420+
if (src0->type == GGML_TYPE_F32) {
7421+
kernel = backend_ctx->kernel_expm1_f32;
7422+
} else {
7423+
kernel = backend_ctx->kernel_expm1_f16;
7424+
}
7425+
}
74367426

7437-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
7438-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
7439-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
7440-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
7441-
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
7442-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
7443-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
7444-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
7427+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
7428+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
7429+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
7430+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
74457431

7446-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
7447-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
7448-
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
7449-
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
7450-
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
7451-
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
7452-
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
7453-
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
7454-
7455-
size_t global_work_size[3];
7456-
if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
7457-
return;
7458-
}
7459-
global_work_size[0] = (size_t)ne10;
7460-
global_work_size[1] = (size_t)ne11;
7461-
global_work_size[2] = (size_t)ne12;
7432+
size_t global_work_size[] = {(size_t)n, 1, 1};
7433+
size_t local_work_size[] = {64, 1, 1};
7434+
7435+
size_t * local_work_size_ptr = local_work_size;
7436+
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
7437+
local_work_size_ptr = nullptr;
7438+
}
74627439

7463-
size_t lws0 = 16, lws1 = 4, lws2 = 1;
7464-
if (ne10 < 16) lws0 = ne10;
7465-
if (ne11 < 4) lws1 = ne11;
7466-
if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
7440+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
7441+
} else {
7442+
// Handle non-contiguous input
7443+
if (src0->type == GGML_TYPE_F32) {
7444+
kernel = backend_ctx->kernel_expm1_f32_nc;
7445+
} else {
7446+
kernel = backend_ctx->kernel_expm1_f16_nc;
7447+
}
74677448

7468-
while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
7469-
while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
7470-
while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
7449+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
7450+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
7451+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
7452+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
7453+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
7454+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00));
7455+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01));
7456+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02));
7457+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03));
7458+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0));
7459+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
7460+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
7461+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
74717462

7463+
int nth = 64;
74727464

7473-
size_t local_work_size[] = {lws0, lws1, lws2};
7465+
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
7466+
size_t local_work_size[] = {(size_t)nth, 1, 1};
74747467

7475-
size_t* local_work_size_ptr = local_work_size;
7476-
if (!backend_ctx->non_uniform_workgroups) {
7477-
if (global_work_size[0] % local_work_size[0] != 0 ||
7478-
global_work_size[1] % local_work_size[1] != 0 ||
7479-
global_work_size[2] % local_work_size[2] != 0) {
7480-
local_work_size_ptr = NULL;
7481-
}
7468+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
74827469
}
7483-
if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
7484-
7485-
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
74867470
}
74877471

74887472
static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7498,18 +7482,8 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c
74987482
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
74997483
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
75007484

7501-
cl_ulong offset0_abs = extra0->offset + src0->view_offs;
7502-
cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
7503-
7504-
cl_kernel kernel;
7505-
if (dst->type == GGML_TYPE_F32) {
7506-
kernel = backend_ctx->kernel_softplus_f32_nd;
7507-
} else if (dst->type == GGML_TYPE_F16) {
7508-
kernel = backend_ctx->kernel_softplus_f16_nd;
7509-
} else {
7510-
GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus");
7511-
}
7512-
GGML_ASSERT(kernel != nullptr);
7485+
cl_ulong offset0 = extra0->offset + src0->view_offs;
7486+
cl_ulong offsetd = extrad->offset + dst->view_offs;
75137487

75147488
const int ne00 = src0->ne[0];
75157489
const int ne01 = src0->ne[1];
@@ -7521,70 +7495,74 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c
75217495
const cl_ulong nb02 = src0->nb[2];
75227496
const cl_ulong nb03 = src0->nb[3];
75237497

7524-
const int ne10 = dst->ne[0];
7525-
const int ne11 = dst->ne[1];
7526-
const int ne12 = dst->ne[2];
7527-
const int ne13 = dst->ne[3];
7498+
const cl_ulong nb0 = dst->nb[0];
7499+
const cl_ulong nb1 = dst->nb[1];
7500+
const cl_ulong nb2 = dst->nb[2];
7501+
const cl_ulong nb3 = dst->nb[3];
75287502

7529-
const cl_ulong nb10 = dst->nb[0];
7530-
const cl_ulong nb11 = dst->nb[1];
7531-
const cl_ulong nb12 = dst->nb[2];
7532-
const cl_ulong nb13 = dst->nb[3];
7503+
cl_kernel kernel;
75337504

7534-
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
7535-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
7536-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
7537-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
7505+
if (ggml_is_contiguous(src0)) {
7506+
// Handle contiguous input
7507+
int n = ggml_nelements(dst);
7508+
if (n % 4 == 0) {
7509+
if (src0->type == GGML_TYPE_F32) {
7510+
kernel = backend_ctx->kernel_softplus_f32_4;
7511+
} else {
7512+
kernel = backend_ctx->kernel_softplus_f16_4;
7513+
}
7514+
n /= 4;
7515+
} else {
7516+
if (src0->type == GGML_TYPE_F32) {
7517+
kernel = backend_ctx->kernel_softplus_f32;
7518+
} else {
7519+
kernel = backend_ctx->kernel_softplus_f16;
7520+
}
7521+
}
75387522

7539-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
7540-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
7541-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
7542-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
7543-
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
7544-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
7545-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
7546-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
7523+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
7524+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
7525+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
7526+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
75477527

7548-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
7549-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
7550-
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
7551-
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
7552-
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
7553-
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
7554-
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
7555-
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
7556-
7557-
size_t global_work_size[3];
7558-
if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
7559-
return;
7560-
}
7561-
global_work_size[0] = (size_t)ne10;
7562-
global_work_size[1] = (size_t)ne11;
7563-
global_work_size[2] = (size_t)ne12;
7528+
size_t global_work_size[] = {(size_t)n, 1, 1};
7529+
size_t local_work_size[] = {64, 1, 1};
7530+
7531+
size_t * local_work_size_ptr = local_work_size;
7532+
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
7533+
local_work_size_ptr = nullptr;
7534+
}
75647535

7565-
size_t lws0 = 16, lws1 = 4, lws2 = 1;
7566-
if (ne10 < 16) lws0 = ne10;
7567-
if (ne11 < 4) lws1 = ne11;
7568-
if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
7536+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
7537+
} else {
7538+
// Handle non-contiguous input
7539+
if (src0->type == GGML_TYPE_F32) {
7540+
kernel = backend_ctx->kernel_softplus_f32_nc;
7541+
} else {
7542+
kernel = backend_ctx->kernel_softplus_f16_nc;
7543+
}
75697544

7570-
while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
7571-
while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
7572-
while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
7545+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
7546+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
7547+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
7548+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
7549+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
7550+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00));
7551+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01));
7552+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02));
7553+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03));
7554+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0));
7555+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
7556+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
7557+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
75737558

7559+
int nth = 64;
75747560

7575-
size_t local_work_size[] = {lws0, lws1, lws2};
7561+
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
7562+
size_t local_work_size[] = {(size_t)nth, 1, 1};
75767563

7577-
size_t* local_work_size_ptr = local_work_size;
7578-
if (!backend_ctx->non_uniform_workgroups) {
7579-
if (global_work_size[0] % local_work_size[0] != 0 ||
7580-
global_work_size[1] % local_work_size[1] != 0 ||
7581-
global_work_size[2] % local_work_size[2] != 0) {
7582-
local_work_size_ptr = NULL;
7583-
}
7564+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
75847565
}
7585-
if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
7586-
7587-
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
75887566
}
75897567

75907568
static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {

0 commit comments

Comments
 (0)