
Gemma 3 Compilation Issues During Generation #39427

@mitchelldehaven

Description

System Info

When using the google/gemma-3-1b-it model, I run into frequent recompilations during generation, eventually resulting in the following error:

torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value
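The error message itself points at two knobs: logging recompilations and raising Dynamo's cache size limit. A minimal sketch of both (the limit of 64 is an arbitrary example value, and raising it only postpones the failure if the input shapes keep changing):

import torch

# Print the guard that triggered each recompilation (equivalent to running with TORCH_LOGS=recompiles).
torch._logging.set_logs(recompiles=True)

# Allow more recompiled variants per compiled function before FailOnRecompileLimitHit is raised.
torch._dynamo.config.cache_size_limit = 64  # arbitrary example value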

There is a simple workaround of sorting the inputs by length (longest first), which minimizes graph recompilation because the input size only ever shrinks; a sketch is shown below. However, that feels hacky. Is this expected behavior, or can the compilation be disabled?
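For reference, the sorting workaround looks roughly like this (a sketch; `prompts` is a hypothetical list of already-formatted prompt strings built as in the reproduction below):

# Process prompts longest-first so the input length only ever shrinks,
# keeping the set of shapes Dynamo sees (and hence recompilations) small.
lengths = [len(tokenizer(p, add_special_tokens=False)["input_ids"]) for p in prompts]
prompts_sorted = [p for _, p in sorted(zip(lengths, prompts), key=lambda t: t[0], reverse=True)]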

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce:

from transformers import Gemma3ForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import datasets
import json
import torch
from tqdm import tqdm


def tools_format_python(tool_dicts):
    # Render each tool spec as a Python-style function signature plus a docstring-like description.
    functions_formatted = []
    tab_s = "    "
    for tool_dict in tool_dicts:
        function_name = tool_dict["name"]
        function_args = [f"{k}: {v['type']}" for k, v in tool_dict["parameters"].items()]
        function_formatted = (
            f"def {function_name}({', '.join(function_args)}):\n" +
            f"{tab_s}{tool_dict['description']}\n\n" +
            f"{tab_s}Args:\n" +
            "\n".join(f"{tab_s*2}{k}: {v['description']}" for k, v in tool_dict["parameters"].items())
        )
        functions_formatted.append(function_formatted)
    tools_prompt = (
        "```python\n" +
        "\n\n".join(functions_formatted) + 
        "\n```\n\n"
    )
    return tools_prompt

llm_model_name = "google/gemma-3-1b-it"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(llm_model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
dataset = datasets.load_dataset("Salesforce/xlam-function-calling-60k")["train"]
system_message = "You are a helpful function calling AI assistant. Below are the Python functions accessible to you:\n"
# Prompts vary in length from sample to sample, so each generate() call can see a new input shape.
for sample in tqdm(dataset):
    sample_tools = json.loads(sample["tools"])
    tools_prompt = tools_format_python(sample_tools)
    gemma_prompt_format = (
        "<bos><start_of_turn>user\n" +
        system_message + 
        tools_prompt + 
        "User: " + sample["query"] + "<end_of_turn>\n<start_of_turn>model\n```tool_call\n"
    )
    tokenized_prompt = tokenizer(gemma_prompt_format, return_tensors="pt", add_special_tokens=False)["input_ids"].to(device)
    # Recompilations accumulate across iterations until FailOnRecompileLimitHit is raised.
    with torch.no_grad(), torch.amp.autocast(device, dtype=torch.bfloat16):
        outputs = model.generate(
            tokenized_prompt,
            max_new_tokens=256,
            use_cache=True,
        )
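An alternative I have not benchmarked (an assumption on my part, not something documented for Gemma 3) is to left-pad each prompt to a bucketed length so generate() only ever sees a handful of distinct input shapes. This assumes the tokenizer defines a pad token, which Gemma's does:

import math

tokenizer.padding_side = "left"  # pad on the left so generation continues from the real tokens

def pad_to_bucket(prompt, bucket=128):
    # Round the prompt length up to the next multiple of `bucket` (128 is an arbitrary choice),
    # so Dynamo sees one shape per bucket instead of one per prompt.
    ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    target = math.ceil(len(ids) / bucket) * bucket
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False,
                    padding="max_length", max_length=target)
    return enc["input_ids"].to(device), enc["attention_mask"].to(device)

input_ids, attention_mask = pad_to_bucket(gemma_prompt_format)  # e.g. the prompt built in the loop above
outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=256, use_cache=True)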

Expected behavior

Dynamic input shapes should not cause the program to eventually fail during generation.
