How would I reparametrize nnx.Module parameters?
#4546
In JAX, it's easy to re-parametrize a neural network using something similar to the following:

How do I achieve something similar using nnx, since the params are part of the model? Of course, I can use something like:

Now if I want to get the grads w.r.t.
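(The original snippet is not shown above; as a rough sketch of the kind of plain-JAX pattern being referred to, where the `reparam` transform and the parameter names are hypothetical:)

```python
import jax
import jax.numpy as jnp

# Hypothetical reparametrization: keep "raw" parameters in the pytree and
# map them through a transform before the network uses them.
def reparam(raw_params):
    return jax.tree.map(lambda p: p * 2.0, raw_params)

def apply_fn(raw_params, x):
    params = reparam(raw_params)
    return x @ params['w'] + params['b']

def loss_fn(raw_params, x, y):
    return jnp.mean((apply_fn(raw_params, x) - y) ** 2)

raw_params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros((3,))}
x, y = jnp.ones((4, 2)), jnp.zeros((4, 3))

# Gradients are taken w.r.t. the raw (pre-reparametrization) parameters.
grads = jax.grad(loss_fn)(raw_params, x, y)
```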
I also looked into the LoRA implementation to see how this issue might be handled. Here's what happens:

```python
def __call__(self, x: jax.Array):
    out = x @ self.lora_a @ self.lora_b
    if self.base_module is not None:
        if not callable(self.base_module):
            raise ValueError('`self.base_module` must be callable.')
        out += self.base_module(x)
    return out
```

But the problem here is that it is essentially calculating `x @ lora_a @ lora_b + base_module(x)`, i.e. adding the LoRA path to the base output rather than differentiating through a single reparametrized weight. So essentially, my question could be simplified as follows: given a reparametrization function, how can I take gradients with respect to the underlying (raw) parameters it is applied to?
Hi @aniquetahir, to get a gradient w.r.t. any substate you can pass a
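(As a sketch of taking gradients w.r.t. only a filtered substate, assuming `nnx.DiffState` accepts a variable-type filter for `nnx.grad`'s `argnums`; the `Scaled` module and `ScaleParam` type here are hypothetical:)

```python
import jax.numpy as jnp
from flax import nnx

# Hypothetical custom variable type, used purely as a gradient filter.
class ScaleParam(nnx.Param):
    pass

class Scaled(nnx.Module):
    def __init__(self, din, dout, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)
        self.scale = ScaleParam(jnp.ones((dout,)))

    def __call__(self, x):
        return self.linear(x) * self.scale.value

def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

model = Scaled(2, 3, rngs=nnx.Rngs(0))
x, y = jnp.ones((4, 2)), jnp.zeros((4, 3))

# Differentiate only the ScaleParam substate of argument 0 (the model).
grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ScaleParam))(model, x, y)
```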
@aniquetahir can you create a separate optimizer for `new_model` (maybe call it `sampled_model`) at the beginning and then simply update it after sampling its params? E.g.:
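(A minimal sketch of that suggestion, assuming `nnx.Optimizer`, `nnx.state`, and `nnx.update`; the sampling step itself is a placeholder:)

```python
import optax
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
sampled_model = nnx.Linear(2, 3, rngs=nnx.Rngs(1))  # separate "sampled" copy
optimizer = nnx.Optimizer(sampled_model, optax.adam(1e-3), wrt=nnx.Param)  # its own optimizer state

# ... later, whenever new parameter values are sampled from `model`:
sampled_params = nnx.state(model, nnx.Param)   # stand-in for the real sampling step
nnx.update(sampled_model, sampled_params)      # push the sampled values into sampled_model
```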