Skip to content

Commit 378143c

Browse files
lishanok authored and copybara-github committed
No public description
PiperOrigin-RevId: 770347940 Change-Id: Id5cbbb006acf99e030989cf803ef74a8a9b5f7c7
1 parent 3b7d08c commit 378143c

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

qkeras/quantizers.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2323,7 +2323,8 @@ def __init__(self,
23232323
qnoise_factor=1.0,
23242324
var_name=None,
23252325
use_ste=True,
2326-
use_variables=False):
2326+
use_variables=False,
2327+
enable_fast_inference=False):
23272328
super().__init__()
23282329
self.bits = bits
23292330
self.integer = integer
@@ -2339,6 +2340,7 @@ def __init__(self,
23392340
assert np.mod(np.log2(negative_slope), 1) == 0
23402341
self.var_name = var_name
23412342
self.use_variables = use_variables
2343+
self.enable_fast_inference = enable_fast_inference
23422344

23432345
def __str__(self):
23442346
# Converts Tensors to printable strings by converting to a numpy array and
@@ -2357,7 +2359,18 @@ def __str__(self):
23572359
flags.append(str(int(self.use_stochastic_rounding)))
23582360
return "quantized_relu(" + ",".join(flags) + ")"
23592361

2362+
@staticmethod
@tf.function(jit_compile=True)
def fast_quantize(p, m_i, factor):
  """Rounds, rescales and clips a pre-scaled input in one compiled function.

  Declared as a static method on purpose: a `tf.function` object binds the
  instance when it is looked up through one, so without `@staticmethod` the
  call `self.fast_quantize(p, m_i, factor)` would pass the quantizer itself as
  `p` and fail with a positional-argument count error.

  Args:
    p: Input already multiplied by the caller's scale. The fast-inference
      branch of `__call__` passes `x * 2 ** (bits - integer)`.
    m_i: Multiplier applied to the clipped value. The fast-inference branch of
      `__call__` passes `1 << integer`.
    factor: Step applied to the rounded input; also sets the upper clip bound
      `1.0 - factor`. The fast-inference branch of `__call__` passes
      `2 ** -bits`.

  Returns:
    `m_i * clip(round(p) * factor, 0.0, 1.0 - factor)`.
  """
  return m_i * tf.clip_by_value(tf.round(p) * factor, 0.0, 1.0 - factor)
2365+
23602366
def __call__(self, x):
2367+
if self.enable_fast_inference:
2368+
# This is the fast inference version of the quantizer.
2369+
m_i = 1 << self.integer
2370+
p = x * (2 ** (self.bits - self.integer))
2371+
factor = 2 ** -self.bits
2372+
return self.fast_quantize(p, m_i, factor)
2373+
23612374
if not self.built:
23622375
self.build(var_name=self.var_name, use_variables=self.use_variables)
23632376

0 commit comments

Comments
 (0)