first test passing
doncarlos999 committed Jan 13, 2025
1 parent bd17ffa commit 0ac1b2b
Showing 1 changed file with 35 additions and 70 deletions.
105 changes: 35 additions & 70 deletions keras/src/quantizers/quantizers.py
@@ -129,66 +129,31 @@ def get_config(self):
 
 def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
     """Adjusts and nudges the quantization range for better accuracy."""
-    n_steps = ops.cast(2**num_bits - 1, "float32")
-    n_steps = n_steps if not narrow_range else n_steps - 1.0
-
-    # Handle the case where min and max are too close
-    # if abs(max_range - min_range) < 1e-10:
-    #     return min_range, max_range, 1.0
-
-    # Calculate the step size
-    step_size = ops.divide((max_range - min_range), n_steps)
-
-    # Calculate the reciprocal of the step size
-    inv_step_size = 1.0 / step_size
-
-    # Round the reciprocal to get an integer
-    rounded_inv_step_size = ops.round(inv_step_size)
-
-    # Calculate the final step size
-    final_step_size = 1.0 / rounded_inv_step_size
-
-    # Calculate the quantized min/max values, ensuring accurate rounding
-    quantized_min = (
-        ops.round(min_range * rounded_inv_step_size) / rounded_inv_step_size
-    )
-    quantized_max = (
-        ops.round(max_range * rounded_inv_step_size) / rounded_inv_step_size
-    )
-
-    # Convert quantization limits to float
-    quant_min_float = ops.cast(quantized_min, "float32")
-    quant_max_float = ops.cast(quantized_max, "float32")
-
-    # Calculate the scale
-    nudged_scale = (max_range - min_range) / (quant_max_float - quant_min_float)
-
-    # Calculate zero point from min
-    zero_point_from_min = quant_min_float - min_range / nudged_scale
-
-    # Determine nudged zero point
-    nudged_zero_point = ops.where(
-        zero_point_from_min < quant_min_float,
-        quantized_min,
-        ops.where(
-            zero_point_from_min > quant_max_float,
-            quantized_max,
-            ops.round(zero_point_from_min),
-        ),
-    )
-
-    # Calculate nudged min and max
-    nudged_min = (quant_min_float - nudged_zero_point) * nudged_scale
-    nudged_max = (quant_max_float - nudged_zero_point) * nudged_scale
-
-    return (
-        nudged_min,
-        nudged_max,
-        final_step_size,
-    )  # Returning nudged values and scale
+    if num_bits < 2:
+        raise ValueError("num_bits must be >= 2")
+
+    quant_max = ops.cast(2**num_bits - 1.0, "float32")
+    quant_max = quant_max if not narrow_range else quant_max - 1.0
+    quant_min = ops.cast(0.0 if not narrow_range else 1.0, "float32")
+
+    # Calculate the scale and ensure it's positive
+    scale = (max_range - min_range) / (quant_max - quant_min)
+    inv_scale = ops.reciprocal(scale)
+
+    # Calculate the zero point from the min range
+    zero_point_from_min = quant_min - min_range / scale
+
+    # Ensure zero point is within valid range [quant_min, quant_max]
+    zero_point = ops.clip(zero_point_from_min, quant_min, quant_max)
+
+    # Nudge zero point to the nearest integer
+    nudged_zero_point = ops.round(zero_point)
+
+    # Calculate nudged limits
+    nudged_min = (quant_min - nudged_zero_point) * scale
+    nudged_max = (quant_max - nudged_zero_point) * scale
+
+    return nudged_min, nudged_max, scale, inv_scale
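
For reference, a minimal NumPy sketch of the nudging math added above; the input range, num_bits=8, and narrow_range=False are assumed example values, not part of the commit:

import numpy as np

# Assumed example inputs (not from the commit): an asymmetric float range.
min_range, max_range, num_bits, narrow_range = -1.0, 2.55, 8, False

quant_min = 1.0 if narrow_range else 0.0
quant_max = float(2**num_bits - 1) - (1.0 if narrow_range else 0.0)

# One float step per integer code.
scale = (max_range - min_range) / (quant_max - quant_min)
# Integer code that should represent 0.0, clipped into the valid code range.
zero_point = np.clip(quant_min - min_range / scale, quant_min, quant_max)
nudged_zero_point = np.round(zero_point)  # 71.83... rounds to 72

# Shift the float range so that 0.0 lands exactly on integer code 72.
nudged_min = (quant_min - nudged_zero_point) * scale  # about -1.0024
nudged_max = (quant_max - nudged_zero_point) * scale  # about 2.5476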


@keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel")
@@ -220,32 +185,32 @@ def fake_quant_with_min_max_vars_per_channel(
     @ops.custom_gradient
     def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
         # Calculate quantization parameters for all channels at once
-        qnt_min, qnt_max, step_size = adjust_and_nudge(
+        nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge(
             min_val, max_val, num_bits, narrow_range
         )
 
-        # Calculate number of steps
-        n_steps = 2**num_bits - 1
-        if narrow_range:
-            n_steps -= 1
-
         # Expand dimensions to allow broadcasting
-        qnt_min = ops.expand_dims(qnt_min, axis=list(range(len(x.shape) - 1)))
-        qnt_max = ops.expand_dims(qnt_max, axis=list(range(len(x.shape) - 1)))
-        step_size = ops.expand_dims(
-            step_size, axis=list(range(len(x.shape) - 1))
-        )
+        nudged_min = ops.expand_dims(
+            nudged_min, axis=list(range(len(x.shape) - 1))
+        )
+        nudged_max = ops.expand_dims(
+            nudged_max, axis=list(range(len(x.shape) - 1))
+        )
+        scale = ops.expand_dims(scale, axis=list(range(len(x.shape) - 1)))
+        inv_scale = ops.expand_dims(
+            inv_scale, axis=list(range(len(x.shape) - 1))
+        )
 
         # Clip and quantize all channels simultaneously
-        x_clipped = ops.clip(x, qnt_min, qnt_max)
-        x_norm = (x_clipped - qnt_min) / step_size
-        x_quantized = ops.round(x_norm)
-        x_quantized = ops.clip(x_quantized, 0.0, n_steps)
-        result = x_quantized * step_size + qnt_min
+        quant_zero = ops.floor(-nudged_min * inv_scale + 0.5)
+        x_clamped = ops.clip(x, nudged_min, nudged_max)
+        x_clamped_shifted = x_clamped - nudged_min
+        result = (
+            ops.floor(x_clamped_shifted * inv_scale - quant_zero + 0.5) * scale
+        )
 
         # Create gradient mask for all channels
         masks = ops.cast(
-            (x >= qnt_min) & (x <= qnt_max),
+            (x >= nudged_min) & (x <= nudged_max),
            dtype="float32",
        )
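
As a rough per-channel illustration of the forward pass above; the shapes, ranges, and values are assumed, with already-nudged limits on a 0.01 step:

import numpy as np

# Assumed toy inputs: x has shape (batch, channels), limits are per channel.
x = np.array([[-1.5, 0.3], [0.7, 2.9]], dtype="float32")
nudged_min = np.array([-1.28, 0.0], dtype="float32")
nudged_max = np.array([1.27, 2.55], dtype="float32")
scale = (nudged_max - nudged_min) / 255.0  # 0.01 per channel
inv_scale = 1.0 / scale

# Same clamp -> quantize -> dequantize round trip as in the diff above.
quant_zero = np.floor(-nudged_min * inv_scale + 0.5)
x_clamped = np.clip(x, nudged_min, nudged_max)
result = np.floor((x_clamped - nudged_min) * inv_scale - quant_zero + 0.5) * scale
# result is approximately [[-1.28, 0.3], [0.7, 2.55]]: in-range values snap to
# the 0.01 grid, out-of-range values are clamped to the nudged limits.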

@@ -258,13 +223,13 @@ def grad(*args, upstream=None):
 
             # Gradient for min_val
             # When x is clipped to min, the gradient flows to min_val
-            min_mask = ops.cast(x <= qnt_min, dtype="float32")
+            min_mask = ops.cast(x <= nudged_min, dtype="float32")
             dims_to_reduce = list(range(len(x.shape) - 1))
             grad_min = ops.sum(upstream * min_mask, axis=dims_to_reduce)
 
             # Gradient for max_val
             # When x is clipped to max, the gradient flows to max_val
-            max_mask = ops.cast(x >= qnt_max, dtype="float32")
+            max_mask = ops.cast(x >= nudged_max, dtype="float32")
             grad_max = ops.sum(upstream * max_mask, axis=dims_to_reduce)
 
             return dx, grad_min, grad_max
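
A small standalone sketch of how these straight-through masks route gradients; the single-channel values and an upstream gradient of ones are assumed:

import numpy as np

# Assumed toy values: one channel, upstream gradient of ones.
x = np.array([-1.5, 0.2, 0.9], dtype="float32")
nudged_min, nudged_max = -1.28, 0.64
upstream = np.ones_like(x)

# Gradient passes straight through where x was not clipped ...
dx = upstream * ((x >= nudged_min) & (x <= nudged_max)).astype("float32")
# ... and is routed to the range variables where it was clipped.
grad_min = np.sum(upstream * (x <= nudged_min).astype("float32"))
grad_max = np.sum(upstream * (x >= nudged_max).astype("float32"))
# dx == [0., 1., 0.], grad_min == 1.0, grad_max == 1.0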
