|
| 1 | +# Copyright 2024 DeepMind Technologies Limited. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from typing import Any |
| 16 | + |
| 17 | +from gemma import gm |
| 18 | +import jax |
| 19 | +import jax.numpy as jnp |
| 20 | +import pytest |
| 21 | + |
| 22 | +BATCH_SIZE = 4 |
| 23 | +SEQ_LEN = 16 |
| 24 | +NUM_IMAGES = 1 |
| 25 | + |
| 26 | + |
| 27 | +def _get_output(model: gm.nn.Transformer, **kwargs) -> tuple[gm.nn.Output, Any]: |
| 28 | + |
| 29 | + def init_fn(**kwargs): |
| 30 | + out, params = model.init_with_output(jax.random.key(0), **kwargs) |
| 31 | + return out, params['params'] |
| 32 | + |
| 33 | + return jax.eval_shape(init_fn, **kwargs) |
| 34 | + |
| 35 | + |
| 36 | +@pytest.mark.parametrize( |
| 37 | + 'model_cls', |
| 38 | + [ |
| 39 | + gm.nn.Gemma3_1B, |
| 40 | + gm.nn.Gemma3_4B, |
| 41 | + gm.nn.Gemma3_12B, |
| 42 | + gm.nn.Gemma3_27B, |
| 43 | + ], |
| 44 | +) |
| 45 | +def test_transformer(model_cls: type[gm.nn.Transformer]): |
| 46 | + model = model_cls() # pylint: disable=missing-kwoa # pytype: disable=missing-parameter |
| 47 | + tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32) |
| 48 | + out, _ = _get_output(model, tokens=tokens) |
| 49 | + assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed) |
| 50 | + |
| 51 | + |
| 52 | +def test_images(): |
| 53 | + |
| 54 | + model = gm.nn.Gemma3_4B() # pylint: disable=missing-kwoa # pytype: disable=missing-parameter |
| 55 | + |
| 56 | + tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32) |
| 57 | + images = jnp.ones((BATCH_SIZE, NUM_IMAGES, 64, 64, 3), dtype=jnp.uint8) |
| 58 | + out, _ = _get_output(model, tokens=tokens, images=images) |
| 59 | + |
| 60 | + assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed) |
| 61 | + |
| 62 | + |
| 63 | +def test_text_only(): |
| 64 | + |
| 65 | + model = gm.nn.Gemma3_4B(text_only=True) |
| 66 | + |
| 67 | + tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32) |
| 68 | + images = jnp.ones((BATCH_SIZE, NUM_IMAGES, 64, 64, 3), dtype=jnp.uint8) |
| 69 | + |
| 70 | + with pytest.raises(ValueError, match='does not have vision encoder'): |
| 71 | + _get_output(model, tokens=tokens, images=images) |
| 72 | + |
| 73 | + out, params = _get_output(model, tokens=tokens) |
| 74 | + assert 'vision_encoder' not in params # Vision params not loaded |
| 75 | + assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed) |
| 76 | + |
| 77 | + |
| 78 | +def test_last_only(): |
| 79 | + model = gm.nn.Gemma3_4B(return_last_only=True) |
| 80 | + tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32) |
| 81 | + out, params = _get_output(model, tokens=tokens) |
| 82 | + assert 'vision_encoder' in params # Vision by default |
| 83 | + assert out.logits.shape == (BATCH_SIZE, model.config.num_embed) |
0 commit comments