Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Porting TF fake_quant_with_min_max functions #20641

Merged
merged 16 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions keras/api/_tf_keras/keras/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,14 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_args_gradient,
)
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel_gradient,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
10 changes: 10 additions & 0 deletions keras/api/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,14 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_args_gradient,
)
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel_gradient,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
10 changes: 10 additions & 0 deletions keras/src/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_args_gradient,
)
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel_gradient,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
from keras.src.saving import serialization_lib
from keras.src.utils.naming import to_snake_case
Expand Down
335 changes: 335 additions & 0 deletions keras/src/quantizers/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,341 @@ def get_config(self):
}


def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
"""Adjusts and nudges the quantization range for better accuracy."""
if num_bits < 2:
raise ValueError("num_bits must be >= 2")

n_steps = float(2**num_bits - 1)
if narrow_range:
n_steps -= 1.0

# Handle the case where min and max are too close
if abs(max_range - min_range) < 1e-10:
return min_range, max_range, 1.0

# Calculate the step size
step_size = (max_range - min_range) / n_steps

# Calculate the reciprocal of the step size
inv_step_size = 1.0 / step_size

# Round the reciprocal to get an integer
rounded_inv_step_size = ops.round(inv_step_size)

# Calculate the final step size
final_step_size = 1.0 / rounded_inv_step_size

# Calculate the quantized min/max values, ensuring accurate rounding
quantized_min = (
ops.round(min_range * rounded_inv_step_size) / rounded_inv_step_size
)
quantized_max = (
ops.round(max_range * rounded_inv_step_size) / rounded_inv_step_size
)

# Convert quantization limits to float
quant_min_float = float(quantized_min)
quant_max_float = float(quantized_max)

# Calculate the scale
nudged_scale = (max_range - min_range) / (quant_max_float - quant_min_float)

# Calculate zero point from min
zero_point_from_min = quant_min_float - min_range / nudged_scale

# Determine nudged zero point
if zero_point_from_min < quant_min_float:
nudged_zero_point = quantized_min
elif zero_point_from_min > quant_max_float:
nudged_zero_point = quantized_max
else:
nudged_zero_point = ops.round(zero_point_from_min)

# Calculate nudged min and max
nudged_min = (quant_min_float - nudged_zero_point) * nudged_scale
nudged_max = (quant_max_float - nudged_zero_point) * nudged_scale

return (
nudged_min,
nudged_max,
final_step_size,
) # Returning nudged values and scale


@keras_export("keras.quantizers.fake_quant_with_min_max_args")
def fake_quant_with_min_max_args(
inputs,
min_range=-6.0,
max_range=6.0,
num_bits=8,
narrow_range=False,
):
"""Fake quantization operation matching TensorFlow's implementation."""

if isinstance(inputs, np.ndarray):
inputs = ops.convert_to_tensor(inputs)

@ops.custom_gradient
def _fake_quant_with_min_max_args(x):
quant_min, quant_max, step_size = adjust_and_nudge(
min_range, max_range, num_bits, narrow_range
)

n_steps = 2**num_bits - 1
if narrow_range:
n_steps -= 1

# Clip and nudge input to the range
x_clipped = ops.clip(x, quant_min, quant_max)
x_norm = (x_clipped - quant_min) / step_size
x_quantized = ops.round(x_norm)
x_quantized = ops.clip(x_quantized, 0.0, n_steps)
result = x_quantized * step_size + quant_min

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args
# Gradient mask: valid within the range
mask = ops.cast(
(x >= quant_min) & (x <= quant_max), dtype=upstream.dtype
)
return ops.multiply(upstream, mask)

return result, grad

return _fake_quant_with_min_max_args(inputs)


@keras_export("keras.quantizers.fake_quant_with_min_max_args_gradient")
def fake_quant_with_min_max_args_gradient(
gradients,
inputs,
min_range=-6.0,
max_range=6.0,
num_bits=8,
narrow_range=False,
):
"""Fake quantization operation with gradient,
matching TensorFlow's implementation."""

if isinstance(inputs, np.ndarray):
inputs = ops.convert_to_tensor(inputs)

def _fake_quant_with_min_max_args_gradient(x):
quant_min, quant_max, step_size = adjust_and_nudge(
min_range, max_range, num_bits, narrow_range
)

n_steps = 2**num_bits - 1
if narrow_range:
n_steps -= 1

# Clip and nudge input to the range
x_clipped = ops.clip(x, quant_min, quant_max)
x_norm = (x_clipped - quant_min) / step_size
x_quantized = ops.round(x_norm)
x_quantized = ops.clip(x_quantized, 0.0, n_steps)
result = x_quantized * step_size + quant_min

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args
# Gradient mask: valid within the range
mask = ops.cast(
(x >= quant_min) & (x <= quant_max), dtype=upstream.dtype
)
return ops.multiply(upstream, mask)

return result, grad

output, grad = _fake_quant_with_min_max_args_gradient(inputs)
return output, grad(gradients)


@keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel")
def fake_quant_with_min_max_vars_per_channel(
inputs,
min_vals,
max_vals,
num_bits,
narrow_range,
):
"""
Perform per-channel fake quantization with custom gradient.

Args:
inputs: Input tensor of float type
min_vals: Per-channel minimum values
max_vals: Per-channel maximum values
num_bits: Quantization bit width (2-16)
narrow_range: Whether to use narrow quantization range

Returns:
Fake-quantized tensor
"""

if isinstance(inputs, np.ndarray):
inputs = ops.convert_to_tensor(inputs)
min_vals = ops.convert_to_tensor(min_vals)
max_vals = ops.convert_to_tensor(max_vals)

@ops.custom_gradient
def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
# Determine the number of channels
num_channels = min_val.shape[-1]

# Initialize an empty list to store quantized values for each channel
quantized_channels = []
masks = []

# Iterate over each channel
for i in range(num_channels):
# Extract min/max values for current channel
current_min = min_val[..., i]
current_max = max_val[..., i]

# Calculate step size and quantized min/max using _adjust_range
qnt_min, qnt_max, step_size = adjust_and_nudge(
current_min, current_max, num_bits, narrow_range
)
# Calculate the number of steps
n_steps = 2**num_bits - 1
if narrow_range:
n_steps -= 1

# Clip and nudge input to the range for the current channel
x_clipped = ops.clip(x[..., i], qnt_min, qnt_max)
x_norm = (x_clipped - qnt_min) / step_size
x_quantized = ops.round(x_norm)
x_quantized = ops.clip(x_quantized, 0.0, n_steps)
result_channel = x_quantized * step_size + qnt_min

quantized_channels.append(result_channel)
mask = ops.cast(
(x[..., i] >= qnt_min) & (x[..., i] <= qnt_max),
dtype=np.float32,
)
masks.append(mask)

# Concatenate quantized channels
result = ops.stack(quantized_channels, axis=-1)

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args

# Gradient mask: valid within the range
return ops.multiply(upstream, mask)

return result, grad

return _fake_quant_with_min_max_vars_per_channel(inputs, min_vals, max_vals)


@keras_export(
"keras.quantizers.fake_quant_with_min_max_vars_per_channel_gradient"
)
def fake_quant_with_min_max_vars_per_channel_gradient(
gradients,
inputs,
min_vals,
max_vals,
num_bits,
narrow_range,
):
"""
Perform per-channel fake quantization with custom gradient.

Args:
inputs: Input tensor of float type
min_vals: Per-channel minimum values
max_vals: Per-channel maximum values
num_bits: Quantization bit width (2-16)
narrow_range: Whether to use narrow quantization range

Returns:
Fake-quantized tensor
"""

if isinstance(inputs, np.ndarray):
inputs = ops.convert_to_tensor(inputs)
min_vals = ops.convert_to_tensor(min_vals)
max_vals = ops.convert_to_tensor(max_vals)

# @ops.custom_gradient
def _fake_quant_with_min_max_vars_per_channel_gradient(x, min_val, max_val):
# Determine the number of channels
num_channels = min_val.shape[-1]

# Initialize an empty list to store quantized values for each channel
quantized_channels = []
between_min_max_masks = []
below_min_masks = []
above_max_masks = []

# Iterate over each channel
for i in range(num_channels):
# Extract min/max values for current channel
current_min = min_val[..., i]
current_max = max_val[..., i]

# Calculate step size and quantized min/max using _adjust_range
qnt_min, qnt_max, step_size = adjust_and_nudge(
current_min, current_max, num_bits, narrow_range
)

# Calculate the number of steps
n_steps = 2**num_bits - 1
if narrow_range:
n_steps -= 1

# Clip and nudge input to the range for the current channel
x_clipped = ops.clip(x[..., i], qnt_min, qnt_max)
x_norm = (x_clipped - qnt_min) / step_size
x_quantized = ops.round(x_norm)
x_quantized = ops.clip(x_quantized, 0.0, n_steps)
result_channel = x_quantized * step_size + qnt_min
between_min_max_mask = ops.cast(
(x[..., i] >= qnt_min) & (x[..., i] <= qnt_max),
dtype=np.float32,
)
below_min_mask = ops.cast((x[..., i] < qnt_min), dtype=np.float32)
above_max_mask = ops.cast((x[..., i] > qnt_max), dtype=np.float32)
between_min_max_masks.append(between_min_max_mask)
below_min_masks.append(below_min_mask)
above_max_masks.append(above_max_mask)
quantized_channels.append(result_channel)

# Concatenate quantized channels
result = ops.stack(quantized_channels, axis=-1)
between_min_max_masks = ops.stack(between_min_max_masks, axis=-1)
below_min_masks = ops.stack(below_min_masks, axis=-1)
above_max_masks = ops.stack(above_max_masks, axis=-1)

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args
backprops_wrt_input = ops.multiply(upstream, between_min_max_masks)
backprops_wrt_min = ops.sum(
ops.multiply(upstream, below_min_masks), axis=0
)
backprops_wrt_max = ops.sum(
ops.multiply(upstream, above_max_masks), axis=0
)

return backprops_wrt_input, backprops_wrt_min, backprops_wrt_max

return result, grad

output, grad = _fake_quant_with_min_max_vars_per_channel_gradient(
inputs, min_vals, max_vals
)
backprops_wrt_input, backprops_wrt_min, backprops_wrt_max = grad(gradients)

return output, backprops_wrt_input, backprops_wrt_min, backprops_wrt_max


"""Float8-related methods"""


Expand Down
Loading
Loading