@@ -4,6 +4,7 @@
 from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
+from keras.src.backend.common.backend_utils import canonicalize_axis
 from keras.src.backend.common.backend_utils import standardize_axis_for_numpy

 """Int8-related classes and methods"""
@@ -127,6 +128,142 @@ def get_config(self):
         }


+def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
+    """Nudge the quantization range so zero is exactly representable."""
+    quant_max = ops.cast(ops.subtract(ops.power(2, num_bits), 1.0), "float32")
+    quant_min = ops.cast(1.0 if narrow_range else 0.0, "float32")
+
+    # Compute the scale that maps the float range onto the integer range.
+    scale = ops.divide(
+        ops.subtract(max_range, min_range), ops.subtract(quant_max, quant_min)
+    )
+    inv_scale = ops.reciprocal(scale)
+
+    # Derive the zero point from the minimum of the float range.
+    zero_point_from_min = quant_min - ops.divide(min_range, scale)
+
+    # Clamp the zero point to the valid integer range [quant_min, quant_max].
+    zero_point = ops.clip(zero_point_from_min, quant_min, quant_max)
+
+    # Nudge the zero point to the nearest integer.
+    nudged_zero_point = ops.round(zero_point)
+
+    # Recompute the float limits implied by the nudged zero point.
+    nudged_min = ops.multiply(ops.subtract(quant_min, nudged_zero_point), scale)
+    nudged_max = ops.multiply(ops.subtract(quant_max, nudged_zero_point), scale)
+
+    return nudged_min, nudged_max, scale, inv_scale
+
+
+@keras_export("keras.quantizers.fake_quant_with_min_max_vars")
+def fake_quant_with_min_max_vars(
+    inputs,
+    min_vals,
+    max_vals,
+    num_bits,
+    narrow_range=False,
+    axis=None,
+):
+    """Perform per-tensor or per-channel fake quantization.
+
+    `[min_vals, max_vals]` define the clamping range for the `inputs`.
+
+    The `inputs` are quantized into the quantization range:
+    - `[0, 2^num_bits - 1]` when `narrow_range=False`
+    - `[1, 2^num_bits - 1]` when `narrow_range=True`
+
+    After quantization, the values are dequantized and output as floats within
+    the `[min_vals, max_vals]` interval.
+
+    This operation supports gradient computation, allowing `min_vals` and
+    `max_vals` to be trained.
+
+    Args:
+        inputs: Input tensor of float dtype.
+        min_vals: A global minimum scalar or a per-channel minimum tensor.
+        max_vals: A global maximum scalar or a per-channel maximum tensor.
+        num_bits: Quantization bit width (e.g., `8` for int8).
+        narrow_range: Whether to use the narrow quantization range.
+        axis: Axis along which to perform per-channel quantization. If `None`,
+            per-tensor quantization is performed. Defaults to `None`.
+
+    Returns:
+        Fake-quantized tensor of the same shape as `inputs`.
+    """
+    inputs = ops.convert_to_tensor(inputs)
+    min_vals = ops.convert_to_tensor(min_vals)
+    max_vals = ops.convert_to_tensor(max_vals)
+
+    if axis is not None:
+        axis = canonicalize_axis(axis, inputs.ndim)
+
+    @ops.custom_gradient
+    def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
+        # Calculate quantization parameters for all channels at once.
+        nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge(
+            min_val, max_val, num_bits, narrow_range
+        )
+
+        # Integer grid index that the real value 0.0 maps to after the shift.
+        quant_zero = ops.floor(
+            ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
+        )
+        # Clamp, quantize (round to the grid), and dequantize back to floats.
+        x_clamped = ops.clip(x, nudged_min, nudged_max)
+        x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
+        result = ops.multiply(
+            ops.floor(
+                ops.add(
+                    ops.subtract(
+                        ops.multiply(x_clamped_shifted, inv_scale), quant_zero
+                    ),
+                    0.5,
+                )
+            ),
+            scale,
+        )
+
+        # Gradient mask: 1 where x lies inside the nudged range, 0 elsewhere.
+        masks = ops.cast(
+            (x >= nudged_min) & (x <= nudged_max),
+            dtype="float32",
+        )
+
+        def grad(*args, upstream=None):
+            if upstream is None:
+                (upstream,) = args
+
+            # Gradient for x: pass through only inside the nudged range.
+            dx = ops.multiply(upstream, masks)
+
+            # Reduction axes for the per-channel min/max gradients: every
+            # axis except the channel axis.
+            axes = [i for i in range(len(dx.shape)) if i != axis]
+
+            # Gradient for min_val: when x is clipped to the minimum, the
+            # gradient flows to min_val.
+            min_mask = ops.cast(x <= nudged_min, dtype="float32")
+            grad_min = ops.multiply(upstream, min_mask)
+            if axis is not None:
+                grad_min = ops.sum(grad_min, axis=axes)
+            else:
+                grad_min = ops.sum(grad_min)
+
+            # Gradient for max_val: when x is clipped to the maximum, the
+            # gradient flows to max_val.
+            max_mask = ops.cast(x >= nudged_max, dtype="float32")
+            grad_max = ops.multiply(upstream, max_mask)
+            if axis is not None:
+                grad_max = ops.sum(grad_max, axis=axes)
+            else:
+                grad_max = ops.sum(grad_max)
+
+            return dx, grad_min, grad_max
+
+        return result, grad
+
+    return _fake_quant_with_min_max_vars_per_channel(inputs, min_vals, max_vals)
+
+
 """Float8-related methods"""


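The `adjust_and_nudge` helper above shifts the clamping range so that a real zero lands exactly on an integer grid point. As a sanity check, here is a plain-NumPy restatement of that arithmetic for one concrete range; the numbers, the `num_bits=8` / `narrow_range=False` setting, and the variable names are illustrative assumptions, not part of the patch:

import numpy as np

# Quantization grid for num_bits=8, narrow_range=False: integers 0..255.
min_range, max_range = -0.1, 1.0
quant_min, quant_max = 0.0, 255.0

scale = (max_range - min_range) / (quant_max - quant_min)   # ~0.0043137
zero_point = np.clip(quant_min - min_range / scale, quant_min, quant_max)
nudged_zero_point = np.round(zero_point)                    # 23.0
nudged_min = (quant_min - nudged_zero_point) * scale        # ~-0.09922
nudged_max = (quant_max - nudged_zero_point) * scale        # ~1.00078

# The nudged range keeps the same width (max_range - min_range), but 0.0 now
# lands exactly on an integer grid point, so exact zeros survive quantization.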
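For reference, a minimal usage sketch of the new function, assuming the `keras_export` entry registers it under the public `keras.quantizers` namespace; the tensor shapes and value ranges below are made up for illustration:

import numpy as np

from keras import ops
from keras.quantizers import fake_quant_with_min_max_vars  # assumed public path

# Per-tensor fake quantization: a single [min, max] pair for the whole tensor.
x = ops.convert_to_tensor(np.linspace(-1.5, 1.5, 8).astype("float32"))
y = fake_quant_with_min_max_vars(x, -1.0, 1.0, num_bits=8)
# Values outside [-1, 1] are clamped; values inside snap to the 8-bit grid.

# Per-channel fake quantization: one [min, max] pair per channel along `axis`.
w = ops.convert_to_tensor(
    np.random.uniform(-2.0, 2.0, size=(4, 3)).astype("float32")
)
min_vals = ops.convert_to_tensor([-2.0, -1.0, -0.5])
max_vals = ops.convert_to_tensor([2.0, 1.0, 0.5])
w_q = fake_quant_with_min_max_vars(w, min_vals, max_vals, num_bits=8, axis=-1)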