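`jax.value_and_grad` returns the loss together with its gradient with respect to the first positional argument by default (`argnums=0`). So in this code snippet you are computing the gradients with respect to the inputs instead of the params. You can fix this by passing the model params as the first argument of the loss function, or by pointing `argnums` at the params argument.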
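Here is a minimal sketch of both fixes. The original snippet is not shown here, so the linear model, `predict`, `loss_fn` and the toy data are hypothetical stand-ins. The point is only where `params` sits relative to `argnums`:

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model; params is a dict pytree.
def predict(params, x):
    return x @ params["w"] + params["b"]

# params is the FIRST argument, so value_and_grad (argnums=0 by default)
# differentiates with respect to params, not the inputs.
def loss_fn(params, x, y):
    preds = predict(params, x)
    return jnp.mean((preds - y) ** 2)

params = {"w": jnp.ones((3,)), "b": jnp.array(0.0)}
x = jnp.arange(6.0).reshape(2, 3)
y = jnp.array([1.0, 2.0])

loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
# grads mirrors the structure of params: {"b": ..., "w": ...}

# Alternative: keep a different argument order and select params via argnums.
def loss_fn_inputs_first(x, y, params):
    return loss_fn(params, x, y)

loss2, grads2 = jax.value_and_grad(loss_fn_inputs_first, argnums=2)(x, y, params)
```

Either way, `grads` has the same pytree structure as `params`, so it can be passed directly to an optimizer update.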

@mwitiderrick
Answer selected by jheek
2 participants

This discussion was converted from issue #2419 on August 29, 2022 10:43.