
[Feat] toploc2 #360


Merged
merged 18 commits on Jun 12, 2025
13 changes: 12 additions & 1 deletion src/zeroband/infer.py
@@ -3,8 +3,8 @@
import os
import shutil
import time
import uuid
from pathlib import Path
import uuid

# Import environment before any other imports
# ruff: noqa: I001
@@ -26,6 +26,7 @@
from zeroband.inference.pipeline import all_reduce, patch_model_load, setup_comm, setup_hooks
from zeroband.inference.rewards import compute_vllm_rewards
from zeroband.inference.toploc import setup_toploc_cache
from zeroband.inference.toploc2 import Toploc2Sampler
from zeroband.utils.monitor import setup_monitor
from zeroband.inference.utils import (
filter_data_by_prompt_length,
@@ -81,7 +82,10 @@ def inference(config: Config):
disable_async_output_proc=True, # We have an off by 1 error in toploc without this flag when cuda graph padding is enabled.
download_dir=config.download_dir,
dtype="bfloat16" if config.dtype == "bf16" else torch.float32,
enable_chunked_prefill=False, # This is required for toploc2 because chunked prefill seems to allow len(seq_groups) != len(selected_token_indices) which is unexpected
)
if config.toploc2:
llm.llm_engine.model_executor.driver_worker.model_runner.sampler = Toploc2Sampler()
tokenizer = llm.get_tokenizer()

# Adjust sampling params based on config
@@ -220,6 +224,7 @@ def inference(config: Config):
# This would work even if the node restarts and resumes from the current step.
generator = np.random.default_rng(node_address_int * current_step_batch_counter + real_step)
indices = generator.integers(0, len(dataset), problems_per_batch)
sampling_params.seed = int(generator.integers(2**32))
else:
# Use modulo to cycle through the dataset instead of terminating
indices = [(dataset_offset + j) % len(dataset) for j in range(problems_per_batch)]
@@ -314,6 +319,11 @@ def inference(config: Config):
monitor.log({"rewards/batch_rewards": batch_rewards})
logger.info(f"Average reward of the batch: {batch_rewards}")

if sampling_params.seed is not None:
sampling_seeds = [sampling_params.seed + i for i in range(sampling_params.n)] * problems_per_batch
else:
sampling_seeds = [None] * batch_samples

# Get parquet table
table = get_parquet_table(
request_outputs,
Expand All @@ -324,6 +334,7 @@ def inference(config: Config):
target_lengths,
problems,
enable_logprobs=config.sampling.logprobs is not None,
seeds=sampling_seeds,
)

# Save outputs to parquet file
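To make the seed bookkeeping above concrete, here is a minimal standalone sketch of how the per-sample seeds written to the parquet output relate to the batch-level seed. The numeric values are hypothetical; only the derivation mirrors the diff, and the dataset index draw that precedes the seed draw in infer.py is omitted.

```python
import numpy as np

# Hypothetical inputs; in infer.py these come from the node address, batch counter and step.
node_address_int, current_step_batch_counter, real_step = 12345, 1, 7
problems_per_batch, n = 2, 3  # prompts per batch, completions per prompt

generator = np.random.default_rng(node_address_int * current_step_batch_counter + real_step)
batch_seed = int(generator.integers(2**32))  # becomes sampling_params.seed

# One seed per completion, repeated for every prompt in the batch,
# matching the sampling_seeds list built before get_parquet_table().
sampling_seeds = [batch_seed + i for i in range(n)] * problems_per_batch
print(batch_seed, sampling_seeds)
```
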
8 changes: 8 additions & 0 deletions src/zeroband/inference/config.py
@@ -18,6 +18,13 @@ class SamplingParamConfig(BaseConfig):
seed: int | None = None
logprobs: int | None = 0 # put to None to disable logprobs calculation

@model_validator(mode="after")
def convert_negative_logprobs_to_none(self):
"""Convert negative logprobs values to None to disable logprobs calculation."""
if self.logprobs is not None and self.logprobs < 0:
self.logprobs = None
return self
Review comment on lines +21 to +26

Collaborator:

Feels more intuitive to err when passing negative values?

Member Author:

This is necessary to disable logprobs. I couldn't find a way to pass None, and since the default is 0, there didn't seem to be a way to make it None other than this.



class DifficultyFilteringConfig(BaseConfig):
solve_rate_field: str = "solve_rate_qwen_r1_distill_7b"
@@ -73,6 +80,7 @@ class Config(BaseConfig):
ckpt_start_path: str | None = None

toploc: bool = False
toploc2: bool = True

rewards: RewardsConfig = RewardsConfig()
difficulty_filtering: DifficultyFilteringConfig | None = None
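As a reference for the review thread above, the sketch below reproduces the validator behaviour in isolation. It uses a plain pydantic BaseModel in place of the repo's BaseConfig, an assumption made only to keep the example self-contained.

```python
from pydantic import BaseModel, model_validator


class SamplingParamConfig(BaseModel):  # stand-in for the repo's BaseConfig subclass
    seed: int | None = None
    logprobs: int | None = 0  # 0 keeps logprob calculation enabled by default

    @model_validator(mode="after")
    def convert_negative_logprobs_to_none(self):
        # Any negative value is coerced to None, which disables logprob calculation;
        # this is the workaround discussed above for a field whose default is 0.
        if self.logprobs is not None and self.logprobs < 0:
            self.logprobs = None
        return self


assert SamplingParamConfig(logprobs=-1).logprobs is None
assert SamplingParamConfig(logprobs=0).logprobs == 0
assert SamplingParamConfig().logprobs == 0
```
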
4 changes: 3 additions & 1 deletion src/zeroband/inference/parquet.py
@@ -43,6 +43,7 @@ def get_parquet_table(
target_lengths: list[int],
problems: Dataset,
enable_logprobs: bool,
seeds: list[int],
) -> pa.Table:
# Iterator over proofs
proof_iter = iter(proofs)
@@ -57,7 +58,7 @@
problems,
):
assert request_output.request_id == request_rewards.request_id
for output, reward in zip(request_output.outputs, request_rewards.rewards):
for output, reward, seed in zip(request_output.outputs, request_rewards.rewards, seeds):
assert output.index == reward.completion_id

# Extract logprobs if enabled and available
Expand All @@ -82,6 +83,7 @@ def get_parquet_table(
"step": step,
"target_lengths": target_length,
"task_type": request_rewards.task_type,
"seed": seed,
}
)

134 changes: 134 additions & 0 deletions src/zeroband/inference/toploc2.py
@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""

from typing import Optional

import torch
from vllm.model_executor.layers.sampler import (
Sampler,
SampleResultArgsType,
SamplerOutput,
_apply_min_p,
_apply_min_tokens_penalty,
_apply_top_k_top_p,
_build_sampler_output,
get_logprobs,
)
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors

# We have to use smaller sizes in the exponential_ function to prevent different kernels
# from being used by different GPUs.
GUMBEL_BATCH_SIZE = 2**16


def generate_neg_gumbel_noise(n: int | tuple[int, int], generator: torch.Generator, device: torch.device):
if isinstance(n, int):
ret = torch.empty(n, device=device)
for i in range(0, n, GUMBEL_BATCH_SIZE):
end = min(i + GUMBEL_BATCH_SIZE, n)
ret[i:end].exponential_(generator=generator).log_()
else:
ret = torch.empty(n[0], n[1], device=device)
for i in range(0, n[0]):
for j in range(0, n[1], GUMBEL_BATCH_SIZE):
end_j = min(j + GUMBEL_BATCH_SIZE, n[1])
ret[i, j:end_j].exponential_(generator=generator).log_()
return ret


class Toploc2Sampler(Sampler):
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert logits is not None
_, vocab_size = logits.shape

# Prepare sampling tensors with pinned memory to avoid blocking.
if not sampling_metadata.reuse_sampling_tensors:
self._init_sampling_tensors(logits, sampling_metadata)
elif self._do_penalties:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self._init_sampling_tensors(logits, sampling_metadata)

assert self._sampling_tensors is not None
sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
do_min_p = self._do_min_p

logits = _apply_min_tokens_penalty(logits, sampling_metadata)

# Apply presence and frequency penalties.
if do_penalties:
logits = apply_penalties(
logits,
sampling_tensors.prompt_tokens,
sampling_tensors.output_tokens,
sampling_tensors.presence_penalties,
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties,
)

# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, sampling_tensors.top_ks)

if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)

# # We use float32 for probabilities and log probabilities.
# # Compute the probabilities.
# probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# # Compute the log probabilities.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

chosen_noises = []

def _sample(logits, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors):
assert len(sampling_metadata.seq_groups) == logits.shape[0]
neg_gumbel_noise = torch.stack(
[generate_neg_gumbel_noise(logits.shape[-1], sg.generator, logits.device) for sg in sampling_metadata.seq_groups]
)
assert neg_gumbel_noise.shape == logits.shape
_race_result = logits - neg_gumbel_noise
token_ids = torch.argmax(_race_result, dim=-1)
chosen_noises.append(torch.gather(neg_gumbel_noise, 1, token_ids.unsqueeze(1)))
return [([token_ids[i].item()], [0]) for i in range(len(sampling_metadata.seq_groups))]

# Sample the next tokens.
maybe_deferred_sample_results = _sample(
logits,
sampling_metadata,
sampling_tensors,
)
# Get the logprobs query results.
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
# Pythonize logprobs now (GPU -> CPU); do not defer.
assert not isinstance(maybe_deferred_sample_results, SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(logprobs, sampling_metadata, maybe_deferred_sample_results)

return _build_sampler_output(
maybe_deferred_sample_results,
sampling_metadata,
prompt_logprobs=prompt_logprobs,
sample_logprobs=sample_logprobs,
on_device_tensors=None,
skip_sampler_cpu_output=False,
)
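
The sampler above draws tokens with the Gumbel-max trick: `exponential_().log_()` produces negative Gumbel noise, so `argmax(logits - noise)` is a sample from `softmax(logits)`. A small self-contained check of that equivalence (illustrative only, not part of the PR):

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
counts = torch.zeros_like(logits)

for _ in range(20_000):
    # The log of an Exp(1) sample is negative Gumbel noise, as in generate_neg_gumbel_noise.
    neg_gumbel = torch.empty_like(logits).exponential_().log_()
    counts[torch.argmax(logits - neg_gumbel)] += 1

print(counts / counts.sum())          # empirical token frequencies
print(torch.softmax(logits, dim=-1))  # should be close to the line above
```
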
1 change: 1 addition & 0 deletions src/zeroband/utils/parquet.py
@@ -17,5 +17,6 @@
("step", pa.int32()),
("target_lengths", pa.int32()),
("task_type", pa.string()),
("seed", pa.int64()), # Optional - can be null
]
)
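
Since the new column is nullable, rows produced without a fixed sampling seed can simply store a null. A small pyarrow sketch with illustrative values:

```python
import pyarrow as pa

schema = pa.schema([("seed", pa.int64())])  # subset of the full output schema
table = pa.Table.from_pylist(
    [{"seed": 42}, {"seed": None}],  # None corresponds to sampling without a fixed seed
    schema=schema,
)
print(table)
```
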
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -107,6 +107,7 @@ def create_dummy_parquet_table(batch_size: int, seq_len: int) -> Table:
"problem_id": pa.array(["0"] * batch_size, type=pa.string()),
"input_logprobs": pa.array([[0.1] * seq_len for _ in range(batch_size)], type=pa.list_(pa.float32())),
"output_logprobs": pa.array([[0.1] * seq_len for _ in range(batch_size)], type=pa.list_(pa.float32())),
"seed": pa.array([42] * batch_size, type=pa.int64()),
}

# Create table directly from dictionary