Description
Implement nucleus sampling (top-p sampling) as a new sampling method in the Gemma text generation toolkit. This addresses a gap noted in gemma/gm/text/__init__.py:34 and provides the missing sampling strategy.
Background
Nucleus sampling was introduced in "The Curious Case of Neural Text Degeneration" (Holtzman et al., 2020). It dynamically truncates the token distribution at each step and has become a popular choice for high-quality generation with modern LLMs.
Current State
The Gemma library currently supports:
- ✅ Greedy sampling (`Greedy`)
- ✅ Random sampling (`RandomSampling`)
- ✅ Top-k sampling (`TopkSampling`)
- ❌ Nucleus sampling
Problem
Each existing method has limitations:
| Method | Issue |
|---|---|
| Greedy | Repetitive, deterministic output |
| Random | May sample very unlikely tokens, leading to incoherent text |
| Top-k | Fixed candidate set size doesn't adapt to context uncertainty |
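The top-k limitation is easy to demonstrate: the size of the nucleus adapts to how peaked the distribution is, while a top-k cutoff is fixed. A minimal illustration in plain NumPy (`nucleus_size` is a hypothetical helper, not library code):

```python
import numpy as np

def nucleus_size(probs: np.ndarray, p: float = 0.9) -> int:
    """Size of the smallest token set whose cumulative probability >= p."""
    sorted_probs = np.sort(probs)[::-1]  # Descending order.
    return int(np.searchsorted(np.cumsum(sorted_probs), p) + 1)

confident = np.array([0.95, 0.02, 0.01, 0.01, 0.01])  # Peaked: model is sure.
uncertain = np.full(5, 0.2)                           # Flat: model is unsure.

print(nucleus_size(confident))  # 1 token suffices
print(nucleus_size(uncertain))  # all 5 tokens are needed
```

With p=0.9, a confident distribution yields a nucleus of a single token, while a flat one keeps all candidates; a fixed k would over- or under-prune in one of the two cases.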
Proposed solution
Add a `NucleusSampling` class that:
- Dynamically selects candidates based on cumulative probability mass;
- Adapts to context - uses fewer tokens when model is confident, more when uncertain.
Technical details
Algorithm
```mermaid
flowchart TD
    A[Apply temperature scaling and convert logits → probabilities] --> B[Sort tokens by probability, descending]
    B --> C[Find nucleus: smallest set whose cumulative prob ≥ p]
    C --> D[Filter out tokens outside the nucleus]
    D --> E[Renormalize remaining probabilities]
    E --> F[Sample from the filtered distribution]
```
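The steps above can be sketched end to end in plain NumPy (a hypothetical reference implementation; the library version would use JAX):

```python
import numpy as np

def nucleus_sample(logits, p=0.9, temperature=1.0, rng=None):
    """Nucleus (top-p) sampling sketch mirroring the flowchart."""
    rng = rng or np.random.default_rng()
    # 1. Temperature scaling, then softmax to get probabilities.
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    # 2. Sort tokens by probability, descending.
    order = np.argsort(probs)[::-1]
    sorted_probs = probs[order]
    # 3. Nucleus: smallest prefix with cumulative probability >= p.
    cutoff = int(np.searchsorted(np.cumsum(sorted_probs), p) + 1)
    # 4.-5. Drop tokens outside the nucleus and renormalize.
    nucleus_probs = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
    # 6. Sample a token id from the filtered distribution.
    return int(order[rng.choice(cutoff, p=nucleus_probs)])
```

For a strongly peaked distribution (e.g. one logit far above the rest) the nucleus collapses to a single token and the call is effectively greedy; for flatter distributions more candidates survive the cutoff.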
API design
```python
@dataclasses.dataclass(frozen=True, kw_only=True)
class NucleusSampling(SamplingMethod):
    temperature: float = 1.0  # Temperature scaling
    p: float = 0.9            # Nucleus threshold (0.0-1.0)

    def get_next_tokens(self, logits, rng):
        ...  # Returns the sampled next tokens.
```
Usage
```python
import gemma.gm as gm

# Conservative (factual text)
sampler = gm.text.NucleusSampling(p=0.7, temperature=0.8)
# Balanced (general purpose)
sampler = gm.text.NucleusSampling(p=0.9, temperature=1.0)
# Creative (diverse output)
sampler = gm.text.NucleusSampling(p=0.95, temperature=1.2)
```
References
- Holtzman, A., Buys, J., Du, L., Forbes, M., Choi, Y. "The Curious Case of Neural Text Degeneration." ICLR 2020.