Why doesn't nnx.fori_loop work here? #4433
onnoeberhard asked this question in Q&A
It seems that a possible workaround is:

```python
import jax
from flax import nnx

# Two independently initialized models.
model = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0)))
model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1)))

# Wrap both models as attributes of a single container Module.
container = nnx.Module()
container.model = model
container.model2 = model2

def f(i, x):
    # Identity loop body: returns the carry unchanged.
    return x

nnx.fori_loop(0, 10, f, container)
```

Is there a reason why this works and the original version does not? And will this "solution" yield the expected results? I am very confused about this error.
Thanks for reporting this. Looks like a bug. Converting this into an issue.
I want to train two models at the same time. To do this, I use `nnx.fori_loop`, but my code throws the following error:

```
ValueError: nnx.fori_loop requires body function's input and output to have the same reference and pytree structure, but they differ. If the mismatch comes from index_mapping field, you might have modified reference structure within the body function, which is not allowed.
```

If I loop with only one model, for example `nnx.fori_loop(0, 10, f, (model, model))`, there is no error. What is the problem here?