@@ -2323,7 +2323,8 @@ def __init__(self,
                qnoise_factor=1.0,
                var_name=None,
                use_ste=True,
-               use_variables=False):
+               use_variables=False,
+               enable_fast_inference=False):
     super().__init__()
     self.bits = bits
     self.integer = integer
@@ -2339,6 +2340,7 @@ def __init__(self,
     assert np.mod(np.log2(negative_slope), 1) == 0
     self.var_name = var_name
     self.use_variables = use_variables
+    self.enable_fast_inference = enable_fast_inference
 
   def __str__(self):
     # Converts Tensors to printable strings by converting to a numpy array and
@@ -2357,7 +2359,18 @@ def __str__(self):
       flags.append(str(int(self.use_stochastic_rounding)))
     return "quantized_relu(" + ",".join(flags) + ")"
 
+  @tf.function(jit_compile=True)
+  def fast_quantize(self, p, m_i, factor):
+    return m_i * tf.clip_by_value(tf.round(p) * factor, 0.0, 1.0 - factor)
+
   def __call__(self, x):
+    if self.enable_fast_inference:
+      # This is the fast inference version of the quantizer.
+      m_i = 1 << self.integer
+      p = x * (2 ** (self.bits - self.integer))
+      factor = 2 ** -self.bits
+      return self.fast_quantize(p, m_i, factor)
+
     if not self.built:
       self.build(var_name=self.var_name, use_variables=self.use_variables)
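For reference, the fast-inference branch added above reduces to a plain power-of-two quantized ReLU: the input is rounded onto a grid with step 2**-(bits - integer) and clipped to [0, 2**integer - 2**(integer - bits)]. The snippet below is a minimal standalone sketch of that arithmetic; the configuration values and input tensor are illustrative only and are not part of the patch or of QKeras' public API.

import tensorflow as tf

# Example configuration; names mirror the attributes used in the diff.
bits, integer = 8, 2
x = tf.constant([-0.3, 0.7, 1.9, 5.0])

m_i = 1 << integer                  # output range of the quantized ReLU is [0, 2**integer)
p = x * (2 ** (bits - integer))     # scale onto the integer grid
factor = 2 ** -bits
y = m_i * tf.clip_by_value(tf.round(p) * factor, 0.0, 1.0 - factor)
print(y.numpy())                    # values: 0.0, 0.703125, 1.90625, 3.984375

With the change applied, the same values should come from invoking the quantizer with the new flag, e.g. quantized_relu(8, 2, enable_fast_inference=True)(x).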