@@ -237,10 +237,14 @@ def __init__(self, *args, **kwds):
237
237
# These are used in _add_user_function to format the function code
238
238
if prefs .devices .cuda_standalone .default_functions_integral_convertion == np .float64 :
239
239
self .default_func_type = 'double'
240
+ self .default_func_suffix = ''
240
241
self .other_func_type = 'float'
242
+ self .other_func_suffix = 'f'
241
243
else : # np.float32
242
244
self .default_func_type = 'float'
245
+ self .default_func_suffix = 'f'
243
246
self .other_func_type = 'double'
247
+ self .other_func_suffix = ''
244
248
# set clip function to either use all float or all double arguments
245
249
# see #51 for details
246
250
if prefs ['core.default_float_dtype' ] == np .float64 :
@@ -634,7 +638,9 @@ def _add_user_function(self, varname, variable):
634
638
# `DEFAULT_FUNCTIONS['cos'] would match intependent of the function name.
635
639
if varname in functions_C99 :
636
640
funccode = funccode .format (default_type = self .default_func_type ,
637
- other_type = self .other_func_type )
641
+ default_f = self .default_func_suffix ,
642
+ other_type = self .other_func_type ,
643
+ other_f = self .other_func_suffix )
638
644
elif varname in ['clip' , 'exprel' ]:
639
645
funccode = funccode .format (float_dtype = self .float_dtype )
640
646
###
@@ -819,15 +825,18 @@ class CUDAAtomicsCodeGenerator(CUDACodeGenerator):
819
825
{{default_type}} _brian_{func}(T value)
820
826
{{{{
821
827
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ > 0))
822
- return {func}(({{default_type}})value);
828
+ // CUDA math functions are only overloaded for floating point types. Hence,
829
+ // here we cast integral types to floating point types.
830
+ return {func}{{default_f}}(({{default_type}})value);
823
831
#else
824
- return {func}(value);
832
+ // Host functions are already overloaded for integral types
833
+ return {func}{{default_f}}(value);
825
834
#endif
826
835
}}}}
827
836
inline __host__ __device__
828
837
{{other_type}} _brian_{func}({{other_type}} value)
829
838
{{{{
830
- return {func}(value);
839
+ return {func}{{other_f}} (value);
831
840
}}}}
832
841
''' .format (func = func_cuda )
833
842
# {default_type} and {other_type} will be formatted in CUDACodeGenerator.determine_keywords()
0 commit comments