
Commit 57b1179

doncarlos999 authored and shashaka committed
Porting TF fake_quant_with_min_max functions (keras-team#20641)
* QAT (squashed this time) (keras-team#1)
* adds fake_quant_with_min_max functions from TF to Keras 3
* addresses PR review comments
* drops another type hint
* swaps out if statements, changes float() to ops.cast, and adds fake_quant_with_min_max_vars function
* fixes missed if statement, adds gradient tests via main function for tf and torch
* fix unbound variable error when not using torch or tf backend (keras-team#2): refactor to use backend-specific gradient functions in tests and merge logic into a single function
* More QAT function revisions (keras-team#3): addresses review feedback to fix the implementation and to move tests to named_parameters rather than individual functions
* QAT revisions (keras-team#4): adds axis param and fixes logic for the per-channel function
* updated implementation
* removed redundant functions
1 parent 7b53bef commit 57b1179

File tree

5 files changed (+525, −0 lines)


keras/api/_tf_keras/keras/quantizers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -12,4 +12,7 @@
 from keras.src.quantizers.quantizers import abs_max_quantize
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
+from keras.src.quantizers.quantizers import (
+    fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel,
+)
 from keras.src.quantizers.quantizers import quantize_and_dequantize

keras/api/quantizers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -12,4 +12,7 @@
 from keras.src.quantizers.quantizers import abs_max_quantize
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
+from keras.src.quantizers.quantizers import (
+    fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel,
+)
 from keras.src.quantizers.quantizers import quantize_and_dequantize
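These exports make the new op, implemented in keras/src/quantizers/quantizers.py below, reachable from the public namespace as `keras.quantizers.fake_quant_with_min_max_vars_per_channel`. A minimal per-tensor sketch, assuming a Keras 3 build that includes this commit (values are illustrative):

import numpy as np
from keras import quantizers

x = np.array([-1.5, -0.4, 0.0, 0.7, 1.2], dtype="float32")

# A single (min, max) pair clamps the whole tensor to [-1, 1] and
# fake-quantizes it to 8 bits; the output stays in float32.
y = quantizers.fake_quant_with_min_max_vars_per_channel(
    x, min_vals=-1.0, max_vals=1.0, num_bits=8
)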

keras/src/quantizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 from keras.src.quantizers.quantizers import abs_max_quantize
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
+from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
 from keras.src.quantizers.quantizers import quantize_and_dequantize
 from keras.src.saving import serialization_lib
 from keras.src.utils.naming import to_snake_case

keras/src/quantizers/quantizers.py

Lines changed: 137 additions & 0 deletions
@@ -4,6 +4,7 @@
 from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
+from keras.src.backend.common.backend_utils import canonicalize_axis
 from keras.src.backend.common.backend_utils import standardize_axis_for_numpy

 """Int8-related classes and methods"""
@@ -127,6 +128,142 @@ def get_config(self):
         }


+def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
+    """Adjusts and nudges the quantization range for better accuracy."""
+
+    quant_max = ops.cast(ops.subtract(ops.power(2, num_bits), 1.0), "float32")
+
+    quant_min = ops.cast(0.0 if not narrow_range else 1.0, "float32")
+
+    # Calculate the scale and ensure it's positive
+    scale = ops.divide(
+        ops.subtract(max_range, min_range), ops.subtract(quant_max, quant_min)
+    )
+
+    inv_scale = ops.reciprocal(scale)
+
+    # Calculate the zero point from the min range
+    zero_point_from_min = quant_min - ops.divide(min_range, scale)
+
+    # Ensure zero point is within valid range [0, quant_max]
+    zero_point = ops.clip(zero_point_from_min, quant_min, quant_max)
+
+    # Nudge zero point if it's very close to an integer
+    nudged_zero_point = ops.round(zero_point)
+
+    # Calculate nudged limits
+    nudged_min = ops.multiply(ops.subtract(quant_min, nudged_zero_point), scale)
+    nudged_max = ops.multiply(ops.subtract(quant_max, nudged_zero_point), scale)
+
+    return nudged_min, nudged_max, scale, inv_scale
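For intuition, the nudging above can be traced with concrete numbers; the plain-Python sketch below mirrors the same arithmetic (illustrative values, not part of the commit):

# Plain-Python walk-through of adjust_and_nudge for an asymmetric range.
min_range, max_range, num_bits, narrow_range = -0.9, 1.0, 8, False

quant_max = float(2**num_bits - 1)                           # 255.0
quant_min = 1.0 if narrow_range else 0.0                     # 0.0

scale = (max_range - min_range) / (quant_max - quant_min)    # ~0.0074510
zero_point_from_min = quant_min - min_range / scale          # ~120.7895
zero_point = min(max(zero_point_from_min, quant_min), quant_max)
nudged_zero_point = round(zero_point)                        # 121
nudged_min = (quant_min - nudged_zero_point) * scale         # ~-0.90157
nudged_max = (quant_max - nudged_zero_point) * scale         # ~0.99843

# The clamping range shifts slightly from [-0.9, 1.0] to roughly
# [-0.9016, 0.9984] so that 0.0 lands exactly on integer step 121.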
+@keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel")
+def fake_quant_with_min_max_vars(
+    inputs,
+    min_vals,
+    max_vals,
+    num_bits,
+    narrow_range=False,
+    axis=None,
+):
+    """
+    Perform per-tensor or per-channel fake quantization.
+
+    `[min_vals, max_vals]` define the clamping range for the `inputs`.
+
+    The `inputs` are quantized into the quantization range:
+    - `[0, 2^num_bits - 1]` when `narrow_range=False`
+    - `[1, 2^num_bits - 1]` when `narrow_range=True`
+
+    After quantization, the values are dequantized and output as floats within
+    the `[min_vals, max_vals]` interval.
+
+    This operation supports gradient computation, allowing `min_vals` and
+    `max_vals` to be trained.
+
+    Args:
+        inputs: Input tensor of float dtype.
+        min_vals: A global minimum scalar or a per-channel minimum tensor.
+        max_vals: A global maximum scalar or a per-channel maximum tensor.
+        num_bits: Quantization bit width (e.g., `8` for int8).
+        narrow_range: Whether to use narrow quantization range.
+        axis: Axis along which to perform per-channel quantization. If `None`,
+            per-tensor quantization is performed. Defaults to `None`.
+
+    Returns:
+        Fake-quantized tensor
+    """
+    inputs = ops.convert_to_tensor(inputs)
+    min_vals = ops.convert_to_tensor(min_vals)
+    max_vals = ops.convert_to_tensor(max_vals)
+
+    if axis is not None:
+        axis = canonicalize_axis(axis, inputs.ndim)
+
+    @ops.custom_gradient
+    def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
+        # Calculate quantization parameters for all channels at once
+        nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge(
+            min_val, max_val, num_bits, narrow_range
+        )
+
+        quant_zero = ops.floor(
+            ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
+        )
+        x_clamped = ops.clip(x, nudged_min, nudged_max)
+        x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
+        result = ops.multiply(
+            ops.floor(
+                ops.add(
+                    ops.subtract(
+                        ops.multiply(x_clamped_shifted, inv_scale), quant_zero
+                    ),
+                    0.5,
+                )
+            ),
+            scale,
+        )
+
+        # Create gradient mask for all channels
+        masks = ops.cast(
+            (x >= nudged_min) & (x <= nudged_max),
+            dtype="float32",
+        )
+
+        def grad(*args, upstream=None):
+            if upstream is None:
+                (upstream,) = args
+
+            # Gradient for x
+            dx = ops.multiply(upstream, masks)
+            axes = [i for i in range(len(dx.shape)) if i != axis]
+            # Gradient for min_val
+            # When x is clipped to min, the gradient flows to min_val
+            min_mask = ops.cast(x <= nudged_min, dtype="float32")
+            grad_min = ops.multiply(upstream, min_mask)
+            if axis is not None:
+                grad_min = ops.sum(grad_min, axis=axes)
+            else:
+                grad_min = ops.sum(grad_min)
+
+            # Gradient for max_val
+            # When x is clipped to max, the gradient flows to max_val
+            max_mask = ops.cast(x >= nudged_max, dtype="float32")
+            grad_max = ops.multiply(upstream, max_mask)
+            if axis is not None:
+                grad_max = ops.sum(grad_max, axis=axes)
+            else:
+                grad_max = ops.sum(grad_max)
+
+            return dx, grad_min, grad_max
+
+        return result, grad
+
+    return _fake_quant_with_min_max_vars_per_channel(inputs, min_vals, max_vals)
+
+
 """Float8-related methods"""

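The per-channel path exercises the `axis` argument added by this commit. A minimal sketch, again assuming a Keras 3 build that includes this commit, with illustrative values (not taken from the commit's tests):

import numpy as np
from keras import quantizers

x = np.array([[-1.2, 0.3], [0.8, 2.5]], dtype="float32")

# One (min, max) pair per channel along the last axis; each column is
# clamped to its own range and fake-quantized to 8 bits.
min_vals = np.array([-1.0, -2.0], dtype="float32")
max_vals = np.array([1.0, 2.0], dtype="float32")
y = quantizers.fake_quant_with_min_max_vars_per_channel(
    x, min_vals, max_vals, num_bits=8, axis=-1
)

Because the op registers a custom gradient, `min_vals` and `max_vals` can also be supplied as trainable variables during quantization-aware training; per the `grad` closure above, the incoming gradient is routed to whichever bound did the clipping and passed straight through for in-range values.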