Description
Several pytest tests fail when the gemma test suite is run against the latest JAX and related dependencies (Flax in particular).
See for example https://github.com/NVIDIA/JAX-Toolbox/actions/runs/14589659816/job/40924621016#step:7:562
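The failures fall into two groups. The first two (`test_module` and `test_module_non_share_scope` in `gemma/peft/_interceptors_test.py`) share one root cause: the test's `WrapperModule.__call__` creates an extra parameter with `self.param('extra_param', lambda _: jnp.zeros(()))`, and the Flax version under test now rejects this with `TypeError: Module.param() missing 1 required positional argument: 'shape'`. Below is a minimal sketch of a possible fix, assuming the new `shape` argument is forwarded to the initializer the same way `*init_args` are in earlier `flax.linen` releases (so the change stays backward compatible); the class shown is a simplified stand-in for the test's `WrapperModule`, not its full definition.

```python
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp


class WrapperModule(nn.Module):
  """Simplified stand-in for the wrapper in _interceptors_test.py."""

  wrapped: nn.Module

  @nn.compact
  def __call__(self, *args: Any, **kwargs: Any) -> Any:
    # Before: self.param('extra_param', lambda _: jnp.zeros(()))
    # After: pass the shape explicitly; the initializer receives it as its
    # second argument, so the created parameter is still a scalar zero.
    self.param('extra_param', lambda _rng, shape: jnp.zeros(shape), ())
    return self.wrapped(*args, **kwargs)


# Sanity check: init still produces a scalar 'extra_param'.
module = WrapperModule(wrapped=nn.Dense(features=4))
params = module.init(jax.random.key(0), jnp.zeros((3,)))
print(jax.tree.map(jnp.shape, params))
```

The remaining five failures come from `gemma/transformer_test.py`; see the note after the log. Full pytest output: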
=================================== FAILURES ===================================
_________________________________ test_module __________________________________
def test_module():
denses_features_replaces = []
def _replace_module(module):
if isinstance(module, nn.Dense):
denses_features_replaces.append(module.features)
return WrapperModule(wrapped=module)
else:
return module
model = MyModule()
with peft.ModuleInterceptor(_replace_module):
> out, params = model.init_with_output(jax.random.key(0), jnp.zeros((3,)))
gemma/peft/_interceptors_test.py:76:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
gemma/peft/_interceptors.py:118: in interceptor
return next_fun(*args, **kwargs)
gemma/peft/_interceptors_test.py:51: in __call__
y1 = model(x)
gemma/peft/_interceptors.py:123: in interceptor
return getattr(module, context.method_name)(*args, **kwargs)
gemma/peft/_interceptors.py:103: in interceptor
return next_fun(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <[RecursionError('maximum recursion depth exceeded') raised in repr()] WrapperModule object at 0x7fc3b813f110>
args = (Array([0., 0., 0.], dtype=float32),), kwargs = {}
@nn.compact
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Create an extra param.
> self.param('extra_param', lambda _: jnp.zeros(()))
E TypeError: Module.param() missing 1 required positional argument: 'shape'
E --------------------
E For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
gemma/peft/_interceptors_test.py:37: TypeError
_________________________ test_module_non_share_scope __________________________
def test_module_non_share_scope():
denses_features_replaces = []
def _replace_module(module):
if isinstance(module, nn.Dense):
denses_features_replaces.append(module.features)
return WrapperModule(wrapped=module, share_scope=False)
else:
return module
model = MyModule()
with peft.ModuleInterceptor(_replace_module):
> out, params = model.init_with_output(jax.random.key(0), jnp.zeros((3,)))
gemma/peft/_interceptors_test.py:117:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
gemma/peft/_interceptors.py:118: in interceptor
return next_fun(*args, **kwargs)
gemma/peft/_interceptors_test.py:51: in __call__
y1 = model(x)
gemma/peft/_interceptors.py:123: in interceptor
return getattr(module, context.method_name)(*args, **kwargs)
gemma/peft/_interceptors.py:103: in interceptor
return next_fun(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <[RecursionError('maximum recursion depth exceeded') raised in repr()] WrapperModule object at 0x7fc3b8417440>
args = (Array([0., 0., 0.], dtype=float32),), kwargs = {}
@nn.compact
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Create an extra param.
> self.param('extra_param', lambda _: jnp.zeros(()))
E TypeError: Module.param() missing 1 required positional argument: 'shape'
E --------------------
E For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
gemma/peft/_interceptors_test.py:37: TypeError
____________________ TransformerTest.test_forward_no_cache0 ____________________
self = <gemma.transformer_test.TransformerTest testMethod=test_forward_no_cache0>
batch_size = 1, seq_size = 4
config = TransformerConfig(num_layers=2, num_embed=4, embed_dim=2, hidden_dim=12, num_heads=3, head_dim=4, num_kv_heads=3, fina...obal_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
@parameterized.parameters([
dict(
batch_size=1,
seq_size=4,
config=transformer_lib.TransformerConfig(
num_layers=2,
num_embed=4, # unused
embed_dim=2,
hidden_dim=12, # unused
num_heads=3,
head_dim=4,
num_kv_heads=3,
max_cache_length=6,
final_logit_softcap=None,
attention_types=[modules.AttentionType.GLOBAL] * 2,
use_post_attn_norm=False,
use_post_ffw_norm=False,
),
)
])
def test_forward_no_cache(
self,
batch_size: int,
seq_size: int,
config: transformer_lib.TransformerConfig,
):
token_input = jnp.ones((batch_size, seq_size), dtype=jnp.int32)
empty_cache = config.init_cache(batch_size, dtype=jnp.float32)
with jax.numpy_rank_promotion('raise'):
> transformer = transformer_lib.Transformer(config=config)
gemma/transformer_test.py:291:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../flax/flax/linen/kw_only_dataclasses.py:235: in init_wrapper
dataclass_init(self, *args, **kwargs)
<string>:6: in __init__
???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Transformer(
# attributes
config = TransformerConfig(num_layers=2, num_embed=4, embed_dim=2, hidden_dim=12, nu...al_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
)
def __post_init__(self):
if type(self) == Transformer: # pylint: disable=unidiomatic-typecheck]
msg = (
'The old Transformer class is deprecated, behave unexpectedly and'
" doesn't support multimodal."
' Instead, `gm.nn.GemmaXX` should be used.'
' See the documentation at https://gemma-llm.readthedocs.io/'
)
> raise DeprecationWarning(msg)
E DeprecationWarning: The old Transformer class is deprecated, behave unexpectedly and doesn't support multimodal. Instead, `gm.nn.GemmaXX` should be used. See the documentation at https://gemma-llm.readthedocs.io/
gemma/transformer.py:418: DeprecationWarning
_____________________ TransformerTest.test_logit_softcap0 ______________________
self = <gemma.transformer_test.TransformerTest testMethod=test_logit_softcap0>
soft_cap_arg = 'final_logit_softcap'
@parameterized.parameters(
('final_logit_softcap',),
('attn_logits_soft_cap',),
)
def test_logit_softcap(
self,
soft_cap_arg,
):
cache_size = 2
batch_size = 1
sequence_length = 1
soft_cap_val = 0.001
attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool)
params = dict(
num_layers=3,
num_embed=4,
embed_dim=2,
num_heads=2,
num_kv_heads=1,
hidden_dim=4,
head_dim=4,
max_cache_length=cache_size,
attention_types=[modules.AttentionType.GLOBAL] * 3,
use_post_attn_norm=False,
use_post_ffw_norm=False,
)
no_soft_cap_args = {
'final_logit_softcap': None,
'attn_logits_soft_cap': None,
}
soft_cap_args = no_soft_cap_args.copy()
soft_cap_args[soft_cap_arg] = soft_cap_val
config_soft_cap = transformer_lib.TransformerConfig(
**(params | soft_cap_args)
)
config_no_soft_cap = transformer_lib.TransformerConfig(
**(params | no_soft_cap_args)
)
all_outputs = []
for config in [config_soft_cap, config_no_soft_cap]:
with jax.numpy_rank_promotion('raise'):
cache = config.init_cache(batch_size, dtype=jnp.float32)
> transformer = transformer_lib.Transformer(config=config)
gemma/transformer_test.py:207:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../flax/flax/linen/kw_only_dataclasses.py:235: in init_wrapper
dataclass_init(self, *args, **kwargs)
<string>:6: in __init__
???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Transformer(
# attributes
config = TransformerConfig(num_layers=3, num_embed=4, embed_dim=2, hidden_dim=4, num...al_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
)
def __post_init__(self):
if type(self) == Transformer: # pylint: disable=unidiomatic-typecheck]
msg = (
'The old Transformer class is deprecated, behave unexpectedly and'
" doesn't support multimodal."
' Instead, `gm.nn.GemmaXX` should be used.'
' See the documentation at https://gemma-llm.readthedocs.io/'
)
> raise DeprecationWarning(msg)
E DeprecationWarning: The old Transformer class is deprecated, behave unexpectedly and doesn't support multimodal. Instead, `gm.nn.GemmaXX` should be used. See the documentation at https://gemma-llm.readthedocs.io/
gemma/transformer.py:418: DeprecationWarning
_____________________ TransformerTest.test_logit_softcap1 ______________________
self = <gemma.transformer_test.TransformerTest testMethod=test_logit_softcap1>
soft_cap_arg = 'attn_logits_soft_cap'
@parameterized.parameters(
('final_logit_softcap',),
('attn_logits_soft_cap',),
)
def test_logit_softcap(
self,
soft_cap_arg,
):
cache_size = 2
batch_size = 1
sequence_length = 1
soft_cap_val = 0.001
attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool)
params = dict(
num_layers=3,
num_embed=4,
embed_dim=2,
num_heads=2,
num_kv_heads=1,
hidden_dim=4,
head_dim=4,
max_cache_length=cache_size,
attention_types=[modules.AttentionType.GLOBAL] * 3,
use_post_attn_norm=False,
use_post_ffw_norm=False,
)
no_soft_cap_args = {
'final_logit_softcap': None,
'attn_logits_soft_cap': None,
}
soft_cap_args = no_soft_cap_args.copy()
soft_cap_args[soft_cap_arg] = soft_cap_val
config_soft_cap = transformer_lib.TransformerConfig(
**(params | soft_cap_args)
)
config_no_soft_cap = transformer_lib.TransformerConfig(
**(params | no_soft_cap_args)
)
all_outputs = []
for config in [config_soft_cap, config_no_soft_cap]:
with jax.numpy_rank_promotion('raise'):
cache = config.init_cache(batch_size, dtype=jnp.float32)
> transformer = transformer_lib.Transformer(config=config)
gemma/transformer_test.py:207:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../flax/flax/linen/kw_only_dataclasses.py:235: in init_wrapper
dataclass_init(self, *args, **kwargs)
<string>:6: in __init__
???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Transformer(
# attributes
config = TransformerConfig(num_layers=3, num_embed=4, embed_dim=2, hidden_dim=4, num...al_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
)
def __post_init__(self):
if type(self) == Transformer: # pylint: disable=unidiomatic-typecheck]
msg = (
'The old Transformer class is deprecated, behave unexpectedly and'
" doesn't support multimodal."
' Instead, `gm.nn.GemmaXX` should be used.'
' See the documentation at https://gemma-llm.readthedocs.io/'
)
> raise DeprecationWarning(msg)
E DeprecationWarning: The old Transformer class is deprecated, behave unexpectedly and doesn't support multimodal. Instead, `gm.nn.GemmaXX` should be used. See the documentation at https://gemma-llm.readthedocs.io/
gemma/transformer.py:418: DeprecationWarning
______________________ TransformerTest.test_transformer0 _______________________
self = <gemma.transformer_test.TransformerTest testMethod=test_transformer0>
num_layers = 3, num_embed = 17, embed_dim = 2, num_heads = 2, num_kv_heads = 2
hidden_dim = 11, head_dim = 8, cache_size = 29, batch_size = 7
sequence_length = 17, expected_outputs_shape = (7, 17, 17)
expected_cache_shape = (7, 29, 2, 8)
@parameterized.parameters(
# Prime number to ease shape tracing
dict(
num_layers=3,
num_embed=17,
embed_dim=2,
num_heads=2,
num_kv_heads=2,
hidden_dim=11,
head_dim=8,
cache_size=29,
batch_size=7,
sequence_length=17,
expected_outputs_shape=(7, 17, 17),
expected_cache_shape=(7, 29, 2, 8),
),
dict(
num_layers=3,
num_embed=4,
embed_dim=2,
num_heads=2,
num_kv_heads=1,
hidden_dim=4,
head_dim=4,
cache_size=2,
batch_size=1,
sequence_length=1,
expected_outputs_shape=(1, 1, 4),
expected_cache_shape=(1, 2, 1, 4),
),
)
def test_transformer(
self,
num_layers,
num_embed,
embed_dim,
num_heads,
num_kv_heads,
hidden_dim,
head_dim,
cache_size,
batch_size,
sequence_length,
expected_outputs_shape,
expected_cache_shape,
):
config = transformer_lib.TransformerConfig(
num_layers=num_layers,
num_embed=num_embed,
embed_dim=embed_dim,
hidden_dim=hidden_dim,
num_heads=num_heads,
head_dim=head_dim,
num_kv_heads=num_kv_heads,
max_cache_length=cache_size,
final_logit_softcap=None,
attention_types=[modules.AttentionType.GLOBAL] * num_layers,
use_post_attn_norm=False,
use_post_ffw_norm=False,
)
cache = config.init_cache(batch_size, dtype=jnp.float32)
attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool)
with jax.numpy_rank_promotion('raise'):
> transformer = transformer_lib.Transformer(config=config)
gemma/transformer_test.py:139:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../flax/flax/linen/kw_only_dataclasses.py:235: in init_wrapper
dataclass_init(self, *args, **kwargs)
<string>:6: in __init__
???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Transformer(
# attributes
config = TransformerConfig(num_layers=3, num_embed=17, embed_dim=2, hidden_dim=11, n...al_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
)
def __post_init__(self):
if type(self) == Transformer: # pylint: disable=unidiomatic-typecheck]
msg = (
'The old Transformer class is deprecated, behave unexpectedly and'
" doesn't support multimodal."
' Instead, `gm.nn.GemmaXX` should be used.'
' See the documentation at https://gemma-llm.readthedocs.io/'
)
> raise DeprecationWarning(msg)
E DeprecationWarning: The old Transformer class is deprecated, behave unexpectedly and doesn't support multimodal. Instead, `gm.nn.GemmaXX` should be used. See the documentation at https://gemma-llm.readthedocs.io/
gemma/transformer.py:418: DeprecationWarning
______________________ TransformerTest.test_transformer1 _______________________
self = <gemma.transformer_test.TransformerTest testMethod=test_transformer1>
num_layers = 3, num_embed = 4, embed_dim = 2, num_heads = 2, num_kv_heads = 1
hidden_dim = 4, head_dim = 4, cache_size = 2, batch_size = 1
sequence_length = 1, expected_outputs_shape = (1, 1, 4)
expected_cache_shape = (1, 2, 1, 4)
@parameterized.parameters(
# Prime number to ease shape tracing
dict(
num_layers=3,
num_embed=17,
embed_dim=2,
num_heads=2,
num_kv_heads=2,
hidden_dim=11,
head_dim=8,
cache_size=29,
batch_size=7,
sequence_length=17,
expected_outputs_shape=(7, 17, 17),
expected_cache_shape=(7, 29, 2, 8),
),
dict(
num_layers=3,
num_embed=4,
embed_dim=2,
num_heads=2,
num_kv_heads=1,
hidden_dim=4,
head_dim=4,
cache_size=2,
batch_size=1,
sequence_length=1,
expected_outputs_shape=(1, 1, 4),
expected_cache_shape=(1, 2, 1, 4),
),
)
def test_transformer(
self,
num_layers,
num_embed,
embed_dim,
num_heads,
num_kv_heads,
hidden_dim,
head_dim,
cache_size,
batch_size,
sequence_length,
expected_outputs_shape,
expected_cache_shape,
):
config = transformer_lib.TransformerConfig(
num_layers=num_layers,
num_embed=num_embed,
embed_dim=embed_dim,
hidden_dim=hidden_dim,
num_heads=num_heads,
head_dim=head_dim,
num_kv_heads=num_kv_heads,
max_cache_length=cache_size,
final_logit_softcap=None,
attention_types=[modules.AttentionType.GLOBAL] * num_layers,
use_post_attn_norm=False,
use_post_ffw_norm=False,
)
cache = config.init_cache(batch_size, dtype=jnp.float32)
attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool)
with jax.numpy_rank_promotion('raise'):
> transformer = transformer_lib.Transformer(config=config)
gemma/transformer_test.py:139:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../flax/flax/linen/kw_only_dataclasses.py:235: in init_wrapper
dataclass_init(self, *args, **kwargs)
<string>:6: in __init__
???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Transformer(
# attributes
config = TransformerConfig(num_layers=3, num_embed=4, embed_dim=2, hidden_dim=4, num...al_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, mm_extra_vocab_size=0, vision_encoder=None)
)
def __post_init__(self):
if type(self) == Transformer: # pylint: disable=unidiomatic-typecheck]
msg = (
'The old Transformer class is deprecated, behave unexpectedly and'
" doesn't support multimodal."
' Instead, `gm.nn.GemmaXX` should be used.'
' See the documentation at https://gemma-llm.readthedocs.io/'
)
> raise DeprecationWarning(msg)
E DeprecationWarning: The old Transformer class is deprecated, behave unexpectedly and doesn't support multimodal. Instead, `gm.nn.GemmaXX` should be used. See the documentation at https://gemma-llm.readthedocs.io/
gemma/transformer.py:418: DeprecationWarning
=============================== warnings summary ===============================
tests/test_gemma.py:16
/opt/gemma/tests/test_gemma.py:16: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython.display
from IPython.core.display import display, HTML
tests/test_gemma.py:32
/opt/gemma/tests/test_gemma.py:32: DeprecationWarning: jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.
backend = jax.lib.xla_bridge.get_backend()
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED gemma/peft/_interceptors_test.py::test_module - TypeError: Module.para...
FAILED gemma/peft/_interceptors_test.py::test_module_non_share_scope - TypeEr...
FAILED gemma/transformer_test.py::TransformerTest::test_forward_no_cache0 - D...
FAILED gemma/transformer_test.py::TransformerTest::test_logit_softcap0 - Depr...
FAILED gemma/transformer_test.py::TransformerTest::test_logit_softcap1 - Depr...
FAILED gemma/transformer_test.py::TransformerTest::test_transformer0 - Deprec...
FAILED gemma/transformer_test.py::TransformerTest::test_transformer1 - Deprec...
============= 7 failed, 67 passed, 2 warnings in 136.56s (0:02:16) =============
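The remaining five failures (`test_forward_no_cache0`, `test_logit_softcap0/1`, `test_transformer0/1`) are all the same pattern: `Transformer.__post_init__` in `gemma/transformer.py` now raises a `DeprecationWarning` as an exception whenever the class is instantiated directly. Because it is a `raise` rather than a `warnings.warn`, pytest warning filters cannot suppress it. The check is an exact-type comparison (`type(self) == Transformer`), so one possible test-side workaround, sketched below under the assumption that keeping coverage of the old class is still desired (the alternative is migrating the tests to the `gm.nn.GemmaXX` classes named in the message), is a trivial subclass; the import alias `transformer_lib` mirrors the one used in the failing tests.

```python
import jax
from gemma import transformer as transformer_lib


class _TestOnlyTransformer(transformer_lib.Transformer):
  """Test-only subclass: type(self) is no longer exactly Transformer,
  so __post_init__ does not raise the DeprecationWarning."""


def build_transformer(config: transformer_lib.TransformerConfig):
  # Same construction the failing tests perform, via the subclass.
  with jax.numpy_rank_promotion('raise'):
    return _TestOnlyTransformer(config=config)
```

Whether upstream prefers a workaround like this or a full migration of the tests to the new classes is worth confirming with the gemma maintainers.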