16
16
from __future__ import absolute_import
17
17
from __future__ import division
18
18
from __future__ import print_function
19
+
19
20
import warnings
20
21
22
+ import numpy as np
21
23
import tensorflow as tf
22
24
from tensorflow .keras import constraints
23
25
from tensorflow .keras import initializers
26
28
from tensorflow .keras .layers import Conv1D
27
29
from tensorflow .keras .layers import Conv2D
28
30
from tensorflow .keras .layers import Conv2DTranspose
29
- from tensorflow .keras .layers import SeparableConv1D
30
- from tensorflow .keras .layers import SeparableConv2D
31
31
from tensorflow .keras .layers import DepthwiseConv2D
32
32
from tensorflow .keras .layers import Dropout
33
33
from tensorflow .keras .layers import InputSpec
34
- from tensorflow .python . eager import context
35
- from tensorflow .python . ops import array_ops
36
- # from tensorflow.python.ops import array_ops
34
+ from tensorflow .keras . layers import SeparableConv1D
35
+ from tensorflow .keras . layers import SeparableConv2D
36
+
37
37
from .qlayers import get_auto_range_constraint_initializer
38
38
from .qlayers import QActivation
39
39
from .quantizers import get_quantized_initializer
40
40
from .quantizers import get_quantizer
41
-
41
+ from tensorflow .python .eager import context
42
+ from tensorflow .python .ops import array_ops
43
+ # from tensorflow.python.ops import array_ops
42
44
from tensorflow_model_optimization .python .core .sparsity .keras .prunable_layer import PrunableLayer
43
45
44
46
@@ -260,32 +262,36 @@ class QConv2D(Conv2D, PrunableLayer):
260
262
# can go over [-1,+1], these values are used to set the clipping
261
263
# value of kernels and biases, respectively, instead of using the
262
264
# constraints specified by the user.
265
+ # mask: Optional mask for kernel weights.
263
266
#
264
267
# we refer the reader to the documentation of Conv2D in Keras for the
265
268
# other parameters.
266
269
#
267
270
268
- def __init__ (self ,
269
- filters ,
270
- kernel_size ,
271
- strides = (1 , 1 ),
272
- padding = "valid" ,
273
- data_format = "channels_last" ,
274
- dilation_rate = (1 , 1 ),
275
- activation = None ,
276
- use_bias = True ,
277
- kernel_initializer = "he_normal" ,
278
- bias_initializer = "zeros" ,
279
- kernel_regularizer = None ,
280
- bias_regularizer = None ,
281
- activity_regularizer = None ,
282
- kernel_constraint = None ,
283
- bias_constraint = None ,
284
- kernel_range = None ,
285
- bias_range = None ,
286
- kernel_quantizer = None ,
287
- bias_quantizer = None ,
288
- ** kwargs ):
271
+ def __init__ (
272
+ self ,
273
+ filters ,
274
+ kernel_size ,
275
+ strides = (1 , 1 ),
276
+ padding = "valid" ,
277
+ data_format = "channels_last" ,
278
+ dilation_rate = (1 , 1 ),
279
+ activation = None ,
280
+ use_bias = True ,
281
+ kernel_initializer = "he_normal" ,
282
+ bias_initializer = "zeros" ,
283
+ kernel_regularizer = None ,
284
+ bias_regularizer = None ,
285
+ activity_regularizer = None ,
286
+ kernel_constraint = None ,
287
+ bias_constraint = None ,
288
+ kernel_range = None ,
289
+ bias_range = None ,
290
+ kernel_quantizer = None ,
291
+ bias_quantizer = None ,
292
+ mask = None ,
293
+ ** kwargs ,
294
+ ):
289
295
290
296
if kernel_range is not None :
291
297
warnings .warn ("kernel_range is deprecated in QConv2D layer." )
@@ -324,6 +330,20 @@ def __init__(self,
324
330
if activation is not None :
325
331
activation = get_quantizer (activation )
326
332
333
+ if mask is not None :
334
+ shape = mask .shape
335
+ if len (shape ) < 2 :
336
+ raise ValueError (
337
+ "Expected shape to have rank at least 2 but provided shape has"
338
+ f" rank { len (shape )} "
339
+ )
340
+ h , w = shape [0 ], shape [1 ]
341
+ self ._mask = np .reshape (
342
+ mask , (h , w , 1 , 1 )
343
+ ) # Extend the dimension to be 4D.
344
+ else :
345
+ self ._mask = None
346
+
327
347
super ().__init__ (
328
348
filters = filters ,
329
349
kernel_size = kernel_size ,
@@ -343,19 +363,44 @@ def __init__(self,
343
363
** kwargs
344
364
)
345
365
366
def convolution_op(self, inputs, kernel):
  """Convolve `inputs` with `kernel` using this layer's configuration.

  Args:
    inputs: Input tensor passed straight through to
      `tf.keras.backend.conv2d`.
    kernel: Kernel tensor to convolve with. The caller decides whether it
      has already been quantized and/or masked; this method uses it as-is.

  Returns:
    The result of `tf.keras.backend.conv2d` evaluated with the layer's
    `strides`, `padding`, `data_format` and `dilation_rate`.
  """
  # Gather the layer-level settings once so the backend call stays short.
  conv_settings = dict(
      strides=self.strides,
      padding=self.padding,
      data_format=self.data_format,
      dilation_rate=self.dilation_rate,
  )
  return tf.keras.backend.conv2d(inputs, kernel, **conv_settings)
375
+
376
@tf.function(jit_compile=True)
def _jit_compiled_convolution_op(self, inputs, kernel):
  """JIT-compiled wrapper around `convolution_op`.

  `call` dispatches to this wrapper instead of calling `convolution_op`
  directly when `self.groups > 1`.

  Args:
    inputs: Input tensor forwarded to `convolution_op`.
    kernel: Kernel tensor forwarded to `convolution_op`.

  Returns:
    Whatever `convolution_op` returns for the given arguments.
  """
  outputs = self.convolution_op(inputs, kernel)
  return outputs
379
+
346
380
def call (self , inputs ):
347
381
if self .kernel_quantizer :
348
382
quantized_kernel = self .kernel_quantizer_internal (self .kernel )
349
383
else :
350
384
quantized_kernel = self .kernel
351
385
352
- outputs = tf .keras .backend .conv2d (
353
- inputs ,
354
- quantized_kernel ,
355
- strides = self .strides ,
356
- padding = self .padding ,
357
- data_format = self .data_format ,
358
- dilation_rate = self .dilation_rate )
386
+ if self ._mask is not None :
387
+ # Apply mask to kernel weights if one is provided.
388
+ quantized_kernel = quantized_kernel * self ._mask
389
+
390
+ # Grouped convolutions are not fully supported on the CPU for compiled
391
+ # functions.
392
+ #
393
+ # This is a workaround taken from TF's core library. Remove when proper
394
+ # support is added.
395
+ # See definition of function "_jit_compiled_convolution_op" at
396
+ # cs/third_party/py/tf_keras/layers/convolutional/base_conv.py for more
397
+ # details.
398
+ if self .groups > 1 :
399
+ outputs = self ._jit_compiled_convolution_op (
400
+ inputs , tf .convert_to_tensor (quantized_kernel )
401
+ )
402
+ else :
403
+ outputs = self .convolution_op (inputs , quantized_kernel )
359
404
360
405
if self .use_bias :
361
406
if self .bias_quantizer :
@@ -364,7 +409,8 @@ def call(self, inputs):
364
409
quantized_bias = self .bias
365
410
366
411
outputs = tf .keras .backend .bias_add (
367
- outputs , quantized_bias , data_format = self .data_format )
412
+ outputs , quantized_bias , data_format = self .data_format
413
+ )
368
414
369
415
if self .activation is not None :
370
416
return self .activation (outputs )
@@ -380,10 +426,19 @@ def get_config(self):
380
426
),
381
427
"kernel_range" : self .kernel_range ,
382
428
"bias_range" : self .bias_range ,
429
+ "mask" : self ._mask .tolist () if self ._mask is not None else None ,
383
430
}
384
- base_config = super (QConv2D , self ).get_config ()
431
+ base_config = super ().get_config ()
385
432
return dict (list (base_config .items ()) + list (config .items ()))
386
433
434
@classmethod
def from_config(cls, config):
  """Build a layer from a config dict, restoring the optional kernel mask.

  `get_config` stores the mask as a nested list (or `None`); the
  constructor reads `mask.shape`, so a list-valued mask is converted back
  to a NumPy array before the layer is built.

  Args:
    config: Mapping of constructor keyword arguments, typically the output
      of `get_config`. It is not modified.

  Returns:
    A new instance of `cls` constructed from `config`.
  """
  # Work on a shallow copy: writing the ndarray back into the caller's dict
  # would silently turn their serializable config into a non-serializable
  # one and leak state between repeated `from_config` calls.
  config = dict(config)
  mask = config.get("mask")
  if mask is not None:
    config["mask"] = np.array(mask)
  return cls(**config)
441
+
387
442
def get_quantization_config (self ):
388
443
return {
389
444
"kernel_quantizer" :
0 commit comments