Skip to content

Conversation

@markknoffler
Copy link

Fix pad_length tuple handling in _tokenize_prompts

Fixes #467

Summary

This PR fixes a type signature mismatch bug where Sampler.pad_length defaults to a tuple (256, 512, 1024) but _tokenize_prompts() only accepted int | None which caused a TypeError when tuples were passed. The fix updates the type signature to accept tuples and adds proper bucket selection logic to handle tuple values correctly.

Problem

The Sampler class at line 129 defines pad_length: None | int | tuple[int, ...] = (256, 512, 1024) with a tuple as the default value but the _tokenize_prompts() method at line 410 only accepted pad_length: int | None = None in its type signature. The implementation at line 418 used max_prompt_len = pad_length or max(len(t) for t in tokens) which broke when pad_length was a tuple because tuples are truthy in Python so the or operator returned the tuple itself instead of calculating max length.

When a tuple was passed to _tokenize_prompts() the code would crash with TypeError: '>' not supported between instances of 'int' and 'tuple' because the tuple would get passed to _functional.pad() as max_length and when it tried to compare seq_length > max_length Python couldn't compare an int to a tuple.

Solution

The fix involves two changes to gemma/gm/text/_sampler.py:

1. Type Signature Update

Changed the pad_length parameter type from int | None to int | tuple[int, ...] | None to match the type signature of Sampler.pad_length. This allows the method to accept tuple values which is consistent with how _prefill.prefill() already handles tuples and makes the API uniform across the codebase.

2. Bucket Selection Logic

Replaced the buggy pad_length or max(...) logic with proper handling for all three cases:

  • None: Uses the actual maximum length of the tokenized prompts
  • int: Uses that integer value directly as the padding length
  • tuple: Iterates through the bucket sizes and picks the smallest bucket that fits the actual max length

The tuple handling works by checking each bucket size in order and selecting the first one that's greater than or equal to the actual prompt length. This matches the behavior of _prefill.prefill() which uses _pad_to_bucket() for the same purpose and if no bucket fits the prompt length exceeds all bucket sizes then it falls back to using the actual max length.

Why Bucket Selection is Necessary

The tuple buckets (256, 512, 1024) are designed to optimize memory usage by choosing the smallest bucket that fits rather than always padding to the maximum size which reduces JAX recompilation overhead while minimizing memory waste. Without the bucket selection logic the tuple would be passed directly to the padding function which expects an integer causing the TypeError crash.

The bucket selection ensures that when pad_length is (256, 512, 1024) and the prompt is 50 tokens long it selects 256 as the padding size instead of trying to use the entire tuple. This makes the API consistent across the codebase since _prefill.prefill() already handles tuples this way.

Testing

The fix was tested by calling _tokenize_prompts() directly with the tuple pad_length value which previously caused a crash. After the fix it correctly handles the tuple and selects the appropriate bucket size without errors.

Impact

This fix resolves the type mismatch bug and makes _tokenize_prompts() consistent with _prefill.prefill() in how they handle pad_length parameters. The bug didn't manifest in normal usage because _get_inputs() doesn't pass pad_length to _tokenize_prompts() but this fix prevents crashes if someone refactors the code to pass it or calls the method directly with tuple values.

@google-cla
Copy link

google-cla bot commented Dec 4, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@markknoffler markknoffler force-pushed the fix-pad-length-tuple-bug branch from 7cff778 to 0b54ad5 Compare December 4, 2025 05:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Type mismatch: _tokenize_prompts() doesn't accept tuple pad_length but Sampler.pad_length defaults to tuple

1 participant