Unclear how to add up two grads of nn.Module type #389
BoyuanChen99 asked this question in General
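I believe you mean the grads are stored inside an `nn.Model`, for example one set up like this:

```python
import jax
import jax.numpy as jnp
from jax import random
import flax
from flax import nn  # legacy flax.nn API (removed in later Flax releases)

# NeRF_Model is your own flax.nn.Module subclass (not defined here).
_, params = NeRF_Model.init(random.PRNGKey(0), jnp.ones((10, 10)))
model = nn.Model(NeRF_Model, params)
optimizer = flax.optim.Adam(learning_rate=0.01).create(model)
del model
# optimizer.target contains the model
```

You could then use a `tree_multimap`:

```python
def dumb_loss(model, x):
    return jnp.sum(model(x))

# The gradient has the same pytree structure as the model (another nn.Model).
grad_fn = jax.grad(dumb_loss)
x1 = random.uniform(random.PRNGKey(0), (10, 10))
x2 = random.uniform(random.PRNGKey(1), (10, 10))
grad1 = grad_fn(optimizer.target, x1)
grad2 = grad_fn(optimizer.target, x2)
# Add the two gradient trees leaf by leaf.
# (Newer JAX releases removed tree_multimap; jax.tree_util.tree_map accepts several trees.)
summed_grad = jax.tree_multimap(lambda x, y: x + y, grad1, grad2)
```

This is also in a Colab, see: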
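Since `flax.nn` and `jax.tree_multimap` have since been removed from Flax and JAX, here is a minimal sketch of the same idea using only current JAX. It assumes a hypothetical toy linear model whose parameters live in a plain dict (standing in for the NeRF model); `init_params` and `apply_fn` are illustrative names, not part of any library. The key point is that `jax.grad` returns a tree with the same structure as the parameters, and `jax.tree_util.tree_map` can apply a function leaf by leaf across several trees of identical structure.

```python
import jax
import jax.numpy as jnp
from jax import random

# Hypothetical stand-in for the model: one dense layer stored as a dict pytree.
def init_params(key, in_dim=10, out_dim=10):
    w_key, _ = random.split(key)
    return {
        "w": random.normal(w_key, (in_dim, out_dim)),
        "b": jnp.zeros((out_dim,)),
    }

def apply_fn(params, x):
    return x @ params["w"] + params["b"]

def dumb_loss(params, x):
    return jnp.sum(apply_fn(params, x))

grad_fn = jax.grad(dumb_loss)  # differentiates w.r.t. the first argument

params = init_params(random.PRNGKey(0))
x1 = random.uniform(random.PRNGKey(0), (10, 10))
x2 = random.uniform(random.PRNGKey(1), (10, 10))

grad1 = grad_fn(params, x1)  # same dict structure as params
grad2 = grad_fn(params, x2)

# tree_map takes one or more trees with identical structure and
# applies the function to corresponding leaves.
summed_grad = jax.tree_util.tree_map(lambda a, b: a + b, grad1, grad2)
```

The result is again a dict with keys `"w"` and `"b"`, so it can be passed anywhere a single gradient tree is expected.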
Description of the model to be implemented
I am trying to add up the grads for an optimizer. The grads have type `flax.nn.Module`. How can I sum them?
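If more than two gradients need to be summed (for example, accumulating over several batches before one optimizer step), the same leaf-wise addition can run in a loop. This is a minimal sketch, assuming a hypothetical toy linear model stored as a dict and a loss that is simply summed across batches; `loss_fn` and `accumulate_grads` are illustrative names. Because differentiation is linear, the accumulated result should match the gradient of the summed loss.

```python
import jax
import jax.numpy as jnp
from jax import random

def loss_fn(params, x):
    # Hypothetical toy loss: a linear model summed over all outputs.
    return jnp.sum(x @ params["w"] + params["b"])

grad_fn = jax.grad(loss_fn)

def accumulate_grads(params, batches):
    """Sum gradients over several batches, leaf by leaf."""
    # Start from an all-zeros tree with the same structure as params.
    total = jax.tree_util.tree_map(jnp.zeros_like, params)
    for x in batches:
        g = grad_fn(params, x)
        total = jax.tree_util.tree_map(jnp.add, total, g)
    return total

params = {"w": random.normal(random.PRNGKey(0), (10, 10)), "b": jnp.zeros((10,))}
batches = [random.uniform(random.PRNGKey(i), (10, 10)) for i in range(3)]

acc = accumulate_grads(params, batches)

# Gradient of the total loss, computed directly for comparison.
summed_loss = lambda p: sum(loss_fn(p, x) for x in batches)
direct = jax.grad(summed_loss)(params)
```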
Dataset the model could be trained on
Image Data
Specific points to consider
/
Reference implementations in other frameworks
/