Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 22.04 WSL
Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax=0.10.2, jax=0.5.0, jaxlib=0.5.0
Python version: 12
GPU/TPU model and memory: 1x Tesla V100, 32.0 GB
CUDA version (if applicable):
Problem you have encountered:
Hello Flax,
I'm stuck debugging linen.Bidirectional because it does not throw understandable errors.
Logic tells me to do something like this:
    def setup(self):
        self.rnn_var = nn.recurrent.Bidirectional(
            nn.RNN(
                nn.recurrent.OptimizedLSTMCell(self.rnn_var_lstm_cell_size),
                return_carry=True,
            ),
            nn.RNN(
                nn.recurrent.OptimizedLSTMCell(self.rnn_var_lstm_cell_size),
                return_carry=True,
            ),
            return_carry=True,
        )  # num_layers=10,
        self.initial_rnn_carry = (
            (  # note: don't forget to add ibatch size in the future.
                nn.initializers.zeros(
                    jax.random.PRNGKey(0),
                    (self.rnn_var_lstm_cell_size,),
                ),  # c (cell state)
                nn.initializers.zeros(
                    jax.random.PRNGKey(0),
                    (self.rnn_var_lstm_cell_size,),
                ),
            ),  # h (hidden state)
            (
                nn.initializers.zeros(
                    jax.random.PRNGKey(0),
                    (self.rnn_var_lstm_cell_size,),
                ),  # c (cell state) for backward
                nn.initializers.zeros(
                    jax.random.PRNGKey(0),
                    (self.rnn_var_lstm_cell_size,),
                ),
            ),  # h (hidden state) for backward
        )

    # ...
    def __call__(self, x: dict[str, Array]) -> Array:
        import pdb
        pdb.set_trace()
        rnn_var_carry, _processed_vars = self.rnn_var_rnn(
            nodes_gather_x, initial_carry=self.initial_rnn_carry, return_carry=True
        )
        # failure here ^
And the problem is, I don't get understandable error messages, at all. For example this:
  File "/Problemologist-flax/dreamerv3_flax/encoder.py", line 159, in __call__
    rnn_var_rnn_carry, _processed_rnn_var = self.rnn_var_rnn(
                                            ^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/linen/recurrent.py", line 1315, in __call__
    carry_forward, outputs_forward = self.forward_rnn(
                                     ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/linen/recurrent.py", line 1135, in __call__
    scan_output = scan(self.cell, carry, inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/core/axes_scan.py", line 152, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/core/axes_scan.py", line 124, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/linen/recurrent.py", line 1114, in scan_fn
    carry, y = cell(carry, x)
               ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/linen/recurrent.py", line 336, in __call__
    dense_params_i[component] = DenseParams(
                                ^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flax/linen/recurrent.py", line 214, in __call__
    (inputs.shape[-1], self.features),
     ~~~~~~~~~~~~^^^^  # <- note: fails at inputs.shape[-1]
IndexError: tuple index out of range
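For context on the last frame: indexing a tuple with [-1] raises this IndexError only when the tuple is empty, so the value reaching inputs.shape[-1] had no axes at all. The sketch below reproduces that message in plain Python and shows a hypothetical pre-check (require_min_ndim is not a Flax function) that could be called before the RNN to get a more descriptive error. The minimum of 3 axes assumes a batch-major (batch, time, features) layout; adjust it if your inputs are laid out differently.

```python
def require_min_ndim(shape, min_ndim, name="inputs"):
    """Hypothetical helper: fail early with a readable message.

    `shape` is any shape tuple, e.g. `x.shape` of an array.
    """
    shape = tuple(shape)
    if len(shape) < min_ndim:
        raise ValueError(
            f"{name} has shape {shape} ({len(shape)} axes); "
            f"expected at least {min_ndim} axes, e.g. (batch, time, features)."
        )
    return shape


# The failing expression from the traceback, on an empty shape tuple.
try:
    ()[-1]
except IndexError as err:
    raw_message = str(err)  # "tuple index out of range"

# The same situation caught one step earlier by the pre-check.
# A 1-D input of 7 steps is assumed here purely for illustration.
try:
    require_min_ndim((7,), min_ndim=3, name="nodes_gather_x")
except ValueError as err:
    readable_message = str(err)

print(raw_message)
print(readable_message)
```

Calling such a check on nodes_gather_x.shape right before self.rnn_var_rnn(...) would point at the offending argument rather than at a line deep inside recurrent.py.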
So what am I supposed to do?
I've already tried at least 7 different setups to work around this, namely going from 1 carry to 4 (as stated here) and working without an initial carry at all.
What you expected to happen:
Give me adequate errors on what I'm doing wrong. I don't want to waste time debugging something I don't even know.
Logs, error messages, etc:
As above.
Also, there was an error along the lines of c, h = carry(x, ...), which I took to mean that I didn't have enough carries, so I resolved it by adding more jnp.zeros to self.initial_rnn_carry.
Steps to reproduce:
Whenever possible, please provide a minimal example. Please consider submitting it as a Colab link.
As above.