Hi, I got this error when running the above model with `GemmaConfig`:

```
"/python3.11/site-packages/jax/_src/numpy/ufuncs.py", line 1280, in multiply
    return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
```

Here is the code that reproduces it:

```
import jax
import jax.numpy as jnp
from transformers import GemmaConfig
from transformers.models.gemma import modeling_flax_gemma

config = GemmaConfig()
model = modeling_flax_gemma.FlaxGemmaModule(config, dtype=jnp.float32)

input_ids = jnp.zeros((32, 128), dtype=jnp.int32)
variables = model.init(
    jax.random.key(0),
    input_ids=input_ids,
)

def model_apply(input_ids):
    return model.apply(variables, input_ids=input_ids)

model_apply(input_ids)
```

I am using transformers 4.53.2 and jax 3.10. Could you please take a look? Thanks!