Custom node type mismatch: expected type: <class 'flax.core.frozen_dict.FrozenDict'>, value: Traced<ShapedArray(float32[256,14,14,3])>with<DynamicJaxprTrace(level=0/1)>. #2420
Answered by jheek
mwitiderrick asked this question in General
Trying to implement batch norm. Am I doing something wrong?
Answered by jheek · Aug 29, 2022
Replies: 1 comment, 1 reply
jax.value_and_grad computes the gradient with respect to the first argument by default. In this code snippet you are therefore computing gradients with respect to the inputs instead of the params, which is why a traced input array shows up where a FrozenDict of parameters is expected. You can fix this by passing the model params as the first argument of the loss function.
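The original snippet is not shown in the thread, but the fix can be sketched with a minimal, self-contained loss function (the linear model and variable names here are illustrative, not from the discussion): params must be the first argument so that `jax.value_and_grad` differentiates with respect to them.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, inputs, targets):
    # Toy linear model standing in for the user's network.
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
inputs = jnp.ones((4, 3))
targets = jnp.zeros((4, 1))

# value_and_grad differentiates wrt the FIRST argument (argnums=0 by default),
# so params must come first; grads has the same pytree structure as params.
loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
print(loss, grads["w"].shape)
```

Had `loss_fn(inputs, params, targets)` been used instead, `grads` would be an array shaped like `inputs` (a traced `ShapedArray`), matching the type-mismatch error in the title.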
Answer selected by jheek