Add nucleus sampling (top-p) method #296

Open
@stefanoamorelli

Description

Implement nucleus sampling (top-p sampling) as a new sampling method in the Gemma text generation toolkit. This addresses a gap at gemma/gm/text/__init__.py:34 and provides the missing sampling strategy.

Background

Nucleus sampling was introduced in "The Curious Case of Neural Text Degeneration" (Holtzman et al., 2020). It is a dynamic sampling technique that has become popular for high-quality text generation with modern LLMs.

Current State

The Gemma library currently supports:

  • ✅ Greedy sampling (Greedy)
  • ✅ Random sampling (RandomSampling)
  • ✅ Top-k sampling (TopkSampling)
  • ❌ Nucleus sampling (not yet implemented)

Problem

Each existing method has limitations:

| Method | Issue |
| --- | --- |
| Greedy | Repetitive, deterministic output |
| Random | May sample very unlikely tokens, leading to incoherent text |
| Top-k | Fixed candidate set size doesn't adapt to context uncertainty |

Proposed solution

Add a NucleusSampling class that:

  1. Dynamically selects candidates based on cumulative probability mass;
  2. Adapts to context: it keeps fewer tokens when the model is confident and more when it is uncertain.
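
To illustrate the adaptive behaviour described above, here is a minimal sketch in plain Python. The helper name `nucleus_size` and both distributions are made up for this example and are not part of the Gemma library. The helper counts how many tokens fall inside the nucleus for a given `p`.

```python
def nucleus_size(probs, p):
    """Number of tokens in the smallest set whose cumulative probability reaches p."""
    total = 0.0
    for count, prob in enumerate(sorted(probs, reverse=True), start=1):
        total += prob
        if total >= p:
            return count
    return len(probs)  # Floating-point shortfall: keep every token.


# Hypothetical next-token distributions over a 6-token vocabulary.
confident = [0.92, 0.04, 0.02, 0.01, 0.005, 0.005]
uncertain = [0.25, 0.22, 0.20, 0.18, 0.10, 0.05]

print(nucleus_size(confident, p=0.9))  # 1
print(nucleus_size(uncertain, p=0.9))  # 5
```

By contrast, top-k with k=3 would keep exactly three candidates in both cases. That includes the 0.04 and 0.02 tokens in the confident case and drops 33% of the probability mass in the uncertain one.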

Technical details

Algorithm

flowchart TD
    A[Convert logits → probabilities with temperature scaling] --> B[Sort tokens by probability descending]
    B --> C[Find nucleus: smallest set where cumulative prob ≥ p]
    C --> D[Filter out tokens outside nucleus]
    D --> E[Renormalize remaining probabilities]
    E --> F[Sample from the filtered distribution]
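
The steps in the flowchart can be sketched in plain Python as follows. This is an illustrative, standard-library reference for a single logits vector, not the proposed implementation. The function name `nucleus_sample` and the use of `random.Random` as the `rng` are assumptions made for this example, and a real implementation in the library would presumably operate on batched arrays with the framework's own RNG.

```python
import math
import random


def nucleus_sample(logits, p=0.9, temperature=1.0, rng=None):
    """Sample one token id from `logits` using top-p (nucleus) filtering."""
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    if temperature <= 0.0:
        raise ValueError("temperature must be positive")
    if rng is None:
        rng = random.Random()

    # Step A: temperature-scaled softmax (subtract the max for stability).
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Step B: token ids sorted by probability, highest first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Steps C-D: keep the smallest prefix whose cumulative probability
    # reaches p; everything after it is filtered out.
    nucleus, cumulative = [], 0.0
    for token_id in order:
        nucleus.append(token_id)
        cumulative += probs[token_id]
        if cumulative >= p:
            break

    # Steps E-F: random.choices normalizes the weights internally, so this
    # is equivalent to renormalizing and then sampling.
    weights = [probs[i] for i in nucleus]
    return rng.choices(nucleus, weights=weights, k=1)[0]


rng = random.Random(0)
# Token 0 alone holds more than 90% of the mass, so the nucleus is {0}.
print(nucleus_sample([5.0, 1.0, 0.5, 0.0, -1.0], p=0.9, rng=rng))  # 0
```

Because the loop appends a token before checking the threshold, the nucleus always contains at least the most likely token, even for very small `p`.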

API design

import dataclasses

@dataclasses.dataclass(frozen=True, kw_only=True)
class NucleusSampling(SamplingMethod):
    temperature: float = 1.0  # Temperature scaling
    p: float = 0.9            # Nucleus threshold (0.0-1.0)

    def get_next_tokens(self, logits, rng):
        """Returns the sampled next-token ids for the given logits."""
        ...

Usage

import gemma.gm as gm

# Conservative (factual text)
sampler = gm.text.NucleusSampling(p=0.7, temperature=0.8)

# Balanced (general purpose) 
sampler = gm.text.NucleusSampling(p=0.9, temperature=1.0)

# Creative (diverse output)
sampler = gm.text.NucleusSampling(p=0.95, temperature=1.2)
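
To give a feel for how the three presets differ, the sketch below counts how many tokens survive the top-p filter under each `(p, temperature)` pair. It uses only the standard library. The helper `kept_tokens` and the logits are invented for illustration, so the exact counts depend entirely on these made-up values.

```python
import math


def kept_tokens(logits, p, temperature):
    """Count the tokens kept by top-p after a temperature-scaled softmax."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = sorted((math.exp(x - peak) for x in scaled), reverse=True)
    total = sum(exps)
    cumulative = 0.0
    for count, e in enumerate(exps, start=1):
        cumulative += e / total
        if cumulative >= p:
            return count
    return len(exps)  # Floating-point shortfall: keep every token.


# Hypothetical logits over an 8-token vocabulary.
logits = [3.0, 2.5, 2.0, 1.0, 0.5, 0.0, -1.0, -2.0]

presets = [("conservative", 0.7, 0.8), ("balanced", 0.9, 1.0), ("creative", 0.95, 1.2)]
for name, p, temperature in presets:
    print(name, kept_tokens(logits, p, temperature))
# conservative 2
# balanced 4
# creative 6
```

Both knobs push in the same direction here: a lower temperature sharpens the distribution and a lower `p` cuts the tail earlier, so the conservative preset samples from the fewest candidates.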
