@@ -39,6 +39,10 @@ def dtype(self):
3939 def compiler (self ):
4040 return self ._settings ['compiler' ]
4141
42+ def single_prec (self , expr = None ):
43+ dtype = sympy_dtype (expr ) if expr is not None else self .dtype
44+ return dtype in [np .float32 , np .float16 ]
45+
4246 def parenthesize (self , item , level , strict = False ):
4347 if isinstance (item , BooleanFunction ):
4448 return "(%s)" % self ._print (item )
@@ -104,9 +108,8 @@ def _print_math_func(self, expr, nest=False, known=None):
104108 except KeyError :
105109 return super ()._print_math_func (expr , nest = nest , known = known )
106110
107- dtype = sympy_dtype (expr )
108- if dtype is np .float32 :
109- cname += 'f'
111+ if self .single_prec (expr ):
112+ cname = '%sf' % cname
110113
111114 args = ', ' .join ((self ._print (arg ) for arg in expr .args ))
112115
@@ -116,7 +119,7 @@ def _print_Pow(self, expr):
116119 # Need to override because of issue #1627
117120 # E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x
118121 try :
119- if expr .exp == - 1 and self .dtype == np . float32 :
122+ if expr .exp == - 1 and self .single_prec () :
120123 PREC = precedence (expr )
121124 return '1.0F/%s' % self .parenthesize (expr .base , PREC )
122125 except AttributeError :
@@ -196,8 +199,8 @@ def _print_Float(self, expr):
196199 elif rv .startswith ('.0' ):
197200 rv = '0.' + rv [2 :]
198201
199- if self .dtype == np . float32 :
200- rv = rv + 'F'
202+ if self .single_prec () :
203+ rv = '%sF' % rv
201204
202205 return rv
203206
@@ -252,8 +255,8 @@ def _print_ComponentAccess(self, expr):
252255
253256 def _print_TrigonometricFunction (self , expr ):
254257 func_name = str (expr .func )
255- if self .dtype == np . float32 :
256- func_name + = 'f'
258+ if self .single_prec () :
259+ func_name = '%sf' % func_name
257260 return '%s(%s)' % (func_name , self ._print (* expr .args ))
258261
259262 def _print_DefFunction (self , expr ):
0 commit comments