Commit a1fec24

lishanok authored and copybara-github committed
No public description
PiperOrigin-RevId: 725819350
Change-Id: Ie196effeb9912c7cdd38e53a2b69a1716147ba4c
1 parent abc660f commit a1fec24

File tree: qkeras/utils.py, tests/utils_test.py

2 files changed: +296, -4 lines changed


qkeras/utils.py

Lines changed: 182 additions & 4 deletions
@@ -309,6 +309,7 @@ def model_save_quantized_weights(model, filename=None, custom_objects={}):
         # Weights store the weight in the format that software inference uses.
         weights.append(weight)
 
+        q_name = ""
         if quantizer:
           if isinstance(quantizer, six.string_types):
             q_name = quantizer
@@ -318,11 +319,10 @@ def model_save_quantized_weights(model, filename=None, custom_objects={}):
             q_name = quantizer.name
           elif hasattr(quantizer, "__class__"):
             q_name = quantizer.__class__.__name__
-          else:
-            q_name = ""
+
         if quantizer and ("_po2" in q_name):
           # Quantized_relu_po2 does not have a sign.
-          if isinstance(quantizer, quantized_po2):
+          if q_name == "quantized_po2":
             has_sign = True
           sign = np.sign(weight)
           # Makes sure values are -1 or +1 only
@@ -332,7 +332,7 @@ def model_save_quantized_weights(model, filename=None, custom_objects={}):
           hw_weight = np.round(np.log2(np.abs(weight)))
           signs.append(sign)
           scales.append([])
-        elif (isinstance(quantizer, quantized_bits) and
+        elif (q_name == "quantized_bits" and
               quantizer.alpha == "auto_po2"):
           unsigned_bits = quantizer.bits - quantizer.keep_negative
           m = K.cast_to_floatx(pow(2, unsigned_bits))
@@ -1352,3 +1352,181 @@ def quantized_model_dump(model,
     print("writing the layer output tensor to ", filename)
     with open(filename, "w") as fid:
       tensor_data.astype(np.float32).tofile(fid)
+
+
+def clone_model_and_freeze_auto_po2_scale(
+    orig_model, orig_model_path=None, quantize_model_weights=False):
+  """Clone model and freeze the scale value of auto_po2 type quantizers.
+
+  Args:
+    orig_model: original model which will be used to clone the new model.
+      If set to None, the function will load the original model
+      from the orig_model_path argument.
+    orig_model_path: The path to the original model file.
+      If set to None, the function will load the original model from the
+      orig_model argument.
+    quantize_model_weights: Bool to quantize weights to HW format.
+      If set to False, the model weights will be in float format.
+      If set to True, the model weights will be in HW format and the function
+      will also check if the hw weights extracted from the new model match
+      the original model.
+
+  Returns:
+    A tuple of the new model and the new model's hw weights.
+
+  Note:
+    + Use this function to retrain a model with a fixed scale value;
+      set quantize_model_weights to False in this case.
+    + This function only supports a collection of common layers that use
+      auto_po2 quantizers. For less common layers, it will raise errors and we
+      will add more support case by case.
+
+  Example usage:
+    model, _ = clone_model_and_freeze_auto_po2_scale(
+        orig_model_path="path/to/model",
+        quantize_model_weights=False)
+  """
+
+  def _create_bn_layer(layer_cfg, bn_inv_quantizer):
+    # Clone batch normalization layer with the new inverse quantizer.
+    if bn_inv_quantizer is not None:
+      layer_cfg["inverse_quantizer"]["config"] = bn_inv_quantizer.get_config()
+    return QBatchNormalization(**layer_cfg)
+
+  def _create_qconv2d_layer(layer_cfg, kernel_quantizer):
+    # Clone QConv2D layer with the new kernel quantizer.
+    if kernel_quantizer is not None:
+      layer_cfg["kernel_quantizer"]["config"] = kernel_quantizer.get_config()
+    return QConv2D(**layer_cfg)
+
+  def _create_qdepthwise_conv2d_layer(layer_cfg, depthwise_quantizer):
+    # Clone QDepthwiseConv2D layer with the new depthwise quantizer.
+    if depthwise_quantizer is not None:
+      layer_cfg["depthwise_quantizer"][
+          "config"] = depthwise_quantizer.get_config()
+    return QDepthwiseConv2D(**layer_cfg)
+
+  def _create_qdense_layer(layer_cfg, kernel_quantizer):
+    # Clone QDense layer with the new kernel quantizer.
+    if kernel_quantizer is not None:
+      layer_cfg["kernel_quantizer"]["config"] = kernel_quantizer.get_config()
+    return QDense(**layer_cfg)
+
+  def _create_other_layer(orig_layer):
+    # Clone other layers.
+    config = orig_layer.get_config()
+    return orig_layer.__class__.from_config(config)
+
+  def _create_quantized_bits_with_post_training_scale(q):
+    # Create a new quantized_bits instance with the fixed scale value.
+    if q is not None:
+      q_cfg = q.get_config()
+      q_cfg["post_training_scale"] = q.scale.numpy()
+      q = quantized_bits(**q_cfg)
+    return q
+
+  def _find_auto_po2_quantizer(layer):
+    # Find the auto_po2 quantizer in the layer. Note that we allow at
+    # most one auto_po2 quantizer in each layer due to the limitation of
+    # the current HW implementation.
+    num_auto_po2_quantizers = 0
+    auto_po2_quantizer = None
+    if hasattr(layer, "quantizers"):
+      for q in layer.quantizers:
+        if hasattr(q, "alpha") and q.alpha == "auto_po2":
+          num_auto_po2_quantizers += 1
+          auto_po2_quantizer = q
+    if num_auto_po2_quantizers > 1:
+      raise ValueError(
+          f"{layer.name} has more than one auto_po2 quantizer. "
+          "Please check if this is expected.")
+    else:
+      return auto_po2_quantizer
+
+  def _check_hw_weights_equal(hw_weights_1, hw_weights_2):
+    # Check if the hw weights extracted from the new model match the
+    # original model.
+    for layer_name in hw_weights_2.keys():
+      for key in hw_weights_2[layer_name].keys():
+
+        val1 = hw_weights_2[layer_name][key]
+        val2 = hw_weights_1[layer_name][key]
+        if isinstance(val1, list):
+          for (v1, v2) in zip(val1, val2):
+            if not np.all(v1 == v2):
+              raise ValueError(
+                  f"{layer_name}/{key}: No Match! v1={v1}, v2={v2}")
+        else:
+          if not np.all(val1 == val2):
+            raise ValueError(
+                f"{layer_name}/{key}: No Match! val1={val1}, val2={val2}")
+
+  # Load the original model with float weights.
+  # Note: weights will be quantized later in silicon flow by calling
+  # model_save_quantized_weights.
+  if orig_model is not None and orig_model_path is not None:
+    raise ValueError(
+        "Only one of orig_model and orig_model_path can be set.")
+  elif orig_model is None and orig_model_path is None:
+    raise ValueError(
+        "One of orig_model and orig_model_path must be set.")
+  elif orig_model_path is not None:
+    orig_model = load_qmodel(orig_model_path, compile=False)
+
+  # Quantize model weights and compute quantizer scale values.
+  quantized_model = tf.keras.models.clone_model(orig_model)
+  quantized_model.set_weights(orig_model.get_weights())
+  # In silicon flow, weight binary files are generated from hw weights.
+  orig_hw_weights = model_save_quantized_weights(
+      quantized_model)
+
+  # Create a new model with fixed scale quantizers.
+  x = inputs = tf.keras.Input(
+      shape=orig_model.input_shape[1:], name=orig_model.layers[0].name)
+  for layer in quantized_model.layers[1:]:
+    layer_class = layer.__class__.__name__
+    auto_po2_quantizer = _find_auto_po2_quantizer(layer)
+    auto_po2_quantizer_with_frozen_scale = (
+        _create_quantized_bits_with_post_training_scale(auto_po2_quantizer))
+    layer_cfg = layer.get_config()
+
+    # To be compatible with different python versions, we do not use
+    # match-case style here.
+    if layer_class == "QConv2D":
+      x = _create_qconv2d_layer(layer_cfg,
+                                auto_po2_quantizer_with_frozen_scale)(x)
+    elif layer_class == "QDepthwiseConv2D":
+      x = _create_qdepthwise_conv2d_layer(
+          layer_cfg, auto_po2_quantizer_with_frozen_scale)(x)
+    elif layer_class == "QBatchNormalization":
+      x = _create_bn_layer(layer_cfg,
+                           auto_po2_quantizer_with_frozen_scale)(x)
+    elif layer_class == "QDense":
+      x = _create_qdense_layer(layer_cfg,
+                               auto_po2_quantizer_with_frozen_scale)(x)
+    else:
+      x = _create_other_layer(layer)(x)
+
+  new_model = tf.keras.Model(inputs, x)
+  # Set the weights of the new model to the original model (float weights).
+  new_model.set_weights(orig_model.get_weights())
+
+  # Check if the new model still has an auto_po2 quantizer.
+  # This function only supports a collection of common layers that use
+  # auto_po2 quantizers. For less common layers, we need to add extra support
+  # in the future.
+  for layer in new_model.layers:
+    q = _find_auto_po2_quantizer(layer)
+    if q is not None and q.post_training_scale is None:
+      raise ValueError(
+          f"{layer.name} in the new model still has auto_po2 quantizer with "
+          "adaptive scales. Please check if this is expected!")
+
+  new_hw_weights = None
+  if quantize_model_weights:
+    new_hw_weights = model_save_quantized_weights(new_model)
+    # Check if the hw weights extracted from the new model match the
+    # original model.
+    _check_hw_weights_equal(orig_hw_weights, new_hw_weights)
+
+  return new_model, new_hw_weights
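
Building on the Example usage in the docstring above, a minimal sketch of the retraining flow described in its Note (the toy model, random data, and optimizer below are illustrative placeholders and are not part of this change):

    import numpy as np
    import tensorflow as tf
    from qkeras import QConv2D, QDense, quantized_bits
    from qkeras.utils import clone_model_and_freeze_auto_po2_scale

    # Toy model with a single auto_po2 kernel quantizer (illustrative only).
    x = x_in = tf.keras.Input((4, 4, 1), name="input")
    x = QConv2D(4, 2,
                kernel_quantizer=quantized_bits(4, 2, 1, alpha="auto_po2"),
                bias_quantizer=quantized_bits(4, 2, 1), name="conv")(x)
    x = tf.keras.layers.Flatten(name="flatten")(x)
    x = QDense(2, kernel_quantizer=quantized_bits(4, 2, 1),
               bias_quantizer=quantized_bits(4, 2, 1), name="dense")(x)
    orig_model = tf.keras.Model(x_in, x)

    # Clone with the auto_po2 scale frozen as post_training_scale; weights
    # stay float so the cloned model can be retrained with the scale fixed.
    frozen_model, _ = clone_model_and_freeze_auto_po2_scale(
        orig_model, quantize_model_weights=False)

    # Retrain; HW weights can be extracted afterwards with
    # model_save_quantized_weights if needed.
    frozen_model.compile(optimizer="adam", loss="mse")
    frozen_model.fit(np.random.rand(8, 4, 4, 1), np.random.rand(8, 2),
                     epochs=1)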

tests/utils_test.py

Lines changed: 114 additions & 0 deletions
@@ -20,6 +20,8 @@
 
 import numpy as np
 import pytest
+import os
+import tempfile
 from tensorflow.keras.layers import *
 from tensorflow.keras.models import *
 
@@ -30,6 +32,8 @@
 from qkeras.utils import is_TFOpLambda_layer
 from qkeras.utils import find_bn_fusing_layer_pair
 from qkeras.utils import add_bn_fusing_weights
+from qkeras.utils import clone_model_and_freeze_auto_po2_scale
+from qkeras.utils import load_qmodel
 
 
 def create_quantized_network():
@@ -223,5 +227,115 @@ def test_find_bn_fusing_layer_pair():
   assert np.all(d["fused_bias"] == np.array([0.09375, 0.65625]))
 
 
+def create_test_model_for_scale_freezing(bias_quantizer):
+  def _create_simple_model(bias_quantizer):
+    x = x_in = tf.keras.Input((4, 4, 1), name="input")
+    x = QConv2D(
+        filters=4, kernel_size=2, strides=2,
+        kernel_quantizer=quantized_bits(4, 2, 1, alpha="auto_po2"),
+        bias_quantizer=quantized_bits(4, 2, 1),
+        use_bias=False,
+        name="conv")(x)
+    x = QDepthwiseConv2D(
+        kernel_size=2, strides=1,
+        depthwise_quantizer=quantized_bits(6, 3, 1, alpha="auto_po2"),
+        use_bias=False,
+        bias_quantizer=quantized_bits(4, 2, 1),
+        name="dw_conv")(x)
+    x = QBatchNormalization(
+        mean_quantizer=quantized_bits(4, 2, 1),
+        gamma_quantizer=None,
+        variance_quantizer=None,
+        beta_quantizer=quantized_bits(4, 0, 1),
+        inverse_quantizer=quantized_bits(8, 0, 1, alpha="auto_po2"),
+        name="bn")(x)
+
+    x = QActivation(activation=quantized_bits(4, 0), name="relu")(x)
+    x = tf.keras.layers.Flatten(name="flatten")(x)
+    x = QDense(units=2,
+               kernel_quantizer=quantized_bits(4, 2, 1, alpha="auto_po2"),
+               bias_quantizer=bias_quantizer, name="dense")(x)
+    model = tf.keras.Model(inputs=x_in, outputs=x)
+
+    return model
+
+  def _set_weights(model):
+    conv_w = [np.array(
+        [0.23, 2.76, 0.1, 0.33, 0.53, 0.16, 0.3, 1.7, -0.9,
+         1.43, 2.31, -0.2, -1.7, 0.39, -2.03, 1.79]).reshape(2, 2, 1, 4)]
+
+    dw_conv_w = [np.array([
+        0.03, 3.6, 2.1, 1.2, 0.13, 1.3, -0.3, 1.2, -0.7,
+        -10.3, 11.7, -0.92, -10.7, 0.59, -1.93, 2.8]).reshape((2, 2, 4, 1))]
+
+    bn_w = [np.array([0.28, 1.33, 2.27, 3.36]),
+            np.array([0.31, 0.1, 0.03, 4.26]),
+            np.array([0.89, -0.21, 1.97, 2.06]),
+            np.array([1.2, 0.9, 13.2, 10.9])]
+
+    dense_w = np.array(
+        [0.13, 0.66, 0.21, 0.23, 1.07, -0.79, 1.83, 1.81])
+    dense_w = [dense_w.reshape((4, 2)), np.array([-1.3, 0.7])]
+
+    model.get_layer("conv").set_weights(conv_w)
+    model.get_layer("dw_conv").set_weights(dw_conv_w)
+    model.get_layer("bn").set_weights(bn_w)
+    model.get_layer("dense").set_weights(dense_w)
+
+  orig_model = _create_simple_model(bias_quantizer)
+  _set_weights(orig_model)
+
+  return orig_model
+
+
+def test_clone_model_and_freeze_auto_po2_scale():
+  """Test that clone_model_and_freeze_auto_po2_scale works properly."""
+
+  orig_model = create_test_model_for_scale_freezing(quantized_bits(4, 2, 1))
+  _, new_hw = clone_model_and_freeze_auto_po2_scale(
+      orig_model, quantize_model_weights=True)
+
+  # Check if the new model's weights and scales are derived properly.
+  np.testing.assert_array_equal(
+      new_hw["conv"]["weights"][0],
+      np.array(
+          [[[[0.5, 6, 0, 0.5]], [[1, 0, 0.5, 3.5]]],
+           [[[-2., 3., 3.5, -0.5]], [[-3.5, 1., -3.5, 3.5]]]]))
+
+  np.testing.assert_array_equal(
+      new_hw["conv"]["scales"][0], np.array([[[[0.25, 0.5, 0.25, 0.25]]]]))
+
+  np.testing.assert_array_equal(
+      new_hw["dw_conv"]["weights"][0].numpy().flatten(),
+      np.array([
+          0., 14, 8, 4, 0, 6, -2, 4, -2, -42, 46, -4, -42, 2, -8, 12]))
+
+  np.testing.assert_array_equal(
+      new_hw["dense"]["scales"][0], np.array([[0.25, 0.25]]))
+
+
+def test_clone_model_and_freeze_auto_po2_scale_serialization():
+  # Test if the cloned model can be saved and loaded properly.
+  orig_model = create_test_model_for_scale_freezing(quantized_bits(4, 2, 1))
+  new_model, _ = clone_model_and_freeze_auto_po2_scale(
+      orig_model, quantize_model_weights=True)
+
+  fd, fname = tempfile.mkstemp(".hdf5")
+  new_model.save(fname)
+  _ = load_qmodel(fname)
+  os.close(fd)
+  os.remove(fname)
+
+
+def test_clone_model_and_freeze_auto_po2_scale_error():
+  orig_model = create_test_model_for_scale_freezing(
+      quantized_bits(4, 2, 1, alpha="auto_po2"))
+  # Test if the function raises an error when there is more than one
+  # auto_po2 quantizer in a layer.
+  with pytest.raises(ValueError):
+    clone_model_and_freeze_auto_po2_scale(
+        orig_model, quantize_model_weights=False)
+
+
 if __name__ == "__main__":
   pytest.main([__file__])
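
As a quick local check, the new tests can be selected by keyword; a minimal sketch, assuming it runs from a repository checkout where tests/utils_test.py resolves to this file:

    import pytest

    # Run only the scale-freezing tests added in this change.
    pytest.main(["tests/utils_test.py", "-k",
                 "clone_model_and_freeze_auto_po2_scale"])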
