From 0ac1b2bd0c7e147552e3e9d99e927da70c63a774 Mon Sep 17 00:00:00 2001
From: Carl
Date: Mon, 13 Jan 2025 17:03:54 +0000
Subject: [PATCH] first test passing

---
 keras/src/quantizers/quantizers.py | 105 ++++++++++-------------------
 1 file changed, 35 insertions(+), 70 deletions(-)

diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py
index 08732d3d92e..1272c453563 100644
--- a/keras/src/quantizers/quantizers.py
+++ b/keras/src/quantizers/quantizers.py
@@ -129,66 +129,31 @@ def get_config(self):
 
 
 def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
     """Adjusts and nudges the quantization range for better accuracy."""
-    if num_bits < 2:
-        raise ValueError("num_bits must be >= 2")
-    n_steps = ops.cast(2**num_bits - 1, "float32")
-    n_steps = n_steps if not narrow_range else n_steps - 1.0
+    quant_max = ops.cast(2**num_bits - 1.0, "float32")
+    quant_max = quant_max if not narrow_range else quant_max - 1.0
 
-    # Handle the case where min and max are too close
-    # if abs(max_range - min_range) < 1e-10:
-    #     return min_range, max_range, 1.0
+    quant_min = ops.cast(0.0 if not narrow_range else 1.0, "float32")
 
-    # Calculate the step size
-    step_size = ops.divide((max_range - min_range), n_steps)
+    # Calculate the scale
+    scale = (max_range - min_range) / (quant_max - quant_min)
 
-    # Calculate the reciprocal of the step size
-    inv_step_size = 1.0 / step_size
+    inv_scale = ops.reciprocal(scale)
 
-    # Round the reciprocal to get an integer
-    rounded_inv_step_size = ops.round(inv_step_size)
+    # Calculate the zero point from the min range
+    zero_point_from_min = quant_min - min_range / scale
 
-    # Calculate the final step size
-    final_step_size = 1.0 / rounded_inv_step_size
+    # Ensure the zero point is within the valid range [quant_min, quant_max]
+    zero_point = ops.clip(zero_point_from_min, quant_min, quant_max)
 
-    # Calculate the quantized min/max values, ensuring accurate rounding
-    quantized_min = (
-        ops.round(min_range * rounded_inv_step_size) / rounded_inv_step_size
-    )
-    quantized_max = (
-        ops.round(max_range * rounded_inv_step_size) / rounded_inv_step_size
-    )
-
-    # Convert quantization limits to float
-    quant_min_float = ops.cast(quantized_min, "float32")
-    quant_max_float = ops.cast(quantized_max, "float32")
-
-    # Calculate the scale
-    nudged_scale = (max_range - min_range) / (quant_max_float - quant_min_float)
-
-    # Calculate zero point from min
-    zero_point_from_min = quant_min_float - min_range / nudged_scale
-
-    # Determine nudged zero point
-    nudged_zero_point = ops.where(
-        zero_point_from_min < quant_min_float,
-        quantized_min,
-        ops.where(
-            zero_point_from_min > quant_max_float,
-            quantized_max,
-            ops.round(zero_point_from_min),
-        ),
-    )
+    # Nudge the zero point to the nearest integer
+    nudged_zero_point = ops.round(zero_point)
 
-    # Calculate nudged min and max
-    nudged_min = (quant_min_float - nudged_zero_point) * nudged_scale
-    nudged_max = (quant_max_float - nudged_zero_point) * nudged_scale
+    # Calculate nudged limits
+    nudged_min = (quant_min - nudged_zero_point) * scale
+    nudged_max = (quant_max - nudged_zero_point) * scale
 
-    return (
-        nudged_min,
-        nudged_max,
-        final_step_size,
-    )  # Returning nudged values and scale
+    return nudged_min, nudged_max, scale, inv_scale
 
 
 @keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel")
@@ -220,32 +185,32 @@ def fake_quant_with_min_max_vars_per_channel(
     @ops.custom_gradient
     def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
         # Calculate quantization parameters for all channels at once
-        qnt_min, qnt_max, step_size = adjust_and_nudge(
+        nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge(
             min_val, max_val, num_bits, narrow_range
         )
 
-        # Calculate number of steps
-        n_steps = 2**num_bits - 1
-        if narrow_range:
-            n_steps -= 1
-
         # Expand dimensions to allow broadcasting
-        qnt_min = ops.expand_dims(qnt_min, axis=list(range(len(x.shape) - 1)))
-        qnt_max = ops.expand_dims(qnt_max, axis=list(range(len(x.shape) - 1)))
-        step_size = ops.expand_dims(
-            step_size, axis=list(range(len(x.shape) - 1))
+        nudged_min = ops.expand_dims(
+            nudged_min, axis=list(range(len(x.shape) - 1))
+        )
+        nudged_max = ops.expand_dims(
+            nudged_max, axis=list(range(len(x.shape) - 1))
+        )
+        scale = ops.expand_dims(scale, axis=list(range(len(x.shape) - 1)))
+        inv_scale = ops.expand_dims(
+            inv_scale, axis=list(range(len(x.shape) - 1))
        )
 
-        # Clip and quantize all channels simultaneously
-        x_clipped = ops.clip(x, qnt_min, qnt_max)
-        x_norm = (x_clipped - qnt_min) / step_size
-        x_quantized = ops.round(x_norm)
-        x_quantized = ops.clip(x_quantized, 0.0, n_steps)
-        result = x_quantized * step_size + qnt_min
+        quant_zero = ops.floor(-nudged_min * inv_scale + 0.5)
+        x_clamped = ops.clip(x, nudged_min, nudged_max)
+        x_clamped_shifted = x_clamped - nudged_min
+        result = (
+            ops.floor(x_clamped_shifted * inv_scale - quant_zero + 0.5) * scale
+        )
 
         # Create gradient mask for all channels
         masks = ops.cast(
-            (x >= qnt_min) & (x <= qnt_max),
+            (x >= nudged_min) & (x <= nudged_max),
             dtype="float32",
         )
 
@@ -258,13 +223,13 @@ def grad(*args, upstream=None):
 
             # Gradient for min_val
             # When x is clipped to min, the gradient flows to min_val
-            min_mask = ops.cast(x <= qnt_min, dtype="float32")
+            min_mask = ops.cast(x <= nudged_min, dtype="float32")
             dims_to_reduce = list(range(len(x.shape) - 1))
             grad_min = ops.sum(upstream * min_mask, axis=dims_to_reduce)
 
             # Gradient for max_val
             # When x is clipped to max, the gradient flows to max_val
-            max_mask = ops.cast(x >= qnt_max, dtype="float32")
+            max_mask = ops.cast(x >= nudged_max, dtype="float32")
             grad_max = ops.sum(upstream * max_mask, axis=dims_to_reduce)
 
             return dx, grad_min, grad_max
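
Reviewer note: the nudging scheme this patch switches to mirrors TensorFlow's
FakeQuant semantics. It fixes an integer grid of quant_max - quant_min steps,
derives the scale from the requested float range, then shifts ("nudges") the
range so that real 0.0 lands exactly on a grid step. A minimal standalone
NumPy sketch of the same math, useful for checking values while reviewing
(adjust_and_nudge_ref is a hypothetical name and np stands in for keras.ops;
this is not the shipped code):

import numpy as np

def adjust_and_nudge_ref(min_range, max_range, num_bits, narrow_range=False):
    # Integer endpoints of the representable grid, e.g. [0, 255] for 8 bits.
    quant_max = float(2**num_bits - 1) - (1.0 if narrow_range else 0.0)
    quant_min = 1.0 if narrow_range else 0.0
    scale = (max_range - min_range) / (quant_max - quant_min)
    inv_scale = 1.0 / scale
    # Zero point implied by min_range, clipped into the representable range.
    zero_point = np.clip(quant_min - min_range / scale, quant_min, quant_max)
    nudged_zero_point = np.round(zero_point)
    # Shift the limits so that 0.0 sits exactly on a grid step.
    nudged_min = (quant_min - nudged_zero_point) * scale
    nudged_max = (quant_max - nudged_zero_point) * scale
    return nudged_min, nudged_max, scale, inv_scale

nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge_ref(-0.1, 0.9, 8)
# Real 0.0 now maps onto an exact integer grid point:
assert abs(round(-nudged_min * inv_scale) * scale + nudged_min) < 1e-7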
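
The custom gradient in the second hunk is a straight-through estimator:
upstream gradients pass through unchanged wherever x lies inside
[nudged_min, nudged_max], and are routed to min_val / max_val wherever x was
clamped. An illustrative NumPy sketch of that routing (fake_quant_grads is a
hypothetical helper that mirrors the grad closure above, not part of the
patch):

import numpy as np

def fake_quant_grads(x, upstream, nudged_min, nudged_max):
    inside = ((x >= nudged_min) & (x <= nudged_max)).astype("float32")
    dx = upstream * inside                  # pass-through inside the range
    reduce_axes = tuple(range(x.ndim - 1))  # reduce all but the channel axis
    grad_min = (upstream * (x <= nudged_min)).sum(axis=reduce_axes)
    grad_max = (upstream * (x >= nudged_max)).sum(axis=reduce_axes)
    return dx, grad_min, grad_max

x = np.array([[-2.0, 0.3], [0.5, 2.0]], dtype="float32")
dx, gmin, gmax = fake_quant_grads(x, np.ones_like(x), -1.0, 1.0)
# dx zeroes the clamped entries; gmin/gmax accumulate per output channel.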