diff --git a/src/zeroband/infer.py b/src/zeroband/infer.py
index f579c548..18d1a238 100644
--- a/src/zeroband/infer.py
+++ b/src/zeroband/infer.py
@@ -3,8 +3,8 @@
 import os
 import shutil
 import time
-import uuid
 from pathlib import Path
+import uuid
 
 # Import environment before any other imports
 # ruff: noqa: I001
@@ -26,6 +26,7 @@
 from zeroband.inference.pipeline import all_reduce, patch_model_load, setup_comm, setup_hooks
 from zeroband.inference.rewards import compute_vllm_rewards
 from zeroband.inference.toploc import setup_toploc_cache
+from zeroband.inference.toploc2 import Toploc2Sampler
 from zeroband.utils.monitor import setup_monitor
 from zeroband.inference.utils import (
     filter_data_by_prompt_length,
@@ -81,7 +82,10 @@ def inference(config: Config):
         disable_async_output_proc=True,  # We have an off by 1 error in toploc without this flag when cuda graph padding is enabled.
         download_dir=config.download_dir,
         dtype="bfloat16" if config.dtype == "bf16" else torch.float32,
+        enable_chunked_prefill=False,  # This is required for toploc2 because chunked prefill seems to allow len(seq_groups) != len(selected_token_indices) which is unexpected
     )
+    if config.toploc2:
+        llm.llm_engine.model_executor.driver_worker.model_runner.sampler = Toploc2Sampler()
     tokenizer = llm.get_tokenizer()
 
     # Adjust sampling params based on config
@@ -220,6 +224,7 @@ def inference(config: Config):
             # This would work even if the node restarts and resumes from the current step.
             generator = np.random.default_rng(node_address_int * current_step_batch_counter + real_step)
             indices = generator.integers(0, len(dataset), problems_per_batch)
+            sampling_params.seed = int(generator.integers(2**32))
         else:
             # Use modulo to cycle through the dataset instead of terminating
             indices = [(dataset_offset + j) % len(dataset) for j in range(problems_per_batch)]
@@ -314,6 +319,11 @@ def inference(config: Config):
             monitor.log({"rewards/batch_rewards": batch_rewards})
         logger.info(f"Average reward of the batch: {batch_rewards}")
 
+        if sampling_params.seed is not None:
+            sampling_seeds = [sampling_params.seed + i for i in range(sampling_params.n)] * problems_per_batch
+        else:
+            sampling_seeds = [None] * batch_samples
+
         # Get parquet table
         table = get_parquet_table(
             request_outputs,
@@ -324,6 +334,7 @@ def inference(config: Config):
             target_lengths,
             problems,
             enable_logprobs=config.sampling.logprobs is not None,
+            seeds=sampling_seeds,
         )
 
         # Save outputs to parquet file
diff --git a/src/zeroband/inference/config.py b/src/zeroband/inference/config.py
index 4802b9de..602a2fc7 100644
--- a/src/zeroband/inference/config.py
+++ b/src/zeroband/inference/config.py
@@ -18,6 +18,13 @@ class SamplingParamConfig(BaseConfig):
     seed: int | None = None
     logprobs: int | None = 0  # put to None to disable logprobs calculation
 
+    @model_validator(mode="after")
+    def convert_negative_logprobs_to_none(self):
+        """Convert negative logprobs values to None to disable logprobs calculation."""
+        if self.logprobs is not None and self.logprobs < 0:
+            self.logprobs = None
+        return self
+
 
 class DifficultyFilteringConfig(BaseConfig):
     solve_rate_field: str = "solve_rate_qwen_r1_distill_7b"
@@ -73,6 +80,7 @@ class Config(BaseConfig):
     ckpt_start_path: str | None = None
 
     toploc: bool = False
+    toploc2: bool = True
 
     rewards: RewardsConfig = RewardsConfig()
     difficulty_filtering: DifficultyFilteringConfig | None = None
diff --git a/src/zeroband/inference/parquet.py b/src/zeroband/inference/parquet.py
index 62a051c2..ea9995a0 100644
--- a/src/zeroband/inference/parquet.py
+++ b/src/zeroband/inference/parquet.py
@@ -43,6 +43,7 @@ def get_parquet_table(
     target_lengths: list[int],
     problems: Dataset,
     enable_logprobs: bool,
+    seeds: list[int],
 ) -> pa.Table:
     # Iterator over proofs
     proof_iter = iter(proofs)
@@ -57,7 +58,7 @@ def get_parquet_table(
         problems,
     ):
         assert request_output.request_id == request_rewards.request_id
-        for output, reward in zip(request_output.outputs, request_rewards.rewards):
+        for output, reward, seed in zip(request_output.outputs, request_rewards.rewards, seeds):
             assert output.index == reward.completion_id
 
             # Extract logprobs if enabled and available
@@ -82,6 +83,7 @@ def get_parquet_table(
                     "step": step,
                     "target_lengths": target_length,
                     "task_type": request_rewards.task_type,
+                    "seed": seed,
                 }
             )
 
diff --git a/src/zeroband/inference/toploc2.py b/src/zeroband/inference/toploc2.py
new file mode 100644
index 00000000..de88d7ec
--- /dev/null
+++ b/src/zeroband/inference/toploc2.py
@@ -0,0 +1,134 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A layer that samples the next tokens from the model's outputs."""
+
+from typing import Optional
+
+import torch
+from vllm.model_executor.layers.sampler import (
+    Sampler,
+    SampleResultArgsType,
+    SamplerOutput,
+    _apply_min_p,
+    _apply_min_tokens_penalty,
+    _apply_top_k_top_p,
+    _build_sampler_output,
+    get_logprobs,
+)
+from vllm.model_executor.layers.utils import apply_penalties
+from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
+
+# We have to use smaller sizes in the exponential_ function to prevent different kernels
+# from being used by different GPUs.
+GUMBEL_BATCH_SIZE = 2**16
+
+
+def generate_neg_gumbel_noise(n: int | tuple[int, int], generator: torch.Generator, device: torch.device):
+    if isinstance(n, int):
+        ret = torch.empty(n, device=device)
+        for i in range(0, n, GUMBEL_BATCH_SIZE):
+            end = min(i + GUMBEL_BATCH_SIZE, n)
+            ret[i:end].exponential_(generator=generator).log_()
+    else:
+        ret = torch.empty(n[0], n[1], device=device)
+        for i in range(0, n[0]):
+            for j in range(0, n[1], GUMBEL_BATCH_SIZE):
+                end_j = min(j + GUMBEL_BATCH_SIZE, n[1])
+                ret[i, j:end_j].exponential_(generator=generator).log_()
+    return ret
+
+
+class Toploc2Sampler(Sampler):
+    def forward(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        """
+        Args:
+            logits: (num_tokens, vocab_size).
+            sampling_metadata: Metadata for sampling.
+        """
+        assert logits is not None
+        _, vocab_size = logits.shape
+
+        # Prepare sampling tensors with pinned memory to avoid blocking.
+        if not sampling_metadata.reuse_sampling_tensors:
+            self._init_sampling_tensors(logits, sampling_metadata)
+        elif self._do_penalties:
+            # In this case, the sampling tensors logic depends on
+            # "output_tokens" of a sequence. As a result, we cannot
+            # reuse sampling tensors, since "output_tokens" changes
+            # between decode runs.
+            self._init_sampling_tensors(logits, sampling_metadata)
+
+        assert self._sampling_tensors is not None
+        sampling_tensors = self._sampling_tensors
+        do_penalties = self._do_penalties
+        do_top_p_top_k = self._do_top_p_top_k
+        do_min_p = self._do_min_p
+
+        logits = _apply_min_tokens_penalty(logits, sampling_metadata)
+
+        # Apply presence and frequency penalties.
+        if do_penalties:
+            logits = apply_penalties(
+                logits,
+                sampling_tensors.prompt_tokens,
+                sampling_tensors.output_tokens,
+                sampling_tensors.presence_penalties,
+                sampling_tensors.frequency_penalties,
+                sampling_tensors.repetition_penalties,
+            )
+
+        # Use float32 to apply temperature scaling.
+        # Use in-place division to avoid creating a new tensor.
+        logits = logits.to(torch.float)
+        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
+
+        if do_top_p_top_k:
+            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, sampling_tensors.top_ks)
+
+        if do_min_p:
+            logits = _apply_min_p(logits, sampling_tensors.min_ps)
+
+        # # We use float32 for probabilities and log probabilities.
+        # # Compute the probabilities.
+        # probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # # Compute the log probabilities.
+        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
+
+        chosen_noises = []
+
+        def _sample(logits, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors):
+            assert len(sampling_metadata.seq_groups) == logits.shape[0]
+            neg_gumbel_noise = torch.stack(
+                [generate_neg_gumbel_noise(logits.shape[-1], sg.generator, logits.device) for sg in sampling_metadata.seq_groups]
+            )
+            assert neg_gumbel_noise.shape == logits.shape
+            _race_result = logits - neg_gumbel_noise
+            token_ids = torch.argmax(_race_result, dim=-1)
+            chosen_noises.append(torch.gather(neg_gumbel_noise, 1, token_ids.unsqueeze(1)))
+            return [([token_ids[i].item()], [0]) for i in range(len(sampling_metadata.seq_groups))]
+
+        # Sample the next tokens.
+        maybe_deferred_sample_results = _sample(
+            logits,
+            sampling_metadata,
+            sampling_tensors,
+        )
+        # Get the logprobs query results.
+        prompt_logprobs = None
+        sample_logprobs = None
+        if not sampling_metadata.skip_sampler_cpu_output:
+            # Pythonize logprobs now (GPU -> CPU); do not defer.
+            assert not isinstance(maybe_deferred_sample_results, SampleResultArgsType)
+            prompt_logprobs, sample_logprobs = get_logprobs(logprobs, sampling_metadata, maybe_deferred_sample_results)
+
+        return _build_sampler_output(
+            maybe_deferred_sample_results,
+            sampling_metadata,
+            prompt_logprobs=prompt_logprobs,
+            sample_logprobs=sample_logprobs,
+            on_device_tensors=None,
+            skip_sampler_cpu_output=False,
+        )
diff --git a/src/zeroband/utils/parquet.py b/src/zeroband/utils/parquet.py
index c4dc547e..0aa0a70b 100644
--- a/src/zeroband/utils/parquet.py
+++ b/src/zeroband/utils/parquet.py
@@ -17,5 +17,6 @@
         ("step", pa.int32()),
         ("target_lengths", pa.int32()),
         ("task_type", pa.string()),
+        ("seed", pa.int64()),  # Optional - can be null
     ]
 )
diff --git a/tests/conftest.py b/tests/conftest.py
index 25753f97..a646dfcc 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -107,6 +107,7 @@ def create_dummy_parquet_table(batch_size: int, seq_len: int) -> Table:
         "problem_id": pa.array(["0"] * batch_size, type=pa.string()),
         "input_logprobs": pa.array([[0.1] * seq_len for _ in range(batch_size)], type=pa.list_(pa.float32())),
         "output_logprobs": pa.array([[0.1] * seq_len for _ in range(batch_size)], type=pa.list_(pa.float32())),
+        "seed": pa.array([42] * batch_size, type=pa.int64()),
     }
 
     # Create table directly from dictionary
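The new Toploc2Sampler swaps vLLM's default multinomial sampling for the Gumbel-max trick: exponential_() followed by log_() yields log(E) with E ~ Exp(1), which is distributed as negative Gumbel(0, 1) noise, so argmax(logits - neg_gumbel_noise) is an exact draw from softmax(logits). The standalone sketch below (plain PyTorch, no vLLM; the 4-token vocabulary and draw count are made up) checks that equivalence empirically.

# Standalone check of the Gumbel-max trick used by Toploc2Sampler._sample.
# Not the actual sampler; vocabulary size and number of draws are illustrative.
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
probs = torch.softmax(logits, dim=-1)

generator = torch.Generator().manual_seed(1234)
num_draws = 200_000

# Same noise construction as generate_neg_gumbel_noise: Exp(1) samples, then log().
neg_gumbel = torch.empty(num_draws, logits.numel()).exponential_(generator=generator).log_()
tokens = torch.argmax(logits - neg_gumbel, dim=-1)  # the "race" in _sample
freq = torch.bincount(tokens, minlength=logits.numel()).float() / num_draws

print("softmax probs   :", probs)
print("empirical freqs :", freq)  # agrees to within ~1e-2 at this many draws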
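The seed column written to the parquet output, together with the deterministic chunked noise generation in generate_neg_gumbel_noise, is what makes the sampling step replayable. The sketch below assumes a verifier reconstructs a torch.Generator directly from the stored seed; how vLLM actually derives each sequence group's generator from SamplingParams.seed is internal to vLLM and not part of this diff, so that part is illustrative.

# Hypothetical replay of the noise generation from a recorded seed.
# generate_neg_gumbel_noise mirrors src/zeroband/inference/toploc2.py;
# the vocabulary size, stored_seed, and generator construction are assumptions.
import torch

GUMBEL_BATCH_SIZE = 2**16  # small chunks so all GPUs hit the same kernel


def generate_neg_gumbel_noise(n: int, generator: torch.Generator, device: torch.device) -> torch.Tensor:
    ret = torch.empty(n, device=device)
    for i in range(0, n, GUMBEL_BATCH_SIZE):
        end = min(i + GUMBEL_BATCH_SIZE, n)
        ret[i:end].exponential_(generator=generator).log_()
    return ret


vocab_size = 151_936   # illustrative vocabulary size
stored_seed = 42       # stands in for the parquet `seed` column

noise_a = generate_neg_gumbel_noise(vocab_size, torch.Generator().manual_seed(stored_seed), torch.device("cpu"))
noise_b = generate_neg_gumbel_noise(vocab_size, torch.Generator().manual_seed(stored_seed), torch.device("cpu"))

# Same seed -> bit-identical noise, so the argmax "race" in Toploc2Sampler._sample
# can be re-run against independently recomputed logits for the committed token.
assert torch.equal(noise_a, noise_b)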
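The convert_negative_logprobs_to_none validator added to SamplingParamConfig turns any negative logprobs value into None, i.e. disables logprob calculation. A minimal pydantic sketch of that behaviour, using a stand-in model rather than the real class (whose BaseConfig parent is not shown in this diff):

# Stand-in model demonstrating the validator added to SamplingParamConfig.
from pydantic import BaseModel, model_validator


class SamplingParamConfigSketch(BaseModel):
    logprobs: int | None = 0  # None disables logprobs calculation

    @model_validator(mode="after")
    def convert_negative_logprobs_to_none(self):
        if self.logprobs is not None and self.logprobs < 0:
            self.logprobs = None
        return self


assert SamplingParamConfigSketch(logprobs=-1).logprobs is None  # negative -> disabled
assert SamplingParamConfigSketch(logprobs=0).logprobs == 0      # unchanged
assert SamplingParamConfigSketch(logprobs=5).logprobs == 5      # unchanged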