
[linen] Undocumented RNN errors #4513

Open
MRiabov opened this issue Jan 29, 2025 · 1 comment

Comments

@MRiabov

MRiabov commented Jan 29, 2025

Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 22.04 WSL
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax=0.10.2, jax=0.5.0, jaxlib=0.5.0
  • Python version: 12
  • GPU/TPU model and memory: 1x Tesla V100, 32.0 GB
  • CUDA version (if applicable):

Problem you have encountered:

Hello Flax,
I'm stuck debugging linen.Bidirectional because it does not raise informative errors.
The logical approach seemed to be something like this:

import jax
import flax.linen as nn
from jax import Array


class Encoder(nn.Module):  # hypothetical name; the real module is in dreamerv3_flax/encoder.py
    rnn_var_lstm_cell_size: int

    def setup(self):
        # Bidirectional runs forward_rnn and backward_rnn over the same inputs and,
        # with return_carry=True, returns ((carry_forward, carry_backward), outputs).
        self.rnn_var_rnn = nn.recurrent.Bidirectional(
            nn.RNN(
                nn.recurrent.OptimizedLSTMCell(self.rnn_var_lstm_cell_size),
                return_carry=True,
            ),
            nn.RNN(
                nn.recurrent.OptimizedLSTMCell(self.rnn_var_lstm_cell_size),
                return_carry=True,
            ),
            return_carry=True,
        )  # num_layers=10,

        # Structure: ((c_forward, h_forward), (c_backward, h_backward)).
        # Note: don't forget to add the batch size in the future.
        self.initial_rnn_carry = (
            (
                nn.initializers.zeros(
                    jax.random.PRNGKey(0),
                    (self.rnn_var_lstm_cell_size,),
                ),  # c (cell state), forward
                nn.initializers.zeros(
                    jax.random.PRNGKey(0),
                    (self.rnn_var_lstm_cell_size,),
                ),  # h (hidden state), forward
            ),
            (
                nn.initializers.zeros(
                    jax.random.PRNGKey(0),
                    (self.rnn_var_lstm_cell_size,),
                ),  # c (cell state), backward
                nn.initializers.zeros(
                    jax.random.PRNGKey(0),
                    (self.rnn_var_lstm_cell_size,),
                ),  # h (hidden state), backward
            ),
        )

    # ...

    def __call__(self, x: dict[str, Array]) -> Array:
        import pdb

        pdb.set_trace()
        # nodes_gather_x is derived from x (omitted in this excerpt).
        rnn_var_carry, _processed_vars = self.rnn_var_rnn(
            nodes_gather_x, initial_carry=self.initial_rnn_carry, return_carry=True
        )
        # failure here ^

The problem is that I don't get understandable error messages at all. For example:

  File "/Problemologist-flax/dreamerv3_flax/encoder.py", line 159, in __call__
    rnn_var_rnn_carry, _processed_rnn_var = self.rnn_var_rnn(
                                          ^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/linen/recurrent.py", line 1315, in __call__
    carry_forward, outputs_forward = self.forward_rnn(
                                     ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/linen/recurrent.py", line 1135, in __call__
    scan_output = scan(self.cell, carry, inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/core/axes_scan.py", line 152, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/core/axes_scan.py", line 124, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/linen/recurrent.py", line 1114, in scan_fn
    carry, y = cell(carry, x)
               ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/linen/recurrent.py", line 336, in __call__
    dense_params_i[component] = DenseParams(
                                ^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/linen/recurrent.py", line 214, in __call__
    (inputs.shape[-1], self.features),
     ~~~~~~~~~~~~^^^^  # <- note: fails at inputs.shape[-1]
IndexError: tuple index out of range
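The error is raised at inputs.shape[-1] inside the LSTM cell. That suggests, though it cannot be confirmed without the reporter's data, that each time step reached OptimizedLSTMCell as a scalar, for instance because nodes_gather_x had shape (time,) instead of (time, features). Until Flax reports this more clearly, a small user-side check can fail early with a readable message. check_bidirectional_lstm_inputs below is a hypothetical helper, not a Flax API, and it assumes the unbatched layout used in the snippet above.

```python
import jax.numpy as jnp


def check_bidirectional_lstm_inputs(inputs, initial_carry, hidden_size):
    """Hypothetical pre-flight check for an unbatched Bidirectional LSTM call.

    Assumes inputs of shape (time, features) and an initial carry of
    ((c_fwd, h_fwd), (c_bwd, h_bwd)), each state of shape (hidden_size,).
    Raises ValueError with a readable message instead of a deep IndexError.
    """
    if len(inputs.shape) != 2:
        raise ValueError(
            f"expected inputs of shape (time, features), got {inputs.shape}; "
            "a 1-D sequence needs a trailing feature axis, e.g. x[:, None]"
        )
    if len(initial_carry) != 2:
        raise ValueError("initial_carry must be (forward_carry, backward_carry)")
    for direction, carry in zip(("forward", "backward"), initial_carry):
        if len(carry) != 2:
            raise ValueError(f"{direction} carry must be a (c, h) pair for an LSTM cell")
        for name, state in zip(("c", "h"), carry):
            if state.shape != (hidden_size,):
                raise ValueError(
                    f"{direction} {name} has shape {state.shape}, "
                    f"expected ({hidden_size},)"
                )


# Usage: a (time,) sequence is rejected before it ever reaches the cell.
z = jnp.zeros((8,))
carry = ((z, z), (z, z))
check_bidirectional_lstm_inputs(jnp.ones((5, 3)), carry, hidden_size=8)  # passes
```

Reshaping a (time,) sequence with x[:, None] gives (time, 1), which provides the feature axis the cell indexes.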

So what am I supposed to do?
I've already tried at least 7 different setups to work around this, namely going from 1 carry to 4 (as stated here) and working without an initial carry at all.

What you expected to happen:

Adequate error messages that tell me what I'm doing wrong. I don't want to waste time debugging something I don't even understand.

Logs, error messages, etc:

As above.
Also, there was an error at something like c, h = carry(x, ...), which I guess meant that I didn't provide enough carries, so I resolved it by adding more jnp.zeros to self.initial_rnn_carry.

Steps to reproduce:

Whenever possible, please provide a minimal example. Please consider submitting it as a Colab link.
As above.

@cgarciae
Collaborator

cgarciae commented Feb 4, 2025

Hi @MRiabov, can you post a fully reproducible example?
