Skip to content

Commit 9477f2e

Browse files
Akshaya Purohitcopybara-github
Akshaya Purohit
authored andcommitted
Add support for masking and groups in QConv2D.
PiperOrigin-RevId: 726258359 Change-Id: I565b1fec70fa3007246111725aeff936f54a1c0d
1 parent a1fec24 commit 9477f2e

File tree

2 files changed

+185
-36
lines changed

2 files changed

+185
-36
lines changed

qkeras/qconvolutional.py

Lines changed: 91 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
from __future__ import absolute_import
1717
from __future__ import division
1818
from __future__ import print_function
19+
1920
import warnings
2021

22+
import numpy as np
2123
import tensorflow as tf
2224
from tensorflow.keras import constraints
2325
from tensorflow.keras import initializers
@@ -26,19 +28,19 @@
2628
from tensorflow.keras.layers import Conv1D
2729
from tensorflow.keras.layers import Conv2D
2830
from tensorflow.keras.layers import Conv2DTranspose
29-
from tensorflow.keras.layers import SeparableConv1D
30-
from tensorflow.keras.layers import SeparableConv2D
3131
from tensorflow.keras.layers import DepthwiseConv2D
3232
from tensorflow.keras.layers import Dropout
3333
from tensorflow.keras.layers import InputSpec
34-
from tensorflow.python.eager import context
35-
from tensorflow.python.ops import array_ops
36-
# from tensorflow.python.ops import array_ops
34+
from tensorflow.keras.layers import SeparableConv1D
35+
from tensorflow.keras.layers import SeparableConv2D
36+
3737
from .qlayers import get_auto_range_constraint_initializer
3838
from .qlayers import QActivation
3939
from .quantizers import get_quantized_initializer
4040
from .quantizers import get_quantizer
41-
41+
from tensorflow.python.eager import context
42+
from tensorflow.python.ops import array_ops
43+
# from tensorflow.python.ops import array_ops
4244
from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer
4345

4446

@@ -260,32 +262,36 @@ class QConv2D(Conv2D, PrunableLayer):
260262
# can go over [-1,+1], these values are used to set the clipping
261263
# value of kernels and biases, respectively, instead of using the
262264
# constraints specified by the user.
265+
# mask: Optional mask for kernel weights.
263266
#
264267
# we refer the reader to the documentation of Conv2D in Keras for the
265268
# other parameters.
266269
#
267270

268-
def __init__(self,
269-
filters,
270-
kernel_size,
271-
strides=(1, 1),
272-
padding="valid",
273-
data_format="channels_last",
274-
dilation_rate=(1, 1),
275-
activation=None,
276-
use_bias=True,
277-
kernel_initializer="he_normal",
278-
bias_initializer="zeros",
279-
kernel_regularizer=None,
280-
bias_regularizer=None,
281-
activity_regularizer=None,
282-
kernel_constraint=None,
283-
bias_constraint=None,
284-
kernel_range=None,
285-
bias_range=None,
286-
kernel_quantizer=None,
287-
bias_quantizer=None,
288-
**kwargs):
271+
def __init__(
272+
self,
273+
filters,
274+
kernel_size,
275+
strides=(1, 1),
276+
padding="valid",
277+
data_format="channels_last",
278+
dilation_rate=(1, 1),
279+
activation=None,
280+
use_bias=True,
281+
kernel_initializer="he_normal",
282+
bias_initializer="zeros",
283+
kernel_regularizer=None,
284+
bias_regularizer=None,
285+
activity_regularizer=None,
286+
kernel_constraint=None,
287+
bias_constraint=None,
288+
kernel_range=None,
289+
bias_range=None,
290+
kernel_quantizer=None,
291+
bias_quantizer=None,
292+
mask=None,
293+
**kwargs,
294+
):
289295

290296
if kernel_range is not None:
291297
warnings.warn("kernel_range is deprecated in QConv2D layer.")
@@ -324,6 +330,20 @@ def __init__(self,
324330
if activation is not None:
325331
activation = get_quantizer(activation)
326332

333+
if mask is not None:
334+
shape = mask.shape
335+
if len(shape) < 2:
336+
raise ValueError(
337+
"Expected shape to have rank at least 2 but provided shape has"
338+
f" rank {len(shape)}"
339+
)
340+
h, w = shape[0], shape[1]
341+
self._mask = np.reshape(
342+
mask, (h, w, 1, 1)
343+
) # Extend the dimension to be 4D.
344+
else:
345+
self._mask = None
346+
327347
super().__init__(
328348
filters=filters,
329349
kernel_size=kernel_size,
@@ -343,19 +363,44 @@ def __init__(self,
343363
**kwargs
344364
)
345365

366+
def convolution_op(self, inputs, kernel):
367+
return tf.keras.backend.conv2d(
368+
inputs,
369+
kernel,
370+
strides=self.strides,
371+
padding=self.padding,
372+
data_format=self.data_format,
373+
dilation_rate=self.dilation_rate,
374+
)
375+
376+
@tf.function(jit_compile=True)
377+
def _jit_compiled_convolution_op(self, inputs, kernel):
378+
return self.convolution_op(inputs, kernel)
379+
346380
def call(self, inputs):
347381
if self.kernel_quantizer:
348382
quantized_kernel = self.kernel_quantizer_internal(self.kernel)
349383
else:
350384
quantized_kernel = self.kernel
351385

352-
outputs = tf.keras.backend.conv2d(
353-
inputs,
354-
quantized_kernel,
355-
strides=self.strides,
356-
padding=self.padding,
357-
data_format=self.data_format,
358-
dilation_rate=self.dilation_rate)
386+
if self._mask is not None:
387+
# Apply mask to kernel weights if one is provided.
388+
quantized_kernel = quantized_kernel * self._mask
389+
390+
# Grouped convolutions are not fully supported on the CPU for compiled
391+
# functions.
392+
#
393+
# This is a workaround taken from TF's core library. Remove when proper
394+
# support is added.
395+
# See definition of function "_jit_compiled_convolution_op" at
396+
# cs/third_party/py/tf_keras/layers/convolutional/base_conv.py for more
397+
# details.
398+
if self.groups > 1:
399+
outputs = self._jit_compiled_convolution_op(
400+
inputs, tf.convert_to_tensor(quantized_kernel)
401+
)
402+
else:
403+
outputs = self.convolution_op(inputs, quantized_kernel)
359404

360405
if self.use_bias:
361406
if self.bias_quantizer:
@@ -364,7 +409,8 @@ def call(self, inputs):
364409
quantized_bias = self.bias
365410

366411
outputs = tf.keras.backend.bias_add(
367-
outputs, quantized_bias, data_format=self.data_format)
412+
outputs, quantized_bias, data_format=self.data_format
413+
)
368414

369415
if self.activation is not None:
370416
return self.activation(outputs)
@@ -380,10 +426,19 @@ def get_config(self):
380426
),
381427
"kernel_range": self.kernel_range,
382428
"bias_range": self.bias_range,
429+
"mask": self._mask.tolist() if self._mask is not None else None,
383430
}
384-
base_config = super(QConv2D, self).get_config()
431+
base_config = super().get_config()
385432
return dict(list(base_config.items()) + list(config.items()))
386433

434+
@classmethod
435+
def from_config(cls, config):
436+
mask = config.get("mask")
437+
if mask is not None:
438+
mask = np.array(mask)
439+
config["mask"] = mask
440+
return cls(**config)
441+
387442
def get_quantization_config(self):
388443
return {
389444
"kernel_quantizer":

tests/qconvolutional_test.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,5 +308,99 @@ def test_qconv2dtranspose():
308308
[2., 3., 4., 4., 3., 2.] ]).reshape((1,6,6,1)).astype(np.float16)
309309
assert_allclose(actual_output, expected_output, rtol=1e-4)
310310

311+
312+
def test_masked_qconv2d_creates_correct_parameters():
313+
mask = mask = np.ones((5, 5), dtype=np.float32)
314+
model = tf.keras.Sequential()
315+
model.add(tf.keras.layers.Input(shape=(10, 10, 1)))
316+
model.add(QConv2D(mask=mask, filters=1, kernel_size=(5, 5), use_bias=False))
317+
318+
# There should be no non-trainable params.
319+
np.testing.assert_equal(len(model.non_trainable_weights), 0)
320+
321+
# Validate number of trainable params. This should be equal to one (5,5)
322+
# kernel.
323+
np.testing.assert_equal(len(model.trainable_weights), 1)
324+
num_trainable_params = np.prod(model.trainable_weights[0].shape)
325+
np.testing.assert_equal(num_trainable_params, 25)
326+
327+
328+
def test_qconv2d_masks_weights():
329+
# Create an arbitrary mask.
330+
mask = np.array(
331+
[
332+
[1.0, 0.0, 1.0, 0.0, 1.0],
333+
[0.0, 0.0, 1.0, 0.0, 0.0],
334+
[1.0, 0.0, 1.0, 0.0, 1.0],
335+
[0.0, 0.0, 1.0, 0.0, 0.0],
336+
[1.0, 0.0, 1.0, 0.0, 1.0],
337+
],
338+
dtype=np.float32,
339+
)
340+
model = tf.keras.Sequential()
341+
model.add(tf.keras.layers.Input(shape=(5, 5, 1)))
342+
model.add(QConv2D(mask=mask, filters=1, kernel_size=(5, 5), use_bias=False))
343+
344+
# Set the weights to be all ones.
345+
model.layers[0].set_weights([np.ones((5, 5, 1, 1), dtype=np.float32)])
346+
347+
# Run inference on a all ones input.
348+
output = model.predict(np.ones((1, 5, 5, 1), dtype=np.float32))
349+
# Output should just be summation of number of ones in the mask.
350+
np.testing.assert_array_equal(
351+
output, np.array([[[[11.0]]]], dtype=np.float32)
352+
)
353+
354+
355+
def test_masked_qconv2d_load_restore_works():
356+
model = tf.keras.Sequential()
357+
model.add(tf.keras.layers.Input(shape=(10, 10, 1)))
358+
model.add(
359+
QConv2D(
360+
mask=np.ones((5, 5), dtype=np.float32),
361+
filters=1,
362+
kernel_size=(5, 5),
363+
use_bias=False,
364+
)
365+
)
366+
367+
with tempfile.TemporaryDirectory() as temp_dir:
368+
model_path = os.path.join(temp_dir, 'model.keras')
369+
# Can save the model.
370+
model.save(model_path)
371+
372+
# Can load the model.
373+
custom_objects = {
374+
'QConv2D': QConv2D,
375+
}
376+
loaded_model = tf.keras.models.load_model(
377+
model_path, custom_objects=custom_objects
378+
)
379+
380+
np.testing.assert_array_equal(
381+
model.layers[0].weights[0], loaded_model.layers[0].weights[0]
382+
)
383+
384+
385+
def test_qconv2d_groups_works():
386+
model = tf.keras.Sequential()
387+
model.add(tf.keras.layers.Input(shape=(10, 10, 10)))
388+
model.add(
389+
QConv2D(
390+
filters=6,
391+
kernel_size=(1, 1),
392+
use_bias=True,
393+
groups=2,
394+
)
395+
)
396+
# Validate number of trainable params.
397+
np.testing.assert_equal(len(model.trainable_weights), 2)
398+
num_trainable_params = np.prod(model.trainable_weights[0].shape) + np.prod(
399+
model.trainable_weights[1].shape
400+
)
401+
expected_trainable_params = 36 # (5*3)*2 + 6
402+
np.testing.assert_equal(num_trainable_params, expected_trainable_params)
403+
404+
311405
if __name__ == '__main__':
312406
pytest.main([__file__])

0 commit comments

Comments
 (0)