The CUDA version of the gelu example does not work on my machine. The reason is that the generated CUDA code builds float_4 values with brace initialization, for example
l_0_3 = float_4{ erf(l_0_3.i_0),erf(l_0_3.i_1),erf(l_0_3.i_2),erf(l_0_3.i_3),};
and this syntax can only be compiled with C++11 or later.
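To illustrate why the dialect matters, here is a minimal host-side C++ sketch (not cubecl's generated code). `float_4{ ... }` is C++11 list-initialization of a temporary. A pre-C++11 front end sees `float_4` as a type name where an expression is expected, which is consistent with the "type name is not allowed" / "expected a ";"" pairs in the log. The struct is a plain-C++ stand-in with the CUDA-specific `__align__(16)` dropped, and `erf4_cxx11`, `make_float_4`, `erf4_cxx03` and `check_equivalent` are hypothetical names used only for this illustration. The sketch itself is built as C++11 (it uses `std::erf`); only the initializer syntax in `make_float_4` is C++98-compatible.

```cpp
#include <cassert>
#include <cmath>

// Plain-C++ stand-in for the generated CUDA struct; the CUDA-specific
// __align__(16) qualifier is omitted so an ordinary host compiler accepts it.
struct float_4 {
    float i_0;
    float i_1;
    float i_2;
    float i_3;
};

// C++11 form, mirroring the generated code: list-initialization of a
// temporary (the trailing comma is allowed in a braced-init-list).
// A pre-C++11 compiler rejects `float_4{` here.
float_4 erf4_cxx11(float_4 v) {
    return float_4{ std::erf(v.i_0), std::erf(v.i_1), std::erf(v.i_2), std::erf(v.i_3), };
}

// Hypothetical helper whose initializer syntax is valid since C++98:
// aggregate-initialize a named local with `= { ... }` and return it.
float_4 make_float_4(float a, float b, float c, float d) {
    float_4 r = { a, b, c, d };
    return r;
}

// Same computation as erf4_cxx11, written without list-initialized temporaries.
float_4 erf4_cxx03(float_4 v) {
    return make_float_4(std::erf(v.i_0), std::erf(v.i_1), std::erf(v.i_2), std::erf(v.i_3));
}

// Checks that both forms produce identical lanes for the same input.
void check_equivalent(float_4 x) {
    float_4 a = erf4_cxx11(x);
    float_4 b = erf4_cxx03(x);
    assert(a.i_0 == b.i_0 && a.i_1 == b.i_1 && a.i_2 == b.i_2 && a.i_3 == b.i_3);
}
```

Compiling this file with `-std=c++11` or newer succeeds. Under a C++03 dialect, `erf4_cxx11` fails with the same kind of error as in the log below.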
System info:
Ubuntu 20.04
nvcc Build cuda_11.8.r11.8/compiler.31833905_0
gcc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Running
cargo run --example gelu --features cuda
fails with the following error message:
thread 'main' panicked at crates/cubecl-cuda/src/compute/server.rs:327:17:
[Compilation Error]
    default_program(36): error: type name is not allowed
    default_program(36): error: expected a ";"
    default_program(38): error: type name is not allowed
    default_program(38): error: expected a ";"
    default_program(40): error: type name is not allowed
    default_program(40): error: expected a ";"
    default_program(42): error: type name is not allowed
    default_program(42): error: expected a ";"
    default_program(44): error: type name is not allowed
    default_program(44): error: expected a ";"
    default_program(27): warning #177-D: variable "rank_2" was declared but never referenced
    default_program(30): warning #550-D: variable "l_0_2" was set but never used
    10 errors detected in the compilation of "default_program".
[Source]
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint;
typedef unsigned long long int uint64;
typedef long long int int64;

struct __align__(16) float_4 {
    float i_0;
    float i_1;
    float i_2;
    float i_3;
};

extern "C" __global__ void kernel(
float_4 input_0[],float_4 output_0[],uint info[]
) {
    int3 absoluteIdx = make_int3(
        blockIdx.x * blockDim.x + threadIdx.x,
        blockIdx.y * blockDim.y + threadIdx.y,
        blockIdx.z * blockDim.z + threadIdx.z
    );
    uint idxGlobal = (absoluteIdx.z * gridDim.x * blockDim.x * gridDim.y * blockDim.y) + (absoluteIdx.y * gridDim.x * blockDim.x) + absoluteIdx.x;
    uint rank = info[0];
    uint rank_2 = rank * 2;
    uint l_0_0;
    bool l_0_1;
    float_4 l_0_2;
    float_4 l_0_3;
    l_0_0 = info[(2 * 2 * info[0]) + 1] / 4;
    l_0_1 = idxGlobal < l_0_0;
    if (l_0_1) {
        l_0_2 = input_0[idxGlobal];
        l_0_3 = float_4{ l_0_2.i_0 / float(1.4142135), l_0_2.i_1 / float(1.4142135), l_0_2.i_2 / float(1.4142135), l_0_2.i_3 / float(1.4142135), };
        l_0_3 = float_4{ erf(l_0_3.i_0),erf(l_0_3.i_1),erf(l_0_3.i_2),erf(l_0_3.i_3),};
        l_0_3 = float_4{ l_0_3.i_0 + float(1.0), l_0_3.i_1 + float(1.0), l_0_3.i_2 + float(1.0), l_0_3.i_3 + float(1.0), };
        l_0_3 = float_4{ l_0_2.i_0 * l_0_3.i_0, l_0_2.i_1 * l_0_3.i_1, l_0_2.i_2 * l_0_3.i_2, l_0_2.i_3 * l_0_3.i_3, };
        l_0_3 = float_4{ l_0_3.i_0 / float(2.0), l_0_3.i_1 / float(2.0), l_0_3.i_2 / float(2.0), l_0_3.i_3 / float(2.0), };
        output_0[idxGlobal] = l_0_3;
    }
}
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
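For orientation, the five `float_4{ ... }` statements inside `if (l_0_1)` apply the same scalar steps to each of the four lanes: divide the input by 1.4142135 (about $\sqrt{2}$), take erf, add 1, multiply by the input, and divide by 2. Read in order, that is the erf-based GELU:

```latex
\mathrm{GELU}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
```

So every one of those statements relies on the C++11 brace-initialized temporary, not only the erf line.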
A possible fix is to append "-std=c++11" or "-std=c++14" to the compiler options defined in cubecl/crates/cubecl-cuda/src/compute/server.rs (line 312 in 370e5cf), for example:
let options = &[arch.as_str(), include_option.as_str(), "-std=c++14"];