Skip to content

Commit 4eedf1e

Browse files
Akshaya Purohitcopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 711792163 Change-Id: I5118482d2fbd0c4a722ab65ca01a2016aa4ce44b
1 parent 8e7a1a4 commit 4eedf1e

File tree

5 files changed

+290
-240
lines changed

5 files changed

+290
-240
lines changed

qkeras/base_quantizer.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
import tensorflow.compat.v2 as tf
17+
import tensorflow.keras.backend as K
18+
19+
20+
def _create_variable_name(attr_name, var_name=None):
21+
"""Creates variable name.
22+
23+
Arguments:
24+
attr_name: string. attribute name
25+
var_name: string. variable name
26+
27+
Returns:
28+
string. variable name
29+
"""
30+
31+
if var_name:
32+
return var_name + "/" + attr_name
33+
34+
# This naming scheme is to solve a problem of a layer having more than
35+
# one quantizer can have multiple qnoise_factor variables with the same
36+
# name of "qnoise_factor".
37+
return attr_name + "_" + str(K.get_uid(attr_name))
38+
39+
40+
class BaseQuantizer(tf.Module):
41+
"""Base quantizer.
42+
43+
Defines behavior all quantizers should follow.
44+
"""
45+
46+
def __init__(self):
47+
self.built = False
48+
49+
def build(self, var_name=None, use_variables=False):
50+
if use_variables:
51+
if hasattr(self, "qnoise_factor"):
52+
self.qnoise_factor = tf.Variable(
53+
lambda: tf.constant(self.qnoise_factor, dtype=tf.float32),
54+
name=_create_variable_name("qnoise_factor", var_name=var_name),
55+
dtype=tf.float32,
56+
trainable=False,
57+
)
58+
self.built = True
59+
60+
def _set_trainable_parameter(self):
61+
pass
62+
63+
def update_qnoise_factor(self, qnoise_factor):
64+
"""Update qnoise_factor."""
65+
if isinstance(self.qnoise_factor, tf.Variable):
66+
# self.qnoise_factor is a tf.Variable.
67+
# This is to update self.qnoise_factor during training.
68+
self.qnoise_factor.assign(qnoise_factor)
69+
else:
70+
if isinstance(qnoise_factor, tf.Variable):
71+
# self.qnoise_factor is a numpy variable, and qnoise_factor is a
72+
# tf.Variable.
73+
self.qnoise_factor = qnoise_factor.eval()
74+
else:
75+
# self.qnoise_factor and qnoise_factor are numpy variables.
76+
# This is to set self.qnoise_factor before building
77+
# (creating tf.Variable) it.
78+
self.qnoise_factor = qnoise_factor
79+
80+
# Override not to expose the quantizer variables.
81+
@property
82+
def variables(self):
83+
return ()
84+
85+
# Override not to expose the quantizer variables.
86+
@property
87+
def trainable_variables(self):
88+
return ()
89+
90+
# Override not to expose the quantizer variables.
91+
@property
92+
def non_trainable_variables(self):
93+
return ()

qkeras/qtools/DnC/divide_and_conquer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,16 @@
2727

2828
import enum
2929
import logging
30-
from typing import List, Any, Union
30+
from typing import Any, List, Union
3131

3232
import numpy as np
3333
import tensorflow as tf
3434

35+
from qkeras import base_quantizer
3536
from qkeras import quantizers
37+
from qkeras.qtools import generate_layer_data_type_map
3638
from qkeras.qtools import qgraph
3739
from qkeras.qtools import qtools_util
38-
from qkeras.qtools import generate_layer_data_type_map
3940
from qkeras.qtools.DnC import dnc_layer_cost_ace
4041

4142

@@ -49,8 +50,11 @@ class CostMode(enum.Enum):
4950
class DivideConquerGraph:
5051
"""This class creates model graph structure and methods to access layers."""
5152

52-
def __init__(self, model: tf.keras.Model,
53-
source_quantizers: quantizers.BaseQuantizer = None):
53+
def __init__(
54+
self,
55+
model: tf.keras.Model,
56+
source_quantizers: base_quantizer.BaseQuantizer = None,
57+
):
5458
self._model = model
5559
self._source_quantizer_list = source_quantizers or [
5660
quantizers.quantized_bits(8, 0, 1)]

qkeras/qtools/quantized_operators/fused_bn_factory.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323
import math
2424

2525
import numpy as np
26-
import copy
26+
27+
from qkeras import base_quantizer
28+
from qkeras.qtools import qtools_util
2729
from qkeras.qtools.quantized_operators import adder_factory
2830
from qkeras.qtools.quantized_operators import divider_factory
2931
from qkeras.qtools.quantized_operators import multiplier_factory
3032
from qkeras.qtools.quantized_operators import quantizer_impl
31-
from qkeras.qtools import qtools_util
32-
from qkeras import quantizers
3333

3434
class FusedBNFactory:
3535
"""determine which quantizer implementation to use.
@@ -48,14 +48,15 @@ class FusedBNFactory:
4848
"""
4949

5050
def make_quantizer(
51-
self, prev_output_quantizer: quantizer_impl.IQuantizer,
51+
self,
52+
prev_output_quantizer: quantizer_impl.IQuantizer,
5253
beta_quantizer: quantizer_impl.IQuantizer,
5354
mean_quantizer: quantizer_impl.IQuantizer,
5455
inverse_quantizer: quantizer_impl.IQuantizer,
5556
prev_bias_quantizer: quantizer_impl.IQuantizer,
5657
use_beta: bool,
5758
use_bias: bool,
58-
qkeras_inverse_quantizer:quantizers.BaseQuantizer
59+
qkeras_inverse_quantizer: base_quantizer.BaseQuantizer,
5960
):
6061
"""Makes a fused_bn quantizer.
6162

qkeras/quantizer_imports.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""Imports for QKeras quantizers."""
17+
18+
from .quantizers import bernoulli
19+
from .quantizers import binary
20+
from .quantizers import quantized_bits
21+
from .quantizers import quantized_hswish
22+
from .quantizers import quantized_linear
23+
from .quantizers import quantized_po2
24+
from .quantizers import quantized_relu
25+
from .quantizers import quantized_relu_po2
26+
from .quantizers import quantized_sigmoid
27+
from .quantizers import quantized_tanh
28+
from .quantizers import quantized_ulaw
29+
from .quantizers import stochastic_binary
30+
from .quantizers import stochastic_ternary
31+
from .quantizers import ternary

0 commit comments

Comments
 (0)