Skip to content

Commit 2198dc3

Browse files
lishanokcopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 726648660 Change-Id: I35836e5b0083e65758f070d9d2886cbfab07f39f
1 parent 9477f2e commit 2198dc3

File tree

2 files changed

+110
-5
lines changed

2 files changed

+110
-5
lines changed

qkeras/qtools/qtools_util.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,16 @@ def get_weights(layer, model_weights_already_quantized=True):
248248
return out
249249

250250

251+
def get_scale_from_quantized_bits_with_auto_po2(quantizer):
252+
"""Get scale from quantized_bits with alpha=auto_po2."""
253+
if hasattr(quantizer.scale, "numpy"):
254+
return quantizer.scale.numpy()
255+
elif isinstance(quantizer.scale, np.ndarray):
256+
return quantizer.scale
257+
else:
258+
return None
259+
260+
251261
def adjust_multiplier_for_auto_po2(multiplier, qkeras_weight_quantizer):
252262
"""Adjust multiplier when weight quantizer is auto_po2 type.
253263
@@ -267,11 +277,9 @@ def adjust_multiplier_for_auto_po2(multiplier, qkeras_weight_quantizer):
267277
qkeras_weight_quantizer.alpha == "auto_po2"):
268278
bits = output_quantizer.bits
269279
int_bits = output_quantizer.int_bits
270-
scale = qkeras_weight_quantizer.scale
271-
if hasattr(scale, "numpy"):
272-
# if scale doesn't have numpy() function, it means the quantizer has
273-
# never being called. Therfore we skip the following steps
274-
scale = scale.numpy()
280+
scale = get_scale_from_quantized_bits_with_auto_po2(
281+
qkeras_weight_quantizer)
282+
if scale is not None:
275283
if isinstance(scale, np.ndarray):
276284
scale = np.squeeze(scale)
277285
max_shift = int(np.log2(np.max(scale)))
@@ -293,6 +301,8 @@ def adjust_multiplier_for_auto_po2(multiplier, qkeras_weight_quantizer):
293301
output_quantizer.bits = total_bits
294302
output_quantizer.int_bits = max_int_bits
295303
else:
304+
# If scale is None, it means the quantizer has
305+
# never being called. Therfore we skip the bitwidth adjustment steps
296306
print("[WARNING] The weight quantizer is never called even though it has "
297307
"alpha=auto_po2. In this case we do not adjust the multiplier and "
298308
"accumulator bit width since we don't know the exact values of "

tests/qtools_util_test.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""Tests for qtools_util module."""
17+
18+
import json
19+
20+
import numpy as np
21+
import pytest
22+
import tensorflow.keras as keras
23+
import tensorflow as tf
24+
25+
from qkeras import quantizers
26+
from qkeras.qtools import qtools_util
27+
28+
from qkeras.qtools import quantized_operators
29+
from qkeras.qtools.quantized_operators import quantizer_factory as quantizer_factory_module
30+
31+
32+
@pytest.mark.parametrize(
33+
"w_bits, w_int_bits, weight_quantizer_scale_type, "
34+
"expected_bits_before_adjustment, expected_int_bits_before_adjustment, "
35+
"expected_bits_after_adjustment, expected_int_bits_after_adjustment",
36+
[
37+
(8, 0, "1.0", 11, 2, 11, 2),
38+
(4, 2, "auto_po2", 7, 4, 10, 5),
39+
(4, 0, "post_training_scale", 7, 2, 10, 5),
40+
],
41+
)
42+
def test_adjust_multiplier_for_auto_po2(
43+
w_bits, w_int_bits, weight_quantizer_scale_type,
44+
expected_bits_before_adjustment, expected_int_bits_before_adjustment,
45+
expected_bits_after_adjustment, expected_int_bits_after_adjustment):
46+
"""Test adjust_multiplier_for_auto_po2 with auto_po2 weight quantizer."""
47+
48+
multiplier_factory = quantized_operators.MultiplierFactory()
49+
quantizer_factory = quantizer_factory_module.QuantizerFactory()
50+
51+
qkeras_input_quantizer = quantizers.quantized_bits(4, 2, 1)
52+
53+
# Generate the weight quantizer.
54+
if weight_quantizer_scale_type in ["auto_po2", "post_training_scale"]:
55+
# Compute the scale for auto_po2 quantizer.
56+
qkeras_weight_quantizer = quantizers.quantized_bits(
57+
bits=w_bits, integer=w_int_bits, keep_negative=True,
58+
symmetric=True, alpha="auto_po2")
59+
weight_arr = np.array([1.07, -1.7, 3.06, 1.93, 0.37, -2.43, 6.3, -2.9]
60+
).reshape((2, 4))
61+
qkeras_weight_quantizer(weight_arr)
62+
63+
if weight_quantizer_scale_type == "post_training_scale":
64+
# Set the post_training_scale as fixed scale.
65+
auto_po2_scale = qkeras_weight_quantizer.scale.numpy()
66+
qkeras_weight_quantizer = quantizers.quantized_bits(
67+
bits=w_bits, integer=w_int_bits, alpha="auto_po2",
68+
post_training_scale=auto_po2_scale)
69+
else:
70+
qkeras_weight_quantizer = quantizers.quantized_bits(w_bits, w_int_bits)
71+
72+
input_quantizer = quantizer_factory.make_quantizer(
73+
qkeras_input_quantizer)
74+
weight_quantizer = quantizer_factory.make_quantizer(
75+
qkeras_weight_quantizer)
76+
77+
multiplier = multiplier_factory.make_multiplier(
78+
weight_quantizer, input_quantizer)
79+
80+
np.testing.assert_equal(multiplier.output.bits,
81+
expected_bits_before_adjustment)
82+
np.testing.assert_equal(multiplier.output.int_bits,
83+
expected_int_bits_before_adjustment)
84+
85+
qtools_util.adjust_multiplier_for_auto_po2(
86+
multiplier, qkeras_weight_quantizer)
87+
print(f"after adjustment: {multiplier.output.bits}, {multiplier.output.int_bits}")
88+
np.testing.assert_equal(multiplier.output.bits,
89+
expected_bits_after_adjustment)
90+
np.testing.assert_equal(multiplier.output.int_bits,
91+
expected_int_bits_after_adjustment)
92+
93+
94+
if __name__ == "__main__":
95+
pytest.main([__file__])

0 commit comments

Comments
 (0)