diff --git a/brian2cuda/cuda_generator.py b/brian2cuda/cuda_generator.py index 5145480a..bbb136fd 100644 --- a/brian2cuda/cuda_generator.py +++ b/brian2cuda/cuda_generator.py @@ -237,10 +237,14 @@ def __init__(self, *args, **kwds): # These are used in _add_user_function to format the function code if prefs.devices.cuda_standalone.default_functions_integral_convertion == np.float64: self.default_func_type = 'double' + self.default_func_suffix = '' self.other_func_type = 'float' + self.other_func_suffix = 'f' else: # np.float32 self.default_func_type = 'float' + self.default_func_suffix = 'f' self.other_func_type = 'double' + self.other_func_suffix = '' # set clip function to either use all float or all double arguments # see #51 for details if prefs['core.default_float_dtype'] == np.float64: @@ -634,7 +638,9 @@ def _add_user_function(self, varname, variable): # `DEFAULT_FUNCTIONS['cos'] would match intependent of the function name. if varname in functions_C99: funccode = funccode.format(default_type=self.default_func_type, - other_type=self.other_func_type) + default_f=self.default_func_suffix, + other_type=self.other_func_type, + other_f=self.other_func_suffix) elif varname in ['clip', 'exprel']: funccode = funccode.format(float_dtype=self.float_dtype) ### @@ -819,15 +825,18 @@ class CUDAAtomicsCodeGenerator(CUDACodeGenerator): {{default_type}} _brian_{func}(T value) {{{{ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ > 0)) - return {func}(({{default_type}})value); + // CUDA math functions are only overloaded for floating point types. Hence, + // here we cast integral types to floating point types. + return {func}{{default_f}}(({{default_type}})value); #else - return {func}(value); + // Host functions are already overloaded for integral types + return {func}{{default_f}}(value); #endif }}}} inline __host__ __device__ {{other_type}} _brian_{func}({{other_type}} value) {{{{ - return {func}(value); + return {func}{{other_f}}(value); }}}} '''.format(func=func_cuda) # {default_type} and {other_type} will be formatted in CUDACodeGenerator.determine_keywords()