
modeling_flax_gemma.FlaxGemmaModule fails with incompatible shapes when run with GemmaConfig #39492

@nhatleSummer22

Description


Hi, I get an incompatible-shapes error when running FlaxGemmaModule with a default GemmaConfig. The traceback ends in:

  File "/python3.11/site-packages/jax/_src/numpy/ufuncs.py", line 1280, in multiply
    return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
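One plausible reading of this traceback, offered as an assumption rather than a confirmed diagnosis: the failing multiply may be the rotary-embedding step. If the Flax Gemma attention gathers a precomputed sin/cos table with position_ids (a Llama-style pattern assumed here), then calling the bare module with only input_ids leaves position_ids as None. The sketch uses small, made-up shapes and plain jax.numpy (no transformers code) to show how indexing a table with None inserts a new axis instead of selecting rows, so the result no longer broadcasts against the query tensor.

```python
import jax.numpy as jnp

# Hypothetical shapes, chosen small for illustration (not taken from the report).
batch, seq_len, num_heads, head_dim, max_pos = 2, 4, 3, 8, 16

# A per-position table, standing in for an assumed precomputed sin/cos cache.
sincos = jnp.ones((max_pos, head_dim))

# Indexing with None (i.e. no position_ids) adds an axis instead of gathering rows.
cos_pos = sincos[None][:, :, None, :]  # shape (1, max_pos, 1, head_dim)

query = jnp.ones((batch, seq_len, num_heads, head_dim))


def try_multiply(a, b):
    """Return a * b, or None if the shapes cannot be broadcast together."""
    try:
        return a * b
    except (TypeError, ValueError):
        return None


print(cos_pos.shape)                         # (1, 16, 1, 8)
print(try_multiply(query, cos_pos) is None)  # True: axis 1 is 4 vs 16

# With explicit positions 0..seq_len-1 the gather produces a matching axis.
position_ids = jnp.broadcast_to(jnp.arange(seq_len)[None, :], (batch, seq_len))
cos_ok = sincos[position_ids][:, :, None, :]  # shape (batch, seq_len, 1, head_dim)
print(try_multiply(query, cos_ok).shape)     # (2, 4, 3, 8)
```

The table contents are irrelevant here (all ones); only the shapes matter for whether the elementwise multiply can broadcast.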

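To reproduce:

import jax
import jax.numpy as jnp
from transformers import GemmaConfig
from transformers.models.gemma import modeling_flax_gemma

# Default GemmaConfig and the bare Flax module (not the pretrained-model wrapper).
config = GemmaConfig()
model = modeling_flax_gemma.FlaxGemmaModule(config, dtype=jnp.float32)

# Dummy batch: 32 sequences of 128 token ids.
input_ids = jnp.zeros((32, 128), dtype=jnp.int32)

# Initialize parameters, passing only input_ids.
variables = model.init(
    jax.random.key(0),
    input_ids=input_ids,
)

def model_apply(input_ids):
    return model.apply(variables, input_ids=input_ids)

model_apply(input_ids)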

I am using transformers 4.53.2 and jax 3.10. Could you please take a look? Thanks!
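Assuming the failure comes from the bare module receiving position_ids=None and attention_mask=None (not confirmed in this report), one possible workaround is to pass both explicitly, as the higher-level FlaxGemmaModel class is believed to do when they are omitted. The helper default_mask_and_positions is hypothetical, not a transformers function; it only builds an all-ones mask and 0..seq_len-1 positions. The commented lines show the assumed keyword arguments for init and apply on the module from the report.

```python
import jax.numpy as jnp


def default_mask_and_positions(input_ids):
    """Hypothetical helper: all-ones attention mask and 0..seq_len-1 position ids.

    Assumed to mirror the defaults a higher-level model wrapper fills in.
    """
    batch_size, seq_len = input_ids.shape
    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    position_ids = jnp.broadcast_to(jnp.arange(seq_len)[None, :], (batch_size, seq_len))
    return attention_mask, position_ids


input_ids = jnp.zeros((32, 128), dtype=jnp.int32)
attention_mask, position_ids = default_mask_and_positions(input_ids)

# Assumed usage with the module from the report (requires transformers and flax):
# variables = model.init(jax.random.key(0), input_ids=input_ids,
#                        attention_mask=attention_mask, position_ids=position_ids)
# outputs = model.apply(variables, input_ids=input_ids,
#                       attention_mask=attention_mask, position_ids=position_ids)
```

Alternatively, using FlaxGemmaModel instead of the raw FlaxGemmaModule may avoid building these arrays by hand, assuming its __call__ supplies the same defaults.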

Metadata


Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
