Apply Gradient Not Behaving Appropriately #1030
Unanswered
vasilavramov asked this question in Q&A
Replies: 1 comment 1 reply
Hi @vasilavramov, I think you'll have to ask your question more concretely and/or share code to get more help. My first guess is that you're doing something like …
Hi, I am currently trying to train a neural network (NN) through reinforcement learning, and I am struggling to update the NN parameters once an epoch of training is complete. The gradients are computed as follows:
```python
def compute_gradients_EV(optimizer, input, params):
    def EV_short(params):
        return policy_EV(input, params)
    policy_fit, grad = jax.value_and_grad(EV_short, has_aux=False)(params)
    return policy_fit, grad
```

Applying the gradients then fails with:

```
TypeError: unsupported operand type(s) for *: 'float' and 'FrozenDict'
```
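For what it's worth, that TypeError can be reproduced without Flax at all: Python does not define multiplication between a float and a dict-like container, so any update rule that computes `lr * grads` directly on the parameter container fails in exactly this way. A minimal, library-free sketch, where a plain dict stands in for `FrozenDict` and a hypothetical recursive `tree_map2` stands in for `jax.tree_util.tree_map`:

```python
# A plain dict stands in for Flax's FrozenDict parameter container.
params = {"dense": {"w": 1.0, "b": 0.5}}
grads  = {"dense": {"w": 0.2, "b": -0.4}}

# Naive scalar-times-container scaling reproduces the reported error:
try:
    scaled = 0.1 * grads
except TypeError as e:
    print(e)  # unsupported operand type(s) for *: 'float' and 'dict'

# The fix is to apply the scalar update leaf by leaf, which is what
# jax.tree_util.tree_map does; this is a simplified two-tree version.
def tree_map2(fn, a, b):
    if isinstance(a, dict):
        return {k: tree_map2(fn, a[k], b[k]) for k in a}
    return fn(a, b)

def sgd_update(params, grads, lr=0.1):
    # params and grads must share the same tree structure.
    return tree_map2(lambda p, g: p - lr * g, params, grads)

print(sgd_update(params, grads))
```

Flax's optimizer performs this leaf-wise update internally, so if it raises this TypeError, a likely culprit is that a float and the gradient container were combined by hand somewhere, or that the gradient pytree does not match the structure the optimizer expects.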
The error makes no sense to me, given that the base documentation of the optim class says apply_gradients should work with a pytree of gradients. I have also tried a different method, similar to the one shown in the examples, where jax.value_and_grad is called with the model and optimizer.target is used for differentiation. However, in that case I get the error that the model I am passing to the differentiating function 'is not a valid JAX type'.
From what I understand, I have successfully computed the gradients; however, updating the parameters via optimizer.apply_gradients() seems impossible, even though it should work. Any help would be much appreciated.
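Regarding the second error, 'is not a valid JAX type' usually means a non-array object (here, the model itself) was passed as the argument being differentiated; jax.grad can only differentiate with respect to pytrees of arrays. The standard fix is to close over the model and make the parameter pytree the sole argument of the loss function. A library-free sketch of that pattern, where a toy finite-difference `grad` stands in for jax.grad and `Model`/`apply` are hypothetical names:

```python
class Model:
    """Stand-in for a network object; not something grad can differentiate."""
    def apply(self, params, x):
        return params["w"] * x + params["b"]

def grad(f, eps=1e-6):
    """Toy finite-difference gradient over a flat dict of floats, standing
    in for jax.grad (which needs a pytree of arrays, not a Model object)."""
    def g(params):
        base = f(params)
        return {k: (f({**params, k: v + eps}) - base) / eps
                for k, v in params.items()}
    return g

model, x, y = Model(), 2.0, 7.0

def loss(params):                    # params is the ONLY differentiated argument
    pred = model.apply(params, x)    # the model is closed over, not passed in
    return (pred - y) ** 2

grads = grad(loss)({"w": 1.0, "b": 0.0})
# Analytically: d/dw = 2*(pred - y)*x = -20, d/db = 2*(pred - y) = -10
print(grads)
```

If I recall the Flax examples correctly, the object differentiated there is optimizer.target (a valid parameter pytree), not the module instance itself, which mirrors this closure pattern.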