Commit 96e6f39

Akshaya Purohit authored and copybara-github committed
No public description
PiperOrigin-RevId: 713012271
Change-Id: Ib9b64b6ddde9ad843fc7612812324b36b7b24fff
1 parent 84f3adf commit 96e6f39

File tree

1 file changed: +8 -5 lines


qkeras/utils.py

Lines changed: 8 additions & 5 deletions
@@ -99,19 +99,20 @@
 ]
 
 
-def find_bn_fusing_layer_pair(model):
+def find_bn_fusing_layer_pair(model, custom_objects={}):
   """Finds layers that can be fused with the following batchnorm layers.
 
   Args:
     model: input model
+    custom_objects: Dict of model specific objects needed for cloning.
 
   Returns:
     Dict that marks all the layer pairs that need to be fused.
 
   Note: supports sequential and non-sequential model
   """
 
-  fold_model = clone_model(model)
+  fold_model = clone_model(model, custom_objects)
   (graph, _) = qgraph.GenerateGraphFromModel(
       fold_model, "quantized_bits(8, 0, 1)", "quantized_bits(8, 0, 1)")
@@ -219,7 +220,7 @@ def apply_quantizer(quantizer, input_weight):
 
 
 # Model utilities: before saving the weights, we want to apply the quantizers
-def model_save_quantized_weights(model, filename=None):
+def model_save_quantized_weights(model, filename=None, custom_objects={}):
   """Quantizes model for inference and save it.
 
   Takes a model with weights, apply quantization function to weights and
@@ -241,17 +242,19 @@ def model_save_quantized_weights(model, filename=None):
     model: model with weights to be quantized.
     filename: if specified, we will save the hdf5 containing the quantized
       weights so that we can use them for inference later on.
+    custom_objects: Dict of model specific objects needed to load/store.
 
   Returns:
     dictionary containing layer name and quantized weights that can be used
       by a hardware generator.
-
   """
 
   saved_weights = {}
 
   # Find the conv/dense layers followed by Batchnorm layers
-  (fusing_layer_pair_dict, bn_layers_to_skip) = find_bn_fusing_layer_pair(model)
+  (fusing_layer_pair_dict, bn_layers_to_skip) = find_bn_fusing_layer_pair(
+      model, custom_objects
+  )
 
   print("... quantizing model")
   for layer in model.layers:
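
For context, the change threads a custom_objects dict from model_save_quantized_weights through find_bn_fusing_layer_pair into clone_model, so that models containing QKeras-specific layers and quantizers can be cloned before batch-norm fusing. A minimal usage sketch follows; the model architecture, layer names, and output filename are illustrative assumptions, not part of the commit:

import tensorflow as tf
from qkeras import QDense, quantized_bits
from qkeras.utils import model_save_quantized_weights

# Illustrative model: a quantized dense layer followed by batch norm,
# which is the pattern find_bn_fusing_layer_pair looks for.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(16,)),
    QDense(8,
           kernel_quantizer=quantized_bits(8, 0, 1),
           bias_quantizer=quantized_bits(8, 0, 1),
           name="qdense"),
    tf.keras.layers.BatchNormalization(name="bn"),
])

# After this commit, custom_objects is forwarded to clone_model so the
# internal fold model can deserialize the QKeras-specific classes.
saved_weights = model_save_quantized_weights(
    model,
    filename="quantized_weights.h5",  # hypothetical output path
    custom_objects={"QDense": QDense, "quantized_bits": quantized_bits},
)

Note that clone_model here is presumably QKeras's own cloning helper in qkeras/utils.py rather than tf.keras.models.clone_model (whose second positional argument is input_tensors, not custom_objects), which is why the dict can be passed straight through.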
