I stumbled upon this because my recipe for checkpointing breaks when my model contains an LSTM:
from flax import nnx
import orbax.checkpoint as ocp

def savemodel(model, path):
    # Split the module into its static graph definition and its state.
    _, state = nnx.split(model)
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(path, state)
Calling savemodel(model, path) throws:
TypeError: JAX array with PRNGKey dtype cannot be converted to a NumPy array. Use jax.random.key_data(arr) if you wish to extract the underlying integer array.
This was surprising because I had used that recipe before and never had a problem with other, non-LSTM modules.
It seems that LSTMCell keeps rngs in its state (see flax/nnx/nn/recurrent.py, line 137 at commit a8a192f).
Is this intentional? Why?
Hi @JoaoAparicio, great question. It's because, in general, the carry initializer might use the random state. In practice it's almost always zeros, but currently we support the general case.
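To make the carry-initializer point concrete, here is a small pure-JAX sketch. It is not Flax's actual LSTMCell code. `initialize_carry` is a hypothetical helper, and it assumes the carry initializer follows the usual `(key, shape, dtype)` signature of `jax.nn.initializers`. With the default zeros initializer the key is simply ignored, but a random initializer needs one. That is presumably why a cell that builds its own carry keeps a source of randomness around.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers


def initialize_carry(key, batch_size, hidden_size, carry_init=initializers.zeros):
    """Hypothetical helper: build an LSTM-style (c, h) carry.

    The initializer receives a key because, in general, it may be random;
    the default (zeros) ignores it.
    """
    c_key, h_key = jax.random.split(key)
    shape = (batch_size, hidden_size)
    c = carry_init(c_key, shape, jnp.float32)
    h = carry_init(h_key, shape, jnp.float32)
    return c, h
```

For example, `initialize_carry(jax.random.key(0), 2, 4)` returns two zero arrays of shape `(2, 4)`. Passing `carry_init=initializers.normal(stddev=0.1)` makes the result depend on the key.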