
T5Gemma2 Support #3604

Closed

contrebande-labs wants to merge 11 commits into huggingface:main from contrebande-labs:t5gemma2

Conversation

@contrebande-labs
Contributor

Hi! As per #3602, here is a WIP just to check if I'm on the right track. If you agree with the general idea, I will write some tests with the t5gemma2 weights. I don't expect you guys to work on the holidays so take your time and get back to me when you can.

-- Vincent

@tomaarsen
Member

Hello!

This is definitely in the right direction, matching what I would do as well. The only downside is that the T5GemmaConfig doesn't immediately expose the hidden_size, but instead uses a text_config which hosts the hidden_size. I've been working on adding this support in #3554:

https://github.com/tomaarsen/sentence-transformers/blob/3c55e8b14a1a5bbfb96da1be6938b92a6a4f5adf/sentence_transformers/base/models/Transformer.py#L874-L900

Compared to main:

def get_word_embedding_dimension(self) -> int:
return self.auto_model.config.hidden_size

I've done some tests, and it looks like this PR works once the get_word_embedding_dimension-changes from #3554 are merged.
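For readers following along, the nested lookup can be sketched in a few lines (my own illustration with stand-in config objects, not the exact #3554 code):

```python
from types import SimpleNamespace

def get_word_embedding_dimension(config) -> int:
    # Sketch: nested configs (e.g. T5Gemma2Config) keep hidden_size on
    # config.text_config, while flat configs keep it on the config itself.
    text_config = getattr(config, "text_config", config)
    return text_config.hidden_size

# Stand-in config objects (hypothetical dimensions, not real model values):
flat = SimpleNamespace(hidden_size=768)
nested = SimpleNamespace(text_config=SimpleNamespace(hidden_size=640))
print(get_word_embedding_dimension(flat))    # 768
print(get_word_embedding_dimension(nested))  # 640
```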


I made some small changes, but sadly I can't push them to this PR, as I don't have the required permissions on the repository hosting it. You can see the commit here: 7b429c8

Guard the T5Gemma2 import for transformers <5 compat, reformat

  • Tom Aarsen

@contrebande-labs
Contributor Author

I've applied your changes. I didn't enable maintainer edits on this PR because I wanted to keep track of things myself. If you have other changes, let me know and I'll apply them.

I would now like to add examples and tests before I take this PR out of draft mode. I'd like to implement a full quantization-aware FR-EN bilingual Matryoshka training, with subsequent ONNX export and OpenVINO ONNX Runtime inference. What do you think?

@contrebande-labs
Contributor Author

I had forgotten get_word_embedding_dimension(); it's done now. I left out the TIMM changes, however.

@contrebande-labs
Contributor Author

Hi Tom!

With the latest commit, loading the model now works:

st_pytorch_model = SentenceTransformer(
    "google/t5gemma-2-270m-270m",
    backend="torch",
    device="xpu",
    trust_remote_code=True,
    local_files_only=False,
    model_kwargs={"dtype": torch.float32},
).eval()

The code is a mix of your other PRs with what's currently on the main branch (e.g. model -> auto_model), but the loading works. Did you intend to merge your other PRs first? If so, I can rebase once they're done. Otherwise, let me know how you suggest I write tests and examples.

@contrebande-labs
Contributor Author

Hi Tom. Alternatively, I can turn this PR over to you so you can pull it, and I'll open another PR for examples and tests.

@contrebande-labs contrebande-labs marked this pull request as ready for review on December 30, 2025 at 16:20
@contrebande-labs contrebande-labs changed the title from "t5gemma2 sentence transformer" to "T5Gemma2 Sentence Transformer support" on Dec 30, 2025
@contrebande-labs contrebande-labs changed the title from "T5Gemma2 Sentence Transformer support" to "T5Gemma2 Support" on Dec 30, 2025
@contrebande-labs
Contributor Author

Hi Tom!

What are your plans for this PR ? Thanks!

Vincent

@tomaarsen
Copy link
Member

Hello!

Apologies for the delays; I've been helping transformers get their v5 ready, with some features that will benefit sentence-transformers. Now I'm back and looking to prepare a v5.3 release (and I'd like to get this in there). I'll keep you posted if there are more reviews or anything else 🤗

  • Tom Aarsen

@contrebande-labs
Contributor Author

Hi Tom, makes sense about transformers v5. I rebased the PR branch onto your latest changes, and I'll check later today whether any adjustments are needed. Glad you're back!


@tomaarsen tomaarsen left a comment


Apologies for the multiple rounds of reviews. I'm experimenting with the changes now, and I see that there's some limitations on the side of transformers currently. I'll make a PR there, but the short version is that if you save a SentenceTransformer("google/t5gemma-2-270m-270m"), you won't be able to load it anymore, as the t5gemma2_encoder is a non-recognized architecture.

Comment on lines +29 to +33
if parse_version(transformers_version) >= parse_version("5.0.0dev0"):
    from transformers import T5Gemma2Config
else:
    class T5Gemma2Config:
        pass

Suggested change
- if parse_version(transformers_version) >= parse_version("5.0.0dev0"):
-     from transformers import T5Gemma2Config
- else:
-     class T5Gemma2Config:
-         pass
+ try:
+     from transformers import T5Gemma2Config
+ except ImportError:
+     class T5Gemma2Config:
+         pass

I realise now that it's likely better to use "Easier to ask for forgiveness than permission".
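To illustrate why the stub matters: downstream isinstance checks keep working (returning False) instead of raising NameError when the real class is unavailable. A self-contained sketch, using a hypothetical package name rather than the actual sentence-transformers code:

```python
# EAFP import guard: try the real class, fall back to an empty stub.
# "nonexistent_pkg" is a hypothetical stand-in for a transformers
# version that does not ship T5Gemma2Config.
try:
    from nonexistent_pkg import T5Gemma2Config
except ImportError:
    class T5Gemma2Config:
        """Stub: isinstance(cfg, T5Gemma2Config) is simply False."""

class SomeOtherConfig:
    pass

cfg = SomeOtherConfig()
# No NameError here; the check just fails cleanly:
print(isinstance(cfg, T5Gemma2Config))  # False
```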

Comment on lines +320 to +334
# Try hidden_sizes list (e.g., ResNet, some vision models)
if hasattr(text_config, "hidden_sizes"):
    if isinstance(text_config.hidden_sizes, list):
        return text_config.hidden_sizes[-1]  # Use final layer dimension
    return text_config.hidden_sizes
elif hasattr(text_config, "hidden_size"):
    return text_config.hidden_size

# Unable to determine dimension
raise ValueError(
    f"Could not determine embedding dimension from model config. "
    f"Config type: {type(text_config).__name__}. "
    f"Available attributes: {[attr for attr in dir(text_config) if 'hidden' in attr.lower() or 'size' in attr.lower() or 'dim' in attr.lower()]}. "
    f"Please report this issue with your model name: {self.auto_model.config.model_type if hasattr(self.auto_model.config, 'model_type') else 'unknown'}"
)

Suggested change
- # Try hidden_sizes list (e.g., ResNet, some vision models)
- if hasattr(text_config, "hidden_sizes"):
-     if isinstance(text_config.hidden_sizes, list):
-         return text_config.hidden_sizes[-1]  # Use final layer dimension
-     return text_config.hidden_sizes
- elif hasattr(text_config, "hidden_size"):
-     return text_config.hidden_size
- # Unable to determine dimension
- raise ValueError(
-     f"Could not determine embedding dimension from model config. "
-     f"Config type: {type(text_config).__name__}. "
-     f"Available attributes: {[attr for attr in dir(text_config) if 'hidden' in attr.lower() or 'size' in attr.lower() or 'dim' in attr.lower()]}. "
-     f"Please report this issue with your model name: {self.auto_model.config.model_type if hasattr(self.auto_model.config, 'model_type') else 'unknown'}"
- )
+ return text_config.hidden_size

I think I'd prefer to keep this small until I actually complete the multi-modality PR.

@contrebande-labs
Contributor Author

contrebande-labs commented Jan 28, 2026

Ok. I'll keep my branch as is, because I can export it to ONNX and thus avoid the issue you mention. I'll rebase once you merge the multimodal patch and the problem is fixed in transformers v5.

Also, when you have a PR in transformers, can you link it here, please?

@tomaarsen
Member

tomaarsen commented Jan 28, 2026

Sure, here's the PR: huggingface/transformers#43559
I trained a model here: https://huggingface.co/tomaarsen/t5gemma2-270m-gooaq-cmnrl
I used the same script as https://huggingface.co/tomaarsen/mpnet-base-gooaq-cmnrl-1024bs-GradCache, but the T5Gemma2-270m variant scores a good bit better: 0.8948 NDCG@10 vs 0.8652 NDCG@10 on in-domain evaluation on a GooAQ dev set. Granted, the T5Gemma2 model is also bigger (270M vs 110M parameters).

Click to see the script used
import random
import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "google/t5gemma-2-270m-270m",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="T5Gemma2 270M encoder trained on GooAQ pairs using CachedMultipleNegativesRankingLoss",
    ),
)

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]

eval_dataset: Dataset = dataset_dict["test"]

# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)

# 5. (Optional) Specify training arguments
run_name = "t5gemma2-270m-gooaq-cmnrl"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=1024,
    per_device_eval_batch_size=1024,
    learning_rate=2e-5 * 4,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # CachedMultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=0.1,
    save_strategy="steps",
    save_steps=0.1,
    save_total_limit=2,
    logging_steps=0.05,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
# Corpus: only the answers belonging to the evaluation queries
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = {qid: dataset[qid]["answer"] for qid in queries}
# Optionally extend the corpus with random distractor documents, e.g.:
# corpus |= {did: dataset[did]["answer"] for did in random.sample(range(len(dataset)), 20_000)}
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="gooaq-dev",
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)

# 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
  • Tom Aarsen

@contrebande-labs
Contributor Author

I see a great future in IR for this architecture too. I can't wait for it to be officially supported by the whole stack. I've subscribed to your PR. I might make a dedicated repo myself if it takes too long.

@tomaarsen
Member

There's some updates on the transformers PR: huggingface/transformers#43559 (comment)
I don't think we'll need the t5gemma_encoder.

  • Tom Aarsen

@contrebande-labs
Contributor Author

Great news! But even once this is fixed, we'll need a new transformers release, correct?

@tomaarsen
Member

tomaarsen commented Jan 29, 2026

Yes, it seems that there's an issue with the configs not being "linked" correctly, so I get a crash when trying to train with the model, even if I load it with T5Gemma2Encoder.

  • Tom Aarsen

@tomaarsen
Member

Implemented via #3644!
Thanks for bringing my attention to this, it's very nice to get this working.

  • Tom Aarsen

@tomaarsen tomaarsen closed this Feb 5, 2026
