Commit fd48dca
Add CUDA context support with build configuration

Adds CUDA context management files (cuda_context.h and cuda_context.cpp) that provide functionality similar to the existing OpenCL context. The changes include:

- CudaContext class inheriting from Context and Singleton
- CUDA kernel management and execution interfaces
- Build system updates to support CUDA via the enable-cuda option
- Conditional linking of the CUDA runtime library on both Windows and Linux
- Addition of the enable-cuda option in meson_options.txt

Signed-off-by: Daekyoung Jung <[email protected]>

1 parent 6147b46 commit fd48dca

File tree

4 files changed: +424 -0 lines changed

meson_options.txt

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ option('enable-fp16', type: 'boolean', value: false)
 option('enable-cublas', type: 'boolean', value: false)
 option('enable-openmp', type: 'boolean', value: true)
 option('enable-opencl', type: 'boolean', value: false)
+option('enable-cuda', type: 'boolean', value: false)
 option('enable-biqgemm', type: 'boolean', value: false)
 option('biqgemm-path', type: 'string', value: '../BiQGEMM')
 option('enable-benchmarks', type: 'boolean', value : false)
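
The option defaults to false, so CUDA support stays off unless explicitly requested. A CUDA-enabled build would presumably be configured along these lines (a sketch of the standard Meson workflow; the build directory name is illustrative):

  meson setup build -Denable-cuda=true
  ninja -C build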

nntrainer/cuda_context.cpp

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
 *
 * @file cuda_context.cpp
 * @date 13 Nov 2025
 * @see https://github.com/nnstreamer/nntrainer
 * @author Samsung Electronics Co., Ltd.
 * @bug No known bugs except for NYI items
 * @brief This file contains app context related functions and classes that
 * manage the global configuration of the current CUDA environment. It also
 * creates the CUDA stream and context.
 */

#include "cuda_context.h"

#include <addition_layer.h>
#include <fc_layer.h>
#include <nntrainer_log.h>
#include <reshape_layer.h>

#include <cuda.h>
#include <cuda_runtime.h>

namespace nntrainer {
std::mutex cuda_factory_mutex;

void CudaContext::initialize() noexcept {
  try {
    if (!cudaInit()) {
      ml_loge("Error: CudaContext::initialize() failed");
      return;
    }

    add_default_object();
    setMemAllocator(std::make_shared<MemAllocator>());

  } catch (std::exception &e) {
    ml_loge("cuda_context: registering layers failed!!, reason: %s", e.what());
  } catch (...) {
    ml_loge("cuda_context: registering layers failed due to unknown reason");
  }
}

void CudaContext::add_default_object() {
  // Register default layers that support CUDA
  registerFactory(nntrainer::createLayer<FullyConnectedLayer>,
                  FullyConnectedLayer::type, ml::train::LayerType::LAYER_FC);

  registerFactory(nntrainer::createLayer<AdditionLayer>, AdditionLayer::type,
                  ml::train::LayerType::LAYER_ADDITION);

  registerFactory(nntrainer::createLayer<ReshapeLayer>, ReshapeLayer::type,
                  ml::train::LayerType::LAYER_RESHAPE);
}

template <typename T>
const int CudaContext::registerFactory(const FactoryType<T> factory,
                                       const std::string &key,
                                       const int int_key) {
  static_assert(
    isSupported<T>::value,
    "cuda_context: given type is not supported for current context");

  auto &index = std::get<IndexType<T>>(factory_map);
  auto &str_map = std::get<StrIndexType<T>>(index);
  auto &int_map = std::get<IntIndexType>(index);

  std::string assigned_key = key == "" ? factory({})->getType() : key;

  std::transform(assigned_key.begin(), assigned_key.end(),
                 assigned_key.begin(),
                 [](unsigned char c) { return std::tolower(c); });

  const std::lock_guard<std::mutex> lock(cuda_factory_mutex);
  if (str_map.find(assigned_key) != str_map.end()) {
    ml_loge("cuda_context: cannot register factory with already taken key: %s",
            key.c_str());
    throw std::invalid_argument(key);
  }

  if (int_key != -1 && int_map.find(int_key) != int_map.end()) {
    ml_loge(
      "cuda_context: cannot register factory with already taken int key: %d",
      int_key);
    throw std::invalid_argument(std::to_string(int_key));
  }

  int assigned_int_key = int_key == -1 ? str_map.size() + 1 : int_key;

  str_map[assigned_key] = factory;
  int_map[assigned_int_key] = assigned_key;

  ml_logd("cuda_context: factory has been registered with key: %s, int_key: %d",
          assigned_key.c_str(), assigned_int_key);

  return assigned_int_key;
}

bool CudaContext::cudaInit() {
  // if already initialized
  if (cuda_initialized)
    return true;

  // Initialize the CUDA context by selecting device 0
  cudaError_t err = cudaSetDevice(0);
  if (err != cudaSuccess) {
    ml_loge("Failed to set CUDA device: %s", cudaGetErrorString(err));
    return false;
  }

  // Create a CUDA stream for asynchronous operations
  err = cudaStreamCreate(&stream_);
  if (err != cudaSuccess) {
    ml_loge("Failed to create CUDA stream: %s", cudaGetErrorString(err));
    return false;
  }

  cuda_initialized = true;
  return cuda_initialized;
}

/**
 * @copydoc const int CudaContext::registerFactory
 */
template const int CudaContext::registerFactory<nntrainer::Layer>(
  const FactoryType<nntrainer::Layer> factory, const std::string &key,
  const int int_key);

} // namespace nntrainer
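
The registration path above lower-cases string keys and auto-assigns integer keys when int_key is -1. A self-contained sketch of just that bookkeeping, using plain standard-library maps in place of nntrainer's factory types, may make the behavior easier to see:

#include <algorithm>
#include <cctype>
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, int> str_map; // stands in for the factory string index
  std::map<int, std::string> int_map; // stands in for the integer index

  auto register_key = [&](std::string key, int int_key) {
    // Keys are normalized to lower case, as in CudaContext::registerFactory.
    std::transform(key.begin(), key.end(), key.begin(),
                   [](unsigned char c) { return std::tolower(c); });
    // int_key == -1 auto-assigns the next integer key.
    int assigned =
      int_key == -1 ? static_cast<int>(str_map.size()) + 1 : int_key;
    str_map[key] = assigned;
    int_map[assigned] = key;
    return assigned;
  };

  std::cout << register_key("FullyConnected", -1) << '\n'; // prints 1
  std::cout << register_key("Addition", -1) << '\n';       // prints 2
}

Because keys are normalized this way, registering "FullyConnected" and later "fullyconnected" would collide, which is why CudaContext::registerFactory throws std::invalid_argument on duplicate keys.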

nntrainer/cuda_context.h

Lines changed: 260 additions & 0 deletions
@@ -0,0 +1,260 @@
// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
 *
 * @file cuda_context.h
 * @date 13 Nov 2025
 * @see https://github.com/nnstreamer/nntrainer
 * @author Samsung Electronics Co., Ltd.
 * @bug No known bugs except for NYI items
 * @brief This file contains app context related functions and classes that
 * manage the global configuration of the current CUDA environment. It also
 * creates the CUDA stream and context.
 */

#ifndef __CUDA_CONTEXT_H__
#define __CUDA_CONTEXT_H__

#include <algorithm>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>

#include <context.h>
#include <layer.h>
#include <layer_devel.h>
#include <mem_allocator.h>

#include "singleton.h"

namespace nntrainer {

extern std::mutex cuda_factory_mutex;

/**
 * @class CudaContext contains user-dependent configuration for CUDA support
 * @brief CUDA support for app context
 */
class CudaContext : public Context, public Singleton<CudaContext> {
public:
  /**
   * @brief Default constructor
   */
  CudaContext() : Context(std::make_shared<ContextData>()) {}

  /**
   * @brief Destructor to release the CUDA context
   */
  ~CudaContext() override {
    if (cuda_initialized) {
      // Release CUDA resources
      if (stream_) {
        cudaStreamDestroy(stream_);
      }
    }
  }

  /**
   * @brief Factory register function, use this function to register a custom
   * object
   *
   * @tparam T object to create. Currently only Layer is supported
   * @param factory factory function that creates std::unique_ptr<T>
   * @param key key to access the factory; if key is empty, the key is derived
   * by calling factory({})->getType()
   * @param int_key key to access the factory by integer; if it is -1
   * (default), the function automatically assigns an integer key and returns
   * it
   * @return const int unique integer value to access the current factory
   * @throw std::invalid_argument when key and/or int_key is already taken
   */
  template <typename T>
  const int registerFactory(const PtrFactoryType<T> factory,
                            const std::string &key = "",
                            const int int_key = -1) {
    FactoryType<T> f = factory;
    return registerFactory(f, key, int_key);
  }

  /**
   * @brief Factory register function, use this function to register a custom
   * object
   *
   * @tparam T object to create. Currently only Layer is supported
   * @param factory factory function that creates std::unique_ptr<T>
   * @param key key to access the factory; if key is empty, the key is derived
   * by calling factory({})->getType()
   * @param int_key key to access the factory by integer; if it is -1
   * (default), the function automatically assigns an integer key and returns
   * it
   * @return const int unique integer value to access the current factory
   * @throw std::invalid_argument when key and/or int_key is already taken
   */
  template <typename T>
  const int registerFactory(const FactoryType<T> factory,
                            const std::string &key = "",
                            const int int_key = -1);

  /**
   * @brief Create an object from the integer key
   *
   * @tparam T type of object; currently only Layer is supported
   * @param int_key integer key
   * @param props property
   * @return PtrType<T> unique pointer to the object
   */
  template <typename T>
  PtrType<T> createObject(const int int_key,
                          const PropsType &props = {}) const {
    static_assert(isSupported<T>::value,
                  "given type is not supported for current app context");
    auto &index = std::get<IndexType<T>>(factory_map);
    auto &int_map = std::get<IntIndexType>(index);

    const auto &entry = int_map.find(int_key);

    if (entry == int_map.end()) {
      ml_loge("Int Key is not found for the object. Key: %d", int_key);
      throw exception::not_supported(std::to_string(int_key));
    }

    // entry is an element of int_map, an unordered_map<int, std::string>
    return createObject<T>(entry->second, props);
  }

  /**
   * @brief Create an object from the string key
   *
   * @tparam T type of object; currently only Layer is supported
   * @param key string key
   * @param props property
   * @return PtrType<T> unique pointer to the object
   */
  template <typename T>
  PtrType<T> createObject(const std::string &key,
                          const PropsType &props = {}) const {
    auto &index = std::get<IndexType<T>>(factory_map);
    auto &str_map = std::get<StrIndexType<T>>(index);

    std::string lower_key;
    lower_key.resize(key.size());

    std::transform(key.begin(), key.end(), lower_key.begin(),
                   [](unsigned char c) { return std::tolower(c); });

    const auto &entry = str_map.find(lower_key);

    if (entry == str_map.end()) {
      ml_loge("Key is not found for the object. Key: %s", lower_key.c_str());
      throw exception::not_supported(lower_key);
    }

    // entry is an element of str_map, an
    // unordered_map<std::string, FactoryType<T>>
    return entry->second(props);
  }

  /**
   * @brief Create a Layer object from the string key
   *
   * @param type string key
   * @param properties property
   * @return std::unique_ptr<nntrainer::Layer> unique pointer to the object
   */
  std::unique_ptr<nntrainer::Layer>
  createLayerObject(const std::string &type,
                    const std::vector<std::string> &properties = {}) override {
    return createObject<nntrainer::Layer>(type, properties);
  }

  /**
   * @brief Create a Layer object from the integer key
   *
   * @param int_key integer key
   * @param properties property
   * @return std::unique_ptr<nntrainer::Layer> unique pointer to the object
   */
  std::unique_ptr<nntrainer::Layer>
  createLayerObject(const int int_key,
                    const std::vector<std::string> &properties = {}) override {
    return createObject<nntrainer::Layer>(int_key, properties);
  }

  /**
   * @brief Get the name of the context
   */
  std::string getName() override { return "cuda"; }

  /**
   * @brief Set the Mem Allocator object
   *
   * @param mem memory allocator object
   */
  void setMemAllocator(std::shared_ptr<MemAllocator> mem) {
    getContextData()->setMemAllocator(mem);
  }

  /**
   * @brief Get the CUDA stream
   * @return cudaStream_t
   */
  cudaStream_t getStream() const { return stream_; }

private:
  /**
   * @brief Overridden init function
   */
  void initialize() noexcept override;

  void add_default_object();

  // flag to check cuda initialization
  bool cuda_initialized = false;

  // CUDA stream for asynchronous operations
  cudaStream_t stream_ = nullptr;

  FactoryMap<nntrainer::Layer> factory_map;

  template <typename Args, typename T> struct isSupportedHelper;

  /**
   * @brief support helper to check if a given type is supported within the
   * cuda context
   */
  template <typename T, typename... Args>
  struct isSupportedHelper<T, CudaContext::FactoryMap<Args...>> {
    static constexpr bool value =
      (std::is_same_v<std::decay_t<T>, std::decay_t<Args>> || ...);
  };

  /**
   * @brief support helper to check if a given type is supported within the
   * cuda context
   */
  template <typename T>
  struct isSupported : isSupportedHelper<T, decltype(factory_map)> {};

  /**
   * @brief Initialize the cuda context and stream
   * @return true if CUDA context and stream creation is successful,
   * false otherwise
   */
  bool cudaInit();
};

/**
 * @copydoc const int CudaContext::registerFactory
 */
extern template const int CudaContext::registerFactory<nntrainer::Layer>(
  const FactoryType<nntrainer::Layer> factory, const std::string &key,
  const int int_key);

} // namespace nntrainer

#endif /* __CUDA_CONTEXT_H__ */
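
For orientation, a hedged usage sketch of the public surface above. The Global() accessor and the "fully_connected"/"unit=10" strings are assumptions based on how nntrainer's existing contexts and layers are typically used, not something this commit shows:

#include <cuda_context.h>

// Sketch: create a CUDA-backed layer through the context's factory and
// obtain the stream for asynchronous work.
void example() {
  auto &cc = nntrainer::CudaContext::Global(); // assumed Singleton accessor

  // String-key lookup is case-insensitive; keys are lower-cased internally.
  auto fc = cc.createLayerObject("fully_connected", {"unit=10"});

  // The stream created by cudaInit() backs asynchronous CUDA operations.
  cudaStream_t stream = cc.getStream();
  (void)stream;
}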
