diff --git a/meson_options.txt b/meson_options.txt index ff144ff536..56a2ec035a 100644 --- a/meson_options.txt +++ b/meson_options.txt @@ -46,6 +46,7 @@ option('enable-fp16', type: 'boolean', value: false) option('enable-cublas', type: 'boolean', value: false) option('enable-openmp', type: 'boolean', value: true) option('enable-opencl', type: 'boolean', value: false) +option('enable-cuda', type: 'boolean', value: false) option('enable-biqgemm', type: 'boolean', value: false) option('biqgemm-path', type: 'string', value: '../BiQGEMM') option('enable-benchmarks', type: 'boolean', value : false) diff --git a/nntrainer/cuda_context.cpp b/nntrainer/cuda_context.cpp new file mode 100644 index 0000000000..daf09641e6 --- /dev/null +++ b/nntrainer/cuda_context.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file cuda_context.cpp + * @date 13 Nov 2025 + * @see https://github.com/nnstreamer/nntrainer + * @author Samsung Electronics Co., Ltd. + * @bug No known bugs except for NYI items + * @brief This file contains app context related functions and classes that + * manages the global configuration of the current CUDA environment. It also + * creates the CUDA stream and context. + */ + +#include "cuda_context.h" + +#include +#include +#include +#include + +#include +#include + +namespace nntrainer { +std::mutex cuda_factory_mutex; + +void CudaContext::initialize() noexcept { + try { + if (!cudaInit()) { + ml_loge("Error: CudaContext::initialize() failed"); + return; + } + + add_default_object(); + setMemAllocator(std::make_shared()); + + } catch (std::exception &e) { + ml_loge("cuda_context: registering layers failed!!, reason: %s", e.what()); + } catch (...) { + ml_loge("cuda_context: registering layer failed due to unknown reason"); + } +}; + +void CudaContext::add_default_object() { + // Register default layers that support CUDA + registerFactory(nntrainer::createLayer, + FullyConnectedLayer::type, ml::train::LayerType::LAYER_FC); + + registerFactory(nntrainer::createLayer, AdditionLayer::type, + ml::train::LayerType::LAYER_ADDITION); + + registerFactory(nntrainer::createLayer, ReshapeLayer::type, + ml::train::LayerType::LAYER_RESHAPE); +} + +template +const int CudaContext::registerFactory(const FactoryType factory, + const std::string &key, + const int int_key) { + static_assert( + isSupported::value, + "cuda_context: given type is not supported for current context"); + + auto &index = std::get>(factory_map); + auto &str_map = std::get>(index); + auto &int_map = std::get(index); + + std::string assigned_key = key == "" ? factory({})->getType() : key; + + std::transform(assigned_key.begin(), assigned_key.end(), assigned_key.begin(), + [](unsigned char c) { return std::tolower(c); }); + + const std::lock_guard lock(cuda_factory_mutex); + if (str_map.find(assigned_key) != str_map.end()) { + ml_loge("cuda_context: cannot register factory with already taken key: %s", + key.c_str()); + throw std::invalid_argument(key); + } + + if (int_key != -1 && int_map.find(int_key) != int_map.end()) { + ml_loge( + "cuda_context: cannot register factory with already taken int key: %d", + int_key); + throw std::invalid_argument(std::to_string(int_key)); + } + + int assigned_int_key = int_key == -1 ? 
str_map.size() + 1 : int_key; + + str_map[assigned_key] = factory; + int_map[assigned_int_key] = assigned_key; + + ml_logd("cuda_context: factory has registered with key: %s, int_key: %d", + assigned_key.c_str(), assigned_int_key); + + return assigned_int_key; +} + +bool CudaContext::cudaInit() { + // if already initialized + if (cuda_initialized) + return true; + + // Initialize CUDA context + cudaError_t err = cudaSetDevice(0); + if (err != cudaSuccess) { + ml_loge("Failed to set CUDA device: %s", cudaGetErrorString(err)); + return false; + } + + // Create CUDA stream for asynchronous operations + err = cudaStreamCreate(&stream_); + if (err != cudaSuccess) { + ml_loge("Failed to create CUDA stream: %s", cudaGetErrorString(err)); + return false; + } + + cuda_initialized = true; + return cuda_initialized; +} + +/** + * @copydoc const int CudaContext::registerFactory + */ +template const int CudaContext::registerFactory( + const FactoryType factory, const std::string &key, + const int int_key); + +} // namespace nntrainer diff --git a/nntrainer/cuda_context.h b/nntrainer/cuda_context.h new file mode 100644 index 0000000000..3cf1ce8dde --- /dev/null +++ b/nntrainer/cuda_context.h @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file cuda_context.h + * @date 13 Nov 2025 + * @see https://github.com/nnstreamer/nntrainer + * @author Samsung Electronics Co., Ltd. + * @bug No known bugs except for NYI items + * @brief This file contains app context related functions and classes that + * manages the global configuration of the current CUDA environment. It also + * creates the CUDA stream and context. + */ + +#ifndef __CUDA_CONTEXT_H__ +#define __CUDA_CONTEXT_H__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include "singleton.h" + +namespace nntrainer { + +extern std::mutex cuda_factory_mutex; + +/** + * @class CudaContext contains user-dependent configuration for CUDA support + * @brief CUDA support for app context + */ +class CudaContext : public Context, public Singleton { +public: + /** + * @brief Default constructor + */ + CudaContext() : Context(std::make_shared()) {} + + /** + * @brief destructor to release cuda context + */ + ~CudaContext() override { + if (cuda_initialized) { + // Release CUDA resources + if (stream_) { + cudaStreamDestroy(stream_); + } + } + }; + + /** + * @brief Factory register function, use this function to register custom + * object + * + * @tparam T object to create. Currently Layer is supported + * @param factory factory function that creates std::unique_ptr + * @param key key to access the factory, if key is empty, try to find key by + * calling factory({})->getType(); + * @param int_key key to access the factory by integer, if it is -1(default), + * the function automatically unsigned the key and return + * @return const int unique integer value to access the current factory + * @throw invalid argument when key and/or int_key is already taken + */ + template + const int registerFactory(const PtrFactoryType factory, + const std::string &key = "", + const int int_key = -1) { + FactoryType f = factory; + return registerFactory(f, key, int_key); + } + + /** + * @brief Factory register function, use this function to register custom + * object + * + * @tparam T object to create. 
Currently Layer is supported + * @param factory factory function that creates std::unique_ptr + * @param key key to access the factory, if key is empty, try to find key by + * calling factory({})->getType(); + * @param int_key key to access the factory by integer, if it is -1(default), + * the function automatically unsigned the key and return + * @return const int unique integer value to access the current factory + * @throw invalid argument when key and/or int_key is already taken + */ + template + const int registerFactory(const FactoryType factory, + const std::string &key = "", + const int int_key = -1); + + /** + * @brief Create an Object from the integer key + * + * @tparam T Type of Object, currently, Only Layer is supported + * @param int_key integer key + * @param props property + * @return PtrType unique pointer to the object + */ + template + PtrType createObject(const int int_key, + const PropsType &props = {}) const { + static_assert(isSupported::value, + "given type is not supported for current app context"); + auto &index = std::get>(factory_map); + auto &int_map = std::get(index); + + const auto &entry = int_map.find(int_key); + + if (entry == int_map.end()) { + ml_loge("Int Key is not found for the object. Key: %d", int_key); + throw exception::not_supported(std::to_string(int_key)); + } + + // entry is an object of int_map which is an unordered_map + return createObject(entry->second, props); + } + + /** + * @brief Create an Object from the string key + * + * @tparam T Type of object, currently, only Layer is supported + * @param key integer key + * @param props property + * @return PtrType unique pointer to the object + */ + template + PtrType createObject(const std::string &key, + const PropsType &props = {}) const { + auto &index = std::get>(factory_map); + auto &str_map = std::get>(index); + + std::string lower_key; + lower_key.resize(key.size()); + + std::transform(key.begin(), key.end(), lower_key.begin(), + [](unsigned char c) { return std::tolower(c); }); + + const auto &entry = str_map.find(lower_key); + + if (entry == str_map.end()) { + ml_loge("Key is not found for the object. 
Key: %s", lower_key.c_str()); + throw exception::not_supported(lower_key); + } + + // entry -> object of str_map -> unordered_map> + return entry->second(props); + } + + /** + * @brief Create a Layer object from the string key + * + * @param type string key + * @param properties property + * @return std::unique_ptr unique pointer to the object + */ + std::unique_ptr + createLayerObject(const std::string &type, + const std::vector &properties = {}) override { + return createObject(type, properties); + } + + /** + * @brief Create a Layer object from the integer key + * + * @param type integer key + * @param properties property + * @return std::unique_ptr unique pointer to the object + */ + std::unique_ptr + createLayerObject(const int int_key, + const std::vector &properties = {}) override { + return createObject(int_key, properties); + } + + /** + * @brief Get the name of the context + */ + std::string getName() override { return "cuda"; } + + /** + * @brief Set the Mem Allocator object + * + * @param mem Memory allocator object + */ + void setMemAllocator(std::shared_ptr mem) { + getContextData()->setMemAllocator(mem); + } + + /** + * @brief Get CUDA stream + * @return cudaStream_t + */ + cudaStream_t getStream() const { return stream_; } + +private: + /** + * @brief Overriden init function + */ + void initialize() noexcept override; + + void add_default_object(); + + // flag to check cuda initialization + bool cuda_initialized = false; + + // CUDA stream for asynchronous operations + cudaStream_t stream_ = nullptr; + + FactoryMap factory_map; + + template struct isSupportedHelper; + + /** + * @brief supportHelper to check if given type is supported within cuda + * context + */ + template + struct isSupportedHelper> { + static constexpr bool value = + (std::is_same_v, std::decay_t> || ...); + }; + + /** + * @brief supportHelper to check if given type is supported within cuda + * context + */ + template + struct isSupported : isSupportedHelper {}; + + /** + * @brief Initialize cuda context and stream + * @return true if CUDA context and stream creation is successful, + * false otherwise + */ + bool cudaInit(); +}; + +/** + * @copydoc const int CudaContext::registerFactory + */ +extern template const int CudaContext::registerFactory( + const FactoryType factory, const std::string &key, + const int int_key); + +} // namespace nntrainer + +#endif /* __CUDA_CONTEXT_H__ */ diff --git a/nntrainer/engine.cpp b/nntrainer/engine.cpp index 86f9e8b320..a5ed99055c 100644 --- a/nntrainer/engine.cpp +++ b/nntrainer/engine.cpp @@ -50,6 +50,12 @@ void Engine::add_default_object() { registerContext("gpu", &cl_context); #endif + +#ifdef ENABLE_CUDA + auto &cuda_context = nntrainer::CudaContext::Global(); + + registerContext("cuda", &cuda_context); +#endif } void Engine::initialize() noexcept { diff --git a/nntrainer/meson.build b/nntrainer/meson.build index 9daa9a04d6..e4ee263cb0 100644 --- a/nntrainer/meson.build +++ b/nntrainer/meson.build @@ -37,6 +37,35 @@ if get_option('enable-opencl') nntrainer_base_deps += clblast_dep endif +if get_option('enable-cuda') + # Add CUDA runtime library dependency + if get_option('platform') == 'windows' + # Windows: Use CUDA runtime library + cuda_dep = dependency('cuda', method: 'cmake', required: false) + if not cuda_dep.found() + # Fallback to manual library specification for Windows + cuda_lib = cc.find_library('cudart', required: false) + if cuda_lib.found() + nntrainer_base_deps += cuda_lib + endif + else + nntrainer_base_deps += cuda_dep + endif + else + # Linux 
and others: Use pkg-config or manual library specification
+    cuda_dep = dependency('cuda', method: 'pkg-config', required: false)
+    if not cuda_dep.found()
+      # Fallback to manual library specification for Linux
+      cuda_lib = cc.find_library('cudart', required: false)
+      if cuda_lib.found()
+        nntrainer_base_deps += cuda_lib
+      endif
+    else
+      nntrainer_base_deps += cuda_dep
+    endif
+  endif
+endif
+
 if get_option('platform') == 'tizen'
   nntrainer_base_deps += dependency('dlog')
 endif
@@ -82,6 +111,12 @@ if get_option('enable-opencl')
   nntrainer_common_sources += 'cl_buffer_manager.cpp'
 endif

+if get_option('enable-cuda')
+  nntrainer_headers += meson.current_source_dir() / 'cuda_context.h'
+  nntrainer_common_sources += 'cuda_context.cpp'
+  extra_defines += '-DENABLE_CUDA=1'
+endif
+
 foreach s : nntrainer_common_sources
   nntrainer_sources += meson.current_source_dir() / s
 endforeach
diff --git a/nntrainer/tensor/cuda_operations/addition_cuda.cu b/nntrainer/tensor/cuda_operations/addition_cuda.cu
new file mode 100644
index 0000000000..112ceee42f
--- /dev/null
+++ b/nntrainer/tensor/cuda_operations/addition_cuda.cu
@@ -0,0 +1,37 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * @file addition_cuda.cu
+ * @date 20 Nov 2025
+ * @brief Common blas CUDA kernels for addition
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Samsung Electronics Co., Ltd.
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#include "addition_cuda.h"
+#include <cuda_runtime.h>
+
+namespace nntrainer {
+
+__global__ void addition_cuda_kernel(const float *input, float *output,
+                                     unsigned int size_input,
+                                     unsigned int size_res) {
+  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < size_res) {
+    output[idx] = output[idx] + input[idx % size_input];
+  }
+}
+
+void addition_cuda(const float *input, float *res, unsigned int size_input,
+                   unsigned int size_res) {
+  const int blockSize = 256;
+  const int gridSize = (size_res + blockSize - 1) / blockSize;
+
+  addition_cuda_kernel<<<gridSize, blockSize>>>(input, res, size_input,
+                                                size_res);
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/tensor/cuda_operations/addition_cuda.h b/nntrainer/tensor/cuda_operations/addition_cuda.h
new file mode 100644
index 0000000000..b1390292fc
--- /dev/null
+++ b/nntrainer/tensor/cuda_operations/addition_cuda.h
@@ -0,0 +1,31 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * @file addition_cuda.h
+ * @date 20 Nov 2025
+ * @brief Common blas CUDA kernels for addition
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Samsung Electronics Co., Ltd.
+ * @bug No known bugs except for NYI items + * + */ + +#ifndef __ADDITION_CUDA_H__ +#define __ADDITION_CUDA_H__ + +namespace nntrainer { + +/** + * @brief addition : sum of all input vectors + * @param[in] input float * for input + * @param[in] res float * for result/output + * @param[in] size_input number of elements in input vector + * @param[in] size_res number of elements in result vector + */ +void addition_cuda(const float *input, float *res, unsigned int size_input, + unsigned int size_res); + +} // namespace nntrainer + +#endif /* __ADDITION_CUDA_H__ */ diff --git a/nntrainer/tensor/cuda_operations/cuda_interface.cpp b/nntrainer/tensor/cuda_operations/cuda_interface.cpp new file mode 100644 index 0000000000..a343eb20e1 --- /dev/null +++ b/nntrainer/tensor/cuda_operations/cuda_interface.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file cuda_interface.cpp + * @date 20 Nov 2025 + * @brief Interface for blas CUDA kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Samsung Electronics Co., Ltd. + * @bug No known bugs except for NYI items + * + */ + +#include +#include + +namespace nntrainer { + +Tensor dotCuda(Tensor const &input, Tensor const &m, bool trans, bool trans_m) { + // TODO: Implement CUDA dot operation + return Tensor(); +} + +void dotCuda(Tensor const &input, Tensor const &m, Tensor &result, bool trans, + bool trans_m) { + // TODO: Implement CUDA dot operation +} + +void dotBatchedCuda(Tensor const &input, Tensor const &m, Tensor &result, + bool trans, bool trans_m) { + // TODO: Implement CUDA batched dot operation +} + +void multiplyCuda(Tensor &input, float const &value) { + // TODO: Implement CUDA multiply operation +} + +void add_i_cuda(Tensor &result, Tensor const &input) { + // TODO: Implement CUDA add operation +} + +void transposeCuda(const std::string &direction, Tensor const &in, + Tensor &result) { + // TODO: Implement CUDA transpose operation +} + +void copyCuda(const Tensor &input, Tensor &result) { + // TODO: Implement CUDA copy operation +} + +float nrm2Cuda(const Tensor &input) { + // TODO: Implement CUDA nrm2 operation + return 0.0f; +} + +float asumCuda(const Tensor &input) { + // TODO: Implement CUDA asum operation + return 0.0f; +} + +int amaxCuda(const Tensor &input) { + // TODO: Implement CUDA amax operation + return 0; +} + +int aminCuda(const Tensor &input) { + // TODO: Implement CUDA amin operation + return 0; +} + +} // namespace nntrainer diff --git a/nntrainer/tensor/cuda_operations/cuda_interface.h b/nntrainer/tensor/cuda_operations/cuda_interface.h new file mode 100644 index 0000000000..628ff1e152 --- /dev/null +++ b/nntrainer/tensor/cuda_operations/cuda_interface.h @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file cuda_interface.h + * @date 20 Nov 2025 + * @brief Interface for blas CUDA kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Samsung Electronics Co., Ltd. 
+ * @bug No known bugs except for NYI items + * + */ + +#ifndef __CUDA_INTERFACE_H__ +#define __CUDA_INTERFACE_H__ + +#include +#include + +namespace nntrainer { + +/** + * @brief Process data and dimensions for CUDA dot operation + * @param[in] input Tensor + * @param[in] m Tensor + * @param[in] RunLayerContext reference + * @param[in] trans bool + * @param[in] trans_m bool + */ +Tensor dotCuda(Tensor const &input, Tensor const &m, bool trans = false, + bool trans_m = false); + +/** + * @brief Process data and dimensions for CUDA dot operation + * @param[in] input Tensor + * @param[in] m Tensor + * @param[in] result Tensor + * @param[in] RunLayerContext reference + * @param[in] trans bool + * @param[in] trans_m bool + */ +void dotCuda(Tensor const &input, Tensor const &m, Tensor &result, + bool trans = false, bool trans_m = false); + +/** + * @brief Process data and dimensions for CUDA dot operation + * @param[in] input Tensor + * @param[in] m Tensor + * @param[in] result Tensor + * @param[in] RunLayerContext reference + * @param[in] trans bool + * @param[in] trans_m bool + */ +void dotBatchedCuda(Tensor const &input, Tensor const &m, Tensor &result, + bool trans = false, bool trans_m = false); + +/** + * @brief Multiply value element by element immediately + * @param[in] input Tensor + * @param[in] value multiplier + * @param[in] RunLayerContext reference + */ +void multiplyCuda(Tensor &input, float const &value); + +/** + * @brief Process data and dimensions for add operation + * @param[in] result Tensor + * @param[in] input Tensor + */ +void add_i_cuda(Tensor &result, Tensor const &input); + +/** + * @brief Process data and dimensions for transpose operation + * @param[in] direction string + * @param[in] input Tensor + * @param[in] result Tensor + */ +void transposeCuda(const std::string &direction, Tensor const &in, + Tensor &result); + +/** + * @brief Copy data from one tensor to another + * + * @param input Tensor + * @param result Tensor + */ +void copyCuda(const Tensor &input, Tensor &result); + +/** + * @brief nrm2 computation : Euclidean norm + * @param input Tensor + * @return Euclidean norm + * @note This function is used to compute the Euclidean norm of a vector. + */ +float nrm2Cuda(const Tensor &input); + +/** + * @brief Absolute sum computation + * + * @param input Tensor + * @return float absolute sum of the elements + */ +float asumCuda(const Tensor &input); + +/** + * @brief Absolute max computation + * + * @param input Tensor + * @return int index of the maximum absolute value + * @note Not necessarily the first if there are multiple maximums. + */ +int amaxCuda(const Tensor &input); + +/** + * @brief Absolute min computation + * + * @param input Tensor + * @return int index of the minimum absolute value + * @note Not necessarily the first if there are multiple minimums. 
+ */ +int aminCuda(const Tensor &input); + +} // namespace nntrainer +#endif /* __CUDA_INTERFACE_H__ */ diff --git a/nntrainer/tensor/cuda_operations/meson.build b/nntrainer/tensor/cuda_operations/meson.build new file mode 100644 index 0000000000..9af6aff6a3 --- /dev/null +++ b/nntrainer/tensor/cuda_operations/meson.build @@ -0,0 +1,40 @@ +# Find CUDA compiler +dep = dependency('cuda', version : '>=13', modules : ['cublas']) + +nvcc = find_program('nvcc', required: true) + +if nvcc.found() + cuda_sources = [ + 'rmsnorm_cuda.cu', + 'addition_cuda.cu' + ] + + cuda_headers = [ + 'rmsnorm_cuda.h', + 'addition_cuda.h', + 'cuda_interface.h' + ] + + kernel_objects = [] + foreach kernel : cuda_sources + obj_name = kernel.replace('.cu', '.o') + obj = custom_target(obj_name, + command: [nvcc, '-c', '-Xcompiler', '/MD', '@INPUT@', '-o', '@OUTPUT@'], + input: kernel, + output: obj_name + ) + kernel_objects += obj + endforeach + + # Add cuda_interface.cpp to regular sources + nntrainer_sources += meson.current_source_dir() / 'cuda_interface.cpp' + + nntrainer_sources += kernel_objects + + foreach h : cuda_headers + nntrainer_headers += meson.current_source_dir() / h + endforeach + +else + message('CUDA compiler (nvcc) not found. CUDA kernels will not be compiled.') +endif diff --git a/nntrainer/tensor/cuda_operations/rmsnorm_cuda.cu b/nntrainer/tensor/cuda_operations/rmsnorm_cuda.cu new file mode 100644 index 0000000000..cb885871f6 --- /dev/null +++ b/nntrainer/tensor/cuda_operations/rmsnorm_cuda.cu @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file rmsnorm_cuda.cpp + * @date 14 Nov 2025 + * @brief Common blas CUDA kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Samsung Electronics Co., Ltd. 
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#include "rmsnorm_cuda.h"
+#include <cuda_runtime.h>
+
+__global__ void rmsnorm_cuda_kernel(const float *input, float *output,
+                                    const float *alpha, float epsilon,
+                                    int H, int W) {
+  // Each block processes one row (height index)
+  int h = blockIdx.x;
+  int index = h * W;
+
+  // Shared memory for reduction
+  extern __shared__ float sdata[];
+
+  // Thread index within block
+  int tid = threadIdx.x;
+  const int blockSize = blockDim.x;
+
+  // Load input data and compute sum of squares
+  const float *in = input + index;
+  float sum_squares = 0.0f;
+
+  // Each thread processes multiple elements if W > blockSize
+  for (int i = tid; i < W; i += blockSize) {
+    float val = in[i];
+    sum_squares += val * val;
+  }
+
+  // Store partial sum in shared memory
+  sdata[tid] = sum_squares;
+  __syncthreads();
+
+  // Reduction in shared memory
+  for (int s = blockSize / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      sdata[tid] += sdata[tid + s];
+    }
+    __syncthreads();
+  }
+
+  // First thread in block computes the final result
+  if (tid == 0) {
+    float mean = sdata[0] / W;
+    float scale = 1.0f / sqrtf(mean + epsilon);
+
+    // Store the scale value in shared memory for reuse
+    sdata[0] = scale;
+  }
+  __syncthreads();
+
+  // Load the computed scale
+  float scale = sdata[0];
+
+  // Compute output values
+  float *out = output + index;
+  for (int i = tid; i < W; i += blockSize) {
+    out[i] = in[i] * scale * alpha[i];
+  }
+}
+
+namespace nntrainer {
+
+void rmsnorm_cuda(const float *input, const float *gamma, float *result,
+                  const float epsilon, unsigned int height,
+                  unsigned int width) {
+  // Define block size
+  const int blockSize = 256;
+
+  // Calculate grid size (one block per row)
+  const int gridSize = height;
+
+  // Shared memory size for reduction
+  const int sharedMemSize = blockSize * sizeof(float);
+
+  // Launch the CUDA kernel
+  rmsnorm_cuda_kernel<<<gridSize, blockSize, sharedMemSize>>>(
+    input, result, gamma, epsilon, height, width);
+}
+
+void sscal_cuda(float *X, const unsigned int N, const float alpha) {
+  // TODO: Implement CUDA kernel for sscal
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/tensor/cuda_operations/rmsnorm_cuda.h b/nntrainer/tensor/cuda_operations/rmsnorm_cuda.h
new file mode 100644
index 0000000000..77ad2fc811
--- /dev/null
+++ b/nntrainer/tensor/cuda_operations/rmsnorm_cuda.h
@@ -0,0 +1,39 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * @file rmsnorm_cuda.h
+ * @date 14 Nov 2025
+ * @brief Common blas CUDA kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Samsung Electronics Co., Ltd.
+ * @bug No known bugs except for NYI items + * + */ + +#pragma once + +namespace nntrainer { + +/** + * @brief rmsnorm each row of the tensor + * @param[in] input float * for input + * @param[in] gamma float * for gamma multiplier for each row + * @param[in] result float * for result + * @param[in] epsilon epsilon to add to each row sum to prevent division by zero + * @param[in] height height of the tensor + * @param[in] width width of the tensor + */ +void rmsnorm_cuda(const float *input, const float *gamma, float *result, + const float epsilon, unsigned int height, unsigned int width); + +/** + * @brief sscal value element by element immediately + * @param[in] X float * input + * @param[in] N unsigned int number of elements + * @param[in] alpha float multiplier + * @param[in] context RunLayerContext reference + */ +void sscal_cuda(float *X, const unsigned int N, const float alpha); + +} // namespace nntrainer diff --git a/nntrainer/tensor/meson.build b/nntrainer/tensor/meson.build index ea6a74f8a3..8e1e88a415 100644 --- a/nntrainer/tensor/meson.build +++ b/nntrainer/tensor/meson.build @@ -90,6 +90,12 @@ if get_option('enable-opencl') nntrainer_inc_abs += meson.current_source_dir() / 'cl_operations' endif +if get_option('enable-cuda') + subdir('cuda_operations') + nntrainer_inc += include_directories('cuda_operations') + nntrainer_inc_abs += meson.current_source_dir() / 'cuda_operations' +endif + foreach s : tensor_sources nntrainer_sources += meson.current_source_dir() / s endforeach diff --git a/test/unittest/meson.build b/test/unittest/meson.build index b3100667bb..f004d0401d 100644 --- a/test/unittest/meson.build +++ b/test/unittest/meson.build @@ -84,6 +84,11 @@ if get_option('enable-opencl') test_target += [['unittest_attention_kernels_cl', []]] endif +if get_option('enable-cuda') + test_target += [['unittest_cuda', []]] + test_target += [['unittest_cuda_addition', []]] +endif + if get_option('enable-fp16') test_target += [['unittest_nntrainer_tensor_fp16', []]] test_target += [['unittest_nntrainer_tensor_pool_fp16', []]] @@ -96,12 +101,20 @@ if get_option('enable-profile') endif foreach target: test_target + cuda_deps = [] + cuda_link_args = [] + if target[0] == 'unittest_cuda' and get_option('enable-cuda') + cuda_deps = [cuda_dep] + cuda_link_args = ['-NODEFAULTLIB:LIBCMT', '-NOIMPLIB', '-NOEXP'] + endif + exe = executable( target[0], [target[0] + '.cpp'] + [target[1]], # below is temporary measure, we will eventually remove unittest_nntrainer_models include_directories: include_directories('models'), - dependencies: unittest_nntrainer_deps, + dependencies: unittest_nntrainer_deps + cuda_deps, + link_args: cuda_link_args, install: get_option('enable-test'), install_dir: application_install_dir ) diff --git a/test/unittest/unittest_cuda.cpp b/test/unittest/unittest_cuda.cpp new file mode 100644 index 0000000000..5f445e6bab --- /dev/null +++ b/test/unittest/unittest_cuda.cpp @@ -0,0 +1,371 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file unittest_cuda.cpp + * @date 18 Nov 2025 + * @brief Unit test for CUDA operations + * @see https://github.com/nnstreamer/nntrainer + * @author Samsung Electronics Co., Ltd. 
+ * @bug No known bugs except for NYI items + */ + +#include +#include +#include +#include +#include +#include + +using namespace nntrainer; + +/** + * @brief Helper function to generate test data + * + * @tparam T data type + * @param size data length + * @param min_val minimum value + * @param max_val maximum value + * @return std::vector random vector + */ +template +static inline std::vector generate_test_data(size_t size, T min_val, + T max_val) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(min_val, max_val); + std::vector vec(size); + for (auto &val : vec) { + val = static_cast(dist(gen)); + } + return vec; +} + +/** + * @brief Test for rmsnorm_cuda function + */ +TEST(nntrainer_CUDA, rmsnorm_cuda_1) { + const int batch = 1; + const int channel = 1; + const int height = 67; + const int width = 3072; + + const float epsilon = 1e-6; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + /// Initialize CPU input data + nntrainer::Tensor input(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor gamma(1, 1, 1, width, t_type_nchw_fp32); + nntrainer::Tensor output_cuda(batch, channel, height, width, + t_type_nchw_fp32); + nntrainer::Tensor output_ref(batch, channel, height, width, t_type_nchw_fp32); + + /// Generate test data + auto input_data = generate_test_data(input.size(), -1.0f, 1.0f); + auto gamma_data = generate_test_data(gamma.size(), 0.5f, 2.0f); + + std::copy(input_data.begin(), input_data.end(), input.getData()); + std::copy(gamma_data.begin(), gamma_data.end(), gamma.getData()); + + /// Allocate CUDA memory + float *d_input = nullptr, *d_gamma = nullptr, *d_output = nullptr; + size_t input_size = input.size() * sizeof(float); + size_t gamma_size = gamma.size() * sizeof(float); + size_t output_size = output_cuda.size() * sizeof(float); + + cudaMalloc((void **)&d_input, input_size); + cudaMalloc((void **)&d_gamma, gamma_size); + cudaMalloc((void **)&d_output, output_size); + + /// Copy data to CUDA memory + cudaMemcpy(d_input, input.getData(), input_size, + cudaMemcpyHostToDevice); + cudaMemcpy(d_gamma, gamma.getData(), gamma_size, + cudaMemcpyHostToDevice); + + /// Reference implementation using CPU + std::function f = [](float x) { return 1 / std::sqrt(x); }; + auto t = input.multiply(input).average(3).add(epsilon); + t.apply_i(f); + input.multiply(t, output_ref); + output_ref.multiply_i(gamma); + + /// Create CUDA events for timing + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + /// CUDA implementation + rmsnorm_cuda(d_input, d_gamma, d_output, epsilon, + input.batch() * input.channel() * input.height(), input.width()); + + /// Record start time + cudaEventRecord(start); + + /// CUDA implementation + rmsnorm_cuda(d_input, d_gamma, d_output, epsilon, + input.batch() * input.channel() * input.height(), input.width()); + + /// Record stop time + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + /// Calculate elapsed time + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + std::cout << "RMSNorm CUDA kernel execution time for input size (" << batch + << ", " << channel << ", " << height << ", " << width + << "): " << milliseconds << " ms" << std::endl; + + /// Copy result back to host + cudaMemcpy(output_cuda.getData(), d_output, output_size, + cudaMemcpyDeviceToHost); + + /// Destroy CUDA events + cudaEventDestroy(start); + cudaEventDestroy(stop); + + /// Free CUDA memory + 
cudaFree(d_input); + cudaFree(d_gamma); + cudaFree(d_output); + + /// Compare results + float mseError = mse(output_cuda.getData(), + output_ref.getData(), output_cuda.size()); + + double cosSim = + cosine_similarity(output_cuda.getData(), + output_ref.getData(), output_cuda.size()); + + const float error_threshold = 1e-5; + const float cosine_threshold = 0.999; + + EXPECT_LE(mseError, error_threshold); + EXPECT_GE(cosSim, cosine_threshold); +} + +/** + * @brief Test for rmsnorm_cuda function with different dimensions + */ +TEST(nntrainer_CUDA, rmsnorm_cuda_2) { + const int batch = 2; + const int channel = 3; + const int height = 32; + const int width = 1024; + + const float epsilon = 1e-6; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + /// Initialize CPU input data + nntrainer::Tensor input(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor gamma(1, 1, 1, width, t_type_nchw_fp32); + nntrainer::Tensor output_cuda(batch, channel, height, width, + t_type_nchw_fp32); + nntrainer::Tensor output_ref(batch, channel, height, width, t_type_nchw_fp32); + + /// Generate test data + auto input_data = generate_test_data(input.size(), -2.0f, 2.0f); + auto gamma_data = generate_test_data(gamma.size(), 0.1f, 1.5f); + + std::copy(input_data.begin(), input_data.end(), input.getData()); + std::copy(gamma_data.begin(), gamma_data.end(), gamma.getData()); + + /// Allocate CUDA memory + float *d_input = nullptr, *d_gamma = nullptr, *d_output = nullptr; + size_t input_size = input.size() * sizeof(float); + size_t gamma_size = gamma.size() * sizeof(float); + size_t output_size = output_cuda.size() * sizeof(float); + + cudaMalloc((void **)&d_input, input_size); + cudaMalloc((void **)&d_gamma, gamma_size); + cudaMalloc((void **)&d_output, output_size); + + /// Copy data to CUDA memory + cudaMemcpy(d_input, input.getData(), input_size, + cudaMemcpyHostToDevice); + cudaMemcpy(d_gamma, gamma.getData(), gamma_size, + cudaMemcpyHostToDevice); + + /// Reference implementation using CPU + std::function f = [](float x) { return 1 / std::sqrt(x); }; + auto t = input.multiply(input).average(3).add(epsilon); + t.apply_i(f); + input.multiply(t, output_ref); + output_ref.multiply_i(gamma); + + /// Create CUDA events for timing + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + /// Record start time + cudaEventRecord(start); + + /// CUDA implementation + rmsnorm_cuda(d_input, d_gamma, d_output, epsilon, + input.batch() * input.channel() * input.height(), input.width()); + + /// Record stop time + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + /// Calculate elapsed time + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + std::cout << "RMSNorm CUDA kernel execution time for input size (" << batch + << ", " << channel << ", " << height << ", " << width + << "): " << milliseconds << " ms" << std::endl; + + /// Copy result back to host + cudaMemcpy(output_cuda.getData(), d_output, output_size, + cudaMemcpyDeviceToHost); + + /// Destroy CUDA events + cudaEventDestroy(start); + cudaEventDestroy(stop); + + /// Free CUDA memory + cudaFree(d_input); + cudaFree(d_gamma); + cudaFree(d_output); + + /// Compare results + float mseError = mse(output_cuda.getData(), + output_ref.getData(), output_cuda.size()); + + double cosSim = + cosine_similarity(output_cuda.getData(), + output_ref.getData(), output_cuda.size()); + + const float error_threshold = 1e-5; + const float 
cosine_threshold = 0.999; + + EXPECT_LE(mseError, error_threshold); + EXPECT_GE(cosSim, cosine_threshold); +} + +/** + * @brief Test for rmsnorm_cuda function with small epsilon + */ +TEST(nntrainer_CUDA, rmsnorm_cuda_3) { + const int batch = 1; + const int channel = 1; + const int height = 10; + const int width = 128; + + const float epsilon = 1e-12; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + /// Initialize CPU input data + nntrainer::Tensor input(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor gamma(1, 1, 1, width, t_type_nchw_fp32); + nntrainer::Tensor output_cuda(batch, channel, height, width, + t_type_nchw_fp32); + nntrainer::Tensor output_ref(batch, channel, height, width, t_type_nchw_fp32); + + /// Generate test data + auto input_data = generate_test_data(input.size(), -0.5f, 0.5f); + auto gamma_data = generate_test_data(gamma.size(), 0.8f, 1.2f); + + std::copy(input_data.begin(), input_data.end(), input.getData()); + std::copy(gamma_data.begin(), gamma_data.end(), gamma.getData()); + + /// Allocate CUDA memory + float *d_input = nullptr, *d_gamma = nullptr, *d_output = nullptr; + size_t input_size = input.size() * sizeof(float); + size_t gamma_size = gamma.size() * sizeof(float); + size_t output_size = output_cuda.size() * sizeof(float); + + cudaMalloc((void **)&d_input, input_size); + cudaMalloc((void **)&d_gamma, gamma_size); + cudaMalloc((void **)&d_output, output_size); + + /// Copy data to CUDA memory + cudaMemcpy(d_input, input.getData(), input_size, + cudaMemcpyHostToDevice); + cudaMemcpy(d_gamma, gamma.getData(), gamma_size, + cudaMemcpyHostToDevice); + + /// Reference implementation using CPU + std::function f = [](float x) { return 1 / std::sqrt(x); }; + auto t = input.multiply(input).average(3).add(epsilon); + t.apply_i(f); + input.multiply(t, output_ref); + output_ref.multiply_i(gamma); + + /// Create CUDA events for timing + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + /// Record start time + cudaEventRecord(start); + + /// CUDA implementation + rmsnorm_cuda(d_input, d_gamma, d_output, epsilon, + input.batch() * input.channel() * input.height(), input.width()); + + /// Record stop time + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + /// Calculate elapsed time + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + std::cout << "RMSNorm CUDA kernel execution time for input size (" << batch + << ", " << channel << ", " << height << ", " << width + << "): " << milliseconds << " ms" << std::endl; + + /// Copy result back to host + cudaMemcpy(output_cuda.getData(), d_output, output_size, + cudaMemcpyDeviceToHost); + + /// Destroy CUDA events + cudaEventDestroy(start); + cudaEventDestroy(stop); + + /// Free CUDA memory + cudaFree(d_input); + cudaFree(d_gamma); + cudaFree(d_output); + + /// Compare results + float mseError = mse(output_cuda.getData(), + output_ref.getData(), output_cuda.size()); + + double cosSim = + cosine_similarity(output_cuda.getData(), + output_ref.getData(), output_cuda.size()); + + const float error_threshold = 1e-5; + const float cosine_threshold = 0.999; + + EXPECT_LE(mseError, error_threshold); + EXPECT_GE(cosSim, cosine_threshold); +} + +GTEST_API_ int main(int argc, char **argv) { + int result = -1; + + try { + testing::InitGoogleTest(&argc, argv); + } catch (...) 
{ + std::cerr << "Error during InitGoogleTest" << std::endl; + return 0; + } + + try { + result = RUN_ALL_TESTS(); + } catch (...) { + std::cerr << "Error during RUN_ALL_TESTS()" << std::endl; + } + + return result; +} diff --git a/test/unittest/unittest_cuda_addition.cpp b/test/unittest/unittest_cuda_addition.cpp new file mode 100644 index 0000000000..816927e16d --- /dev/null +++ b/test/unittest/unittest_cuda_addition.cpp @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. + * + * @file unittest_cuda_addition.cpp + * @date 20 Nov 2025 + * @brief Unit test for CUDA addition operations + * @see https://github.com/nnstreamer/nntrainer + * @author Samsung Electronics Co., Ltd. + * @bug No known bugs except for NYI items + */ + +#include +#include +#include +#include +#include "addition_cuda.h" +#include + +#define EXPECT_IN_RANGE(VAL, MIN, MAX) \ + EXPECT_GE((VAL), (MIN)); \ + EXPECT_LE((VAL), (MAX)) + +#define CUDA_CHECK(call) \ + do { \ + cudaError_t error = call; \ + if (error != cudaSuccess) { \ + std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ << " - " \ + << cudaGetErrorString(error) << std::endl; \ + FAIL(); \ + } \ + } while (0) + +using namespace nntrainer; + + +/** + * @brief Test for addition_cuda function + */ +TEST(nntrainer_CUDA, addition_cuda) { + const int size_input = 1024; + const int size_res = 1024; + + // Allocate host memory + std::vector input_data(size_input); + std::vector res_data(size_res); + std::vector expected_data(size_res); + + // Allocate device memory + float *d_input = nullptr; + float *d_res = nullptr; + CUDA_CHECK(cudaMalloc(&d_input, size_input * sizeof(float))); + CUDA_CHECK(cudaMalloc(&d_res, size_res * sizeof(float))); + + // Create CUDA events for timing + cudaEvent_t start, stop; + CUDA_CHECK(cudaEventCreate(&start)); + CUDA_CHECK(cudaEventCreate(&stop)); + + // Call CUDA function 10 times + for (int i = 0; i < 10; ++i) { + // Initialize input data + for (int j = 0; j < size_input; ++j) { + input_data[j] = static_cast((j + i) % 100) / 100.0f; + } + + // Initialize result data + for (int j = 0; j < size_res; ++j) { + res_data[j] = static_cast((j + i) % 50) / 100.0f; + expected_data[j] = res_data[j] + input_data[j % size_input]; + } + + // Copy data to device + CUDA_CHECK(cudaMemcpy(d_input, input_data.data(), size_input * sizeof(float), + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_res, res_data.data(), size_res * sizeof(float), + cudaMemcpyHostToDevice)); + + if (i == 0) { + // First call without timing + addition_cuda(d_input, d_res, size_input, size_res); + } else { + // Subsequent calls with timing + CUDA_CHECK(cudaEventRecord(start)); + addition_cuda(d_input, d_res, size_input, size_res); + CUDA_CHECK(cudaEventRecord(stop)); + CUDA_CHECK(cudaEventSynchronize(stop)); + + float milliseconds = 0; + CUDA_CHECK(cudaEventElapsedTime(&milliseconds, start, stop)); + std::cout << "addition_cuda kernel execution time for call " << (i + 1) + << ": " << milliseconds << " ms" << std::endl; + } + + // Copy result back to host + std::vector result_data(size_res); + CUDA_CHECK(cudaMemcpy(result_data.data(), d_res, size_res * sizeof(float), + cudaMemcpyDeviceToHost)); + + // Check results (only for the last iteration) + if (i == 9) { + float mseError = + mse(result_data.data(), expected_data.data(), size_res); + + double cosSim = cosine_similarity(result_data.data(), + expected_data.data(), size_res); + + const float epsilon = 1e-5; + + if (mseError > 
epsilon) { + std::cout << "MSE Error: " << mseError << std::endl; + } + EXPECT_IN_RANGE(mseError, 0, epsilon); + + if ((float)cosSim < 0.99) { + std::cout << "Cosine Similarity: " << (float)cosSim << std::endl; + } + EXPECT_IN_RANGE((float)cosSim, 0.99, 1); + } + } + + // Destroy CUDA events + CUDA_CHECK(cudaEventDestroy(start)); + CUDA_CHECK(cudaEventDestroy(stop)); + + // Free device memory + CUDA_CHECK(cudaFree(d_input)); + CUDA_CHECK(cudaFree(d_res)); +} + +GTEST_API_ int main(int argc, char **argv) { + int result = -1; + + try { + testing::InitGoogleTest(&argc, argv); + } catch (...) { + std::cerr << "Error during InitGoogleTest" << std::endl; + return 0; + } + + try { + result = RUN_ALL_TESTS(); + } catch (...) { + std::cerr << "Error during RUN_ALL_TESTS()" << std::endl; + } + + return result; +}
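
Reviewer note (not part of the patch): a minimal host-side sketch of how the addition_cuda() kernel introduced above could be exercised, mirroring the flow of unittest_cuda_addition.cpp. Only addition_cuda() and standard CUDA runtime calls are taken from the patch; the buffer names, sizes, and printed output are illustrative assumptions, not part of the PR.

// illustrative only: exercise addition_cuda() from host code
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include "addition_cuda.h" // hypothetical include path for this sketch

int main() {
  const unsigned int size_input = 1024, size_res = 1024; // made-up sizes
  std::vector<float> in(size_input, 1.0f), res(size_res, 2.0f);

  float *d_in = nullptr, *d_res = nullptr;
  cudaMalloc((void **)&d_in, size_input * sizeof(float));
  cudaMalloc((void **)&d_res, size_res * sizeof(float));
  cudaMemcpy(d_in, in.data(), size_input * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_res, res.data(), size_res * sizeof(float),
             cudaMemcpyHostToDevice);

  // res[i] += in[i % size_input], accumulated in place on the device
  nntrainer::addition_cuda(d_in, d_res, size_input, size_res);

  // the launch is asynchronous; surface any launch error before reading back
  cudaError_t err = cudaDeviceSynchronize();
  if (err != cudaSuccess)
    std::printf("addition_cuda failed: %s\n", cudaGetErrorString(err));

  cudaMemcpy(res.data(), d_res, size_res * sizeof(float),
             cudaMemcpyDeviceToHost);
  std::printf("res[0] = %f\n", res[0]); // 3.0 expected for this toy data

  cudaFree(d_in);
  cudaFree(d_res);
  return 0;
}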
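
Reviewer note (not part of the patch): a sketch of how the new context is expected to be reached once Engine::add_default_object() has registered it. It only uses names visible in this diff (CudaContext::Global(), createLayerObject(), getStream()); the "addition" layer key and the assumption that the Singleton base triggers initialize() on first Global() access are unverified here.

// illustrative only: create a CUDA-backed layer through the registered context
#ifdef ENABLE_CUDA
#include <cuda_runtime.h>
#include "cuda_context.h" // hypothetical include path for this sketch

void run_cuda_layer_example() {
  // Global() returns the singleton; initialize() (assumed to be invoked by the
  // Singleton base) registers the default factories and creates the stream
  auto &cuda_ctx = nntrainer::CudaContext::Global();

  // "addition" is assumed to be the default key exported by AdditionLayer::type
  auto layer = cuda_ctx.createLayerObject("addition");

  // the stream created in cudaInit() can be reused for async kernel launches
  cudaStream_t stream = cuda_ctx.getStream();
  (void)layer;
  (void)stream;
}
#endif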