 ]


-def find_bn_fusing_layer_pair(model):
+def find_bn_fusing_layer_pair(model, custom_objects={}):
   """Finds layers that can be fused with the following batchnorm layers.

   Args:
     model: input model
+    custom_objects: Dict of model specific objects needed for cloning.

   Returns:
     Dict that marks all the layer pairs that need to be fused.

   Note: supports sequential and non-sequential model
   """

-  fold_model = clone_model(model)
+  fold_model = clone_model(model, custom_objects)
   (graph, _) = qgraph.GenerateGraphFromModel(
       fold_model, "quantized_bits(8, 0, 1)", "quantized_bits(8, 0, 1)")
@@ -219,7 +220,7 @@ def apply_quantizer(quantizer, input_weight):


 # Model utilities: before saving the weights, we want to apply the quantizers
-def model_save_quantized_weights(model, filename=None):
+def model_save_quantized_weights(model, filename=None, custom_objects={}):
   """Quantizes model for inference and save it.

   Takes a model with weights, apply quantization function to weights and
@@ -241,17 +242,19 @@ def model_save_quantized_weights(model, filename=None):
     model: model with weights to be quantized.
     filename: if specified, we will save the hdf5 containing the quantized
       weights so that we can use them for inference later on.
+    custom_objects: Dict of model specific objects needed to load/store.

   Returns:
     dictionary containing layer name and quantized weights that can be used
     by a hardware generator.
-
   """

   saved_weights = {}

   # Find the conv/dense layers followed by Batchnorm layers
-  (fusing_layer_pair_dict, bn_layers_to_skip) = find_bn_fusing_layer_pair(model)
+  (fusing_layer_pair_dict, bn_layers_to_skip) = find_bn_fusing_layer_pair(
+      model, custom_objects
+  )

   print("... quantizing model")
   for layer in model.layers:
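
A minimal usage sketch of the new parameter (assumptions: "model" is an already-built QKeras model, and the QDense/quantized_bits entries are illustrative examples of custom objects, not part of this change):

    from qkeras import QDense, quantized_bits
    from qkeras.utils import model_save_quantized_weights

    # Register the custom layers/quantizers the model was built with, so the
    # internal clone_model call can reconstruct it before batchnorm fusing.
    custom_objects = {"QDense": QDense, "quantized_bits": quantized_bits}
    saved_weights = model_save_quantized_weights(
        model, filename="quantized_weights.h5", custom_objects=custom_objects)

Passing the same dict used to construct the model lets both find_bn_fusing_layer_pair and model_save_quantized_weights clone models containing custom layers without raising unknown-layer errors.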