Gemma is a family of open-weights Large Language Models (LLMs) by Google DeepMind, based on Gemini research and technology.
This repository contains the implementation of the `gemma` PyPI package, a JAX library to use and fine-tune Gemma.
For examples and use cases, see our documentation. Please report issues and feedback on our GitHub.
- Install JAX for CPU, GPU, or TPU by following the instructions on the JAX website.
- Run `pip install gemma`.
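
After installing, you can check which backend JAX picked up. This is a minimal sketch using the standard `jax.devices()` call; the exact device names printed depend on your platform:

```python
import jax

# Lists the accelerators JAX can see, e.g. [CpuDevice(id=0)]
# on a CPU-only install, or GPU/TPU devices when available.
print(jax.devices())
```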
Here is a minimal example of a multi-turn, multi-modal conversation with Gemma:
```python
from gemma import gm

# Model and parameters
model = gm.nn.Gemma3_4B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)

# Example of multi-turn conversation
sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    multi_turn=True,
)

prompt = """Which of the two images do you prefer?

Image 1: <start_of_image>
Image 2: <start_of_image>

Write your answer as a poem."""
out0 = sampler.chat(prompt, images=[image1, image2])

out1 = sampler.chat('What about the other image?')
```
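
The snippet above assumes `image1` and `image2` are already loaded. One way to obtain them is sketched below with PIL and NumPy; the file paths are placeholders, and the exact input format `ChatSampler` accepts (dtype, resolution) is an assumption here, so check the documentation for details:

```python
import numpy as np
from PIL import Image

# Hypothetical paths; replace with your own image files.
# We assume the sampler accepts HxWx3 uint8 NumPy arrays;
# verify the expected input format in the Gemma documentation.
image1 = np.asarray(Image.open('image1.jpg').convert('RGB'))
image2 = np.asarray(Image.open('image2.jpg').convert('RGB'))
```

With `multi_turn=True`, the sampler keeps the previous turns in its context, which is why the second `chat` call can refer back to the images without resending them.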
Our documentation contains various Colabs and tutorials. Additionally, our `examples/` folder contains scripts to fine-tune and sample with Gemma.
- To use this library: Gemma documentation
- Technical reports for metrics and model capabilities
- Other Gemma implementations and documentation on the Gemma ecosystem
To download the model weights, see our documentation.
Gemma can run on CPU, GPU, and TPU. For GPU, we recommend 8GB+ of GPU RAM for the 2B checkpoint and 24GB+ of GPU RAM for the 7B checkpoint.
This is not an official Google product.