Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions keras/api/_tf_keras/keras/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,9 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
5 changes: 5 additions & 0 deletions keras/api/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,9 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
5 changes: 5 additions & 0 deletions keras/src/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
from keras.src.saving import serialization_lib
from keras.src.utils.naming import to_snake_case
Expand Down
138 changes: 138 additions & 0 deletions keras/src/quantizers/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,144 @@ def get_config(self):
}


def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
"""Adjusts and nudges the quantization range for better accuracy."""

quant_max = ops.cast(ops.subtract(ops.power(2, num_bits), 1.0), "float32")

quant_min = ops.cast(0.0 if not narrow_range else 1.0, "float32")

# Calculate the scale and ensure it's positive
scale = ops.divide(
ops.subtract(max_range, min_range), ops.subtract(quant_max, quant_min)
)

inv_scale = ops.reciprocal(scale)

# Calculate the zero point from the min range
zero_point_from_min = quant_min - ops.divide(min_range, scale)

# Ensure zero point is within valid range [0, quant_max]
zero_point = ops.clip(zero_point_from_min, quant_min, quant_max)

# Nudge zero point if it's very close to an integer
nudged_zero_point = ops.round(zero_point)

# Calculate nudged limits
nudged_min = ops.multiply(ops.subtract(quant_min, nudged_zero_point), scale)
nudged_max = ops.multiply(ops.subtract(quant_max, nudged_zero_point), scale)

return nudged_min, nudged_max, scale, inv_scale


@keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel")
def fake_quant_with_min_max_vars_per_channel(
inputs,
min_vals,
max_vals,
num_bits,
narrow_range,
):
"""
Perform per-channel fake quantization with custom gradient using vectorized
operations.

Args:
inputs: Input tensor of float type
min_vals: Per-channel minimum values
max_vals: Per-channel maximum values
num_bits: Quantization bit width (2-16)
narrow_range: Whether to use narrow quantization range

Returns:
Fake-quantized tensor
"""
inputs = ops.convert_to_tensor(inputs)
min_vals = ops.convert_to_tensor(min_vals)
max_vals = ops.convert_to_tensor(max_vals)

@ops.custom_gradient
def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
# Calculate quantization parameters for all channels at once
nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge(
min_val, max_val, num_bits, narrow_range
)

quant_zero = ops.floor(
ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
)
x_clamped = ops.clip(x, nudged_min, nudged_max)
x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
result = ops.multiply(
ops.floor(
ops.add(
ops.subtract(
ops.multiply(x_clamped_shifted, inv_scale), quant_zero
),
0.5,
)
),
scale,
)

# Create gradient mask for all channels
masks = ops.cast(
(x >= nudged_min) & (x <= nudged_max),
dtype="float32",
)

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args

# Gradient for x
dx = ops.multiply(upstream, masks)

# Gradient for min_val
# When x is clipped to min, the gradient flows to min_val
min_mask = ops.cast(x <= nudged_min, dtype="float32")
grad_min = ops.sum(ops.multiply(upstream, min_mask))

# Gradient for max_val
# When x is clipped to max, the gradient flows to max_val
max_mask = ops.cast(x >= nudged_max, dtype="float32")
grad_max = ops.sum(ops.multiply(upstream, max_mask))

return dx, grad_min, grad_max

return result, grad

return _fake_quant_with_min_max_vars_per_channel(inputs, min_vals, max_vals)


@keras_export("keras.quantizers.fake_quant_with_min_max_args")
def fake_quant_with_min_max_args(
inputs,
min_vals,
max_vals,
num_bits=8,
narrow_range=False,
):
"""Fake quantization operation matching TensorFlow's implementation."""
return fake_quant_with_min_max_vars_per_channel(
inputs, min_vals, max_vals, num_bits, narrow_range
)


@keras_export("keras.quantizers.fake_quant_with_min_max_vars")
def fake_quant_with_min_max_vars(
inputs,
min_vals,
max_vals,
num_bits=8,
narrow_range=False,
):
"""Fake quantization operation matching TensorFlow's implementation."""
return fake_quant_with_min_max_vars_per_channel(
inputs, min_vals, max_vals, num_bits, narrow_range
)


"""Float8-related methods"""


Expand Down
Loading
Loading