Skip to content

Commit

Permalink
all simple tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
doncarlos999 committed Jan 13, 2025
1 parent 0ac1b2b commit d582f65
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions keras/src/quantizers/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,12 @@ def grad(*args, upstream=None):
# Gradient for min_val
# When x is clipped to min, the gradient flows to min_val
min_mask = ops.cast(x <= nudged_min, dtype="float32")
dims_to_reduce = list(range(len(x.shape) - 1))
grad_min = ops.sum(upstream * min_mask, axis=dims_to_reduce)
grad_min = ops.sum(upstream * min_mask)

# Gradient for max_val
# When x is clipped to max, the gradient flows to max_val
max_mask = ops.cast(x >= nudged_max, dtype="float32")
grad_max = ops.sum(upstream * max_mask, axis=dims_to_reduce)
grad_max = ops.sum(upstream * max_mask)

return dx, grad_min, grad_max

Expand Down

0 comments on commit d582f65

Please sign in to comment.