diff --git a/meson_options.txt b/meson_options.txt
index ff144ff536..56a2ec035a 100644
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -46,6 +46,7 @@ option('enable-fp16', type: 'boolean', value: false)
 option('enable-cublas', type: 'boolean', value: false)
 option('enable-openmp', type: 'boolean', value: true)
 option('enable-opencl', type: 'boolean', value: false)
+option('enable-cuda', type: 'boolean', value: false)
 option('enable-biqgemm', type: 'boolean', value: false)
 option('biqgemm-path', type: 'string', value: '../BiQGEMM')
 option('enable-benchmarks', type: 'boolean', value : false)
diff --git a/nntrainer/cuda_context.cpp b/nntrainer/cuda_context.cpp
new file mode 100644
index 0000000000..daf09641e6
--- /dev/null
+++ b/nntrainer/cuda_context.cpp
@@ -0,0 +1,129 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * @file cuda_context.cpp
+ * @date 13 Nov 2025
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Samsung Electronics Co., Ltd.
+ * @bug No known bugs except for NYI items
+ * @brief This file contains app context related functions and classes that
+ * manage the global configuration of the current CUDA environment. It also
+ * creates the CUDA stream and context.
+ */
+
+#include "cuda_context.h"
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace nntrainer {
+std::mutex cuda_factory_mutex;
+
+void CudaContext::initialize() noexcept {
+  try {
+    if (!cudaInit()) {
+      ml_loge("Error: CudaContext::initialize() failed");
+      return;
+    }
+
+    add_default_object();
+    setMemAllocator(std::make_shared<MemAllocator>());
+
+  } catch (std::exception &e) {
+    ml_loge("cuda_context: registering layers failed!!, reason: %s", e.what());
+  } catch (...) {
+    ml_loge("cuda_context: registering layer failed due to unknown reason");
+  }
+}
+
+void CudaContext::add_default_object() {
+  // Register default layers that support CUDA
+  registerFactory(nntrainer::createLayer<FullyConnectedLayer>,
+                  FullyConnectedLayer::type, ml::train::LayerType::LAYER_FC);
+
+  registerFactory(nntrainer::createLayer<AdditionLayer>, AdditionLayer::type,
+                  ml::train::LayerType::LAYER_ADDITION);
+
+  registerFactory(nntrainer::createLayer<ReshapeLayer>, ReshapeLayer::type,
+                  ml::train::LayerType::LAYER_RESHAPE);
+}
+
+template <typename T>
+const int CudaContext::registerFactory(const FactoryType<T> factory,
+                                       const std::string &key,
+                                       const int int_key) {
+  static_assert(
+    isSupported<T>::value,
+    "cuda_context: given type is not supported for current context");
+
+  auto &index = std::get<IndexType<T>>(factory_map);
+  auto &str_map = std::get<StrIndexType<T>>(index);
+  auto &int_map = std::get<IntIndexType>(index);
+
+  std::string assigned_key = key == "" ? factory({})->getType() : key;
+
+  std::transform(assigned_key.begin(), assigned_key.end(),
+                 assigned_key.begin(),
+                 [](unsigned char c) { return std::tolower(c); });
+
+  const std::lock_guard<std::mutex> lock(cuda_factory_mutex);
+  if (str_map.find(assigned_key) != str_map.end()) {
+    ml_loge("cuda_context: cannot register factory with already taken key: %s",
+            key.c_str());
+    throw std::invalid_argument(key);
+  }
+
+  if (int_key != -1 && int_map.find(int_key) != int_map.end()) {
+    ml_loge(
+      "cuda_context: cannot register factory with already taken int key: %d",
+      int_key);
+    throw std::invalid_argument(std::to_string(int_key));
+  }
+
+  int assigned_int_key = int_key == -1 ? str_map.size() + 1 : int_key;
+
+  str_map[assigned_key] = factory;
+  int_map[assigned_int_key] = assigned_key;
+
+  ml_logd("cuda_context: factory has registered with key: %s, int_key: %d",
+          assigned_key.c_str(), assigned_int_key);
+
+  return assigned_int_key;
+}
+
+bool CudaContext::cudaInit() {
+  // if already initialized
+  if (cuda_initialized)
+    return true;
+
+  // Initialize CUDA context
+  cudaError_t err = cudaSetDevice(0);
+  if (err != cudaSuccess) {
+    ml_loge("Failed to set CUDA device: %s", cudaGetErrorString(err));
+    return false;
+  }
+
+  // Create CUDA stream for asynchronous operations
+  err = cudaStreamCreate(&stream_);
+  if (err != cudaSuccess) {
+    ml_loge("Failed to create CUDA stream: %s", cudaGetErrorString(err));
+    return false;
+  }
+
+  cuda_initialized = true;
+  return cuda_initialized;
+}
+
+/**
+ * @copydoc const int CudaContext::registerFactory
+ */
+template const int CudaContext::registerFactory<ml::train::Layer>(
+  const FactoryType<ml::train::Layer> factory, const std::string &key,
+  const int int_key);
+
+} // namespace nntrainer
diff --git a/nntrainer/cuda_context.h b/nntrainer/cuda_context.h
new file mode 100644
index 0000000000..3cf1ce8dde
--- /dev/null
+++ b/nntrainer/cuda_context.h
@@ -0,0 +1,260 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * @file cuda_context.h
+ * @date 13 Nov 2025
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Samsung Electronics Co., Ltd.
+ * @bug No known bugs except for NYI items
+ * @brief This file contains app context related functions and classes that
+ * manage the global configuration of the current CUDA environment. It also
+ * creates the CUDA stream and context.
+ */
+
+#ifndef __CUDA_CONTEXT_H__
+#define __CUDA_CONTEXT_H__
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include "singleton.h"
+
+namespace nntrainer {
+
+extern std::mutex cuda_factory_mutex;
+
+/**
+ * @class CudaContext contains user-dependent configuration for CUDA support
+ * @brief CUDA support for app context
+ */
+class CudaContext : public Context, public Singleton<CudaContext> {
+public:
+  /**
+   * @brief Default constructor
+   */
+  CudaContext() : Context(std::make_shared<ContextData>()) {}
+
+  /**
+   * @brief destructor to release cuda context
+   */
+  ~CudaContext() override {
+    if (cuda_initialized) {
+      // Release CUDA resources
+      if (stream_) {
+        cudaStreamDestroy(stream_);
+      }
+    }
+  }
+
+  /**
+   * @brief Factory register function, use this function to register custom
+   * object
+   *
+   * @tparam T object to create. Currently Layer is supported
+   * @param factory factory function that creates std::unique_ptr<T>
+   * @param key key to access the factory, if key is empty, try to find key by
+   * calling factory({})->getType();
+   * @param int_key key to access the factory by integer, if it is -1(default),
+   * the function automatically assigns an integer key and returns it
+   * @return const int unique integer value to access the current factory
+   * @throw invalid argument when key and/or int_key is already taken
+   */
+  template <typename T>
+  const int registerFactory(const PtrFactoryType<T> factory,
+                            const std::string &key = "",
+                            const int int_key = -1) {
+    FactoryType<T> f = factory;
+    return registerFactory<T>(f, key, int_key);
+  }
+
+  /**
+   * @brief Factory register function, use this function to register custom
+   * object
+   *
+   * @tparam T object to create. Currently Layer is supported
+   * @param factory factory function that creates std::unique_ptr<T>
+   * @param key key to access the factory, if key is empty, try to find key by
+   * calling factory({})->getType();
+   * @param int_key key to access the factory by integer, if it is -1(default),
+   * the function automatically assigns an integer key and returns it
+   * @return const int unique integer value to access the current factory
+   * @throw invalid argument when key and/or int_key is already taken
+   */
+  template <typename T>
+  const int registerFactory(const FactoryType<T> factory,
+                            const std::string &key = "",
+                            const int int_key = -1);
+
+  /**
+   * @brief Create an Object from the integer key
+   *
+   * @tparam T Type of Object, currently, Only Layer is supported
+   * @param int_key integer key
+   * @param props property
+   * @return PtrType<T> unique pointer to the object
+   */
+  template <typename T>
+  PtrType<T> createObject(const int int_key,
+                          const PropsType &props = {}) const {
+    static_assert(isSupported<T>::value,
+                  "given type is not supported for current app context");
+    auto &index = std::get<IndexType<T>>(factory_map);
+    auto &int_map = std::get<IntIndexType>(index);
+
+    const auto &entry = int_map.find(int_key);
+
+    if (entry == int_map.end()) {
+      ml_loge("Int Key is not found for the object. Key: %d", int_key);
+      throw exception::not_supported(std::to_string(int_key));
+    }
+
+    // entry is an object of int_map which is an unordered_map<int, std::string>
+    return createObject<T>(entry->second, props);
+  }
+
+  /**
+   * @brief Create an Object from the string key
+   *
+   * @tparam T Type of object, currently, only Layer is supported
+   * @param key string key
+   * @param props property
+   * @return PtrType<T> unique pointer to the object
+   */
+  template <typename T>
+  PtrType<T> createObject(const std::string &key,
+                          const PropsType &props = {}) const {
+    auto &index = std::get<IndexType<T>>(factory_map);
+    auto &str_map = std::get<StrIndexType<T>>(index);
+
+    std::string lower_key;
+    lower_key.resize(key.size());
+
+    std::transform(key.begin(), key.end(), lower_key.begin(),
+                   [](unsigned char c) { return std::tolower(c); });
+
+    const auto &entry = str_map.find(lower_key);
+
+    if (entry == str_map.end()) {
+      ml_loge("Key is not found for the object. Key: %s", lower_key.c_str());
Key: %s", lower_key.c_str()); + throw exception::not_supported(lower_key); + } + + // entry -> object of str_map -> unordered_map> + return entry->second(props); + } + + /** + * @brief Create a Layer object from the string key + * + * @param type string key + * @param properties property + * @return std::unique_ptr unique pointer to the object + */ + std::unique_ptr + createLayerObject(const std::string &type, + const std::vector &properties = {}) override { + return createObject(type, properties); + } + + /** + * @brief Create a Layer object from the integer key + * + * @param type integer key + * @param properties property + * @return std::unique_ptr unique pointer to the object + */ + std::unique_ptr + createLayerObject(const int int_key, + const std::vector &properties = {}) override { + return createObject(int_key, properties); + } + + /** + * @brief Get the name of the context + */ + std::string getName() override { return "cuda"; } + + /** + * @brief Set the Mem Allocator object + * + * @param mem Memory allocator object + */ + void setMemAllocator(std::shared_ptr mem) { + getContextData()->setMemAllocator(mem); + } + + /** + * @brief Get CUDA stream + * @return cudaStream_t + */ + cudaStream_t getStream() const { return stream_; } + +private: + /** + * @brief Overriden init function + */ + void initialize() noexcept override; + + void add_default_object(); + + // flag to check cuda initialization + bool cuda_initialized = false; + + // CUDA stream for asynchronous operations + cudaStream_t stream_ = nullptr; + + FactoryMap factory_map; + + template struct isSupportedHelper; + + /** + * @brief supportHelper to check if given type is supported within cuda + * context + */ + template + struct isSupportedHelper> { + static constexpr bool value = + (std::is_same_v, std::decay_t> || ...); + }; + + /** + * @brief supportHelper to check if given type is supported within cuda + * context + */ + template + struct isSupported : isSupportedHelper {}; + + /** + * @brief Initialize cuda context and stream + * @return true if CUDA context and stream creation is successful, + * false otherwise + */ + bool cudaInit(); +}; + +/** + * @copydoc const int CudaContext::registerFactory + */ +extern template const int CudaContext::registerFactory( + const FactoryType factory, const std::string &key, + const int int_key); + +} // namespace nntrainer + +#endif /* __CUDA_CONTEXT_H__ */ diff --git a/nntrainer/engine.cpp b/nntrainer/engine.cpp index 86f9e8b320..a5ed99055c 100644 --- a/nntrainer/engine.cpp +++ b/nntrainer/engine.cpp @@ -50,6 +50,12 @@ void Engine::add_default_object() { registerContext("gpu", &cl_context); #endif + +#ifdef ENABLE_CUDA + auto &cuda_context = nntrainer::CudaContext::Global(); + + registerContext("cuda", &cuda_context); +#endif } void Engine::initialize() noexcept { diff --git a/nntrainer/meson.build b/nntrainer/meson.build index 9daa9a04d6..6d0f3dd549 100644 --- a/nntrainer/meson.build +++ b/nntrainer/meson.build @@ -37,6 +37,35 @@ if get_option('enable-opencl') nntrainer_base_deps += clblast_dep endif +if get_option('enable-cuda') + # Add CUDA runtime library dependency + if get_option('platform') == 'windows' + # Windows: Use CUDA runtime library + cuda_dep = dependency('cuda', method: 'cmake', required: false) + if not cuda_dep.found() + # Fallback to manual library specification for Windows + cuda_lib = cc.find_library('cudart', required: false) + if cuda_lib.found() + nntrainer_base_deps += cuda_lib + endif + else + nntrainer_base_deps += cuda_dep + endif + else + # Linux 
+    cuda_dep = dependency('cuda', method: 'pkg-config', required: false)
+    if not cuda_dep.found()
+      # Fallback to manual library specification for Linux
+      cuda_lib = cc.find_library('cudart', required: false)
+      if cuda_lib.found()
+        nntrainer_base_deps += cuda_lib
+      endif
+    else
+      nntrainer_base_deps += cuda_dep
+    endif
+  endif
+endif
+
 if get_option('platform') == 'tizen'
   nntrainer_base_deps += dependency('dlog')
 endif
@@ -66,6 +95,11 @@ foreach elem : nntrainer_elements
   nntrainer_inc_abs += meson.current_source_dir() / elem
 endforeach
 
+# Add CUDA operations subdir if CUDA is enabled
+if get_option('enable-cuda')
+  subdir('tensor/cuda_operations')
+endif
+
 nntrainer_common_sources = [
   'nntrainer_logger.cpp',
   'app_context.cpp',
@@ -82,6 +116,12 @@ if get_option('enable-opencl')
   nntrainer_common_sources += 'cl_buffer_manager.cpp'
 endif
 
+if get_option('enable-cuda')
+  nntrainer_headers += meson.current_source_dir() / 'cuda_context.h'
+  nntrainer_common_sources += 'cuda_context.cpp'
+  extra_defines += '-DENABLE_CUDA=1'
+endif
+
 foreach s : nntrainer_common_sources
   nntrainer_sources += meson.current_source_dir() / s
 endforeach
diff --git a/nntrainer/tensor/cuda_operations/meson.build b/nntrainer/tensor/cuda_operations/meson.build
new file mode 100644
index 0000000000..61cafc73c6
--- /dev/null
+++ b/nntrainer/tensor/cuda_operations/meson.build
@@ -0,0 +1,34 @@
+# Find CUDA compiler
+dep = dependency('cuda', version : '>=13', modules : ['cublas'])
+
+nvcc = find_program('nvcc', required: true)
+
+if nvcc.found()
+  cuda_sources = [
+    'rmsnorm_cuda.cu'
+  ]
+
+  cuda_headers = [
+    'rmsnorm_cuda.h'
+  ]
+
+  kernel_objects = []
+  foreach kernel : cuda_sources
+    obj_name = kernel.replace('.cu', '.o')
+    obj = custom_target(obj_name,
+      command: [nvcc, '-c', '-Xcompiler', '/MD', '@INPUT@', '-o', '@OUTPUT@'],
+      input: kernel,
+      output: obj_name
+    )
+    kernel_objects += obj
+  endforeach
+
+  nntrainer_sources += kernel_objects
+
+  foreach h : cuda_headers
+    nntrainer_headers += meson.current_source_dir() / h
+  endforeach
+
+else
+  message('CUDA compiler (nvcc) not found. CUDA kernels will not be compiled.')
+endif
diff --git a/nntrainer/tensor/cuda_operations/rmsnorm_cuda.cu b/nntrainer/tensor/cuda_operations/rmsnorm_cuda.cu
new file mode 100644
index 0000000000..cb885871f6
--- /dev/null
+++ b/nntrainer/tensor/cuda_operations/rmsnorm_cuda.cu
@@ -0,0 +1,95 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * @file rmsnorm_cuda.cu
+ * @date 14 Nov 2025
+ * @brief Common blas CUDA kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Samsung Electronics Co., Ltd.
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#include "rmsnorm_cuda.h"
+#include
+
+__global__ void rmsnorm_cuda_kernel(const float *input, float *output,
+                                    const float *alpha, float epsilon,
+                                    int H, int W) {
+  // Each block processes one row (height index)
+  int h = blockIdx.x;
+  int index = h * W;
+
+  // Shared memory for reduction
+  extern __shared__ float sdata[];
+
+  // Thread index within block
+  int tid = threadIdx.x;
+  const int blockSize = blockDim.x;
+
+  // Load input data and compute sum of squares
+  const float *in = input + index;
+  float sum_squares = 0.0f;
+
+  // Each thread processes multiple elements if W > blockSize
+  for (int i = tid; i < W; i += blockSize) {
+    float val = in[i];
+    sum_squares += val * val;
+  }
+
+  // Store partial sum in shared memory
+  sdata[tid] = sum_squares;
+  __syncthreads();
+
+  // Reduction in shared memory
+  for (int s = blockSize / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      sdata[tid] += sdata[tid + s];
+    }
+    __syncthreads();
+  }
+
+  // First thread in block computes the final result
+  if (tid == 0) {
+    float mean = sdata[0] / W;
+    float scale = 1.0f / sqrtf(mean + epsilon);
+
+    // Store the scale value in shared memory for reuse
+    sdata[0] = scale;
+  }
+  __syncthreads();
+
+  // Load the computed scale
+  float scale = sdata[0];
+
+  // Compute output values
+  float *out = output + index;
+  for (int i = tid; i < W; i += blockSize) {
+    out[i] = in[i] * scale * alpha[i];
+  }
+}
+
+namespace nntrainer {
+
+void rmsnorm_cuda(const float *input, const float *gamma, float *result,
+                  const float epsilon, unsigned int height,
+                  unsigned int width) {
+  // Define block size
+  const int blockSize = 256;
+
+  // Calculate grid size (one block per row)
+  const int gridSize = height;
+
+  // Shared memory size for reduction
+  const int sharedMemSize = blockSize * sizeof(float);
+
+  // Launch the CUDA kernel
+  rmsnorm_cuda_kernel<<<gridSize, blockSize, sharedMemSize>>>(
+    input, result, gamma, epsilon, height, width);
+}
+
+void sscal_cuda(float *X, const unsigned int N, const float alpha) {
+  // TODO: Implement CUDA kernel for sscal
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/tensor/cuda_operations/rmsnorm_cuda.h b/nntrainer/tensor/cuda_operations/rmsnorm_cuda.h
new file mode 100644
index 0000000000..77ad2fc811
--- /dev/null
+++ b/nntrainer/tensor/cuda_operations/rmsnorm_cuda.h
@@ -0,0 +1,39 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * @file rmsnorm_cuda.h
+ * @date 14 Nov 2025
+ * @brief Common blas CUDA kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Samsung Electronics Co., Ltd.
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#pragma once
+
+namespace nntrainer {
+
+/**
+ * @brief rmsnorm each row of the tensor
+ * @param[in] input float * for input
+ * @param[in] gamma float * for gamma, elementwise multiplier along the width
+ * axis
+ * @param[in] result float * for result
+ * @param[in] epsilon epsilon added to each row's mean of squares to prevent
+ * division by zero
+ * @param[in] height height of the tensor
+ * @param[in] width width of the tensor
+ */
+void rmsnorm_cuda(const float *input, const float *gamma, float *result,
+                  const float epsilon, unsigned int height,
+                  unsigned int width);
+
+/**
+ * @brief sscal value element by element immediately
+ * @param[in] X float * input
+ * @param[in] N unsigned int number of elements
+ * @param[in] alpha float multiplier
+ */
+void sscal_cuda(float *X, const unsigned int N, const float alpha);
+
+} // namespace nntrainer
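
For reviewers who want to exercise the patch, here is a minimal, untested sketch of how the new context is intended to be reached (it assumes a build configured with -Denable-cuda=true, that Singleton::Global() runs initialize() on first use as it does for the other contexts, and that "fully_connected" / "unit=8" are the usual nntrainer layer type and property strings; none of this is part of the diff itself):

#include "cuda_context.h"

int main() {
  // Global() is expected to run cudaInit() and register the default layers.
  auto &ctx = nntrainer::CudaContext::Global();

  // Create one of the layers registered in add_default_object().
  auto fc = ctx.createLayerObject("fully_connected", {"unit=8"});

  // The stream created in cudaInit() is exposed for kernel launches.
  cudaStream_t stream = ctx.getStream();

  (void)fc;
  (void)stream;
  return 0;
}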
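And a rough host-side sketch of driving the new rmsnorm_cuda() entry point directly with device buffers (the helper name, buffer handling, and the 1e-6f epsilon are illustrative only; error checking is elided):

#include <cuda_runtime.h>
#include "rmsnorm_cuda.h"

void run_rmsnorm_example(const float *h_in, const float *h_gamma, float *h_out,
                         unsigned int height, unsigned int width) {
  const size_t row_bytes = width * sizeof(float);
  const size_t bytes = height * row_bytes;

  float *d_in = nullptr, *d_gamma = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, bytes);
  cudaMalloc(&d_gamma, row_bytes); // gamma holds one value per column
  cudaMalloc(&d_out, bytes);

  cudaMemcpy(d_in, h_in, bytes, cudaMemcpyHostToDevice);
  cudaMemcpy(d_gamma, h_gamma, row_bytes, cudaMemcpyHostToDevice);

  // One block per row; each row is divided by its RMS and scaled by gamma.
  nntrainer::rmsnorm_cuda(d_in, d_gamma, d_out, 1e-6f, height, width);

  // cudaMemcpy on the default stream synchronizes with the kernel launch.
  cudaMemcpy(h_out, d_out, bytes, cudaMemcpyDeviceToHost);

  cudaFree(d_in);
  cudaFree(d_gamma);
  cudaFree(d_out);
}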