diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 1272c453563..9d47642ef34 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -224,13 +224,12 @@ def grad(*args, upstream=None): # Gradient for min_val # When x is clipped to min, the gradient flows to min_val min_mask = ops.cast(x <= nudged_min, dtype="float32") - dims_to_reduce = list(range(len(x.shape) - 1)) - grad_min = ops.sum(upstream * min_mask, axis=dims_to_reduce) + grad_min = ops.sum(upstream * min_mask) # Gradient for max_val # When x is clipped to max, the gradient flows to max_val max_mask = ops.cast(x >= nudged_max, dtype="float32") - grad_max = ops.sum(upstream * max_mask, axis=dims_to_reduce) + grad_max = ops.sum(upstream * max_mask) return dx, grad_min, grad_max