Open
Description
# Common imports
import os
import jax.numpy as jnp
import tensorflow_datasets as tfds
# Gemma imports
from gemma import gm
# Ask XLA to let JAX preallocate the full (1.00) fraction of GPU memory.
# NOTE(review): XLA reads this variable when the JAX backend is first
# initialised. It is assigned after `import jax.numpy` above, which is
# presumably still early enough because no device has been touched yet —
# confirm, or move this line above the JAX import.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

# Pull the first two training examples' pictures from Oxford Flowers 102.
ds = tfds.data_source("oxford_flowers102", split="train")
image1, image2 = (ds[idx]["image"] for idx in (0, 1))

# Gemma 3 4B architecture plus its instruction-tuned checkpoint weights.
model = gm.nn.Gemma3_4B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)

# multi_turn=True presumably makes the sampler carry the first exchange into
# the second `chat` call, so the follow-up question has context — verify
# against the ChatSampler docs.
sampler = gm.text.ChatSampler(model=model, params=params, multi_turn=True)

# One (prompt, image) pair per conversational turn. The <start_of_image>
# token presumably marks where the supplied image is spliced into the prompt.
_TURNS = (
    ("What can you say about this image: <start_of_image>", image1),
    ("What about this other image?: <start_of_image>", image2),
)
# Turns are evaluated strictly in order, first reply bound to `out`,
# second to `out1`.
out, out1 = (sampler.chat(prompt, images=img) for prompt, img in _TURNS)
Metadata
Metadata
Assignees
Labels
No labels