-
Notifications
You must be signed in to change notification settings - Fork 99
Add CUDA context support and build configuration #3567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dkjung
wants to merge
2
commits into
nnstreamer:main
Choose a base branch
from
dkjung:feature/cuda2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+604
−0
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| /** | ||
| * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. | ||
| * | ||
| * @file cuda_context.cpp | ||
| * @date 13 Nov 2025 | ||
| * @see https://github.com/nnstreamer/nntrainer | ||
| * @author Samsung Electronics Co., Ltd. | ||
| * @bug No known bugs except for NYI items | ||
| * @brief This file contains app context related functions and classes that | ||
| * manages the global configuration of the current CUDA environment. It also | ||
| * creates the CUDA stream and context. | ||
| */ | ||
|
|
||
| #include "cuda_context.h" | ||
|
|
||
| #include <addition_layer.h> | ||
| #include <fc_layer.h> | ||
| #include <nntrainer_log.h> | ||
| #include <reshape_layer.h> | ||
|
|
||
| #include <cuda.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| namespace nntrainer { | ||
| std::mutex cuda_factory_mutex; | ||
|
|
||
| void CudaContext::initialize() noexcept { | ||
| try { | ||
| if (!cudaInit()) { | ||
| ml_loge("Error: CudaContext::initialize() failed"); | ||
| return; | ||
| } | ||
|
|
||
| add_default_object(); | ||
| setMemAllocator(std::make_shared<MemAllocator>()); | ||
|
|
||
| } catch (std::exception &e) { | ||
| ml_loge("cuda_context: registering layers failed!!, reason: %s", e.what()); | ||
| } catch (...) { | ||
| ml_loge("cuda_context: registering layer failed due to unknown reason"); | ||
| } | ||
| }; | ||
|
|
||
| void CudaContext::add_default_object() { | ||
| // Register default layers that support CUDA | ||
| registerFactory(nntrainer::createLayer<FullyConnectedLayer>, | ||
| FullyConnectedLayer::type, ml::train::LayerType::LAYER_FC); | ||
|
|
||
| registerFactory(nntrainer::createLayer<AdditionLayer>, AdditionLayer::type, | ||
| ml::train::LayerType::LAYER_ADDITION); | ||
|
|
||
| registerFactory(nntrainer::createLayer<ReshapeLayer>, ReshapeLayer::type, | ||
| ml::train::LayerType::LAYER_RESHAPE); | ||
| } | ||
|
|
||
| template <typename T> | ||
| const int CudaContext::registerFactory(const FactoryType<T> factory, | ||
| const std::string &key, | ||
| const int int_key) { | ||
| static_assert( | ||
| isSupported<T>::value, | ||
| "cuda_context: given type is not supported for current context"); | ||
|
|
||
| auto &index = std::get<IndexType<T>>(factory_map); | ||
| auto &str_map = std::get<StrIndexType<T>>(index); | ||
| auto &int_map = std::get<IntIndexType>(index); | ||
|
|
||
| std::string assigned_key = key == "" ? factory({})->getType() : key; | ||
|
|
||
| std::transform(assigned_key.begin(), assigned_key.end(), assigned_key.begin(), | ||
| [](unsigned char c) { return std::tolower(c); }); | ||
|
|
||
| const std::lock_guard<std::mutex> lock(cuda_factory_mutex); | ||
| if (str_map.find(assigned_key) != str_map.end()) { | ||
| ml_loge("cuda_context: cannot register factory with already taken key: %s", | ||
| key.c_str()); | ||
| throw std::invalid_argument(key); | ||
| } | ||
|
|
||
| if (int_key != -1 && int_map.find(int_key) != int_map.end()) { | ||
| ml_loge( | ||
| "cuda_context: cannot register factory with already taken int key: %d", | ||
| int_key); | ||
| throw std::invalid_argument(std::to_string(int_key)); | ||
| } | ||
|
|
||
| int assigned_int_key = int_key == -1 ? str_map.size() + 1 : int_key; | ||
|
|
||
| str_map[assigned_key] = factory; | ||
| int_map[assigned_int_key] = assigned_key; | ||
|
|
||
| ml_logd("cuda_context: factory has registered with key: %s, int_key: %d", | ||
| assigned_key.c_str(), assigned_int_key); | ||
|
|
||
| return assigned_int_key; | ||
| } | ||
|
|
||
| bool CudaContext::cudaInit() { | ||
| // if already initialized | ||
| if (cuda_initialized) | ||
| return true; | ||
|
|
||
| // Initialize CUDA context | ||
| cudaError_t err = cudaSetDevice(0); | ||
| if (err != cudaSuccess) { | ||
| ml_loge("Failed to set CUDA device: %s", cudaGetErrorString(err)); | ||
| return false; | ||
| } | ||
|
|
||
| // Create CUDA stream for asynchronous operations | ||
| err = cudaStreamCreate(&stream_); | ||
| if (err != cudaSuccess) { | ||
| ml_loge("Failed to create CUDA stream: %s", cudaGetErrorString(err)); | ||
| return false; | ||
| } | ||
|
|
||
| cuda_initialized = true; | ||
| return cuda_initialized; | ||
| } | ||
|
|
||
| /** | ||
| * @copydoc const int CudaContext::registerFactory | ||
| */ | ||
| template const int CudaContext::registerFactory<nntrainer::Layer>( | ||
| const FactoryType<nntrainer::Layer> factory, const std::string &key, | ||
| const int int_key); | ||
|
|
||
| } // namespace nntrainer | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,260 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| /** | ||
| * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. | ||
| * | ||
| * @file cuda_context.h | ||
| * @date 13 Nov 2025 | ||
| * @see https://github.com/nnstreamer/nntrainer | ||
| * @author Samsung Electronics Co., Ltd. | ||
| * @bug No known bugs except for NYI items | ||
| * @brief This file contains app context related functions and classes that | ||
| * manages the global configuration of the current CUDA environment. It also | ||
| * creates the CUDA stream and context. | ||
| */ | ||
|
|
||
| #ifndef __CUDA_CONTEXT_H__ | ||
| #define __CUDA_CONTEXT_H__ | ||
|
|
||
| #include <algorithm> | ||
| #include <functional> | ||
| #include <memory> | ||
| #include <mutex> | ||
| #include <stdexcept> | ||
| #include <string> | ||
| #include <type_traits> | ||
| #include <unordered_map> | ||
| #include <vector> | ||
|
|
||
| #include <cuda.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <context.h> | ||
| #include <layer.h> | ||
| #include <layer_devel.h> | ||
| #include <mem_allocator.h> | ||
|
|
||
| #include "singleton.h" | ||
|
|
||
| namespace nntrainer { | ||
|
|
||
| extern std::mutex cuda_factory_mutex; | ||
|
|
||
| /** | ||
| * @class CudaContext contains user-dependent configuration for CUDA support | ||
| * @brief CUDA support for app context | ||
| */ | ||
| class CudaContext : public Context, public Singleton<CudaContext> { | ||
| public: | ||
| /** | ||
| * @brief Default constructor | ||
| */ | ||
| CudaContext() : Context(std::make_shared<ContextData>()) {} | ||
|
|
||
| /** | ||
| * @brief destructor to release cuda context | ||
| */ | ||
| ~CudaContext() override { | ||
| if (cuda_initialized) { | ||
| // Release CUDA resources | ||
| if (stream_) { | ||
| cudaStreamDestroy(stream_); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| /** | ||
| * @brief Factory register function, use this function to register custom | ||
| * object | ||
| * | ||
| * @tparam T object to create. Currently Layer is supported | ||
| * @param factory factory function that creates std::unique_ptr<T> | ||
| * @param key key to access the factory, if key is empty, try to find key by | ||
| * calling factory({})->getType(); | ||
| * @param int_key key to access the factory by integer, if it is -1(default), | ||
| * the function automatically unsigned the key and return | ||
| * @return const int unique integer value to access the current factory | ||
| * @throw invalid argument when key and/or int_key is already taken | ||
| */ | ||
| template <typename T> | ||
| const int registerFactory(const PtrFactoryType<T> factory, | ||
| const std::string &key = "", | ||
| const int int_key = -1) { | ||
| FactoryType<T> f = factory; | ||
| return registerFactory(f, key, int_key); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Factory register function, use this function to register custom | ||
| * object | ||
| * | ||
| * @tparam T object to create. Currently Layer is supported | ||
| * @param factory factory function that creates std::unique_ptr<T> | ||
| * @param key key to access the factory, if key is empty, try to find key by | ||
| * calling factory({})->getType(); | ||
| * @param int_key key to access the factory by integer, if it is -1(default), | ||
| * the function automatically unsigned the key and return | ||
| * @return const int unique integer value to access the current factory | ||
| * @throw invalid argument when key and/or int_key is already taken | ||
| */ | ||
| template <typename T> | ||
| const int registerFactory(const FactoryType<T> factory, | ||
| const std::string &key = "", | ||
| const int int_key = -1); | ||
|
|
||
| /** | ||
| * @brief Create an Object from the integer key | ||
| * | ||
| * @tparam T Type of Object, currently, Only Layer is supported | ||
| * @param int_key integer key | ||
| * @param props property | ||
| * @return PtrType<T> unique pointer to the object | ||
| */ | ||
| template <typename T> | ||
| PtrType<T> createObject(const int int_key, | ||
| const PropsType &props = {}) const { | ||
| static_assert(isSupported<T>::value, | ||
| "given type is not supported for current app context"); | ||
| auto &index = std::get<IndexType<T>>(factory_map); | ||
| auto &int_map = std::get<IntIndexType>(index); | ||
|
|
||
| const auto &entry = int_map.find(int_key); | ||
|
|
||
| if (entry == int_map.end()) { | ||
| ml_loge("Int Key is not found for the object. Key: %d", int_key); | ||
| throw exception::not_supported(std::to_string(int_key)); | ||
| } | ||
|
|
||
| // entry is an object of int_map which is an unordered_map<int, std::string> | ||
| return createObject<T>(entry->second, props); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Create an Object from the string key | ||
| * | ||
| * @tparam T Type of object, currently, only Layer is supported | ||
| * @param key integer key | ||
| * @param props property | ||
| * @return PtrType<T> unique pointer to the object | ||
| */ | ||
| template <typename T> | ||
| PtrType<T> createObject(const std::string &key, | ||
| const PropsType &props = {}) const { | ||
| auto &index = std::get<IndexType<T>>(factory_map); | ||
| auto &str_map = std::get<StrIndexType<T>>(index); | ||
|
|
||
| std::string lower_key; | ||
| lower_key.resize(key.size()); | ||
|
|
||
| std::transform(key.begin(), key.end(), lower_key.begin(), | ||
| [](unsigned char c) { return std::tolower(c); }); | ||
|
|
||
| const auto &entry = str_map.find(lower_key); | ||
|
|
||
| if (entry == str_map.end()) { | ||
| ml_loge("Key is not found for the object. Key: %s", lower_key.c_str()); | ||
| throw exception::not_supported(lower_key); | ||
| } | ||
|
|
||
| // entry -> object of str_map -> unordered_map<std::string, FactoryType<T>> | ||
| return entry->second(props); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Create a Layer object from the string key | ||
| * | ||
| * @param type string key | ||
| * @param properties property | ||
| * @return std::unique_ptr<nntrainer::Layer> unique pointer to the object | ||
| */ | ||
| std::unique_ptr<nntrainer::Layer> | ||
| createLayerObject(const std::string &type, | ||
| const std::vector<std::string> &properties = {}) override { | ||
| return createObject<nntrainer::Layer>(type, properties); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Create a Layer object from the integer key | ||
| * | ||
| * @param type integer key | ||
| * @param properties property | ||
| * @return std::unique_ptr<nntrainer::Layer> unique pointer to the object | ||
| */ | ||
| std::unique_ptr<nntrainer::Layer> | ||
| createLayerObject(const int int_key, | ||
| const std::vector<std::string> &properties = {}) override { | ||
| return createObject<nntrainer::Layer>(int_key, properties); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Get the name of the context | ||
| */ | ||
| std::string getName() override { return "cuda"; } | ||
|
|
||
| /** | ||
| * @brief Set the Mem Allocator object | ||
| * | ||
| * @param mem Memory allocator object | ||
| */ | ||
| void setMemAllocator(std::shared_ptr<MemAllocator> mem) { | ||
| getContextData()->setMemAllocator(mem); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Get CUDA stream | ||
| * @return cudaStream_t | ||
| */ | ||
| cudaStream_t getStream() const { return stream_; } | ||
|
|
||
| private: | ||
| /** | ||
| * @brief Overriden init function | ||
| */ | ||
| void initialize() noexcept override; | ||
|
|
||
| void add_default_object(); | ||
|
|
||
| // flag to check cuda initialization | ||
| bool cuda_initialized = false; | ||
|
|
||
| // CUDA stream for asynchronous operations | ||
| cudaStream_t stream_ = nullptr; | ||
|
|
||
| FactoryMap<nntrainer::Layer> factory_map; | ||
|
|
||
| template <typename Args, typename T> struct isSupportedHelper; | ||
|
|
||
| /** | ||
| * @brief supportHelper to check if given type is supported within cuda | ||
| * context | ||
| */ | ||
| template <typename T, typename... Args> | ||
| struct isSupportedHelper<T, CudaContext::FactoryMap<Args...>> { | ||
| static constexpr bool value = | ||
| (std::is_same_v<std::decay_t<T>, std::decay_t<Args>> || ...); | ||
| }; | ||
|
|
||
| /** | ||
| * @brief supportHelper to check if given type is supported within cuda | ||
| * context | ||
| */ | ||
| template <typename T> | ||
| struct isSupported : isSupportedHelper<T, decltype(factory_map)> {}; | ||
|
|
||
| /** | ||
| * @brief Initialize cuda context and stream | ||
| * @return true if CUDA context and stream creation is successful, | ||
| * false otherwise | ||
| */ | ||
| bool cudaInit(); | ||
| }; | ||
|
|
||
| /** | ||
| * @copydoc const int CudaContext::registerFactory | ||
| */ | ||
| extern template const int CudaContext::registerFactory<nntrainer::Layer>( | ||
| const FactoryType<nntrainer::Layer> factory, const std::string &key, | ||
| const int int_key); | ||
|
|
||
| } // namespace nntrainer | ||
|
|
||
| #endif /* __CUDA_CONTEXT_H__ */ |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it ok to let the caller of initialize() unaware of the init error?