Fix pad_length tuple handling in _tokenize_prompts #468
Fixes #467
## Summary

This PR fixes a type signature mismatch bug: `Sampler.pad_length` defaults to the tuple `(256, 512, 1024)`, but `_tokenize_prompts()` only accepted `int | None`, which caused a `TypeError` when a tuple was passed. The fix updates the type signature to accept tuples and adds bucket selection logic to handle tuple values correctly.

## Problem
The `Sampler` class at line 129 defines `pad_length: None | int | tuple[int, ...] = (256, 512, 1024)`, with a tuple as the default value, but the `_tokenize_prompts()` method at line 410 only accepted `pad_length: int | None = None` in its type signature. The implementation at line 418 used `max_prompt_len = pad_length or max(len(t) for t in tokens)`, which broke when `pad_length` was a tuple: non-empty tuples are truthy in Python, so the `or` operator returned the tuple itself instead of computing the max length.

When a tuple was passed to `_tokenize_prompts()`, the code crashed with `TypeError: '>' not supported between instances of 'int' and 'tuple'`: the tuple was forwarded to `_functional.pad()` as `max_length`, and the comparison `seq_length > max_length` cannot compare an int to a tuple.

## Solution
The fix involves two changes to `gemma/gm/text/_sampler.py`:

### 1. Type Signature Update

Changed the `pad_length` parameter type from `int | None` to `int | tuple[int, ...] | None` to match the type of `Sampler.pad_length`. This allows the method to accept tuple values, which is consistent with how `_prefill.prefill()` already handles tuples and makes the API uniform across the codebase.

### 2. Bucket Selection Logic
Replaced the buggy `pad_length or max(...)` logic with proper handling for all three cases:

- `None`: uses the actual maximum length of the tokenized prompts
- `int`: uses that integer value directly as the padding length
- `tuple`: iterates through the bucket sizes and picks the smallest bucket that fits the actual max length

The tuple handling checks each bucket size in order and selects the first one that is greater than or equal to the actual prompt length. This matches the behavior of `_prefill.prefill()`, which uses `_pad_to_bucket()` for the same purpose. If no bucket fits (the prompt length exceeds all bucket sizes), it falls back to the actual max length.

## Why Bucket Selection is Necessary
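The underlying pitfall can be reproduced in isolation with a few lines of standalone Python (this is not the library code, just the truthiness behavior it tripped over):

```python
tokens = [[0] * 50]  # one 50-token prompt

# Non-empty tuples are truthy, so `or` returns the tuple itself
# instead of falling through to the max-length computation.
pad_length = (256, 512, 1024)
max_prompt_len = pad_length or max(len(t) for t in tokens)
print(max_prompt_len)  # (256, 512, 1024) -- the tuple, not an int

# Downstream, the padding code compares a sequence length to this value,
# which is where the crash surfaced.
try:
    50 > max_prompt_len  # deliberately reproduce the comparison
except TypeError as e:
    print(e)  # '>' not supported between instances of 'int' and 'tuple'
```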
The tuple buckets `(256, 512, 1024)` optimize memory usage by choosing the smallest bucket that fits rather than always padding to the maximum size, which reduces JAX recompilation overhead while minimizing memory waste. Without bucket selection, the tuple would be passed directly to the padding function, which expects an integer, causing the `TypeError` crash.

Bucket selection ensures that when `pad_length` is `(256, 512, 1024)` and the prompt is 50 tokens long, 256 is selected as the padding size instead of the entire tuple being used. This also keeps the API consistent across the codebase, since `_prefill.prefill()` already handles tuples this way.

## Testing
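The described behavior can be approximated with a standalone check; the helper below mirrors the bucket selection the PR describes and stands in for the logic inside `_tokenize_prompts()` (it is not the actual method):

```python
def select_bucket(buckets: tuple[int, ...], actual_len: int) -> int:
    """Pick the smallest bucket >= actual_len, else fall back to actual_len."""
    for bucket in buckets:
        if bucket >= actual_len:
            return bucket
    return actual_len  # prompt exceeds all buckets

assert select_bucket((256, 512, 1024), 50) == 256      # smallest fitting bucket
assert select_bucket((256, 512, 1024), 400) == 512     # next bucket up
assert select_bucket((256, 512, 1024), 2000) == 2000   # exceeds all buckets
print("bucket selection checks passed")
```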
The fix was tested by calling `_tokenize_prompts()` directly with a tuple `pad_length` value, which previously crashed. After the fix, the method handles the tuple and selects the appropriate bucket size without errors.

## Impact
This fix resolves the type mismatch bug and makes `_tokenize_prompts()` consistent with `_prefill.prefill()` in how they handle the `pad_length` parameter. The bug didn't manifest in normal usage because `_get_inputs()` doesn't pass `pad_length` to `_tokenize_prompts()`, but the fix prevents crashes if someone refactors the code to pass it or calls the method directly with a tuple value.