
Commit 14f2141

Conchylicultor authored and The gemma Authors committed
Add transformer tests
PiperOrigin-RevId: 753152343
1 parent 45d1d91 commit 14f2141

File tree

1 file changed: +83 -0 lines changed


gemma/gm/nn/_transformer_test.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@

# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from gemma import gm
import jax
import jax.numpy as jnp
import pytest

BATCH_SIZE = 4
SEQ_LEN = 16
NUM_IMAGES = 1


def _get_output(model: gm.nn.Transformer, **kwargs) -> tuple[gm.nn.Output, Any]:

  def init_fn(**kwargs):
    out, params = model.init_with_output(jax.random.key(0), **kwargs)
    return out, params['params']

  return jax.eval_shape(init_fn, **kwargs)


@pytest.mark.parametrize(
    'model_cls',
    [
        gm.nn.Gemma3_1B,
        gm.nn.Gemma3_4B,
        gm.nn.Gemma3_12B,
        gm.nn.Gemma3_27B,
    ],
)
def test_transformer(model_cls: type[gm.nn.Transformer]):
  model = model_cls()  # pylint: disable=missing-kwoa  # pytype: disable=missing-parameter
  tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)
  out, _ = _get_output(model, tokens=tokens)
  assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed)


def test_images():

  model = gm.nn.Gemma3_4B()  # pylint: disable=missing-kwoa  # pytype: disable=missing-parameter

  tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)
  images = jnp.ones((BATCH_SIZE, NUM_IMAGES, 64, 64, 3), dtype=jnp.uint8)
  out, _ = _get_output(model, tokens=tokens, images=images)

  assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed)


def test_text_only():

  model = gm.nn.Gemma3_4B(text_only=True)

  tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)
  images = jnp.ones((BATCH_SIZE, NUM_IMAGES, 64, 64, 3), dtype=jnp.uint8)

  with pytest.raises(ValueError, match='does not have vision encoder'):
    _get_output(model, tokens=tokens, images=images)

  out, params = _get_output(model, tokens=tokens)
  assert 'vision_encoder' not in params  # Vision params not loaded
  assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed)


def test_last_only():
  model = gm.nn.Gemma3_4B(return_last_only=True)
  tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)
  out, params = _get_output(model, tokens=tokens)
  assert 'vision_encoder' in params  # Vision by default
  assert out.logits.shape == (BATCH_SIZE, model.config.num_embed)
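
Note on the pattern in `_get_output` above: wrapping `model.init_with_output` inside `jax.eval_shape` traces the model abstractly, so the tests only compute output shapes and dtypes and never allocate real parameters, which is presumably what keeps even the Gemma3_27B case cheap to run. A minimal self-contained sketch of the same idea, using a tiny Flax module as a hypothetical stand-in for gm.nn.Transformer (TinyModel is not part of this commit):

# Minimal sketch of the shape-only initialization pattern (illustrative only;
# TinyModel is a hypothetical stand-in, not the gemma API).
import jax
import jax.numpy as jnp
from flax import linen as nn


class TinyModel(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)


model = TinyModel()
x = jnp.ones((4, 16), dtype=jnp.float32)


def init_fn(x):
  # init_with_output returns (outputs, variables) in a single pass.
  out, variables = model.init_with_output(jax.random.key(0), x)
  return out, variables['params']


# eval_shape only traces init_fn: no parameters or activations are
# materialized, just jax.ShapeDtypeStruct descriptions of them.
out_spec, params_spec = jax.eval_shape(init_fn, x)
assert out_spec.shape == (4, 8)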
