
Unit test failures #276

Open
@olupton

Description


There are several pytest tests that fail when run against the latest JAX, Flax, etc.
See for example https://github.com/NVIDIA/JAX-Toolbox/actions/runs/14589659816/job/40924621016#step:7:562

=================================== FAILURES ===================================
_________________________________ test_module __________________________________

    def test_module():
      denses_features_replaces = []
    
      def _replace_module(module):
        if isinstance(module, nn.Dense):
          denses_features_replaces.append(module.features)
          return WrapperModule(wrapped=module)
        else:
          return module
    
      model = MyModule()
      with peft.ModuleInterceptor(_replace_module):
>       out, params = model.init_with_output(jax.random.key(0), jnp.zeros((3,)))

gemma/peft/_interceptors_test.py:76: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
gemma/peft/_interceptors.py:118: in interceptor
    return next_fun(*args, **kwargs)
gemma/peft/_interceptors_test.py:51: in __call__
    y1 = model(x)
gemma/peft/_interceptors.py:123: in interceptor
    return getattr(module, context.method_name)(*args, **kwargs)
gemma/peft/_interceptors.py:103: in interceptor
    return next_fun(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <[RecursionError('maximum recursion depth exceeded') raised in repr()] WrapperModule object at 0x7fc3b813f110>
args = (Array([0., 0., 0.], dtype=float32),), kwargs = {}

    @nn.compact
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
      # Create an extra param.
>     self.param('extra_param', lambda _: jnp.zeros(()))
E     TypeError: Module.param() missing 1 required positional argument: 'shape'
E     --------------------
E     For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

gemma/peft/_interceptors_test.py:37: TypeError
_________________________ test_module_non_share_scope __________________________

    def test_module_non_share_scope():
      denses_features_replaces = []
    
      def _replace_module(module):
        if isinstance(module, nn.Dense):
          denses_features_replaces.append(module.features)
          return WrapperModule(wrapped=module, share_scope=False)
        else:
          return module
    
      model = MyModule()
      with peft.ModuleInterceptor(_replace_module):
>       out, params = model.init_with_output(jax.random.key(0), jnp.zeros((3,)))

gemma/peft/_interceptors_test.py:117: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
gemma/peft/_interceptors.py:118: in interceptor
    return next_fun(*args, **kwargs)
gemma/peft/_interceptors_test.py:51: in __call__
    y1 = model(x)
gemma/peft/_interceptors.py:123: in interceptor
    return getattr(module, context.method_name)(*args, **kwargs)
gemma/peft/_interceptors.py:103: in interceptor
    return next_fun(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <[RecursionError('maximum recursion depth exceeded') raised in repr()] WrapperModule object at 0x7fc3b8417440>
args = (Array([0., 0., 0.], dtype=float32),), kwargs = {}

    @nn.compact
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
      # Create an extra param.
>     self.param('extra_param', lambda _: jnp.zeros(()))
E     TypeError: Module.param() missing 1 required positional argument: 'shape'
E     --------------------
E     For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

gemma/peft/_interceptors_test.py:37: TypeError
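
The two `_interceptors_test.py` failures above appear to come from a change in the Flax `Module.param()` signature, which now expects an explicit `shape` argument. Below is a minimal sketch of how the test helper might be updated, assuming `param()` still forwards trailing positional arguments to the init function as the stable Flax linen API does; `WrapperModule` here is a simplified stand-in for the one in the test, not the actual `gemma.peft` implementation:

```python
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp


class WrapperModule(nn.Module):  # simplified stand-in for the test's WrapperModule
  wrapped: nn.Module

  @nn.compact
  def __call__(self, *args: Any, **kwargs: Any) -> Any:
    # Create the extra param, passing the shape explicitly as an init argument
    # so the call also satisfies the newer Module.param() signature.
    self.param('extra_param', lambda key, shape: jnp.zeros(shape), ())
    return self.wrapped(*args, **kwargs)


if __name__ == '__main__':
  model = WrapperModule(wrapped=nn.Dense(features=2))
  out, params = model.init_with_output(jax.random.key(0), jnp.zeros((3,)))
  print(jax.tree_util.tree_map(jnp.shape, params))
```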
____________________ TransformerTest.test_forward_no_cache0 ____________________

self = <gemma.transformer_test.TransformerTest testMethod=test_forward_no_cache0>
batch_size = 1, seq_size = 4
config = TransformerConfig(num_layers=2, num_embed=4, embed_dim=2, hidden_dim=12, num_heads=3, head_dim=4, num_kv_heads=3, fina...obal_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)

    @parameterized.parameters([
        dict(
            batch_size=1,
            seq_size=4,
            config=transformer_lib.TransformerConfig(
                num_layers=2,
                num_embed=4,  # unused
                embed_dim=2,
                hidden_dim=12,  # unused
                num_heads=3,
                head_dim=4,
                num_kv_heads=3,
                max_cache_length=6,
                final_logit_softcap=None,
                attention_types=[modules.AttentionType.GLOBAL] * 2,
                use_post_attn_norm=False,
                use_post_ffw_norm=False,
            ),
        )
    ])
    def test_forward_no_cache(
        self,
        batch_size: int,
        seq_size: int,
        config: transformer_lib.TransformerConfig,
    ):
      token_input = jnp.ones((batch_size, seq_size), dtype=jnp.int32)
      empty_cache = config.init_cache(batch_size, dtype=jnp.float32)
      with jax.numpy_rank_promotion('raise'):
>       transformer = transformer_lib.Transformer(config=config)

gemma/transformer_test.py:291: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../flax/flax/linen/kw_only_dataclasses.py:235: in init_wrapper
    dataclass_init(self, *args, **kwargs)
<string>:6: in __init__
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Transformer(
    # attributes
    config = TransformerConfig(num_layers=2, num_embed=4, embed_dim=2, hidden_dim=12, nu...al_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
)

    def __post_init__(self):
      if type(self) == Transformer:  # pylint: disable=unidiomatic-typecheck]
        msg = (
            'The old Transformer class is deprecated, behave unexpectedly and'
            " doesn't support multimodal."
            ' Instead, `gm.nn.GemmaXX` should be used.'
            ' See the documentation at [https://gemma-llm.readthedocs.io/.](https://gemma-llm.readthedocs.io/) '
        )
>       raise DeprecationWarning(msg)
E       DeprecationWarning: The old Transformer class is deprecated, behave unexpectedly and doesn't support multimodal. Instead, `gm.nn.GemmaXX` should be used. See the documentation at [https://gemma-llm.readthedocs.io/.](https://gemma-llm.readthedocs.io/)

gemma/transformer.py:418: DeprecationWarning
_____________________ TransformerTest.test_logit_softcap0 ______________________

self = <gemma.transformer_test.TransformerTest testMethod=test_logit_softcap0>
soft_cap_arg = 'final_logit_softcap'

    @parameterized.parameters(
        ('final_logit_softcap',),
        ('attn_logits_soft_cap',),
    )
    def test_logit_softcap(
        self,
        soft_cap_arg,
    ):
      cache_size = 2
      batch_size = 1
      sequence_length = 1
      soft_cap_val = 0.001
    
      attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool)
    
      params = dict(
          num_layers=3,
          num_embed=4,
          embed_dim=2,
          num_heads=2,
          num_kv_heads=1,
          hidden_dim=4,
          head_dim=4,
          max_cache_length=cache_size,
          attention_types=[modules.AttentionType.GLOBAL] * 3,
          use_post_attn_norm=False,
          use_post_ffw_norm=False,
      )
    
      no_soft_cap_args = {
          'final_logit_softcap': None,
          'attn_logits_soft_cap': None,
      }
    
      soft_cap_args = no_soft_cap_args.copy()
      soft_cap_args[soft_cap_arg] = soft_cap_val
    
      config_soft_cap = transformer_lib.TransformerConfig(
          **(params | soft_cap_args)
      )
      config_no_soft_cap = transformer_lib.TransformerConfig(
          **(params | no_soft_cap_args)
      )
    
      all_outputs = []
      for config in [config_soft_cap, config_no_soft_cap]:
        with jax.numpy_rank_promotion('raise'):
          cache = config.init_cache(batch_size, dtype=jnp.float32)
>         transformer = transformer_lib.Transformer(config=config)

gemma/transformer_test.py:207: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../flax/flax/linen/kw_only_dataclasses.py:235: in init_wrapper
    dataclass_init(self, *args, **kwargs)
<string>:6: in __init__
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Transformer(
    # attributes
    config = TransformerConfig(num_layers=3, num_embed=4, embed_dim=2, hidden_dim=4, num...al_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
)

    def __post_init__(self):
      if type(self) == Transformer:  # pylint: disable=unidiomatic-typecheck]
        msg = (
            'The old Transformer class is deprecated, behave unexpectedly and'
            " doesn't support multimodal."
            ' Instead, `gm.nn.GemmaXX` should be used.'
            ' See the documentation at [https://gemma-llm.readthedocs.io/.](https://gemma-llm.readthedocs.io/) '
        )
>       raise DeprecationWarning(msg)
E       DeprecationWarning: The old Transformer class is deprecated, behave unexpectedly and doesn't support multimodal. Instead, `gm.nn.GemmaXX` should be used. See the documentation at [https://gemma-llm.readthedocs.io/.](https://gemma-llm.readthedocs.io/)

gemma/transformer.py:418: DeprecationWarning
_____________________ TransformerTest.test_logit_softcap1 ______________________

self = <gemma.transformer_test.TransformerTest testMethod=test_logit_softcap1>
soft_cap_arg = 'attn_logits_soft_cap'

    @parameterized.parameters(
        ('final_logit_softcap',),
        ('attn_logits_soft_cap',),
    )
    def test_logit_softcap(
        self,
        soft_cap_arg,
    ):
      cache_size = 2
      batch_size = 1
      sequence_length = 1
      soft_cap_val = 0.001
    
      attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool)
    
      params = dict(
          num_layers=3,
          num_embed=4,
          embed_dim=2,
          num_heads=2,
          num_kv_heads=1,
          hidden_dim=4,
          head_dim=4,
          max_cache_length=cache_size,
          attention_types=[modules.AttentionType.GLOBAL] * 3,
          use_post_attn_norm=False,
          use_post_ffw_norm=False,
      )
    
      no_soft_cap_args = {
          'final_logit_softcap': None,
          'attn_logits_soft_cap': None,
      }
    
      soft_cap_args = no_soft_cap_args.copy()
      soft_cap_args[soft_cap_arg] = soft_cap_val
    
      config_soft_cap = transformer_lib.TransformerConfig(
          **(params | soft_cap_args)
      )
      config_no_soft_cap = transformer_lib.TransformerConfig(
          **(params | no_soft_cap_args)
      )
    
      all_outputs = []
      for config in [config_soft_cap, config_no_soft_cap]:
        with jax.numpy_rank_promotion('raise'):
          cache = config.init_cache(batch_size, dtype=jnp.float32)
>         transformer = transformer_lib.Transformer(config=config)

gemma/transformer_test.py:207: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../flax/flax/linen/kw_only_dataclasses.py:235: in init_wrapper
    dataclass_init(self, *args, **kwargs)
<string>:6: in __init__
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Transformer(
    # attributes
    config = TransformerConfig(num_layers=3, num_embed=4, embed_dim=2, hidden_dim=4, num...al_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
)

    def __post_init__(self):
      if type(self) == Transformer:  # pylint: disable=unidiomatic-typecheck]
        msg = (
            'The old Transformer class is deprecated, behave unexpectedly and'
            " doesn't support multimodal."
            ' Instead, `gm.nn.GemmaXX` should be used.'
            ' See the documentation at [https://gemma-llm.readthedocs.io/.](https://gemma-llm.readthedocs.io/) '
        )
>       raise DeprecationWarning(msg)
E       DeprecationWarning: The old Transformer class is deprecated, behave unexpectedly and doesn't support multimodal. Instead, `gm.nn.GemmaXX` should be used. See the documentation at [https://gemma-llm.readthedocs.io/.](https://gemma-llm.readthedocs.io/)

gemma/transformer.py:418: DeprecationWarning
______________________ TransformerTest.test_transformer0 _______________________

self = <gemma.transformer_test.TransformerTest testMethod=test_transformer0>
num_layers = 3, num_embed = 17, embed_dim = 2, num_heads = 2, num_kv_heads = 2
hidden_dim = 11, head_dim = 8, cache_size = 29, batch_size = 7
sequence_length = 17, expected_outputs_shape = (7, 17, 17)
expected_cache_shape = (7, 29, 2, 8)

    @parameterized.parameters(
        # Prime number to ease shape tracing
        dict(
            num_layers=3,
            num_embed=17,
            embed_dim=2,
            num_heads=2,
            num_kv_heads=2,
            hidden_dim=11,
            head_dim=8,
            cache_size=29,
            batch_size=7,
            sequence_length=17,
            expected_outputs_shape=(7, 17, 17),
            expected_cache_shape=(7, 29, 2, 8),
        ),
        dict(
            num_layers=3,
            num_embed=4,
            embed_dim=2,
            num_heads=2,
            num_kv_heads=1,
            hidden_dim=4,
            head_dim=4,
            cache_size=2,
            batch_size=1,
            sequence_length=1,
            expected_outputs_shape=(1, 1, 4),
            expected_cache_shape=(1, 2, 1, 4),
        ),
    )
    def test_transformer(
        self,
        num_layers,
        num_embed,
        embed_dim,
        num_heads,
        num_kv_heads,
        hidden_dim,
        head_dim,
        cache_size,
        batch_size,
        sequence_length,
        expected_outputs_shape,
        expected_cache_shape,
    ):
    
      config = transformer_lib.TransformerConfig(
          num_layers=num_layers,
          num_embed=num_embed,
          embed_dim=embed_dim,
          hidden_dim=hidden_dim,
          num_heads=num_heads,
          head_dim=head_dim,
          num_kv_heads=num_kv_heads,
          max_cache_length=cache_size,
          final_logit_softcap=None,
          attention_types=[modules.AttentionType.GLOBAL] * num_layers,
          use_post_attn_norm=False,
          use_post_ffw_norm=False,
      )
      cache = config.init_cache(batch_size, dtype=jnp.float32)
      attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool)
      with jax.numpy_rank_promotion('raise'):
>       transformer = transformer_lib.Transformer(config=config)

gemma/transformer_test.py:139: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../flax/flax/linen/kw_only_dataclasses.py:235: in init_wrapper
    dataclass_init(self, *args, **kwargs)
<string>:6: in __init__
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Transformer(
    # attributes
    config = TransformerConfig(num_layers=3, num_embed=17, embed_dim=2, hidden_dim=11, n...al_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
)

    def __post_init__(self):
      if type(self) == Transformer:  # pylint: disable=unidiomatic-typecheck]
        msg = (
            'The old Transformer class is deprecated, behave unexpectedly and'
            " doesn't support multimodal."
            ' Instead, `gm.nn.GemmaXX` should be used.'
            ' See the documentation at [https://gemma-llm.readthedocs.io/.](https://gemma-llm.readthedocs.io/) '
        )
>       raise DeprecationWarning(msg)
E       DeprecationWarning: The old Transformer class is deprecated, behave unexpectedly and doesn't support multimodal. Instead, `gm.nn.GemmaXX` should be used. See the documentation at [https://gemma-llm.readthedocs.io/.](https://gemma-llm.readthedocs.io/)

gemma/transformer.py:418: DeprecationWarning
______________________ TransformerTest.test_transformer1 _______________________

self = <gemma.transformer_test.TransformerTest testMethod=test_transformer1>
num_layers = 3, num_embed = 4, embed_dim = 2, num_heads = 2, num_kv_heads = 1
hidden_dim = 4, head_dim = 4, cache_size = 2, batch_size = 1
sequence_length = 1, expected_outputs_shape = (1, 1, 4)
expected_cache_shape = (1, 2, 1, 4)

    @parameterized.parameters(
        # Prime number to ease shape tracing
        dict(
            num_layers=3,
            num_embed=17,
            embed_dim=2,
            num_heads=2,
            num_kv_heads=2,
            hidden_dim=11,
            head_dim=8,
            cache_size=29,
            batch_size=7,
            sequence_length=17,
            expected_outputs_shape=(7, 17, 17),
            expected_cache_shape=(7, 29, 2, 8),
        ),
        dict(
            num_layers=3,
            num_embed=4,
            embed_dim=2,
            num_heads=2,
            num_kv_heads=1,
            hidden_dim=4,
            head_dim=4,
            cache_size=2,
            batch_size=1,
            sequence_length=1,
            expected_outputs_shape=(1, 1, 4),
            expected_cache_shape=(1, 2, 1, 4),
        ),
    )
    def test_transformer(
        self,
        num_layers,
        num_embed,
        embed_dim,
        num_heads,
        num_kv_heads,
        hidden_dim,
        head_dim,
        cache_size,
        batch_size,
        sequence_length,
        expected_outputs_shape,
        expected_cache_shape,
    ):
    
      config = transformer_lib.TransformerConfig(
          num_layers=num_layers,
          num_embed=num_embed,
          embed_dim=embed_dim,
          hidden_dim=hidden_dim,
          num_heads=num_heads,
          head_dim=head_dim,
          num_kv_heads=num_kv_heads,
          max_cache_length=cache_size,
          final_logit_softcap=None,
          attention_types=[modules.AttentionType.GLOBAL] * num_layers,
          use_post_attn_norm=False,
          use_post_ffw_norm=False,
      )
      cache = config.init_cache(batch_size, dtype=jnp.float32)
      attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool)
      with jax.numpy_rank_promotion('raise'):
>       transformer = transformer_lib.Transformer(config=config)

gemma/transformer_test.py:139: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../flax/flax/linen/kw_only_dataclasses.py:235: in init_wrapper
    dataclass_init(self, *args, **kwargs)
<string>:6: in __init__
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Transformer(
    # attributes
    config = TransformerConfig(num_layers=3, num_embed=4, embed_dim=2, hidden_dim=4, num...al_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
)

    def __post_init__(self):
      if type(self) == Transformer:  # pylint: disable=unidiomatic-typecheck]
        msg = (
            'The old Transformer class is deprecated, behave unexpectedly and'
            " doesn't support multimodal."
            ' Instead, `gm.nn.GemmaXX` should be used.'
            ' See the documentation at [https://gemma-llm.readthedocs.io/.](https://gemma-llm.readthedocs.io/) '
        )
>       raise DeprecationWarning(msg)
E       DeprecationWarning: The old Transformer class is deprecated, behave unexpectedly and doesn't support multimodal. Instead, `gm.nn.GemmaXX` should be used. See the documentation at [https://gemma-llm.readthedocs.io/.](https://gemma-llm.readthedocs.io/)

gemma/transformer.py:418: DeprecationWarning
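
The five `transformer_test.py` failures are all the same issue: `Transformer.__post_init__` now raises a `DeprecationWarning` whenever the class is instantiated directly, and these tests construct it directly. The long-term fix is presumably migrating the tests to the `gm.nn` classes the message points to; as a stopgap to keep the legacy tests running, one option (based only on the `type(self) == Transformer` check shown above, not an officially supported path, and assuming the usual `gemma` import layout) is to instantiate a trivial subclass:

```python
from gemma import modules
from gemma import transformer as transformer_lib


class _LegacyTransformer(transformer_lib.Transformer):
  """Trivial subclass: type(self) is no longer exactly Transformer,
  so __post_init__ does not raise the DeprecationWarning."""


# Config values copied from test_forward_no_cache above.
config = transformer_lib.TransformerConfig(
    num_layers=2,
    num_embed=4,
    embed_dim=2,
    hidden_dim=12,
    num_heads=3,
    head_dim=4,
    num_kv_heads=3,
    max_cache_length=6,
    final_logit_softcap=None,
    attention_types=[modules.AttentionType.GLOBAL] * 2,
    use_post_attn_norm=False,
    use_post_ffw_norm=False,
)
transformer = _LegacyTransformer(config=config)
```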
=============================== warnings summary ===============================
tests/test_gemma.py:16
  /opt/gemma/tests/test_gemma.py:16: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display
    from IPython.core.display import display, HTML

tests/test_gemma.py:32
  /opt/gemma/tests/test_gemma.py:32: DeprecationWarning: jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.
    backend = jax.lib.xla_bridge.get_backend()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
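
The two warnings reported from `tests/test_gemma.py` are not failures, but each deprecation message names its own replacement API; a sketch of the corresponding updates:

```python
# Replacement imports suggested by the deprecation messages above.
from IPython.display import HTML, display  # instead of IPython.core.display

from jax.extend import backend as jex_backend

backend = jex_backend.get_backend()  # instead of jax.lib.xla_bridge.get_backend()
```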
=========================== short test summary info ============================
FAILED gemma/peft/_interceptors_test.py::test_module - TypeError: Module.para...
FAILED gemma/peft/_interceptors_test.py::test_module_non_share_scope - TypeEr...
FAILED gemma/transformer_test.py::TransformerTest::test_forward_no_cache0 - D...
FAILED gemma/transformer_test.py::TransformerTest::test_logit_softcap0 - Depr...
FAILED gemma/transformer_test.py::TransformerTest::test_logit_softcap1 - Depr...
FAILED gemma/transformer_test.py::TransformerTest::test_transformer0 - Deprec...
FAILED gemma/transformer_test.py::TransformerTest::test_transformer1 - Deprec...
============= 7 failed, 67 passed, 2 warnings in 136.56s (0:02:16) =============
