Skip to content

use atomics for loading ICD loader handle vs. C++ static initialization #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions scripts/loader.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ apihandles = {
#include <alloca.h>
#endif

#include <atomic>

#define _SCL_MAX_NUM_PLATFORMS 64

#define _SCL_VALIDATE_HANDLE_RETURN_ERROR(_handle, _error) ${"\\"}
Expand Down Expand Up @@ -318,13 +320,34 @@ struct _cl_sampler {
// Platform abstraction for the real ICD loader library:
//   _sclModuleHandle        - native handle type for a loaded library
//   _sclOpenICDLoader       - opens the ICD loader library
//   _sclCloseICDLoader      - releases a handle returned by _sclOpenICDLoader
//   _sclGetFunctionAddress  - looks up an exported symbol by name
#if defined(_WIN32)
// Windows: the ICD loader is OpenCL.dll, accessed via the Win32 library API.
using _sclModuleHandle = HMODULE;
#define _sclOpenICDLoader() ::LoadLibraryA("OpenCL.dll")
#define _sclCloseICDLoader(_module) ::FreeLibrary(_module)
#define _sclGetFunctionAddress(_module, _name) ::GetProcAddress(_module, _name)
#else
// Elsewhere: the ICD loader is libOpenCL.so, accessed via the dlfcn API.
using _sclModuleHandle = void*;
#define _sclOpenICDLoader() ::dlopen("libOpenCL.so", RTLD_LAZY | RTLD_LOCAL)
#define _sclCloseICDLoader(_module) ::dlclose(_module)
#define _sclGetFunctionAddress(_module, _name) ::dlsym(_module, _name)
#endif

// Shared ICD loader handle. Starts out NULL; a non-NULL value is the handle
// of the opened loader library.
static std::atomic<_sclModuleHandle> g_ICDLoaderHandle{NULL};

// Helper that returns the shared handle to the ICD loader, opening the
// loader library on first use.
//
// The handle is published in g_ICDLoaderHandle with an atomic
// compare-exchange. If two callers open the library concurrently, only the
// first one to publish wins; the other closes its own handle and returns the
// published one.
//
// Returns NULL when the loader library cannot be opened. A failed open is not
// recorded, so a later call tries again.
static inline _sclModuleHandle _sclGetICDLoaderHandle(void)
{
    _sclModuleHandle current = g_ICDLoaderHandle.load();
    if (current != NULL) {
        // Already opened and published.
        return current;
    }

    _sclModuleHandle fresh = _sclOpenICDLoader();
    if (fresh == NULL) {
        // Open failed; current is still NULL here.
        return current;
    }

    if (!g_ICDLoaderHandle.compare_exchange_strong(current, fresh)) {
        // Another caller published first. The failed exchange stored the
        // published handle into current, so drop ours and use theirs.
        _sclCloseICDLoader(fresh);
        return current;
    }

    return fresh;
}

// This is a helper function to find a platform from context properties:
static inline cl_platform_id _sclGetPlatfromFromContextProperties(
const cl_context_properties* properties)
Expand All @@ -350,7 +373,7 @@ CL_API_ENTRY cl_int CL_API_CALL clGetPlatformIDs(
cl_platform_id* platforms,
cl_uint* num_platforms)
{
static _sclModuleHandle module = _sclOpenICDLoader();
_sclModuleHandle module = _sclGetICDLoaderHandle();
_sclpfn_clGetPlatformIDs _clGetPlatformIDs =
(_sclpfn_clGetPlatformIDs)_sclGetFunctionAddress(
module, "clGetPlatformIDs");
Expand Down Expand Up @@ -425,7 +448,7 @@ CL_API_ENTRY void* CL_API_CALL clGetExtensionFunctionAddress(
const char* function_name)
{
#if 0
static _sclModuleHandle module = _sclOpenICDLoader();
_sclModuleHandle module = _sclGetICDLoaderHandle();
_sclpfn_clGetExtensionFunctionAddress _clGetExtensionFunctionAddress =
(_sclpfn_clGetExtensionFunctionAddress)::GetProcAddress(
module, "clGetExtensionFunctionAddress");
Expand Down
27 changes: 25 additions & 2 deletions src/loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include <alloca.h>
#endif

#include <atomic>

#define _SCL_MAX_NUM_PLATFORMS 64

#define _SCL_VALIDATE_HANDLE_RETURN_ERROR(_handle, _error) \
Expand Down Expand Up @@ -1531,13 +1533,34 @@ struct _cl_sampler {
// Platform abstraction for the real ICD loader library:
//   _sclModuleHandle        - native handle type for a loaded library
//   _sclOpenICDLoader       - opens the ICD loader library
//   _sclCloseICDLoader      - releases a handle returned by _sclOpenICDLoader
//   _sclGetFunctionAddress  - looks up an exported symbol by name
#if defined(_WIN32)
// Windows: the ICD loader is OpenCL.dll, accessed via the Win32 library API.
using _sclModuleHandle = HMODULE;
#define _sclOpenICDLoader() ::LoadLibraryA("OpenCL.dll")
#define _sclCloseICDLoader(_module) ::FreeLibrary(_module)
#define _sclGetFunctionAddress(_module, _name) ::GetProcAddress(_module, _name)
#else
// Elsewhere: the ICD loader is libOpenCL.so, accessed via the dlfcn API.
using _sclModuleHandle = void*;
#define _sclOpenICDLoader() ::dlopen("libOpenCL.so", RTLD_LAZY | RTLD_LOCAL)
#define _sclCloseICDLoader(_module) ::dlclose(_module)
#define _sclGetFunctionAddress(_module, _name) ::dlsym(_module, _name)
#endif

// Shared ICD loader handle. Starts out NULL; a non-NULL value is the handle
// of the opened loader library.
static std::atomic<_sclModuleHandle> g_ICDLoaderHandle{NULL};

// Helper that returns the shared handle to the ICD loader, opening the
// loader library on first use.
//
// The handle is published in g_ICDLoaderHandle with an atomic
// compare-exchange. If two callers open the library concurrently, only the
// first one to publish wins; the other closes its own handle and returns the
// published one.
//
// Returns NULL when the loader library cannot be opened. A failed open is not
// recorded, so a later call tries again.
static inline _sclModuleHandle _sclGetICDLoaderHandle(void)
{
    _sclModuleHandle current = g_ICDLoaderHandle.load();
    if (current != NULL) {
        // Already opened and published.
        return current;
    }

    _sclModuleHandle fresh = _sclOpenICDLoader();
    if (fresh == NULL) {
        // Open failed; current is still NULL here.
        return current;
    }

    if (!g_ICDLoaderHandle.compare_exchange_strong(current, fresh)) {
        // Another caller published first. The failed exchange stored the
        // published handle into current, so drop ours and use theirs.
        _sclCloseICDLoader(fresh);
        return current;
    }

    return fresh;
}

// This is a helper function to find a platform from context properties:
static inline cl_platform_id _sclGetPlatfromFromContextProperties(
const cl_context_properties* properties)
Expand All @@ -1563,7 +1586,7 @@ CL_API_ENTRY cl_int CL_API_CALL clGetPlatformIDs(
cl_platform_id* platforms,
cl_uint* num_platforms)
{
static _sclModuleHandle module = _sclOpenICDLoader();
_sclModuleHandle module = _sclGetICDLoaderHandle();
_sclpfn_clGetPlatformIDs _clGetPlatformIDs =
(_sclpfn_clGetPlatformIDs)_sclGetFunctionAddress(
module, "clGetPlatformIDs");
Expand Down Expand Up @@ -1638,7 +1661,7 @@ CL_API_ENTRY void* CL_API_CALL clGetExtensionFunctionAddress(
const char* function_name)
{
#if 0
static _sclModuleHandle module = _sclOpenICDLoader();
_sclModuleHandle module = _sclGetICDLoaderHandle();
_sclpfn_clGetExtensionFunctionAddress _clGetExtensionFunctionAddress =
(_sclpfn_clGetExtensionFunctionAddress)::GetProcAddress(
module, "clGetExtensionFunctionAddress");
Expand Down