System Info
When using `google/gemma-3-1b-it`, I run into frequent recompilation issues, eventually resulting in the following error message:
```
torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value
```
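For reference, the two knobs the error message itself points at can be set like this (a minimal sketch; the limit value is an arbitrary example, not a recommendation):

```python
import torch

# Raise TorchDynamo's recompile budget, as the error message suggests.
# 256 is an arbitrary example value, not a recommendation.
torch._dynamo.config.cache_size_limit = 256

# To see what triggers each recompilation, launch the script with
#   TORCH_LOGS=recompiles python repro.py
# as the error message also suggests.
```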
There is a simple workaround: sorting the inputs by length, longest first, to minimize graph recompilations caused by growing input sizes (sketched below). However, that seems hacky. Is this expected behavior, or can the compilation be disabled?
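The workaround looks roughly like this (a sketch only, not the exact code I run; `build_prompt` is a placeholder for the prompt construction in the reproduction below):

```python
# Sketch of the length-sorting workaround. `build_prompt` stands in for the
# prompt formatting shown in the reproduction script below.
prompts = [build_prompt(sample) for sample in dataset]

# Generate from longest to shortest so input sizes never increase across iterations.
prompts.sort(key=lambda p: len(tokenizer(p)["input_ids"]), reverse=True)

for prompt in prompts:
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(device)
    with torch.no_grad():
        model.generate(input_ids, max_new_tokens=256, use_cache=True)
```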
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Steps to reproduce:
````python
from transformers import Gemma3ForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import datasets
import json
import torch
from tqdm import tqdm


def tools_format_python(tool_dicts):
    functions_formatted = []
    tab_s = " "
    for tool_dict in tool_dicts:
        function_name = tool_dict["name"]
        function_args = [f"{k}: {v['type']}" for k, v in tool_dict["parameters"].items()]
        function_formatted = (
            f"def {function_name}({', '.join(function_args)})\n" +
            f"{tab_s}{tool_dict['description']}\n\n" +
            f"{tab_s}Args:\n" +
            "\n".join(f"{tab_s*2}{k}: {v['description']}" for k, v in tool_dict["parameters"].items())
        )
        functions_formatted.append(function_formatted)
    tools_prompt = (
        "```python\n" +
        "\n\n".join(functions_formatted) +
        "\n```\n\n"
    )
    return tools_prompt


llm_model_name = "google/gemma-3-1b-it"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(llm_model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)

dataset = datasets.load_dataset("Salesforce/xlam-function-calling-60k")["train"]

system_message = "You are a helpful function calling AI assistant. Below are the Python functions accessible to you:\n"

for sample in tqdm(dataset):
    sample_tools = json.loads(sample["tools"])
    tools_prompt = tools_format_python(sample_tools)
    gemma_prompt_format = (
        "<bos><start_of_turn>user\n" +
        system_message +
        tools_prompt +
        "User: " + sample["query"] + "<end_of_turn>\n<start_of_turn>model\n```tool_call\n"
    )
    tokenized_prompt = tokenizer(gemma_prompt_format, return_tensors="pt", add_special_tokens=False)["input_ids"].to(device)
    with torch.no_grad(), torch.amp.autocast(device, dtype=torch.bfloat16):
        outputs = model.generate(
            tokenized_prompt,
            max_new_tokens=256,
            use_cache=True,
        )
````
Expected behavior
Dynamic input shapes should not cause the program to eventually fail with a recompile-limit error.