Why does LSTMCell keep rngs in its state? #4509

Open
JoaoAparicio opened this issue Jan 28, 2025 · 1 comment

JoaoAparicio commented Jan 28, 2025

It seems that LSTMCell keeps rngs in its state:

self.rngs = rngs

Is this intentional? Why?

I stumbled upon this because my recipe for checkpointing breaks when my model contains an LSTM:

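import orbax.checkpoint as ocp
from flax import nnx

def savemodel(model, path):
    # Drop the graph definition; `state` still holds every variable, RNG keys included.
    _, state = nnx.split(model)
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(path, state)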

Calling savemodel(model, path) throws:

TypeError: JAX array with PRNGKey dtype cannot be converted to a NumPy array. Use jax.random.key_data(arr) if you wish to extract the underlying integer array.

This was surprising because I have used this recipe before without any problem on models that contain no LSTM modules.

cgarciae (Collaborator) commented Feb 3, 2025

Hi @JoaoAparicio, great question. It's because, in general, the carry initializer might use the random state. In practice it's almost always zeros, but we currently support the general case.
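
As a rough illustration of that distinction, the sketch below uses the standard `jax.nn.initializers` functions directly rather than LSTMCell itself. The shape `hidden_shape` and the variable names are assumptions chosen for the example. Both initializers accept a key, but only the random one actually depends on it, which is why a cell that supports arbitrary carry initializers needs access to some random state.

```python
import jax
import jax.numpy as jnp

hidden_shape = (1, 4)  # hypothetical (batch, hidden_features)
key_a, key_b = jax.random.split(jax.random.key(0))

zeros_init = jax.nn.initializers.zeros                # accepts a key but ignores it
normal_init = jax.nn.initializers.normal(stddev=1.0)  # draws samples using the key

zeros_a = zeros_init(key_a, hidden_shape, jnp.float32)
zeros_b = zeros_init(key_b, hidden_shape, jnp.float32)
normal_a = normal_init(key_a, hidden_shape, jnp.float32)
normal_b = normal_init(key_b, hidden_shape, jnp.float32)

# A zeros carry is identical for any key; a random carry changes with the key.
zeros_ignore_key = bool(jnp.array_equal(zeros_a, zeros_b))
normal_uses_key = not bool(jnp.array_equal(normal_a, normal_b))
```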
