Commit 42922de

Add CUDA context support and build configuration
This commit adds CUDA context management files (cuda_context.h and cuda_context.cpp) that provide functionality similar to the existing OpenCL context. The changes include:

- Implementation of a CudaContext class inheriting from Context and Singleton
- A CUDA kernel management and execution interface
- Build-system updates to support CUDA via the enable-cuda Meson option
- Conditional linking of the CUDA runtime library on both Windows and Linux
- Addition of the enable-cuda option in meson_options.txt
- Implementation of an RMSNorm CUDA kernel and its build configuration

Signed-off-by: Daekyoung Jung <[email protected]>
1 parent: fd48dca
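
The conditional linking of the CUDA runtime that the message mentions is not visible in the five files below. As a rough sketch of what such a hunk typically looks like in this build system — the variable names cc and nntrainer_base_deps are assumptions, not confirmed by this diff:

```meson
# Hypothetical sketch -- the actual linking hunk is not shown in this view.
cc = meson.get_compiler('cpp')
if get_option('enable-cuda')
  # Locate the CUDA runtime library and add it to the project's dependency list
  cuda_rt_dep = cc.find_library('cudart', required: true)
  nntrainer_base_deps += cuda_rt_dep # assumed name for the project's dep list
endif
```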

File tree

5 files changed: +180 −0 lines changed

nntrainer/engine.cpp

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,12 @@ void Engine::add_default_object() {
 
   registerContext("gpu", &cl_context);
 #endif
+
+#ifdef ENABLE_CUDA
+  auto &cuda_context = nntrainer::CudaContext::Global();
+
+  registerContext("cuda", &cuda_context);
+#endif
 }
 
 void Engine::initialize() noexcept {
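
The cuda_context.h and cuda_context.cpp files registered here are described in the commit message but not shown in this view. As a minimal sketch of the singleton shape the call above implies — assuming CudaContext mirrors the existing OpenCL context pattern, with every name beyond CudaContext and Global() being a guess:

```cpp
// Hypothetical sketch of cuda_context.h -- the real file is not in this view.
#pragma once

namespace nntrainer {

/// Assumed to inherit from Context and Singleton, per the commit message.
class CudaContext {
public:
  /// Process-wide instance, matching the CudaContext::Global() call
  /// in engine.cpp above.
  static CudaContext &Global() {
    static CudaContext instance; // constructed once; thread-safe since C++11
    return instance;
  }

  CudaContext(const CudaContext &) = delete;
  CudaContext &operator=(const CudaContext &) = delete;

private:
  CudaContext() = default; // singleton: construct only through Global()
};

} // namespace nntrainer
```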

nntrainer/meson.build

Lines changed: 6 additions & 0 deletions
@@ -95,6 +95,11 @@ foreach elem : nntrainer_elements
   nntrainer_inc_abs += meson.current_source_dir() / elem
 endforeach
 
+# Add CUDA operations subdir if CUDA is enabled
+if get_option('enable-cuda')
+  subdir('tensor/cuda_operations')
+endif
+
 nntrainer_common_sources = [
   'nntrainer_logger.cpp',
   'app_context.cpp',
@@ -114,6 +119,7 @@ endif
 if get_option('enable-cuda')
   nntrainer_headers += meson.current_source_dir() / 'cuda_context.h'
   nntrainer_common_sources += 'cuda_context.cpp'
+  extra_defines += '-DENABLE_CUDA=1'
 endif
 
 foreach s : nntrainer_common_sources
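
The get_option('enable-cuda') calls above read a boolean option declared in meson_options.txt; the second hunk's unchanged context shows the option predates this commit. A plausible declaration, as a sketch — the default value and description are assumptions, since that file is not shown here:

```meson
# meson_options.txt (sketch; not shown in this diff)
option('enable-cuda', type: 'boolean', value: false,
       description: 'Enable the CUDA backend')
```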
nntrainer/tensor/cuda_operations/meson.build

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+# Find CUDA compiler
+dep = dependency('cuda', version : '>=13', modules : ['cublas'])
+
+nvcc = find_program('nvcc', required: true)
+
+if nvcc.found()
+  cuda_sources = [
+    'rmsnorm_cuda.cu'
+  ]
+
+  cuda_headers = [
+    'rmsnorm_cuda.h'
+  ]
+
+  kernel_objects = []
+  foreach kernel : cuda_sources
+    obj_name = kernel.replace('.cu', '.o')
+    obj = custom_target(obj_name,
+      command: [nvcc, '-c', '-Xcompiler', '/MD', '@INPUT@', '-o', '@OUTPUT@'],
+      input: kernel,
+      output: obj_name
+    )
+    kernel_objects += obj
+  endforeach
+
+  nntrainer_sources += kernel_objects
+
+  foreach h : cuda_headers
+    nntrainer_headers += meson.current_source_dir() / h
+  endforeach
+
+else
+  message('CUDA compiler (nvcc) not found. CUDA kernels will not be compiled.')
+endif
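
Note that -Xcompiler /MD forwards the MSVC dynamic-CRT flag to the host compiler, so this custom_target is Windows-oriented as written; each target expands to roughly `nvcc -c -Xcompiler /MD rmsnorm_cuda.cu -o rmsnorm_cuda.o`. A platform-guarded variant might look like this (a sketch, not part of the commit):

```meson
# Sketch: forward /MD to the host compiler only when targeting Windows/MSVC.
if host_machine.system() == 'windows'
  nvcc_host_flags = ['-Xcompiler', '/MD']
else
  nvcc_host_flags = []
endif
# then: command: [nvcc, '-c'] + nvcc_host_flags + ['@INPUT@', '-o', '@OUTPUT@']
```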
nntrainer/tensor/cuda_operations/rmsnorm_cuda.cu

Lines changed: 95 additions & 0 deletions

@@ -0,0 +1,95 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * @file rmsnorm_cuda.cu
+ * @date 14 Nov 2025
+ * @brief Common blas CUDA kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Samsung Electronics Co., Ltd.
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#include "rmsnorm_cuda.h"
+#include <cuda_runtime.h>
+
+__global__ void rmsnorm_cuda_kernel(const float *input, float *output,
+                                    const float *alpha, float epsilon,
+                                    int H, int W) {
+  // Each block processes one row (height index)
+  int h = blockIdx.x;
+  int index = h * W;
+
+  // Shared memory for reduction
+  extern __shared__ float sdata[];
+
+  // Thread index within block
+  int tid = threadIdx.x;
+  const int blockSize = blockDim.x;
+
+  // Load input data and compute sum of squares
+  const float *in = input + index;
+  float sum_squares = 0.0f;
+
+  // Each thread processes multiple elements if W > blockSize
+  for (int i = tid; i < W; i += blockSize) {
+    float val = in[i];
+    sum_squares += val * val;
+  }
+
+  // Store partial sum in shared memory
+  sdata[tid] = sum_squares;
+  __syncthreads();
+
+  // Reduction in shared memory
+  for (int s = blockSize / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      sdata[tid] += sdata[tid + s];
+    }
+    __syncthreads();
+  }
+
+  // First thread in block computes the final result
+  if (tid == 0) {
+    float mean = sdata[0] / W;
+    float scale = 1.0f / sqrtf(mean + epsilon);
+
+    // Store the scale value in shared memory for reuse
+    sdata[0] = scale;
+  }
+  __syncthreads();
+
+  // Load the computed scale
+  float scale = sdata[0];
+
+  // Compute output values
+  float *out = output + index;
+  for (int i = tid; i < W; i += blockSize) {
+    out[i] = in[i] * scale * alpha[i];
+  }
+}
+
+namespace nntrainer {
+
+void rmsnorm_cuda(const float *input, const float *gamma, float *result,
+                  const float epsilon, unsigned int height, unsigned int width) {
+  // Define block size
+  const int blockSize = 256;
+
+  // Calculate grid size (one block per row)
+  const int gridSize = height;
+
+  // Shared memory size for reduction
+  const int sharedMemSize = blockSize * sizeof(float);
+
+  // Launch the CUDA kernel
+  rmsnorm_cuda_kernel<<<gridSize, blockSize, sharedMemSize>>>(
+    input, result, gamma, epsilon, height, width);
+}
+
+void sscal_cuda(float *X, const unsigned int N, const float alpha) {
+  // TODO: Implement CUDA kernel for sscal
+}
+
+} // namespace nntrainer
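
For reference, with row width $W$, the kernel computes the standard RMSNorm of each row $x$, which is exactly the sum-of-squares reduction and scaling above:

$$\operatorname{rms}(x) = \sqrt{\frac{1}{W}\sum_{i=1}^{W} x_i^2 + \epsilon}, \qquad \text{out}_i = \frac{x_i}{\operatorname{rms}(x)}\,\gamma_i, \quad i = 1,\dots,W.$$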
nntrainer/tensor/cuda_operations/rmsnorm_cuda.h

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * @file rmsnorm_cuda.h
+ * @date 14 Nov 2025
+ * @brief Common blas CUDA kernels
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Samsung Electronics Co., Ltd.
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#pragma once
+
+namespace nntrainer {
+
+/**
+ * @brief rmsnorm each row of the tensor
+ * @param[in] input float * for input
+ * @param[in] gamma float * for gamma multiplier for each row
+ * @param[in] result float * for result
+ * @param[in] epsilon epsilon to add to each row sum to prevent division by zero
+ * @param[in] height height of the tensor
+ * @param[in] width width of the tensor
+ */
+void rmsnorm_cuda(const float *input, const float *gamma, float *result,
+                  const float epsilon, unsigned int height, unsigned int width);
+
+/**
+ * @brief sscal value element by element immediately
+ * @param[in] X float * input
+ * @param[in] N unsigned int number of elements
+ * @param[in] alpha float multiplier
+ * @param[in] context RunLayerContext reference
+ */
+void sscal_cuda(float *X, const unsigned int N, const float alpha);
+
+} // namespace nntrainer
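
A minimal host-side usage sketch of this API — an assumption, not part of the commit; error handling is elided, and cudaDeviceSynchronize is added because the kernel launch inside rmsnorm_cuda is asynchronous:

```cpp
// Sketch: normalize H rows of width W on the GPU via rmsnorm_cuda.
#include <cuda_runtime.h>
#include <vector>
#include "rmsnorm_cuda.h"

int main() {
  const unsigned int H = 4, W = 1024;
  std::vector<float> in(H * W, 1.0f), gamma(W, 1.0f), out(H * W);

  // Device buffers: H*W input/output, W-wide gamma (kernel indexes alpha[i], i < W)
  float *d_in, *d_gamma, *d_out;
  cudaMalloc(&d_in, H * W * sizeof(float));
  cudaMalloc(&d_gamma, W * sizeof(float));
  cudaMalloc(&d_out, H * W * sizeof(float));
  cudaMemcpy(d_in, in.data(), H * W * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_gamma, gamma.data(), W * sizeof(float), cudaMemcpyHostToDevice);

  // One block per row: divides each row by its RMS, then scales by gamma
  nntrainer::rmsnorm_cuda(d_in, d_gamma, d_out, 1e-6f, H, W);
  cudaDeviceSynchronize(); // wait for the asynchronous launch to finish

  cudaMemcpy(out.data(), d_out, H * W * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(d_in);
  cudaFree(d_gamma);
  cudaFree(d_out);
  return 0;
}
```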
