From c63fc49905dd1f55963bd7cfed91366e5d59e27d Mon Sep 17 00:00:00 2001
From: Alexander Matveev
Date: Thu, 9 Jan 2025 17:00:34 +0000
Subject: [PATCH 01/18] TPU rebase from Rob's PR - in process

---
 vllm/attention/selector.py             |    4 +
 vllm/platforms/interface.py            |    1 +
 vllm/v1/attention/backends/pallas.py   |  345 ++++
 vllm/v1/worker/__tpu_model_runner.py   |  981 +++++++++++++++
 vllm/v1/worker/__tpu_worker.py         |  198 +++
 vllm/v1/worker/tpu_model_runner_new.py | 1081 ++++++++++++++++++
 6 files changed, 2610 insertions(+)
 create mode 100644 vllm/v1/attention/backends/pallas.py
 create mode 100644 vllm/v1/worker/__tpu_model_runner.py
 create mode 100644 vllm/v1/worker/__tpu_worker.py
 create mode 100644 vllm/v1/worker/tpu_model_runner_new.py

diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index d263839705690..ca5dc94c94270 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -163,6 +163,10 @@ def _cached_get_attn_backend(
         logger.info("Using Pallas backend.")
         from vllm.attention.backends.pallas import PallasAttentionBackend
         return PallasAttentionBackend
+    elif backend == _Backend.PALLAS_VLLM_V1:
+        logger.info("Using Pallas backend.")
+        from vllm.v1.attention.backends.pallas import PallasAttentionBackendV1
+        return PallasAttentionBackendV1
     elif backend == _Backend.NO_ATTENTION:
         from vllm.attention.backends.placeholder_attn import (
             PlaceholderAttentionBackend)
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index ddccaa2ce0148..37079a80411d2 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -32,6 +32,7 @@ class _Backend(enum.Enum):
     FLASHINFER = enum.auto()
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
+    PALLAS_VLLM_V1 = enum.auto()
     IPEX = enum.auto()
     NO_ATTENTION = enum.auto()
 
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
new file mode 100644
index 0000000000000..070c238acd63d
--- /dev/null
+++ b/vllm/v1/attention/backends/pallas.py
@@ -0,0 +1,345 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+import torch_xla.experimental.custom_kernel  # Required to register custom ops.
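+# (This import registers the torch.ops.xla custom kernels used below, e.g.
+# flash_attention, paged_attention and multi_queries_paged_attention.)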
+ +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "PALLAS_VLLM_V1" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + @torch.compile(backend="openxla") + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dists: Tuple[torch.Tensor, torch.Tensor], + ) -> None: + src_indices, dst_indices = src_to_dists + for k_cache, v_cache in kv_caches: + torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) + k_cache[:, dst_indices] = k_cache[:, src_indices] + torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) + v_cache[:, dst_indices] = v_cache[:, src_indices] + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if kv_cache_dtype != "auto": + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + if logits_soft_cap is not None: + raise NotImplementedError( + "Attention logits soft-capping is not supported.") + + if 
torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + assert k_scale == 1.0 and v_scale == 1.0 + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, + dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. 
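+                # These two knobs are tile sizes for the multi-queries
+                # paged-attention kernel: roughly, how many KV-cache pages
+                # and how many query tokens each kernel block processes at
+                # a time. The padded seq_len must stay divisible by
+                # num_queries_per_compute_block (asserted below).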
+ num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + ) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. + return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + # NOTE(woosuk): A temporary workaround to avoid the error: + # "xla::paged_attention() Expected a value of type 'str' for + # argument 'megacore_mode' but instead found type 'NoneType'." 
+ if megacore_mode is not None: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + ) + else: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + ) + return output diff --git a/vllm/v1/worker/__tpu_model_runner.py b/vllm/v1/worker/__tpu_model_runner.py new file mode 100644 index 0000000000000..7963fe4973b55 --- /dev/null +++ b/vllm/v1/worker/__tpu_model_runner.py @@ -0,0 +1,981 @@ +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +import torch_xla.core.xla_model as xm + +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MultiModalDataDict +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_pin_memory_available +from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, + PallasAttentionMetadata) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.sample.metadata import SamplingMetadata + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME: Find a more reliable way to prevent possible bugs. +_PAD_SLOT_ID = 1_000_000_000 + + +@dataclass +class PrefillInputData: + + request_ids: List + prompt_lens: List + token_ids: List + position_ids: List + attn_metadata: List + + def zipped(self): + return zip(self.request_ids, self.prompt_lens, self.token_ids, + self.position_ids, self.attn_metadata) + + +@dataclass +class DecodeInputData: + + num_decodes: int + token_ids: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + attn_metadata: PallasAttentionMetadata = None + + +class TPUModelRunner: + + def __init__( + self, + vllm_config: VllmConfig, + ): + # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, 
self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + + # Model-related. + self.num_attn_layers = model_config.get_num_attention_layers( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + + # List[k_cache, v_cache] + self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.scheduler_config.max_num_seqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + ) + + self.prefill_positions = torch.tensor( + range(self.max_model_len), + device="cpu", + ).to(torch.int32).reshape(1, -1) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + + # Remove the requests from the persistent batch. + stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. + for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + end_index = start_index + num_new_blocks + req_state.block_ids.extend(req_data.new_block_ids) + self.input_batch.block_table_cpu[ + req_index, start_index:end_index] = req_data.new_block_ids + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for req_data in scheduler_output.scheduled_new_reqs: + req_id = req_data.req_id + sampling_params = req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=req_data.prompt_token_ids, + prompt=req_data.prompt, + multi_modal_data=req_data.multi_modal_data, + sampling_params=sampling_params, + generator=generator, + block_ids=req_data.block_ids, + num_computed_tokens=req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for req_data in scheduler_output.scheduled_resumed_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = req_data.block_ids + req_state.num_computed_tokens = req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # THIS MOVES ALL THE DECODES TO THE FIRST N IN BATCH. + # Condense the batched states if there are empty indices. 
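+        # Resulting batch layout (relied upon by _prepare_prefill_inputs
+        # and _prepare_decode_inputs):
+        #   indices [0, num_decodes)        -> running decode requests
+        #   indices [num_decodes, num_reqs) -> newly added prefill requests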
+ removed_req_indices = sorted(removed_req_indices, reverse=True) + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + # ALL THE PREFILLS ARE THE LAST M IN THE BATCH. + # These are added at the end after the bacth is condensed. + self.input_batch.num_prefills = len(req_ids_to_add) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state, None) + + def _prepare_prefill_inputs( + self, + num_scheduled_tokens: List[int], + ) -> PrefillInputData: + # Each prefill run separately with shape [1, padded_prompt_len]. + # So we create lists that will be used in execute_model(). + + prefill_request_ids = [] + prefill_prompt_lens = [] + prefill_token_ids = [] + prefill_position_ids = [] + prefill_attn_metadata = [] + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + for idx in range(num_decodes, num_reqs): + prefill_request_ids.append(self.input_batch.req_ids[idx]) + + # STATIC SHAPE: prefills are padded to the next power of 2. + prompt_len = num_scheduled_tokens[idx] + padded_prompt_len = _get_padded_prefill_len(prompt_len) + prefill_prompt_lens.append(prompt_len) + assert padded_prompt_len <= self.max_model_len + + # TOKEN_IDS. + token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ + idx, :padded_prompt_len].reshape(1, -1)) + prefill_token_ids.append(token_ids.to(self.device)) + + # POSITIONS. + positions = self.prefill_positions[:, :padded_prompt_len] + prefill_position_ids.append(positions.to(self.device)) + + # SLOT_MAPPING. + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_numbers = self.input_batch.block_table_cpu_tensor[ + idx, positions // self.block_size].reshape(1, -1) + block_offsets = positions % self.block_size + slot_mapping = block_numbers * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[:, prompt_len:] = _PAD_SLOT_ID + slot_mapping = slot_mapping.long() + + # ATTN_METADATA. + prefill_attn_metadata.append( + PallasAttentionMetadata( + is_prompt=True, + slot_mapping=slot_mapping.to(self.device), + block_tables=None, + context_lens=None, + )) + + return PrefillInputData( + request_ids=prefill_request_ids, + prompt_lens=prefill_prompt_lens, + token_ids=prefill_token_ids, + position_ids=prefill_position_ids, + attn_metadata=prefill_attn_metadata, + ) + + def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData: + # Decodes run as one single padded batch with shape [batch, 1] + # + # We need to set _PAD_SLOT_ID for the padding tokens in the + # slot_mapping, such that the attention KV cache insertion + # logic knows to ignore those indicies. Otherwise, the + # padding data can be dummy since we have a causal mask. + + if num_decodes == 0: + return DecodeInputData(num_decodes=0) + + # PAD FOR STATIC SHAPES. + padded_batch_size = _get_padded_batch_size(num_decodes) + + # POSITIONS. [batch, 1] + # We slice at the end, since we use the positions for gathering. + positions = torch.from_numpy( + self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) + index = positions.to(torch.int64) + positions = positions[:padded_batch_size] + + # TOKEN_IDS. 
[batch, 1] + token_ids = torch.gather( + input=torch.from_numpy(self.input_batch.token_ids_cpu), + dim=1, + index=index, + )[:padded_batch_size] + + # SLOT_MAPPING [batch, 1] + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_number = torch.gather( + input=self.input_batch.block_table_cpu_tensor, + dim=1, + index=(index // self.block_size)) + block_offsets = index % self.block_size + slot_mapping = block_number * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[num_decodes:] = _PAD_SLOT_ID + slot_mapping = slot_mapping[:padded_batch_size] + + # BLOCK_TABLE [batch, max_num_blocks_per_req] + block_table = self.input_batch.block_table_cpu_tensor[: + padded_batch_size] + + # CONTEXT_LENS [batch_size] + context_lens = (positions.reshape(-1) + 1) + + # CPU<>TPU sync happens here. + return DecodeInputData(num_decodes=num_decodes, + token_ids=token_ids.to(self.device), + position_ids=positions.to(self.device), + attn_metadata=PallasAttentionMetadata( + is_prompt=False, + slot_mapping=slot_mapping.to(self.device), + block_tables=block_table.to(self.device), + context_lens=context_lens.to(self.device), + )) + + def _prepare_inputs( + self, scheduler_output: "SchedulerOutput" + ) -> Tuple[PrefillInputData, Optional[DecodeInputData]]: + + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + + # NOTE: assert that all the decodes are "decodes". + if idx < num_decodes: + assert num_tokens == 1 + + return ( + self._prepare_prefill_inputs(num_scheduled_tokens), + self._prepare_decode_inputs(num_decodes), + ) + + def _prepare_sampling( + self, + scheduler_output: "SchedulerOutput", + ) -> SamplingMetadata: + skip_copy = True + if (scheduler_output.finished_req_ids + or scheduler_output.preempted_req_ids): + skip_copy = False + if (scheduler_output.scheduled_new_reqs + or scheduler_output.scheduled_resumed_reqs): + skip_copy = False + # Create the sampling metadata. + sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) + return sampling_metadata + + @torch.no_grad() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + self._update_states(scheduler_output) + prefill_data, decode_data = self._prepare_inputs(scheduler_output) + num_reqs = self.input_batch.num_reqs + sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) + + ######################### DECODES ######################### + # Decodes run as one single batch with [padded_batch, 1] + if decode_data.num_decodes > 0: + + # FORWARD. + selected_token_ids = self.model(decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches, + is_prompt=False) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. 
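+            # Slicing on the device with the per-step num_decodes would
+            # change the XLA graph's shapes and force a recompile, so the
+            # full padded tensor is transferred and sliced on the host.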
+ token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] + sampled_token_ids_list = token_ids.tolist() + sampled_token_ids[:decode_data.num_decodes] = token_ids + + # UPDATE REQUEST STATE. + for i, req_id in enumerate( + self.input_batch.req_ids[:decode_data.num_decodes]): + req_state = self.requests[req_id] + + # TODO: ASSERT NO CHUNKED PREFILL. + assert scheduler_output.num_scheduled_tokens[req_id] == 1 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len == req_state.num_tokens + + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + ######################### PREFILLS ######################### + # Prefills run separately with shape [1, padded_prefill_len], + # due to lack of variable length attention kernel so far. + for idx, (req_id, prompt_len, token_ids, position_ids, + attn_metadata) in enumerate(prefill_data.zipped()): + + # FORWARD. + selected_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + self.kv_caches, + is_prompt=True) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_id = selected_token_ids.cpu()[prompt_len - 1].item() + sampled_token_ids[decode_data.num_decodes + idx] = token_id + req_state = self.requests[req_id] + + # TODO: ASSERT NO PREFIX CACHING. + assert req_state.num_computed_tokens == 0 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + # TODO: ASSERT NO CHUNKED PREFILL. + assert seq_len == req_state.num_tokens + assert prompt_len == seq_len + + # UPDATE REQUEST STATE. + req_idx = self.input_batch.req_id_to_index[req_id] + self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids[:num_reqs], + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids_cpu=sampled_token_ids, + logprob_token_ids_cpu=None, + logprobs_cpu=None, + ) + + def load_model(self) -> None: + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + + # xm_tp_rank = xr.global_ordinal() + # with patch( + # "vllm.model_executor.layers.vocab_parallel_embedding." 
+ # "get_tensor_model_parallel_rank", + # return_value=xm_tp_rank): + # model = get_model(vllm_config=self.vllm_config) + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + self.model = ModelWrapper(model) + + def _dummy_run(self, batch_size: int, seq_len: int, + kv_caches: List[torch.Tensor], is_prompt: bool) -> None: + """Dummy warmup run for memory usage and graph compilation.""" + + input_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = None if is_prompt else torch.zeros( + (batch_size, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device, + ) + context_lens = None if is_prompt else torch.ones( + (batch_size, ), + dtype=torch.int32, + device=self.device, + ) + attn_metadata = PallasAttentionMetadata( + is_prompt=is_prompt, + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + ) + + # NOTE: There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). + if is_prompt: + torch._dynamo.mark_dynamic(input_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + + # Dummy run. + self.model(input_ids, + position_ids, + attn_metadata, + kv_caches, + is_prompt=is_prompt) + + def profile_run(self) -> None: + """Profile to measure peak memory during forward pass.""" + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value `None`. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + dummy_kv_caches = [( + torch.tensor([], dtype=torch.float32, device=self.device), + torch.tensor([], dtype=torch.float32, device=self.device), + ) for _ in range(self.num_attn_layers)] + + # Round to multiple of 16. + seq_len = (self.max_num_tokens + 15) // 16 * 16 + + # Run empty forward. + self._dummy_run(batch_size=1, + seq_len=seq_len, + kv_caches=dummy_kv_caches, + is_prompt=True) + + def capture_model(self) -> None: + """Compile the model.""" + + logger.info("Compiling the model with different input shapes.") + + # Prefill shapes. 
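+        # Warm up prefill graphs for seq_len = 16, 32, 64, ... (doubling)
+        # until max_model_len or max_num_batched_tokens is reached, so the
+        # XLA graphs are compiled (and disk-cached) before serving starts.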
+ start = time.perf_counter() + for batch_size in [1]: + seq_len = 16 + while True: + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + is_prompt=True) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + if seq_len >= self.model_config.max_model_len: + break + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + seq_len = seq_len * 2 + + end = time.perf_counter() + logger.info("Compilation for prefill done in %.2f s.", end - start) + + # Decode shapes. + start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + is_prompt=False) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode done in %.2f s.", end - start) + + def initialize_kv_cache(self, num_blocks: int) -> None: + assert len(self.kv_caches) == 0 + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + for _ in range(self.num_attn_layers): + self.kv_caches.append(( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), + )) + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalDataDict"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.zeros((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) + + # Attention-related. + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. 
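+        # Each sampling parameter keeps a device tensor plus a CPU-side
+        # mirror (exposed as a numpy view); the CPU copy is written per
+        # request in add_request() and transferred to the device in
+        # make_sampling_metadata().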
+ self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + self.num_prefills = 0 + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # 
The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def num_decodes(self) -> int: + return self.num_reqs - self.num_prefills + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 + + +class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): + + def __init__(self, model: nn.Module): + self.model = model + compiled_callable = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_callable) + + def __call__(self, *args, is_prompt: bool, **kwargs): + if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: + # not fully compiled yet, or not using the custom dispatcher, + # let PyTorch handle it + return self.compiled_callable(*args, **kwargs) + # the 3 compiled codes are: + # 0: for profiling + # 1: for prompt + # 2: for decode + # dispatch to 
the compiled code directly, skip PyTorch + if is_prompt: + with self.dispatch_to_code(1): + return self.forward(*args, **kwargs) + else: + with self.dispatch_to_code(2): + return self.forward(*args, **kwargs) + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + attn_metadata: PallasAttentionMetadata, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + attn_metadata: The Pallas attention metadata. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, None) + + # Greedy sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + return argmax_token_ids.squeeze(dim=1) + + +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. + if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. 
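+    # e.g. 1-16 -> 16, 17 -> 32, 100 -> 128, 1000 -> 1024
+    # (i.e. the next power of two, which is always a multiple of 16 here).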
+ if x <= 16: + return 16 + return 1 << (x - 1).bit_length() diff --git a/vllm/v1/worker/__tpu_worker.py b/vllm/v1/worker/__tpu_worker.py new file mode 100644 index 0000000000000..866c1dbf6ea98 --- /dev/null +++ b/vllm/v1/worker/__tpu_worker.py @@ -0,0 +1,198 @@ +"""A TPU worker class.""" + +import os +from typing import TYPE_CHECKING, Tuple + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_model_runner import TPUModelRunner + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + + +class TPUWorker: + + def __init__(self, vllm_config: VllmConfig, local_rank: int, rank: int, + distributed_init_method: str): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + def initialize(self): + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # NOTE: This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = TPUModelRunner(self.vllm_config) + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. 
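+        # Resulting cache layout: <VLLM_XLA_CACHE_PATH>/tp{world_size}_rank{rank}.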
+ world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + + self.model_runner.profile_run() + + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_tpu_memory = m["bytes_limit"] + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024) + logger.debug("Total Memory: %sGB", + total_tpu_memory // 1024 // 1024 // 1024) + + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + num_tpu_blocks = int( + (total_tpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 + return num_tpu_blocks, 0 + + def initialize_cache(self, num_tpu_blocks: int) -> None: + """Allocate TPU and CPU KV cache with the specified number of blocks.""" + + if num_tpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_tpu_blocks + max_model_len = self.model_config.max_model_len + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.model_runner.initialize_kv_cache(num_tpu_blocks) + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + xm.mark_step() + xm.wait_device_ops() + m = xm.get_memory_info(self.device) + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak GB Used Post KV Cache: %sGB", + peak_memory // 1024 // 1024 // 1024) + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + output = self.model_runner.execute_model(scheduler_output) + # TODO(woosuk): Send the output to the engine process. + return output + + +# TODO: this is a duplicate. 
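+# Bytes needed to store one KV-cache block across all attention layers:
+#   2 (K and V) * block_size * num_kv_heads * head_size
+#   * num_attention_layers * dtype_size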
+def _get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total diff --git a/vllm/v1/worker/tpu_model_runner_new.py b/vllm/v1/worker/tpu_model_runner_new.py new file mode 100644 index 0000000000000..9d910be89e640 --- /dev/null +++ b/vllm/v1/worker/tpu_model_runner_new.py @@ -0,0 +1,1081 @@ +import gc +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Tuple, cast, Optional + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn + +from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed.parallel_state import graph_capture +from vllm.forward_context import set_forward_context +from vllm.inputs import INPUT_REGISTRY +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.sampling_params import SamplingType +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, + LayerBlockType, cdiv, is_pin_memory_available) +from vllm.v1.attention.backends.pallas import PallasMetadata +from vllm.v1.engine.mm_input_mapper import MMInputMapperClient +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME: Find a more reliable way to prevent possible bugs. 
+_PAD_SLOT_ID = 1_000_000_000 + + +@dataclass +class PrefillInputData: + + request_ids: List + prompt_lens: List + token_ids: List + position_ids: List + attn_metadata: List + + def zipped(self): + return zip(self.request_ids, self.prompt_lens, self.token_ids, + self.position_ids, self.attn_metadata) + + +@dataclass +class DecodeInputData: + + num_decodes: int + token_ids: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + attn_metadata: PallasMetadata = None + + +class TPUModelRunner: + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + self.is_multimodal_model = model_config.is_multimodal_model + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + self.max_num_reqs = scheduler_config.max_num_seqs + + # Model-related. + self.num_attn_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + self.num_query_heads = model_config.get_num_attention_heads( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + self.hidden_size = model_config.get_hidden_size() + + # Multi-modal data support + self.input_registry = INPUT_REGISTRY + self.mm_registry = MULTIMODAL_REGISTRY + + # NOTE: Initialized input mapper is only used for processing dummy + # multimodal data into multimodal kwargs for GPU memory profiling. + self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) + self.mm_input_mapper_profiling.use_cache = False + + self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 + self.encoder_cache_size = self.scheduler_config.encoder_cache_size + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + # req_id -> (input_id -> encoder_output) + self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + # Persistent batch. 
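+        # Unlike __tpu_model_runner.py, this runner reuses the GPU runner's
+        # InputBatch / CachedRequestState (vllm.v1.worker.gpu_input_batch)
+        # to hold per-request token ids, block tables and sampling params
+        # on the CPU side.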
+ self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=model_config.get_vocab_size(), + ) + + self.prefill_positions = torch.tensor(range(self.max_model_len), + device="cpu", + dtype=torch.int32).reshape( + 1, -1) + + self.new_req_ids = None + + # TODO: Remove this + # self.use_cuda_graph = (self.vllm_config.compilation_config.level + # == CompilationLevel.PIECEWISE + # and not self.model_config.enforce_eager) + # # TODO(woosuk): Provide an option to tune the max cudagraph batch size. + # # The convention is different. + # # self.cudagraph_batch_sizes sorts in ascending order. + # # The batch sizes in the config are in descending order. + # self.cudagraph_batch_sizes = list( + # reversed(self.vllm_config.compilation_config.capture_sizes)) + + # # Cache the device properties. + # self.device_properties = torch.cuda.get_device_properties(self.device) + # self.num_sms = self.device_properties.multi_processor_count + + # # Persistent buffers for CUDA graphs. + # self.input_ids = torch.zeros(self.max_num_tokens, + # dtype=torch.int32, + # device=self.device) + # self.positions = torch.zeros(self.max_num_tokens, + # dtype=torch.int64, + # device=self.device) + # self.inputs_embeds = torch.zeros( + # (self.max_num_tokens, self.hidden_size), + # dtype=self.dtype, + # device=self.device) + + # # OPTIMIZATION: Cache the tensors rather than creating them every step. + # self.arange_np = np.arange(max(self.max_num_reqs + 1, + # self.max_model_len), + # dtype=np.int32) + # # NOTE(woosuk): These tensors are "stateless", i.e., they are literally + # # a faster version of creating a new tensor every time. Thus, we should + # # not make any assumptions about the values in these tensors. + # self.input_ids_cpu = torch.zeros(self.max_num_tokens, + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # self.input_ids_np = self.input_ids_cpu.numpy() + # self.positions_cpu = torch.zeros(self.max_num_tokens, + # dtype=torch.int64, + # device="cpu", + # pin_memory=self.pin_memory) + # self.positions_np = self.positions_cpu.numpy() + # self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # self.slot_mapping_np = self.slot_mapping_cpu.numpy() + # self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # self.query_start_loc_np = self.query_start_loc_cpu.numpy() + # self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the requests from the persistent batch. 
+ stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. + for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + req_state.block_ids.extend(req_data.new_block_ids) + self.input_batch.block_table.append_row(req_index, start_index, + req_data.new_block_ids) + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt=new_req_data.prompt, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for res_req_data in scheduler_output.scheduled_resumed_reqs: + req_id = res_req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = res_req_data.block_ids + req_state.num_computed_tokens = res_req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # For TPU, we keep all of the decode requests before the + # prefill requests in the batch sequence. + # 1. First condense, so all decodes move to start + # 2. Then add new prefills to the end of the batch + removed_req_indices = sorted(removed_req_indices, reverse=True) + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state, None) # Append last + self.new_req_ids = req_ids_to_add + + def _prepare_prefill_inputs( + self, + num_scheduled_tokens: List[int], + ) -> PrefillInputData: + # Each prefill run separately with shape [1, padded_prompt_len]. + # So we create lists that will be used in execute_model(). + + prefill_request_ids = [] + prefill_prompt_lens = [] + prefill_token_ids = [] + prefill_position_ids = [] + prefill_attn_metadata = [] + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = num_reqs - self.new_req_ids + for idx in range(num_decodes, num_reqs): + prefill_request_ids.append(self.input_batch.req_ids[idx]) + + prompt_len = num_scheduled_tokens[idx] + prefill_prompt_lens.append(prompt_len) + + # STATIC SHAPE: prefills are padded to the next power of 2. 
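As an aside, a minimal sketch of the power-of-two padding applied in the next few lines, mirroring the `_get_padded_prefill_len` helper defined at the bottom of this file; the example prompt lengths are arbitrary:

# Prompt lengths are bucketed into a small set of static shapes so that XLA
# compiles one graph per bucket rather than one per exact prompt length.
def padded_prefill_len(prompt_len: int) -> int:
    if prompt_len <= 16:
        return 16                                  # minimum bucket
    return 1 << (prompt_len - 1).bit_length()      # next power of two

assert [padded_prefill_len(n) for n in (1, 16, 17, 100, 1000)] == \
       [16, 16, 32, 128, 1024]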
+ padded_prompt_len = _get_padded_prefill_len(prompt_len) + assert padded_prompt_len <= self.max_model_len + + # TOKEN_IDS. + token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ + idx, :padded_prompt_len].reshape(1, -1)) + prefill_token_ids.append(token_ids.to(self.device)) + + # POSITIONS. + positions = self.prefill_positions[:, :padded_prompt_len] + prefill_position_ids.append(positions.to(self.device)) + + # SLOT_MAPPING. + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor( + ) + block_numbers = block_table_cpu_tensor[idx, positions // + self.block_size].reshape( + 1, -1) + block_offsets = positions % self.block_size + slot_mapping = block_numbers * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[:, prompt_len:] = _PAD_SLOT_ID + slot_mapping = slot_mapping.long() + + prefill_attn_metadata.append( + PallasMetadata( + num_prefills=1, + num_prefill_tokens=padded_prompt_len, + num_decode_tokens=0, + slot_mapping=slot_mapping.to(self.device), + block_tables=None, + context_lens=None, + effective_query_lens=None, + )) + + return PrefillInputData( + request_ids=prefill_request_ids, + prompt_lens=prefill_prompt_lens, + token_ids=prefill_token_ids, + position_ids=prefill_position_ids, + attn_metadata=prefill_attn_metadata, + ) + + def _prepare_decode_inputs(self) -> DecodeInputData: + # Decodes run as one single padded batch with shape [batch, 1] + # + # We need to set _PAD_SLOT_ID for the padding tokens in the + # slot_mapping, such that the attention KV cache insertion + # logic knows to ignore those indicies. Otherwise, the + # padding data can be dummy since we have a causal mask. + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = num_reqs - self.new_req_ids + + if num_decodes == 0: + return DecodeInputData(num_decodes=0) + + # PAD FOR STATIC SHAPES. + padded_batch_size = _get_padded_batch_size(num_decodes) + + # POSITIONS. [batch, 1] + # We slice at the end, since we use the positions for gathering. + positions = torch.from_numpy( + self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) + index = positions.to(torch.int64) + positions = positions[:padded_batch_size] + + # TOKEN_IDS. [batch, 1] + token_ids = torch.gather( + input=torch.from_numpy(self.input_batch.token_ids_cpu), + dim=1, + index=index, + )[:padded_batch_size] + + # SLOT_MAPPING [batch, 1] + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor() + block_number = torch.gather(input=block_table_cpu_tensor, + dim=1, + index=(index // self.block_size)) + block_offsets = index % self.block_size + slot_mapping = block_number * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. 
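A small self-contained sketch of the slot-mapping arithmetic used by both the prefill path above and the decode path here; the block table, block size, and positions are toy values chosen for illustration:

# slot = physical index of a token in the KV cache:
#   block_number = block_table[req, position // block_size]
#   slot         = block_number * block_size + position % block_size
import torch

_PAD = 1_000_000_000                     # stands in for _PAD_SLOT_ID
block_size = 16
block_table = torch.tensor([[3, 7]])     # request 0 owns physical blocks 3 and 7
positions = torch.tensor([[0, 15, 16, 20]])
block_numbers = block_table[0, positions // block_size]
slot_mapping = block_numbers * block_size + positions % block_size
assert slot_mapping.tolist() == [[48, 63, 112, 116]]
# Padding tokens are pointed at an out-of-range slot so that the KV-cache
# insertion kernel ignores them.
slot_mapping[:, 2:] = _PAD               # pretend the last two positions are padding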
+ slot_mapping[num_decodes:] = _PAD_SLOT_ID + slot_mapping = slot_mapping[:padded_batch_size] + + # BLOCK_TABLE [batch, max_num_blocks_per_req] + block_table = block_table_cpu_tensor[:padded_batch_size] + + # CONTEXT_LENS [batch_size] + context_lens = (positions.reshape(-1) + 1) + + # CPU<>TPU sync happens here. + return DecodeInputData(num_decodes=num_decodes, + token_ids=token_ids.to(self.device), + position_ids=positions.to(self.device), + attn_metadata=PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=padded_batch_size, + slot_mapping=slot_mapping.to(self.device), + block_tables=block_table.to(self.device), + context_lens=context_lens.to(self.device), + effective_query_lens=None, + )) + + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + num_decodes = num_reqs - self.new_req_ids + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + # TODO: Verify this works with TPUs + self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + max_num_scheduled_tokens = 0 + for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + assert req_id is not None + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + + # NOTE: Assert that all the decodes are "decodes". + if idx < num_decodes: + assert num_tokens == 1 + assert max_num_scheduled_tokens > 0 + + return ( + self._prepare_prefill_inputs(num_scheduled_tokens), + self._prepare_decode_inputs(num_decodes), + ) + + # # OPTIMIZATION: Start copying the block table first. + # # This way, we can overlap the copy with the following CPU operations. + # self.input_batch.block_table.commit(num_reqs) + + # # Get the number of scheduled tokens for each request. + # # TODO: The Python loop can be slow. Optimize. + # num_scheduled_tokens = [] + # max_num_scheduled_tokens = 0 + # for req_id in self.input_batch.req_ids[:num_reqs]: + # assert req_id is not None + # num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # num_scheduled_tokens.append(num_tokens) + # max_num_scheduled_tokens = max(max_num_scheduled_tokens, + # num_tokens) + # num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + # assert max_num_scheduled_tokens > 0 + + # # Get request indices. + # # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + # req_indices = np.repeat(self.arange_np[:num_reqs], + # num_scheduled_tokens) + + # # Get batched arange. + # # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # arange = np.concatenate( + # [self.arange_np[:n] for n in num_scheduled_tokens]) + + # # Get positions. + # positions_np = self.positions_np[:total_num_scheduled_tokens] + # np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + # arange, + # out=positions_np) + + # # Get token indices. + # # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # # where M is the max_model_len. 
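An illustrative NumPy sketch of the token-index mapping described in the surrounding (commented-out) GPU-path comments, using the same toy example of three requests scheduled for 2, 5, and 3 tokens:

import numpy as np

M = 10                                            # stand-in for max_model_len
num_scheduled_tokens = np.array([2, 5, 3])
req_indices = np.repeat(np.arange(3), num_scheduled_tokens)
# -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
arange = np.concatenate([np.arange(n) for n in num_scheduled_tokens])
# -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
num_computed = np.zeros(3, dtype=np.int64)        # nothing computed yet
positions = num_computed[req_indices] + arange
token_indices = positions + req_indices * M
# -> [0, 1, 10, 11, 12, 13, 14, 20, 21, 22]: row-major offsets into a
#    flattened [num_reqs, M] token_ids array, matching "0, 1, M, M + 1, ...".
assert token_indices.tolist() == [0, 1, 10, 11, 12, 13, 14, 20, 21, 22]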
+ # token_indices = (positions_np + + # req_indices * self.input_batch.token_ids_cpu.shape[1]) + # # NOTE(woosuk): We use torch.index_select instead of np.take here + # # because torch.index_select is much faster than np.take for large + # # tensors. + # torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + # 0, + # torch.from_numpy(token_indices), + # out=self.input_ids_cpu[:total_num_scheduled_tokens]) + + # # Calculate the slot mapping. + # # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # # where K is the max_num_blocks_per_req and the block size is 2. + # # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # # because M (max_model_len) is not necessarily divisible by block_size. + # block_table_indices = (req_indices * self.max_num_blocks_per_req + + # positions_np // self.block_size) + # # NOTE(woosuk): We use torch.index_select instead of np.take here + # # because torch.index_select is much faster than np.take for large + # # tensors. + # block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + # block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + # block_offsets = positions_np % self.block_size + # np.add(block_numbers * self.block_size, + # block_offsets, + # out=self.slot_mapping_np[:total_num_scheduled_tokens]) + + # # Prepare the attention metadata. + # self.query_start_loc_np[0] = 0 + # np.cumsum(num_scheduled_tokens, + # out=self.query_start_loc_np[1:num_reqs + 1]) + + # seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + + # num_scheduled_tokens) + # max_seq_len = seq_lens.max() + # self.seq_start_loc_np[0] = 0 + # np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1]) + + # # Copy the tensors to the GPU. + # self.input_ids[:total_num_scheduled_tokens].copy_( + # self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + # self.positions[:total_num_scheduled_tokens].copy_( + # self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) + # query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( + # self.device, non_blocking=True) + # seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to( + # self.device, non_blocking=True) + # slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( + # self.device, non_blocking=True).long() + + # # Prepare for cascade attention if needed. + # common_prefix_len = (scheduler_output.num_common_prefix_blocks * + # self.block_size) + # if common_prefix_len == 0: + # # Common case. + # use_cascade = False + # else: + # # NOTE(woosuk): Cascade attention uses two attention kernels: one + # # for the common prefix and the other for the rest. For the first + # # kernel, we concatenate all the query tokens (possibly from + # # different requests) and treat them as if they are from the same + # # request. Then, we use bi-directional attention to process the + # # common prefix in the KV cache. Importantly, this means that the + # # first kernel does not do any masking. 
+ + # # Consider the following example: + # # Request 1's input query: [D, E, X] + # # Request 1's kv cache: [A, B, C, D, E, X] + # # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) + # # Request 2's input query: [E, Y] + # # Request 2's kv cache: [A, B, C, D, E, Y] + # # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) + + # # If we use [A, B, C, D, E] as the common prefix, then the + # # first kernel will compute the bi-directional attention between + # # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. + # # However, this is wrong because D in Request 1 should not attend to + # # E in the common prefix (i.e., we need masking). + # # To avoid this, [A, B, C, D] should be the common prefix. + # # That is, the common prefix should be capped by the minimum + # # num_computed_tokens among the requests, and plus one to include + # # the first token of the query. + + # # In practice, we use [A, B, C] as the common prefix, instead of + # # [A, B, C, D] (i.e., the common prefix is capped by the minimum + # # num_computed_tokens, without plus one). + # # This is because of an implementation detail: We want to always + # # use two kernels for cascade attention. Let's imagine: + # # Request 3's input query: [D] + # # Request 3's kv cache: [A, B, C, D] + # # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) + # # If we use [A, B, C, D] as the common prefix for Request 1-3, + # # then Request 3 will be processed only by the first kernel, + # # and the second kernel will get an empty input. While this is not + # # a fundamental problem, our current implementation does not support + # # this case. + # common_prefix_len = min( + # common_prefix_len, + # self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + # # common_prefix_len should be a multiple of the block size. + # common_prefix_len = (common_prefix_len // self.block_size * + # self.block_size) + # use_cascade = FlashAttentionBackend.use_cascade_attention( + # common_prefix_len=common_prefix_len, + # query_lens=num_scheduled_tokens, + # num_query_heads=self.num_query_heads, + # num_kv_heads=self.num_kv_heads, + # use_alibi=False, # FIXME + # use_sliding_window=self.sliding_window is not None, + # num_sms=self.num_sms, + # ) + + # if use_cascade: + # # TODO: Optimize. + # cu_prefix_query_lens = torch.tensor( + # [0, total_num_scheduled_tokens], + # dtype=torch.int32, + # device=self.device) + # cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], + # dtype=torch.int32, + # device=self.device) + # cu_suffix_kv_lens = ( + # self.seq_start_loc_np[:num_reqs + 1] - + # self.arange_np[:num_reqs + 1] * common_prefix_len) + # cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to( + # self.device) + # else: + # cu_prefix_query_lens = None + # cu_prefix_kv_lens = None + # cu_suffix_kv_lens = None + + # attn_metadata = FlashAttentionMetadata( + # num_actual_tokens=total_num_scheduled_tokens, + # max_query_len=max_num_scheduled_tokens, + # query_start_loc=query_start_loc, + # max_seq_len=max_seq_len, + # seq_start_loc=seq_start_loc, + # block_table=( + # self.input_batch.block_table.get_device_tensor()[:num_reqs]), + # slot_mapping=slot_mapping, + # use_cascade=use_cascade, + # common_prefix_len=common_prefix_len, + # cu_prefix_query_lens=cu_prefix_query_lens, + # cu_prefix_kv_lens=cu_prefix_kv_lens, + # cu_suffix_kv_lens=cu_suffix_kv_lens, + # ) + # # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # # request in the batch. 
While we should not sample any token from this + # # partial request, we do so for simplicity. We will ignore the sampled + # # token from the partial request. + # # TODO: Support prompt logprobs. + # logits_indices = query_start_loc[1:] - 1 + # return attn_metadata, logits_indices + + def _prepare_sampling( + self, + scheduler_output: "SchedulerOutput", + ) -> SamplingMetadata: + skip_copy = True + if (scheduler_output.finished_req_ids + or scheduler_output.preempted_req_ids): + skip_copy = False + if (scheduler_output.scheduled_new_reqs + or scheduler_output.scheduled_resumed_reqs): + skip_copy = False + # Create the sampling metadata. + req_id_output_token_ids: Dict[str, List[int]] = \ + {req_id: req.output_token_ids \ + for req_id, req in self.requests.items()} + + sampling_metadata = self.input_batch.make_sampling_metadata( + req_id_output_token_ids, skip_copy) + return sampling_metadata + + def _execute_encoder(self, scheduler_output: "SchedulerOutput"): + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return + + # Batch the multi-modal inputs. + mm_inputs: List[MultiModalKwargs] = [] + req_input_ids: List[Tuple[str, int]] = [] + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + for input_id in encoder_input_ids: + mm_inputs.append(req_state.mm_inputs[input_id]) + req_input_ids.append((req_id, input_id)) + batched_mm_inputs = MultiModalKwargs.batch(mm_inputs) + batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, + device=self.device) + + # Run the encoder. + # `encoder_outputs` is either of the following: + # 1. A tensor of shape [num_images, feature_size, hidden_size] + # in case when feature_size is fixed across all images. + # 2. A list (length: num_images) of tensors, each of shape + # [feature_size, hidden_size] in case when the feature size is + # dynamic depending on input images. + encoder_outputs = self.model.get_multimodal_embeddings( + **batched_mm_inputs) + + # Cache the encoder outputs. + for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): + if req_id not in self.encoder_cache: + self.encoder_cache[req_id] = {} + self.encoder_cache[req_id][input_id] = output + + def _gather_encoder_outputs( + self, + scheduler_output: "SchedulerOutput", + ) -> List[torch.Tensor]: + encoder_outputs: List[torch.Tensor] = [] + num_reqs = self.input_batch.num_reqs + for req_id in self.input_batch.req_ids[:num_reqs]: + assert req_id is not None + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + mm_positions = req_state.mm_positions + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info["offset"] + num_encoder_tokens = pos_info["length"] + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, + # num_computed_tokens + num_scheduled_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_scheduled_tokens: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. 
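A worked example of the range-overlap bookkeeping used in this method to pick the needed slice of an encoder output; all numbers are illustrative:

# decode window : [num_computed_tokens, num_computed_tokens + num_scheduled)
# encoder tokens: [start_pos, start_pos + num_encoder_tokens)
def needed_slice(num_computed_tokens, num_scheduled, start_pos,
                 num_encoder_tokens):
    start_idx = max(num_computed_tokens - start_pos, 0)
    end_idx = min(num_computed_tokens - start_pos + num_scheduled,
                  num_encoder_tokens)
    return start_idx, end_idx

# Window [10, 15) vs. encoder span [12, 18): positions 12..14 overlap,
# so encoder_output[0:3] is needed in this step.
assert needed_slice(10, 5, 12, 6) == (0, 3)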
+ continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens) + assert start_idx < end_idx + assert req_id in self.encoder_cache + assert i in self.encoder_cache[req_id] + encoder_output = self.encoder_cache[req_id][i] + encoder_outputs.append(encoder_output[start_idx:end_idx]) + return encoder_outputs + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + self._update_states(scheduler_output) + + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_encoder(scheduler_output) + encoder_outputs = self._gather_encoder_outputs(scheduler_output) + else: + encoder_outputs = [] + + # Prepare the decoder inputs. + attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if (self.use_cuda_graph + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) + else: + # Eager mode. + num_input_tokens = num_scheduled_tokens + attn_metadata.num_input_tokens = num_input_tokens + + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + input_ids = self.input_ids[:num_scheduled_tokens] + if encoder_outputs: + inputs_embeds = self.model.get_input_embeddings( + input_ids, encoder_outputs) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) + inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + + # Run the decoder. + # Use persistent buffers for CUDA graphs. + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + positions=self.positions[:num_input_tokens], + kv_caches=self.kv_caches, + attn_metadata=None, + inputs_embeds=inputs_embeds, + ) + hidden_states = hidden_states[:num_scheduled_tokens] + hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(hidden_states, None) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self._prepare_sampling(scheduler_output) + sampler_output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + + sampled_token_ids = sampler_output.sampled_token_ids + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + num_reqs = self.input_batch.num_reqs + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + assert req_id is not None + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len <= req_state.num_tokens + if seq_len == req_state.num_tokens: + # Append the sampled token to the output token ids. 
+ token_id = sampled_token_ids[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + self.input_batch.num_tokens[i] += 1 + req_state.output_token_ids.append(token_id) + else: + # Ignore the sampled token from the partial request. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators.get(i) + if generator is not None: + # This relies on cuda-specific torch-internal impl details + generator.set_offset(generator.get_offset() - 4) + + if sampler_output.logprob_token_ids is None: + logprob_token_ids = None + else: + logprob_token_ids = sampler_output.logprob_token_ids.cpu() + if sampler_output.logprobs is None: + logprobs = None + else: + logprobs = sampler_output.logprobs.cpu() + + # num_reqs entries should be non-None + assert all( + req_id is not None for req_id in + self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=sampled_token_ids, + logprob_token_ids_cpu=logprob_token_ids, + logprobs_cpu=logprobs, + ) + return model_runner_output + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: # noqa: SIM117 + self.model = get_model(vllm_config=self.vllm_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + @torch.inference_mode() + def _dummy_run( + self, + model: nn.Module, + num_tokens: int, + kv_caches: List[torch.Tensor], + ) -> torch.Tensor: + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + with set_forward_context(None, self.vllm_config): + hidden_states = model( + input_ids=input_ids, + positions=self.positions[:num_tokens], + kv_caches=kv_caches, + attn_metadata=None, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def profile_run(self) -> None: + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value `None`. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + dummy_kv_caches = [ + torch.tensor([], dtype=torch.float32, device=self.device) + for _ in range(self.num_attn_layers) + ] + + # Profile with multimodal encoder & encoder cache. + if self.is_multimodal_model: + + # Create dummy batch of multimodal inputs. + dummy_request_data = self.input_registry.dummy_data_for_profiling( + model_config=self.model_config, + seq_len=self.max_num_tokens, + mm_registry=self.mm_registry, + ) + dummy_mm_data = dummy_request_data.multi_modal_data + + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501 + self.model_config) + + dummy_data_modality, max_tokens_per_mm_item = max( + max_tokens_by_modality_dict.items(), key=lambda item: item[1]) + + # Check how many items of this modality can be supported by + # the encoder cache budget. 
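For orientation, a worked numeric example of the encoder/decoder budget check performed in the lines that follow; every number below is an assumption chosen only to make the arithmetic concrete:

max_num_encoder_input_tokens = 8192   # assumed scheduler setting
encoder_cache_size = 8192             # assumed cache setting
max_tokens_per_mm_item = 576          # e.g. one image's worth of embeddings
max_num_reqs = 256                    # assumed max batch size
max_mm_items_per_req = 1              # assumed per-prompt multimodal limit

encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)
max_items_encoder = encoder_budget // max_tokens_per_mm_item      # 14
max_items_decoder = max_num_reqs * max_mm_items_per_req           # 256
max_num_mm_items = min(max_items_encoder, max_items_decoder)      # 14
assert max_num_mm_items == 14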
+ encoder_cache_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + max_num_mm_items_encoder_budget = encoder_cache_budget // \ + max_tokens_per_mm_item + + # TODO: Allow users to set encoder_cache_budget in case this + # happens. + assert max_num_mm_items_encoder_budget > 0, ( + f"Encoder cache budget={encoder_cache_budget} is too small to " + f"support the maximum possible size of multimodal embeddings" + f"={max_tokens_per_mm_item}.") + + # Check how many items of this modality can be supported by + # the decoder budget. + max_mm_items_per_req = max( + self.mm_registry.get_mm_limits_per_prompt( + self.model_config).values()) + + # NOTE: We do not consider max_num_batched_tokens on purpose + # because the multimodal embeddings can be generated in advance + # and chunked prefilled. + max_num_mm_items_decoder_budget = self.max_num_reqs * \ + max_mm_items_per_req + + max_num_mm_items = min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget) + + # Dummy data definition in V0 may contain multiple multimodal items + # (e.g, multiple images) for a single request, therefore here we + # always replicate first item by max_num_mm_items times since in V1 + # they are scheduled to be processed separately. + + # Case when models have a merged processor, their dummy data is + # already batched `MultiModalKwargs`, therefore we take the first + # `MultiModalKwargsItem` from the desired modality to profile on. + if isinstance(dummy_mm_data, MultiModalKwargs): + dummy_mm_item = dummy_mm_data.get_item( + modality=dummy_data_modality, item_index=0) + dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) + + # Case when models have dummy data explicitly defined as + # `MultiModalDataDict`, so they need to be processed through input + # mapper. + # TODO (ywang96): deprecate this path once merged processor is + # supported on all models. + else: + mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs( + mm_data=dummy_mm_data, + mm_hashes=None, + mm_processor_kwargs=None, + precomputed_mm_inputs=None) + dummy_mm_kwargs = mm_kwargs_list[0] + + batched_dummy_mm_inputs = MultiModalKwargs.batch( + [dummy_mm_kwargs] * max_num_mm_items) + batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( + batched_dummy_mm_inputs, device=self.device) + + # Run multimodal encoder. + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + assert len(dummy_encoder_outputs) == max_num_mm_items, ( + "Expected dimension 0 of encoder outputs to match the number " + f"of multimodal data items: {max_num_mm_items}, got " + f"{len(dummy_encoder_outputs)=} instead. This is most likely " + "due to the 'get_multimodal_embeddings' method of the model " + "not implemented correctly.") + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + + # Trigger compilation for general shape. + hidden_states = self._dummy_run(self.model, self.max_num_tokens, + dummy_kv_caches) + logits = self.model.compute_logits(hidden_states, None) + logits = logits[:self.max_num_tokens] + # TODO(woosuk): Consider the memory usage of the sampler. + torch.cuda.synchronize() + del hidden_states, logits + self.encoder_cache.clear() + gc.collect() + + def capture_model(self) -> None: + if not self.use_cuda_graph: + logger.warning( + "Skipping CUDA graph capture. 
Please add " + "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE) + return + + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + # Trigger CUDA graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + with graph_capture(device=self.device): + for num_tokens in reversed(self.cudagraph_batch_sizes): + for _ in range(self.vllm_config.compilation_config. + cudagraph_num_of_warmups): + self._dummy_run(self.model, num_tokens, self.kv_caches) + self._dummy_run(self.model, num_tokens, self.kv_caches) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + # This usually takes 5~20 seconds. + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, cuda_graph_size / (1 << 30)) + + def initialize_kv_cache(self, num_blocks: int) -> None: + assert len(self.kv_caches) == 0 + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + for _ in range(self.num_attn_layers): + self.kv_caches.append( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device)) + + +# TODO: Duplicate with V0, refactor +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. + if x <= 16: + return 16 + return 1 << (x - 1).bit_length() + + +# TODO: Duplicate with V0, refactor +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. 
+ if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 From ae3c487ed115172bc77d52e7d13d6ea65954513b Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 14:39:13 +0000 Subject: [PATCH 02/18] finished tpu model runner --- vllm/v1/worker/__tpu_model_runner.py | 981 ------------------------- vllm/v1/worker/tpu_model_runner_new.py | 816 +++++++++++++------- 2 files changed, 562 insertions(+), 1235 deletions(-) delete mode 100644 vllm/v1/worker/__tpu_model_runner.py diff --git a/vllm/v1/worker/__tpu_model_runner.py b/vllm/v1/worker/__tpu_model_runner.py deleted file mode 100644 index 7963fe4973b55..0000000000000 --- a/vllm/v1/worker/__tpu_model_runner.py +++ /dev/null @@ -1,981 +0,0 @@ -import time -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple - -import numpy as np -import torch -import torch.distributed -import torch.nn as nn -import torch_xla.core.xla_model as xm - -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MultiModalDataDict -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_pin_memory_available -from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, - PallasAttentionMetadata) -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.sample.metadata import SamplingMetadata - -if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput - -logger = init_logger(__name__) - -# Here we utilize the behavior that out-of-bound index is ignored. -# FIXME: Find a more reliable way to prevent possible bugs. 
-_PAD_SLOT_ID = 1_000_000_000 - - -@dataclass -class PrefillInputData: - - request_ids: List - prompt_lens: List - token_ids: List - position_ids: List - attn_metadata: List - - def zipped(self): - return zip(self.request_ids, self.prompt_lens, self.token_ids, - self.position_ids, self.attn_metadata) - - -@dataclass -class DecodeInputData: - - num_decodes: int - token_ids: Optional[torch.Tensor] = None - position_ids: Optional[torch.Tensor] = None - attn_metadata: PallasAttentionMetadata = None - - -class TPUModelRunner: - - def __init__( - self, - vllm_config: VllmConfig, - ): - # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - model_config = self.model_config - cache_config = self.cache_config - scheduler_config = self.scheduler_config - parallel_config = self.parallel_config - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - self.dtype = self.model_config.dtype - if cache_config.cache_dtype == "auto": - self.kv_cache_dtype = self.dtype - else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] - - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens - - # Model-related. - self.num_attn_layers = model_config.get_num_attention_layers( - parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() - - # List[k_cache, v_cache] - self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] - - # Request states. - self.requests: Dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.scheduler_config.max_num_seqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - ) - - self.prefill_positions = torch.tensor( - range(self.max_model_len), - device="cpu", - ).to(torch.int32).reshape(1, -1) - - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: - # Remove stopped requests from the cached states. - # Keep the states of the pre-empted requests. - for req_id in scheduler_output.finished_req_ids: - self.requests.pop(req_id, None) - - # Remove the requests from the persistent batch. - stopped_req_ids = set().union( - scheduler_output.preempted_req_ids, - scheduler_output.finished_req_ids, - ) - removed_req_indices: List[int] = [] - for req_id in stopped_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) - - # Update the states of the running requests. 
- for req_data in scheduler_output.scheduled_running_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - req_index = self.input_batch.req_id_to_index[req_id] - - # Update the num_computed_tokens. - req_state.num_computed_tokens = req_data.num_computed_tokens - self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - - # Update the block table. - num_new_blocks = len(req_data.new_block_ids) - if num_new_blocks == 0: - continue - start_index = len(req_state.block_ids) - end_index = start_index + num_new_blocks - req_state.block_ids.extend(req_data.new_block_ids) - self.input_batch.block_table_cpu[ - req_index, start_index:end_index] = req_data.new_block_ids - - req_ids_to_add: List[str] = [] - # Add new requests to the cached states. - for req_data in scheduler_output.scheduled_new_reqs: - req_id = req_data.req_id - sampling_params = req_data.sampling_params - if sampling_params.sampling_type == SamplingType.RANDOM_SEED: - generator = torch.Generator(device=self.device) - generator.manual_seed(sampling_params.seed) - else: - generator = None - - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=req_data.prompt_token_ids, - prompt=req_data.prompt, - multi_modal_data=req_data.multi_modal_data, - sampling_params=sampling_params, - generator=generator, - block_ids=req_data.block_ids, - num_computed_tokens=req_data.num_computed_tokens, - output_token_ids=[], - ) - req_ids_to_add.append(req_id) - - # Update the cached states of the resumed requests. - for req_data in scheduler_output.scheduled_resumed_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - - req_state.block_ids = req_data.block_ids - req_state.num_computed_tokens = req_data.num_computed_tokens - req_ids_to_add.append(req_id) - - # THIS MOVES ALL THE DECODES TO THE FIRST N IN BATCH. - # Condense the batched states if there are empty indices. - removed_req_indices = sorted(removed_req_indices, reverse=True) - if removed_req_indices: - self.input_batch.condense(removed_req_indices) - - # ALL THE PREFILLS ARE THE LAST M IN THE BATCH. - # These are added at the end after the bacth is condensed. - self.input_batch.num_prefills = len(req_ids_to_add) - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - self.input_batch.add_request(req_state, None) - - def _prepare_prefill_inputs( - self, - num_scheduled_tokens: List[int], - ) -> PrefillInputData: - # Each prefill run separately with shape [1, padded_prompt_len]. - # So we create lists that will be used in execute_model(). - - prefill_request_ids = [] - prefill_prompt_lens = [] - prefill_token_ids = [] - prefill_position_ids = [] - prefill_attn_metadata = [] - - # DECODES are the first num_decodes REQUESTS. - # PREFILLS are the next num_reqs - num_decodes REQUESTS. - num_reqs = self.input_batch.num_reqs - num_decodes = self.input_batch.num_decodes - for idx in range(num_decodes, num_reqs): - prefill_request_ids.append(self.input_batch.req_ids[idx]) - - # STATIC SHAPE: prefills are padded to the next power of 2. - prompt_len = num_scheduled_tokens[idx] - padded_prompt_len = _get_padded_prefill_len(prompt_len) - prefill_prompt_lens.append(prompt_len) - assert padded_prompt_len <= self.max_model_len - - # TOKEN_IDS. - token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ - idx, :padded_prompt_len].reshape(1, -1)) - prefill_token_ids.append(token_ids.to(self.device)) - - # POSITIONS. 
- positions = self.prefill_positions[:, :padded_prompt_len] - prefill_position_ids.append(positions.to(self.device)) - - # SLOT_MAPPING. - # The "slot" is the "physical index" of a token in the KV cache. - # Look up the block_idx in the block table (logical<>physical map) - # to compute this. - block_numbers = self.input_batch.block_table_cpu_tensor[ - idx, positions // self.block_size].reshape(1, -1) - block_offsets = positions % self.block_size - slot_mapping = block_numbers * self.block_size + block_offsets - # Set an out of range value for the padding tokens so that they - # are ignored when inserting into the KV cache. - slot_mapping[:, prompt_len:] = _PAD_SLOT_ID - slot_mapping = slot_mapping.long() - - # ATTN_METADATA. - prefill_attn_metadata.append( - PallasAttentionMetadata( - is_prompt=True, - slot_mapping=slot_mapping.to(self.device), - block_tables=None, - context_lens=None, - )) - - return PrefillInputData( - request_ids=prefill_request_ids, - prompt_lens=prefill_prompt_lens, - token_ids=prefill_token_ids, - position_ids=prefill_position_ids, - attn_metadata=prefill_attn_metadata, - ) - - def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData: - # Decodes run as one single padded batch with shape [batch, 1] - # - # We need to set _PAD_SLOT_ID for the padding tokens in the - # slot_mapping, such that the attention KV cache insertion - # logic knows to ignore those indicies. Otherwise, the - # padding data can be dummy since we have a causal mask. - - if num_decodes == 0: - return DecodeInputData(num_decodes=0) - - # PAD FOR STATIC SHAPES. - padded_batch_size = _get_padded_batch_size(num_decodes) - - # POSITIONS. [batch, 1] - # We slice at the end, since we use the positions for gathering. - positions = torch.from_numpy( - self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) - index = positions.to(torch.int64) - positions = positions[:padded_batch_size] - - # TOKEN_IDS. [batch, 1] - token_ids = torch.gather( - input=torch.from_numpy(self.input_batch.token_ids_cpu), - dim=1, - index=index, - )[:padded_batch_size] - - # SLOT_MAPPING [batch, 1] - # The "slot" is the "physical index" of a token in the KV cache. - # Look up the block_idx in the block table (logical<>physical map) - # to compute this. - block_number = torch.gather( - input=self.input_batch.block_table_cpu_tensor, - dim=1, - index=(index // self.block_size)) - block_offsets = index % self.block_size - slot_mapping = block_number * self.block_size + block_offsets - # Set an out of range value for the padding tokens so that they - # are ignored when inserting into the KV cache. - slot_mapping[num_decodes:] = _PAD_SLOT_ID - slot_mapping = slot_mapping[:padded_batch_size] - - # BLOCK_TABLE [batch, max_num_blocks_per_req] - block_table = self.input_batch.block_table_cpu_tensor[: - padded_batch_size] - - # CONTEXT_LENS [batch_size] - context_lens = (positions.reshape(-1) + 1) - - # CPU<>TPU sync happens here. 
- return DecodeInputData(num_decodes=num_decodes, - token_ids=token_ids.to(self.device), - position_ids=positions.to(self.device), - attn_metadata=PallasAttentionMetadata( - is_prompt=False, - slot_mapping=slot_mapping.to(self.device), - block_tables=block_table.to(self.device), - context_lens=context_lens.to(self.device), - )) - - def _prepare_inputs( - self, scheduler_output: "SchedulerOutput" - ) -> Tuple[PrefillInputData, Optional[DecodeInputData]]: - - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - - num_reqs = self.input_batch.num_reqs - num_decodes = self.input_batch.num_decodes - - # Get the number of scheduled tokens for each request. - # TODO: The Python loop can be slow. Optimize. - num_scheduled_tokens = [] - for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens.append(num_tokens) - - # NOTE: assert that all the decodes are "decodes". - if idx < num_decodes: - assert num_tokens == 1 - - return ( - self._prepare_prefill_inputs(num_scheduled_tokens), - self._prepare_decode_inputs(num_decodes), - ) - - def _prepare_sampling( - self, - scheduler_output: "SchedulerOutput", - ) -> SamplingMetadata: - skip_copy = True - if (scheduler_output.finished_req_ids - or scheduler_output.preempted_req_ids): - skip_copy = False - if (scheduler_output.scheduled_new_reqs - or scheduler_output.scheduled_resumed_reqs): - skip_copy = False - # Create the sampling metadata. - sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) - return sampling_metadata - - @torch.no_grad() - def execute_model( - self, - scheduler_output: "SchedulerOutput", - ) -> ModelRunnerOutput: - self._update_states(scheduler_output) - prefill_data, decode_data = self._prepare_inputs(scheduler_output) - num_reqs = self.input_batch.num_reqs - sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) - - ######################### DECODES ######################### - # Decodes run as one single batch with [padded_batch, 1] - if decode_data.num_decodes > 0: - - # FORWARD. - selected_token_ids = self.model(decode_data.token_ids, - decode_data.position_ids, - decode_data.attn_metadata, - self.kv_caches, - is_prompt=False) - - # NOTE: TPU<>CPU sync happens here. - # We need to call .cpu() first to avoid recompilation. - token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] - sampled_token_ids_list = token_ids.tolist() - sampled_token_ids[:decode_data.num_decodes] = token_ids - - # UPDATE REQUEST STATE. - for i, req_id in enumerate( - self.input_batch.req_ids[:decode_data.num_decodes]): - req_state = self.requests[req_id] - - # TODO: ASSERT NO CHUNKED PREFILL. - assert scheduler_output.num_scheduled_tokens[req_id] == 1 - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len == req_state.num_tokens - - token_id = sampled_token_ids_list[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - req_state.output_token_ids.append(token_id) - - ######################### PREFILLS ######################### - # Prefills run separately with shape [1, padded_prefill_len], - # due to lack of variable length attention kernel so far. - for idx, (req_id, prompt_len, token_ids, position_ids, - attn_metadata) in enumerate(prefill_data.zipped()): - - # FORWARD. 
- selected_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - self.kv_caches, - is_prompt=True) - - # NOTE: TPU<>CPU sync happens here. - # We need to call .cpu() first to avoid recompilation. - token_id = selected_token_ids.cpu()[prompt_len - 1].item() - sampled_token_ids[decode_data.num_decodes + idx] = token_id - req_state = self.requests[req_id] - - # TODO: ASSERT NO PREFIX CACHING. - assert req_state.num_computed_tokens == 0 - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - - # TODO: ASSERT NO CHUNKED PREFILL. - assert seq_len == req_state.num_tokens - assert prompt_len == seq_len - - # UPDATE REQUEST STATE. - req_idx = self.input_batch.req_id_to_index[req_id] - self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id - req_state.output_token_ids.append(token_id) - - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids[:num_reqs], - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids_cpu=sampled_token_ids, - logprob_token_ids_cpu=None, - logprobs_cpu=None, - ) - - def load_model(self) -> None: - - # NOTE(woosuk): While the executor assigns the TP ranks to the worker - # process, the ranks can be different from the ranks internally assigned - # by the xm runtime. Therefore, there is a mismatch in the rank - # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. - # This is not a problem in linear layers because all-reduce is - # rank-agnostic. However, it matters for all-gather as the ranks - # determine the order of concatenating the output tensors. - # As a workaround, we use the xm's rank assignment only when loading - # the embedding weights. - - # xm_tp_rank = xr.global_ordinal() - # with patch( - # "vllm.model_executor.layers.vocab_parallel_embedding." - # "get_tensor_model_parallel_rank", - # return_value=xm_tp_rank): - # model = get_model(vllm_config=self.vllm_config) - model = get_model(vllm_config=self.vllm_config) - model = model.eval() - xm.wait_device_ops() - self.model = ModelWrapper(model) - - def _dummy_run(self, batch_size: int, seq_len: int, - kv_caches: List[torch.Tensor], is_prompt: bool) -> None: - """Dummy warmup run for memory usage and graph compilation.""" - - input_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((batch_size, seq_len), - dtype=torch.int64, - device=self.device) - block_tables = None if is_prompt else torch.zeros( - (batch_size, self.max_num_blocks_per_req), - dtype=torch.int32, - device=self.device, - ) - context_lens = None if is_prompt else torch.ones( - (batch_size, ), - dtype=torch.int32, - device=self.device, - ) - attn_metadata = PallasAttentionMetadata( - is_prompt=is_prompt, - slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - ) - - # NOTE: There are two stages of compilation: torch.compile and - # XLA compilation. Using `mark_dynamic` can reduce the torch.compile - # overhead by reusing the FX graph for different shapes. - # However, the XLA graph will still require static shapes and needs to - # be re-compiled for every different shapes. This overhead is inevitable - # in the first run, but can be skipped afterwards as we cache the XLA - # graphs in the disk (VLLM_XLA_CACHE_PATH). 
- if is_prompt: - torch._dynamo.mark_dynamic(input_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) - else: - torch._dynamo.mark_dynamic(input_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - - # Dummy run. - self.model(input_ids, - position_ids, - attn_metadata, - kv_caches, - is_prompt=is_prompt) - - def profile_run(self) -> None: - """Profile to measure peak memory during forward pass.""" - - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value `None`. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - dummy_kv_caches = [( - torch.tensor([], dtype=torch.float32, device=self.device), - torch.tensor([], dtype=torch.float32, device=self.device), - ) for _ in range(self.num_attn_layers)] - - # Round to multiple of 16. - seq_len = (self.max_num_tokens + 15) // 16 * 16 - - # Run empty forward. - self._dummy_run(batch_size=1, - seq_len=seq_len, - kv_caches=dummy_kv_caches, - is_prompt=True) - - def capture_model(self) -> None: - """Compile the model.""" - - logger.info("Compiling the model with different input shapes.") - - # Prefill shapes. - start = time.perf_counter() - for batch_size in [1]: - seq_len = 16 - while True: - self._dummy_run(batch_size, - seq_len, - self.kv_caches, - is_prompt=True) - xm.wait_device_ops() - logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - if seq_len >= self.model_config.max_model_len: - break - num_tokens = batch_size * seq_len - if num_tokens >= self.scheduler_config.max_num_batched_tokens: - break - seq_len = seq_len * 2 - - end = time.perf_counter() - logger.info("Compilation for prefill done in %.2f s.", end - start) - - # Decode shapes. 
- start = time.time() - seq_len = 1 - batch_size = 8 # Must be in sync with _get_padded_batch_size() - while True: - self._dummy_run(batch_size, - seq_len, - self.kv_caches, - is_prompt=False) - xm.wait_device_ops() - logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - - if batch_size >= self.scheduler_config.max_num_seqs: - break - batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 - - end = time.time() - logger.info("Compilation for decode done in %.2f s.", end - start) - - def initialize_kv_cache(self, num_blocks: int) -> None: - assert len(self.kv_caches) == 0 - kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size) - for _ in range(self.num_attn_layers): - self.kv_caches.append(( - torch.zeros(kv_cache_shape, - dtype=self.kv_cache_dtype, - device=self.device), - torch.zeros(kv_cache_shape, - dtype=self.kv_cache_dtype, - device=self.device), - )) - - -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - multi_modal_data: Optional["MultiModalDataDict"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] - - @property - def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - -class InputBatch: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_blocks_per_req: int, - device: torch.device, - pin_memory: bool, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req - self.device = device - self.pin_memory = pin_memory - - self.req_ids: List[Optional[str]] = [None] * max_num_reqs - self.req_id_to_index: Dict[str, int] = {} - - self.token_ids_cpu = np.zeros((max_num_reqs, max_model_len), - dtype=np.int32) - self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) - - # Attention-related. - self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() - - # Sampling-related. 
- self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() - - # req_index -> generator - self.generators: Dict[int, torch.Generator] = {} - - self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() - - self.num_prefills = 0 - - def add_request( - self, - request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs - - req_id = request.req_id - self.req_ids[req_index] = req_id - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. - num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids - - sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_id) - else: - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: - self.top_k_reqs.add(req_id) - - self.generators[req_index] = request.generator - - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) - - def remove_request(self, req_id: str) -> Optional[int]: - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - self.req_ids[req_index] = None - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) - return req_index - - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() - - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: - # 
The batched states are empty. - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = empty_req_indices.pop() - if empty_index >= last_req_index: - break - - # Swap the states. - req_id = self.req_ids[last_req_index] - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - # TODO(woosuk): Optimize the copy of token_ids_cpu and - # block_table_cpu. - self.token_ids_cpu[empty_index] = self.token_ids_cpu[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - # Decrement last_req_index since it is now empty. - last_req_index -= 1 - - def make_sampling_metadata( - self, - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - ) - - @property - def num_reqs(self) -> int: - return len(self.req_id_to_index) - - @property - def num_decodes(self) -> int: - return self.num_reqs - self.num_prefills - - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 - - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 - - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 - - @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 - - @property - def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 - - -class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): - - def __init__(self, model: nn.Module): - self.model = model - compiled_callable = torch.compile(self.forward, - backend="openxla", - fullgraph=True, - dynamic=False) - super().__init__(compiled_callable) - - def __call__(self, *args, is_prompt: bool, **kwargs): - if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: - # not fully compiled yet, or not using the custom dispatcher, - # let PyTorch handle it - return self.compiled_callable(*args, **kwargs) - # the 3 compiled codes are: - # 0: for profiling - # 1: for prompt - # 2: for decode - # dispatch to 
the compiled code directly, skip PyTorch - if is_prompt: - with self.dispatch_to_code(1): - return self.forward(*args, **kwargs) - else: - with self.dispatch_to_code(2): - return self.forward(*args, **kwargs) - - def forward( - self, - token_ids: torch.Tensor, - position_ids: torch.Tensor, - attn_metadata: PallasAttentionMetadata, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> torch.Tensor: - """Executes the forward pass of the model and samples the next token. - - Args: - token_ids: The input token IDs of shape [batch_size, seq_len]. - position_ids: The input position IDs of shape [batch_size, seq_len]. - attn_metadata: The Pallas attention metadata. - kv_caches: The key and value caches. They can be None during the - memory profiling at initialization. - """ - - # Skip this in memory profiling at initialization. - if kv_caches[0][0].numel() > 0: - # index_copy_(slot_mapping) only works when the inserted dimension - # is 0. However, the KV cache in the Pallas backend has the shape - # [num_kv_heads, num_blocks, block_size, head_size]. To make it - # work, we need to flatten the first three dimensions and modify - # the slot_mapping accordingly. - num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape - slot_mapping = attn_metadata.slot_mapping - slot_mapping = slot_mapping.flatten() - head_indicies = torch.arange(0, - num_kv_heads, - device=slot_mapping.device, - dtype=slot_mapping.dtype) - head_indicies *= block_size * num_blocks - slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( - -1, num_kv_heads) - slot_mapping = slot_mapping + head_indicies.view(1, -1) - slot_mapping = slot_mapping.flatten() - attn_metadata.slot_mapping = slot_mapping - - hidden_states = self.model( - token_ids, - position_ids, - kv_caches, - attn_metadata, - ) - hidden_states = hidden_states.flatten(0, 1) - logits = self.model.compute_logits(hidden_states, None) - - # Greedy sampling. - argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - return argmax_token_ids.squeeze(dim=1) - - -def _get_padded_batch_size(batch_size: int) -> int: - # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. - # To meet this requirement in the simplest way, we set the minimal batch - # size to 8. - if batch_size <= 8: - return 8 - else: - return ((batch_size + 15) // 16) * 16 - - -def _get_padded_prefill_len(x: int) -> int: - # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence - # length to be a multiple of 16. We pad the prompt length to the nearest - # multiple of 16. This is also good for performance. 
- if x <= 16: - return 16 - return 1 << (x - 1).bit_length() diff --git a/vllm/v1/worker/tpu_model_runner_new.py b/vllm/v1/worker/tpu_model_runner_new.py index 9d910be89e640..2637319a492c9 100644 --- a/vllm/v1/worker/tpu_model_runner_new.py +++ b/vllm/v1/worker/tpu_model_runner_new.py @@ -1,13 +1,20 @@ import gc import time +import enum from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Tuple, cast, Optional +from unittest.mock import patch import numpy as np import torch import torch.distributed import torch.nn as nn +# TPU XLA related +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +from vllm.attention import AttentionMetadata from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context @@ -18,7 +25,7 @@ from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LayerBlockType, cdiv, is_pin_memory_available) -from vllm.v1.attention.backends.pallas import PallasMetadata +from vllm.v1.attention.backends.pallas import PallasMetadata, PallasAttentionBackend from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -30,8 +37,22 @@ logger = init_logger(__name__) # Here we utilize the behavior that out-of-bound index is ignored. -# FIXME: Find a more reliable way to prevent possible bugs. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. _PAD_SLOT_ID = 1_000_000_000 +# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. +_ENABLE_TOP_P = False +# FIXME(woosuk): A temporary hack to support `n > 1`. +# This can significantly affect the performance if too large. +_MAX_NUM_SAMPLES = 128 + + +class ExecutionMode(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + PREFIX_PREFILL = enum.auto() + + def is_prefill(self) -> bool: + return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) @dataclass @@ -74,6 +95,7 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config + self.device_config = vllm_config.device_config model_config = self.model_config cache_config = self.cache_config @@ -753,6 +775,84 @@ def _gather_encoder_outputs( encoder_outputs.append(encoder_output[start_idx:end_idx]) return encoder_outputs + def execute_model_xxx(): + prefill_data, decode_data = self._prepare_inputs(scheduler_output) + num_reqs = self.input_batch.num_reqs + sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) + + ######################### DECODES ######################### + # Decodes run as one single batch with [padded_batch, 1] + if decode_data.num_decodes > 0: + + # FORWARD. + selected_token_ids = self.model(decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches, + is_prompt=False) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] + sampled_token_ids_list = token_ids.tolist() + sampled_token_ids[:decode_data.num_decodes] = token_ids + + # UPDATE REQUEST STATE. + for i, req_id in enumerate( + self.input_batch.req_ids[:decode_data.num_decodes]): + req_state = self.requests[req_id] + + # TODO: ASSERT NO CHUNKED PREFILL. 
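# [Illustrative sketch -- editor's addition, not from this patch] The runner
# above executes all decodes as one [padded_batch, 1] step and each prefill as
# a separate [1, padded_prompt_len] step. A small sketch of the resulting
# shapes, using local stand-ins that mirror _get_padded_batch_size and
# _get_padded_prefill_len:
def _pad_batch(n: int) -> int:
    return 8 if n <= 8 else ((n + 15) // 16) * 16

def _pad_prompt(n: int) -> int:
    return 16 if n <= 16 else 1 << (n - 1).bit_length()

decode_shape = (_pad_batch(5), 1)           # 5 decode requests  -> (8, 1)
prefill_shapes = [(1, _pad_prompt(13)),     # 13-token prompt    -> (1, 16)
                  (1, _pad_prompt(100))]    # 100-token prompt   -> (1, 128)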
+ assert scheduler_output.num_scheduled_tokens[req_id] == 1 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len == req_state.num_tokens + + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + ######################### PREFILLS ######################### + # Prefills run separately with shape [1, padded_prefill_len], + # due to lack of variable length attention kernel so far. + for idx, (req_id, prompt_len, token_ids, position_ids, + attn_metadata) in enumerate(prefill_data.zipped()): + + # FORWARD. + selected_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + self.kv_caches, + is_prompt=True) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_id = selected_token_ids.cpu()[prompt_len - 1].item() + sampled_token_ids[decode_data.num_decodes + idx] = token_id + req_state = self.requests[req_id] + + # TODO: ASSERT NO PREFIX CACHING. + assert req_state.num_computed_tokens == 0 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + # TODO: ASSERT NO CHUNKED PREFILL. + assert seq_len == req_state.num_tokens + assert prompt_len == seq_len + + # UPDATE REQUEST STATE. + req_idx = self.input_batch.req_id_to_index[req_id] + self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids[:num_reqs], + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids_cpu=sampled_token_ids, + logprob_token_ids_cpu=None, + logprobs_cpu=None, + ) + @torch.inference_mode() def execute_model( self, @@ -760,102 +860,152 @@ def execute_model( ) -> ModelRunnerOutput: self._update_states(scheduler_output) - if self.is_multimodal_model: - # Run the multimodal encoder if any. - self._execute_encoder(scheduler_output) - encoder_outputs = self._gather_encoder_outputs(scheduler_output) - else: - encoder_outputs = [] + # TODO: Ressurect this code + # if self.is_multimodal_model: + # # Run the multimodal encoder if any. + # self._execute_encoder(scheduler_output) + # encoder_outputs = self._gather_encoder_outputs(scheduler_output) + # else: + # encoder_outputs = [] # Prepare the decoder inputs. - attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if (self.use_cuda_graph - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_scheduled_tokens) - else: - # Eager mode. - num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens - - if self.is_multimodal_model: - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. - input_ids = self.input_ids[:num_scheduled_tokens] - if encoder_outputs: - inputs_embeds = self.model.get_input_embeddings( - input_ids, encoder_outputs) - else: - inputs_embeds = self.model.get_input_embeddings(input_ids) - # TODO(woosuk): Avoid the copy. Optimize. 
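# [Illustrative sketch -- editor's addition, not from this patch] In the decode
# update above, each request schedules exactly one token, so (assuming the
# bookkeeping implied by the asserts) seq_len = num_computed_tokens + 1 equals
# the request's current num_tokens and is the first free slot of its
# token_ids_cpu row. A tiny worked example with placeholder values:
import numpy as np

token_ids_cpu = np.zeros((1, 16), dtype=np.int32)
token_ids_cpu[0, :3] = [7, 8, 9]        # prompt tokens (already computed)
token_ids_cpu[0, 3] = 42                # token sampled at prefill time
num_computed_tokens, num_scheduled = 3, 1
seq_len = num_computed_tokens + num_scheduled   # 4 == num_tokens (3 prompt + 1 output)
token_ids_cpu[0, seq_len] = 1234        # this step's sampled token -> slot 4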
- self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) - inputs_embeds = self.inputs_embeds[:num_input_tokens] - input_ids = None - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. - input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = None - - # Run the decoder. - # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( - input_ids=input_ids, - positions=self.positions[:num_input_tokens], - kv_caches=self.kv_caches, - attn_metadata=None, - inputs_embeds=inputs_embeds, - ) - hidden_states = hidden_states[:num_scheduled_tokens] - hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(hidden_states, None) - - # Sample the next token and get logprobs if needed. - sampling_metadata = self._prepare_sampling(scheduler_output) - sampler_output = self.model.sample( - logits=logits, - sampling_metadata=sampling_metadata, - ) + prefill_data, decode_data = self._prepare_inputs(scheduler_output) - sampled_token_ids = sampler_output.sampled_token_ids - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. num_reqs = self.input_batch.num_reqs - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - assert req_id is not None + sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) + + # attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + num_input_tokens = num_scheduled_tokens + # attn_metadata.num_input_tokens = num_input_tokens + + # TODO: Resurrect this code + # if self.is_multimodal_model: + # # NOTE(woosuk): To unify token ids and soft tokens (vision + # # embeddings), we always use embeddings (rather than token ids) + # # as input to the multimodal model, even when the input is text. + # input_ids = self.input_ids[:num_scheduled_tokens] + # if encoder_outputs: + # inputs_embeds = self.model.get_input_embeddings( + # input_ids, encoder_outputs) + # else: + # inputs_embeds = self.model.get_input_embeddings(input_ids) + # # TODO(woosuk): Avoid the copy. Optimize. + # self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) + # inputs_embeds = self.inputs_embeds[:num_input_tokens] + # input_ids = None + # else: + # # For text-only models, we use token ids as input. + # # While it is possible to use embeddings as input just like the + # # multimodal models, it is not desirable for performance since + # # then the embedding layer is not included in the CUDA graph. + # input_ids = self.input_ids[:num_input_tokens] + # inputs_embeds = None + + ######################### DECODES ######################### + # Decodes run as one single batch with [padded_batch, 1] + if decode_data.num_decodes > 0: + # FORWARD. + selected_token_ids = self.model(decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches, + is_prompt=False) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] + sampled_token_ids_list = token_ids.tolist() + sampled_token_ids[:decode_data.num_decodes] = token_ids + + # UPDATE REQUEST STATE. 
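# [Illustrative sketch -- editor's addition, not from this patch] The decode
# section above and the prefill loop below fill one flat sampled_token_ids
# tensor: decode results occupy slots [0:num_decodes] and the idx-th prefill
# lands in slot num_decodes + idx. A tiny sketch with placeholder token IDs:
import torch

num_decodes, num_prefills = 3, 2
sampled = torch.empty(num_decodes + num_prefills, dtype=torch.int32)
sampled[:num_decodes] = torch.tensor([11, 12, 13], dtype=torch.int32)
for idx, tok in enumerate([21, 22]):    # one new token per finished prefill
    sampled[num_decodes + idx] = tok
# sampled -> tensor([11, 12, 13, 21, 22], dtype=torch.int32)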
+ for i, req_id in enumerate( + self.input_batch.req_ids[:decode_data.num_decodes]): + req_state = self.requests[req_id] + + # TODO: ASSERT NO CHUNKED PREFILL. + assert scheduler_output.num_scheduled_tokens[req_id] == 1 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len == req_state.num_tokens + + # TODO: Verify if req_id_to_index mapping is needed here! + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + ######################### PREFILLS ######################### + # Prefills run separately with shape [1, padded_prefill_len], + # due to lack of variable length attention kernel so far. + for idx, (req_id, prompt_len, token_ids, position_ids, + attn_metadata) in enumerate(prefill_data.zipped()): + # FORWARD. + selected_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + self.kv_caches, + is_prompt=True) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_id = selected_token_ids.cpu()[prompt_len - 1].item() + sampled_token_ids[decode_data.num_decodes + idx] = token_id req_state = self.requests[req_id] + + # TODO: ASSERT NO PREFIX CACHING. + assert req_state.num_computed_tokens == 0 seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len <= req_state.num_tokens - if seq_len == req_state.num_tokens: - # Append the sampled token to the output token ids. - token_id = sampled_token_ids[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - self.input_batch.num_tokens[i] += 1 - req_state.output_token_ids.append(token_id) - else: - # Ignore the sampled token from the partial request. - # Rewind the generator state as if the token was not sampled. - generator = self.input_batch.generators.get(i) - if generator is not None: - # This relies on cuda-specific torch-internal impl details - generator.set_offset(generator.get_offset() - 4) - - if sampler_output.logprob_token_ids is None: - logprob_token_ids = None - else: - logprob_token_ids = sampler_output.logprob_token_ids.cpu() - if sampler_output.logprobs is None: - logprobs = None - else: - logprobs = sampler_output.logprobs.cpu() + + # TODO: ASSERT NO CHUNKED PREFILL. + assert seq_len == req_state.num_tokens + assert prompt_len == seq_len + + # UPDATE REQUEST STATE. + req_idx = self.input_batch.req_id_to_index[req_id] + self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + # TODO: Remove + # # Sample the next token and get logprobs if needed. + # sampling_metadata = self._prepare_sampling(scheduler_output) + # sampler_output = self.model.sample( + # logits=logits, + # sampling_metadata=sampling_metadata, + # ) + + # sampled_token_ids = sampler_output.sampled_token_ids + # # TODO(woosuk): The following loop can be slow since it iterates over + # # the requests one by one. Optimize. + # num_reqs = self.input_batch.num_reqs + # for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + # assert req_id is not None + # req_state = self.requests[req_id] + # seq_len = (req_state.num_computed_tokens + + # scheduler_output.num_scheduled_tokens[req_id]) + # assert seq_len <= req_state.num_tokens + # if seq_len == req_state.num_tokens: + # # Append the sampled token to the output token ids. 
+ # token_id = sampled_token_ids[i] + # self.input_batch.token_ids_cpu[i, seq_len] = token_id + # self.input_batch.num_tokens[i] += 1 + # req_state.output_token_ids.append(token_id) + # else: + # # Ignore the sampled token from the partial request. + # # Rewind the generator state as if the token was not sampled. + # generator = self.input_batch.generators.get(i) + # if generator is not None: + # # This relies on cuda-specific torch-internal impl details + # generator.set_offset(generator.get_offset() - 4) + + # if sampler_output.logprob_token_ids is None: + # logprob_token_ids = None + # else: + # logprob_token_ids = sampler_output.logprob_token_ids.cpu() + # if sampler_output.logprobs is None: + # logprobs = None + # else: + # logprobs = sampler_output.logprobs.cpu() # num_reqs entries should be non-None assert all( @@ -866,45 +1016,161 @@ def execute_model( model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=sampled_token_ids, - logprob_token_ids_cpu=logprob_token_ids, - logprobs_cpu=logprobs, + sampled_token_ids_cpu=sampled_token_ids, + logprob_token_ids_cpu=None, + logprobs_cpu=None, ) + return model_runner_output def load_model(self) -> None: - logger.info("Starting to load model %s...", self.model_config.model) - with DeviceMemoryProfiler() as m: # noqa: SIM117 - self.model = get_model(vllm_config=self.vllm_config) - - self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB", - self.model_memory_usage / float(2**30)) + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." 
+ "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + model = ModelWrapper(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) @torch.inference_mode() def _dummy_run( self, - model: nn.Module, - num_tokens: int, - kv_caches: List[torch.Tensor], - ) -> torch.Tensor: - if self.is_multimodal_model: - input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] + batch_size: int, + seq_len: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + exec_mode: ExecutionMode, + ) -> None: + exec_mode = ExecutionMode(exec_mode) + if exec_mode.is_prefill(): + seq_len = (seq_len + 15) // 16 * 16 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + if exec_mode == ExecutionMode.PREFILL: + attn_metadata = PallasMetadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=None, + context_lens=None, + effective_query_lens=None, + ) + + else: + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + + effective_query_lens = torch.ones_like(context_lens) + + attn_metadata = PallasMetadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) else: - input_ids = self.input_ids[:num_tokens] - inputs_embeds = None - with set_forward_context(None, self.vllm_config): - hidden_states = model( - input_ids=input_ids, - positions=self.positions[:num_tokens], - kv_caches=kv_caches, - attn_metadata=None, - inputs_embeds=inputs_embeds, + assert seq_len == 1 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + attn_metadata = PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size * seq_len, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, ) - return hidden_states + + t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 + + # NOTE(woosuk): There are two stages of compilation: torch.compile and + # XLA compilation. 
Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). + if exec_mode.is_prefill(): + # Prefll + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + # Decode + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(input_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(t, 0) + torch._dynamo.mark_dynamic(p, 0) + + # Dummy run. + self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, + num_samples, kv_caches) def profile_run(self) -> None: + """Profile to measure peak memory during forward pass.""" + # use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value `None`. # the `dtype` argument does not matter, and we use `float32` as @@ -912,152 +1178,184 @@ def profile_run(self) -> None: # it is important to create tensors inside the loop, rather than # multiplying the list, to avoid Dynamo from treating them as # tensor aliasing. - dummy_kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(self.num_attn_layers) - ] - - # Profile with multimodal encoder & encoder cache. - if self.is_multimodal_model: - - # Create dummy batch of multimodal inputs. - dummy_request_data = self.input_registry.dummy_data_for_profiling( - model_config=self.model_config, - seq_len=self.max_num_tokens, - mm_registry=self.mm_registry, - ) - dummy_mm_data = dummy_request_data.multi_modal_data - - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501 - self.model_config) - - dummy_data_modality, max_tokens_per_mm_item = max( - max_tokens_by_modality_dict.items(), key=lambda item: item[1]) - - # Check how many items of this modality can be supported by - # the encoder cache budget. - encoder_cache_budget = min(self.max_num_encoder_input_tokens, - self.encoder_cache_size) - max_num_mm_items_encoder_budget = encoder_cache_budget // \ - max_tokens_per_mm_item - - # TODO: Allow users to set encoder_cache_budget in case this - # happens. - assert max_num_mm_items_encoder_budget > 0, ( - f"Encoder cache budget={encoder_cache_budget} is too small to " - f"support the maximum possible size of multimodal embeddings" - f"={max_tokens_per_mm_item}.") - - # Check how many items of this modality can be supported by - # the decoder budget. - max_mm_items_per_req = max( - self.mm_registry.get_mm_limits_per_prompt( - self.model_config).values()) - - # NOTE: We do not consider max_num_batched_tokens on purpose - # because the multimodal embeddings can be generated in advance - # and chunked prefilled. 
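# [Illustrative sketch -- editor's addition, not from this patch] The
# torch._dynamo.mark_dynamic calls above flag which dimension varies between
# warm-up runs, so torch.compile can reuse a single FX graph for all padded
# shapes while XLA still compiles one executable per concrete shape. A minimal,
# device-agnostic illustration of the API:
import torch

tokens = torch.zeros((1, 32), dtype=torch.int32)
torch._dynamo.mark_dynamic(tokens, 1)   # dim 1 (the sequence length) may vary
# Any subsequently torch.compile-d function that consumes `tokens` can then
# reuse one traced graph for the different padded sequence lengths above.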
- max_num_mm_items_decoder_budget = self.max_num_reqs * \ - max_mm_items_per_req - - max_num_mm_items = min(max_num_mm_items_encoder_budget, - max_num_mm_items_decoder_budget) - - # Dummy data definition in V0 may contain multiple multimodal items - # (e.g, multiple images) for a single request, therefore here we - # always replicate first item by max_num_mm_items times since in V1 - # they are scheduled to be processed separately. - - # Case when models have a merged processor, their dummy data is - # already batched `MultiModalKwargs`, therefore we take the first - # `MultiModalKwargsItem` from the desired modality to profile on. - if isinstance(dummy_mm_data, MultiModalKwargs): - dummy_mm_item = dummy_mm_data.get_item( - modality=dummy_data_modality, item_index=0) - dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) - - # Case when models have dummy data explicitly defined as - # `MultiModalDataDict`, so they need to be processed through input - # mapper. - # TODO (ywang96): deprecate this path once merged processor is - # supported on all models. - else: - mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs( - mm_data=dummy_mm_data, - mm_hashes=None, - mm_processor_kwargs=None, - precomputed_mm_inputs=None) - dummy_mm_kwargs = mm_kwargs_list[0] - - batched_dummy_mm_inputs = MultiModalKwargs.batch( - [dummy_mm_kwargs] * max_num_mm_items) - batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( - batched_dummy_mm_inputs, device=self.device) - - # Run multimodal encoder. - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - assert len(dummy_encoder_outputs) == max_num_mm_items, ( - "Expected dimension 0 of encoder outputs to match the number " - f"of multimodal data items: {max_num_mm_items}, got " - f"{len(dummy_encoder_outputs)=} instead. This is most likely " - "due to the 'get_multimodal_embeddings' method of the model " - "not implemented correctly.") - - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - - # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.model, self.max_num_tokens, - dummy_kv_caches) - logits = self.model.compute_logits(hidden_states, None) - logits = logits[:self.max_num_tokens] - # TODO(woosuk): Consider the memory usage of the sampler. - torch.cuda.synchronize() - del hidden_states, logits - self.encoder_cache.clear() - gc.collect() + dummy_kv_caches = [( + torch.tensor([], dtype=torch.float32, device=self.device), + torch.tensor([], dtype=torch.float32, device=self.device), + ) for _ in range(self.num_attn_layers)] + + # Run empty forward. + self._dummy_run( + batch_size=1, + seq_len=self.max_num_tokens, # Will be rounded to 16 multiple + kv_caches=dummy_kv_caches, + exec_mode=ExecutionMode.PREFILL) def capture_model(self) -> None: - if not self.use_cuda_graph: - logger.warning( - "Skipping CUDA graph capture. 
Please add " - "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE) - return + """Compile the model.""" + + logger.info("Compiling the model with different input shapes.") + + # Capture prefill shapes + start = time.perf_counter() + for batch_size in [1]: + seq_len = 16 + while True: + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + exec_mode=ExecutionMode.PREFILL) + xm.wait_device_ops() + logger.info(" -- batch_size: %d, seq_len: %d", batch_size, + seq_len) + + if seq_len >= self.model_config.max_model_len: + break - start_time = time.perf_counter() - start_free_gpu_memory = torch.cuda.mem_get_info()[0] - - # Trigger CUDA graph capture for specific shapes. - # Capture the large shapes first so that the smaller shapes - # can reuse the memory pool allocated for the large shapes. - with graph_capture(device=self.device): - for num_tokens in reversed(self.cudagraph_batch_sizes): - for _ in range(self.vllm_config.compilation_config. - cudagraph_num_of_warmups): - self._dummy_run(self.model, num_tokens, self.kv_caches) - self._dummy_run(self.model, num_tokens, self.kv_caches) - - end_time = time.perf_counter() - end_free_gpu_memory = torch.cuda.mem_get_info()[0] - elapsed_time = end_time - start_time - cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory - # This usually takes 5~20 seconds. - logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / (1 << 30)) + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + + # Move to next seq_len + seq_len = seq_len * 2 + + end = time.perf_counter() + logger.info("Compilation for prefill shapes is done in %.2f [secs].", + end - start) + + # Capture decode shapes. + start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + exec_mode=ExecutionMode.DECODE) + xm.wait_device_ops() + logger.info(" -- batch_size: %d, seq_len: %d", batch_size, + seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode shapes is done in %.2f [secs].", + end - start) def initialize_kv_cache(self, num_blocks: int) -> None: assert len(self.kv_caches) == 0 - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) for _ in range(self.num_attn_layers): - self.kv_caches.append( + self.kv_caches.append(( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), torch.zeros(kv_cache_shape, dtype=self.kv_cache_dtype, - device=self.device)) + device=self.device), + )) + + +# TODO: This is duplicate from V0, refactor +class ModelWrapper(nn.Module): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + input_lens: torch.Tensor, + t: torch.Tensor, + p: torch.Tensor, + num_samples: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. 
+ attn_metadata: The Pallas attention metadata. + input_lens: The actual input lengths of shape [batch_size]. + t: The sampling temperature of shape [batch_size]. + p: The top-p probability of shape [batch_size]. + num_samples: Number of samples to draw from each logits vector. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + batch_size, seq_len = token_ids.shape + # Calculate the positions to sample from. + start_indicies = torch.arange( + batch_size, dtype=torch.int32, device=input_lens.device) * seq_len + logits_indices = start_indicies + input_lens - 1 + + # FIXME(woosuk): This is a temporary hack to avoid using the existing + # sampler and sampling metadata. + sampling_metadata = SamplingMetadata( + seq_groups=[], + selected_token_indices=logits_indices, + categorized_sample_indices={}, + num_prompts=attn_metadata.num_prefills, + ) + + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Argmax sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = argmax_token_ids.repeat(1, num_samples) + + # Zero temperature means greedy decoding. Avoid division by zero. + nonzero_t = torch.where(t != 0, t, 1.0) + logits = logits / nonzero_t.unsqueeze(dim=1) + if _ENABLE_TOP_P: + logits = _apply_top_p(logits, p.unsqueeze(dim=1)) + + # Random sampling. 
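# [Illustrative sketch -- editor's addition, not from this patch] A worked
# example of the logits_indices computation above: with hidden states flattened
# to [batch_size * seq_len, ...], the last real token of row b sits at
# b * seq_len + input_lens[b] - 1 (values below are placeholders):
import torch

batch_size, seq_len = 3, 8
input_lens = torch.tensor([5, 8, 2], dtype=torch.int32)   # real (unpadded) lengths
start_indices = torch.arange(batch_size, dtype=torch.int32) * seq_len
logits_indices = start_indices + input_lens - 1
# logits_indices -> tensor([ 4, 15, 17], dtype=torch.int32)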
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + sampled_token_ids = torch.multinomial(probs, + num_samples, + replacement=True) + if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + sampled_token_ids = sampled_token_ids.squeeze(dim=-1) + next_token_ids = torch.where(t != 0, sampled_token_ids, + argmax_token_ids) + return next_token_ids # TODO: Duplicate with V0, refactor @@ -1079,3 +1377,13 @@ def _get_padded_batch_size(batch_size: int) -> int: return 8 else: return ((batch_size + 15) // 16) * 16 + + +# TODO: Duplicate with V0, refactor +def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + logits_sorted = torch.sort(logits, dim=-1, descending=True).values + sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1) + cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True) + cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) + logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) + return logits From 35d139d3332abd987fa6058c5af58e33909226e8 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 15:04:45 +0000 Subject: [PATCH 03/18] add tpu worker --- vllm/v1/worker/__tpu_worker.py | 198 --------------------------------- 1 file changed, 198 deletions(-) delete mode 100644 vllm/v1/worker/__tpu_worker.py diff --git a/vllm/v1/worker/__tpu_worker.py b/vllm/v1/worker/__tpu_worker.py deleted file mode 100644 index 866c1dbf6ea98..0000000000000 --- a/vllm/v1/worker/__tpu_worker.py +++ /dev/null @@ -1,198 +0,0 @@ -"""A TPU worker class.""" - -import os -from typing import TYPE_CHECKING, Tuple - -import torch -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr - -import vllm.envs as envs -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.logger import init_logger -from vllm.model_executor import set_random_seed -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.tpu_model_runner import TPUModelRunner - -if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput - -logger = init_logger(__name__) - - -class TPUWorker: - - def __init__(self, vllm_config: VllmConfig, local_rank: int, rank: int, - distributed_init_method: str): - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - - def initialize(self): - os.environ["PJRT_DEVICE"] = "TPU" - torch.set_grad_enabled(False) - torch.set_default_dtype(self.model_config.dtype) - - # NOTE: This is just to initialize the TP group and broadcast - # the input objects on CPU. The all-reduce and all-gather ops on TPU - # are invoked by `xm.all_reduce` and `xm.all_gather` which use their - # own context. 
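# [Illustrative sketch -- editor's addition, not from this patch] A worked
# example of the top-p cutoff used by _apply_top_p in the new model runner
# above. For logits [2.0, 1.0, 0.0] the sorted probabilities are roughly
# [0.665, 0.245, 0.090]; with p = 0.9 the cumulative mass passes p at the
# second token, so only the lowest-logit token is masked out:
import torch

logits = torch.tensor([[2.0, 1.0, 0.0]])
p = torch.tensor([[0.9]])
logits_sorted = torch.sort(logits, dim=-1, descending=True).values
sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)   # -> 1
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)           # -> 1.0
masked = logits.masked_fill(logits < cutoff_logit, -float("inf"))
# masked -> tensor([[2., 1., -inf]])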
- init_distributed_environment( - world_size=self.parallel_config.world_size, - rank=self.rank, - local_rank=self.local_rank, - distributed_init_method=self.distributed_init_method, - backend="gloo", - ) - ensure_model_parallel_initialized( - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size) - - # Device initialization should happen after initializing the distributed - # runtime. - self.device = xm.xla_device() - self.device_config.device = self.device - - # Init ModelRunner here, so that we have access to self.device. - self.model_runner = TPUModelRunner(self.vllm_config) - - # Set random seed. - set_random_seed(self.model_config.seed) - xm.set_rng_state(self.model_config.seed, self.device) - - # Increase the cache size limit, which is the maximum number of - # dynamo graphs that can be compiled. - # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and - # 30-40 graphs for decode. 128 is an arbitrary safe number. - torch._dynamo.config.cache_size_limit = 128 - # Use persistent cache to avoid XLA recompilation. - # NOTE(woosuk): Set per-rank cache path since different ranks - # can have slightly different XLA graphs. - world_size = self.parallel_config.world_size - rank = xr.global_ordinal() - per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, - f"tp{world_size}_rank{rank}") - xr.initialize_cache(per_rank_path, readonly=False) - - def load_model(self): - self.model_runner.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - - self.model_runner.profile_run() - - # Synchronize before measuring the memory usage. - xm.wait_device_ops() - - # Get the maximum amount of memory used by the model weights and - # intermediate activations. - m = xm.get_memory_info(self.device) - total_tpu_memory = m["bytes_limit"] - peak_memory = m[ - "peak_bytes_used"] # Weights + intermediate activations. - logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024) - logger.debug("Total Memory: %sGB", - total_tpu_memory // 1024 // 1024 // 1024) - - cache_block_size = _get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - num_tpu_blocks = int( - (total_tpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 - return num_tpu_blocks, 0 - - def initialize_cache(self, num_tpu_blocks: int) -> None: - """Allocate TPU and CPU KV cache with the specified number of blocks.""" - - if num_tpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - - max_seq_len = self.cache_config.block_size * num_tpu_blocks - max_model_len = self.model_config.max_model_len - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). 
Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") - - self.model_runner.initialize_kv_cache(num_tpu_blocks) - - # Get the maximum amount of memory used by the model weights and - # intermediate activations. - xm.mark_step() - xm.wait_device_ops() - m = xm.get_memory_info(self.device) - peak_memory = m[ - "peak_bytes_used"] # Weights + intermediate activations. - logger.debug("Peak GB Used Post KV Cache: %sGB", - peak_memory // 1024 // 1024 // 1024) - - def compile_or_warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model() - - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) - - def execute_model( - self, - scheduler_output: "SchedulerOutput", - ) -> ModelRunnerOutput: - output = self.model_runner.execute_model(scheduler_output) - # TODO(woosuk): Send the output to the engine process. - return output - - -# TODO: this is a duplicate. -def _get_cache_block_size( - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, -) -> int: - head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) - num_attention_layers = model_config.get_num_attention_layers( - parallel_config) - - key_cache_block = cache_config.block_size * num_heads * head_size - value_cache_block = key_cache_block - total = num_attention_layers * (key_cache_block + value_cache_block) - if cache_config.cache_dtype == "auto": - dtype = model_config.dtype - else: - dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - dtype_size = get_dtype_size(dtype) - return dtype_size * total From 2656fb298ae8f482e5e07d32bb0fccd4e469747b Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 15:07:22 +0000 Subject: [PATCH 04/18] add files --- vllm/v1/worker/tpu_worker_new.py | 244 +++++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 vllm/v1/worker/tpu_worker_new.py diff --git a/vllm/v1/worker/tpu_worker_new.py b/vllm/v1/worker/tpu_worker_new.py new file mode 100644 index 0000000000000..c696ae5f5349d --- /dev/null +++ b/vllm/v1/worker/tpu_worker_new.py @@ -0,0 +1,244 @@ +"""A GPU worker class.""" +import gc +import os +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.distributed +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size +from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_model_runner_new import TPUModelRunner + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + + +class TPUWorker: + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + ): + + # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = 
vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def initialize(self): + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. + world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = TPUModelRunner(self.vllm_config, self.device) + + def load_model(self) -> None: + self.model_runner.load_model() + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. 
+ Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + + self.model_runner.profile_run() + + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_tpu_memory = m["bytes_limit"] + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024) + logger.debug("Total Memory: %sGB", + total_tpu_memory // 1024 // 1024 // 1024) + + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + num_tpu_blocks = int( + (total_tpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 + return num_tpu_blocks, 0 + + def initialize_cache(self, num_tpu_blocks: int) -> None: + """Allocate TPU and CPU KV cache with the specified number of blocks.""" + if num_tpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_tpu_blocks + max_model_len = self.model_config.max_model_len + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.model_runner.initialize_kv_cache(num_tpu_blocks) + + # For debug: Get the maximum amount of memory used by the model weights and + # intermediate activations. + # TODO: Remove this? + xm.mark_step() + xm.wait_device_ops() + m = xm.get_memory_info(self.device) + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak GB Used Post KV Cache: %sGB", + peak_memory // 1024 // 1024 // 1024) + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + output = self.model_runner.execute_model(scheduler_output) + return output if self.rank == 0 else None + + def profile(self, is_start: bool = True): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + if is_start: + self.profiler.start() + else: + self.profiler.stop() + + def check_health(self) -> None: + # worker will always be healthy as long as it's running. 
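# [Illustrative sketch -- editor's addition, not from this patch] The KV block
# budget computed above is (total_tpu_memory * gpu_memory_utilization -
# peak_memory) // cache_block_size, rounded down to a multiple of 8. A worked
# example with assumed numbers (16 GiB device, 0.9 utilization, 8 GiB peak,
# 2 MiB per block -- see the cache-block-size sketch further below):
GiB = 1024 ** 3
total_tpu_memory, peak_memory = 16 * GiB, 8 * GiB
gpu_memory_utilization = 0.9
cache_block_size = 2 * 1024 * 1024    # bytes per KV block
num_tpu_blocks = int(
    (total_tpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)
num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8    # 3276 -> 3272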
+ return + + +# TODO: Duplicate, refactor +def _get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total From 6a7633a92ce91e0685d0704e7f3d8344fa39c450 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 15:12:53 +0000 Subject: [PATCH 05/18] add executor --- vllm/v1/executor/uniproc_tpu_executor.py | 81 ++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 vllm/v1/executor/uniproc_tpu_executor.py diff --git a/vllm/v1/executor/uniproc_tpu_executor.py b/vllm/v1/executor/uniproc_tpu_executor.py new file mode 100644 index 0000000000000..957fe20355cee --- /dev/null +++ b/vllm/v1/executor/uniproc_tpu_executor.py @@ -0,0 +1,81 @@ +import os +from typing import Optional, Tuple + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.executor.abstract import Executor +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_worker_new import TPUWorker + +logger = init_logger(__name__) + + +class UniprocTPUExecutor(Executor): + + def __init__(self, vllm_config: VllmConfig) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.worker: TPUWorker = self._create_worker() + self.worker.initialize() + self.worker.load_model() + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> TPUWorker: + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + return TPUWorker( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + ) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.worker.determine_num_available_blocks() + + def initialize(self, num_tpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. 
+ logger.info("# TPU blocks: %d", num_tpu_blocks) + self.worker.initialize_cache(num_tpu_blocks) + self.worker.compile_or_warm_up_model() + + def execute_model( + self, + scheduler_output, + ) -> ModelRunnerOutput: + output = self.worker.execute_model(scheduler_output) + return output + + def profile(self, is_start: bool = True): + self.worker.profile(is_start) + + def shutdown(self): + pass + + def check_health(self) -> None: + # UniprocTPUExecutor will always be healthy as long as + # it's running. + return From 56621b4cd49c7c2155624cabae04c6dca08b6646 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 15:35:20 +0000 Subject: [PATCH 06/18] store tmp --- vllm/v1/executor/__tpu_executor_v1.py | 80 ++ vllm/v1/worker/__tpu_model_runner_v1.py | 981 ++++++++++++++++++++++++ vllm/v1/worker/__tpu_worker_v1.py | 198 +++++ vllm/worker/__tpu_model_runner_v0.py | 835 ++++++++++++++++++++ 4 files changed, 2094 insertions(+) create mode 100644 vllm/v1/executor/__tpu_executor_v1.py create mode 100644 vllm/v1/worker/__tpu_model_runner_v1.py create mode 100644 vllm/v1/worker/__tpu_worker_v1.py create mode 100644 vllm/worker/__tpu_model_runner_v0.py diff --git a/vllm/v1/executor/__tpu_executor_v1.py b/vllm/v1/executor/__tpu_executor_v1.py new file mode 100644 index 0000000000000..5e6e63086946d --- /dev/null +++ b/vllm/v1/executor/__tpu_executor_v1.py @@ -0,0 +1,80 @@ +from typing import Optional, Tuple + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_worker import TPUWorker + +logger = init_logger(__name__) + +# import torch_xla.debug.profiler as xp + + +class TPUExecutor: + + def __init__(self, vllm_config: VllmConfig) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.worker = self._create_worker() + self.worker.initialize() + self.worker.load_model() + + # self.server = xp.start_server(9012) + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> TPUWorker: + """Return worker init args for a given rank.""" + + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + return TPUWorker( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + ) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.worker.determine_num_available_blocks() + + def initialize_cache(self, num_tpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. 
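
Taken together, the executor methods above imply a simple call order. The snippet below is a usage sketch only, assuming the patch is applied; how the engine builds its `VllmConfig` and `SchedulerOutput` is out of scope here and both are passed in as opaque arguments.

from vllm.v1.executor.uniproc_tpu_executor import UniprocTPUExecutor

def run_one_step(vllm_config, scheduler_output):
    # __init__ creates the TPU worker, initializes the distributed runtime,
    # and loads the model weights.
    executor = UniprocTPUExecutor(vllm_config)
    # Profile to size the KV cache, then allocate it and warm up / compile.
    num_tpu_blocks, _ = executor.determine_num_available_blocks()
    executor.initialize(num_tpu_blocks)
    # One scheduler step produces one ModelRunnerOutput.
    return executor.execute_model(scheduler_output)
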
+ logger.info("# TPU blocks: %d", num_tpu_blocks) + self.worker.initialize_cache(num_tpu_blocks) + self.worker.compile_or_warm_up_model() + + def execute_model( + self, + scheduler_output, + ) -> ModelRunnerOutput: + # xp.trace_detached('localhost:9012', "./profiles") + output = self.worker.execute_model(scheduler_output) + return output + + def check_health(self) -> None: + # TPUExecutor will always be healthy as long as + # it's running. + return diff --git a/vllm/v1/worker/__tpu_model_runner_v1.py b/vllm/v1/worker/__tpu_model_runner_v1.py new file mode 100644 index 0000000000000..7963fe4973b55 --- /dev/null +++ b/vllm/v1/worker/__tpu_model_runner_v1.py @@ -0,0 +1,981 @@ +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +import torch_xla.core.xla_model as xm + +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MultiModalDataDict +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_pin_memory_available +from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, + PallasAttentionMetadata) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.sample.metadata import SamplingMetadata + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME: Find a more reliable way to prevent possible bugs. +_PAD_SLOT_ID = 1_000_000_000 + + +@dataclass +class PrefillInputData: + + request_ids: List + prompt_lens: List + token_ids: List + position_ids: List + attn_metadata: List + + def zipped(self): + return zip(self.request_ids, self.prompt_lens, self.token_ids, + self.position_ids, self.attn_metadata) + + +@dataclass +class DecodeInputData: + + num_decodes: int + token_ids: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + attn_metadata: PallasAttentionMetadata = None + + +class TPUModelRunner: + + def __init__( + self, + vllm_config: VllmConfig, + ): + # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = 
model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + + # Model-related. + self.num_attn_layers = model_config.get_num_attention_layers( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + + # List[k_cache, v_cache] + self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.scheduler_config.max_num_seqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + ) + + self.prefill_positions = torch.tensor( + range(self.max_model_len), + device="cpu", + ).to(torch.int32).reshape(1, -1) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + + # Remove the requests from the persistent batch. + stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. + for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + end_index = start_index + num_new_blocks + req_state.block_ids.extend(req_data.new_block_ids) + self.input_batch.block_table_cpu[ + req_index, start_index:end_index] = req_data.new_block_ids + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for req_data in scheduler_output.scheduled_new_reqs: + req_id = req_data.req_id + sampling_params = req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=req_data.prompt_token_ids, + prompt=req_data.prompt, + multi_modal_data=req_data.multi_modal_data, + sampling_params=sampling_params, + generator=generator, + block_ids=req_data.block_ids, + num_computed_tokens=req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for req_data in scheduler_output.scheduled_resumed_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = req_data.block_ids + req_state.num_computed_tokens = req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # THIS MOVES ALL THE DECODES TO THE FIRST N IN BATCH. + # Condense the batched states if there are empty indices. 
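
The toy snippet below illustrates, outside the patch, what this condensing does: holes left by removed requests are filled by moving the highest live entries down, which keeps the running decodes packed at the front so that prefills added afterwards land at the tail. The real logic lives in `InputBatch.condense` further down; the names here are illustrative only.

from typing import List, Optional

def condense(slots: List[Optional[str]], empty: List[int]) -> None:
    empty = sorted(empty, reverse=True)      # largest hole first
    last = len(slots) - 1
    while empty:
        while slots[last] is None:           # find the last live entry
            last -= 1
        hole = empty.pop()                   # smallest remaining hole
        if hole >= last:
            break
        slots[hole], slots[last] = slots[last], None
        last -= 1

slots = ["req-a", None, "req-b", None, "req-c"]
condense(slots, empty=[1, 3])
print(slots)  # ['req-a', 'req-c', 'req-b', None, None]
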
+ removed_req_indices = sorted(removed_req_indices, reverse=True) + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + # ALL THE PREFILLS ARE THE LAST M IN THE BATCH. + # These are added at the end after the bacth is condensed. + self.input_batch.num_prefills = len(req_ids_to_add) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state, None) + + def _prepare_prefill_inputs( + self, + num_scheduled_tokens: List[int], + ) -> PrefillInputData: + # Each prefill run separately with shape [1, padded_prompt_len]. + # So we create lists that will be used in execute_model(). + + prefill_request_ids = [] + prefill_prompt_lens = [] + prefill_token_ids = [] + prefill_position_ids = [] + prefill_attn_metadata = [] + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + for idx in range(num_decodes, num_reqs): + prefill_request_ids.append(self.input_batch.req_ids[idx]) + + # STATIC SHAPE: prefills are padded to the next power of 2. + prompt_len = num_scheduled_tokens[idx] + padded_prompt_len = _get_padded_prefill_len(prompt_len) + prefill_prompt_lens.append(prompt_len) + assert padded_prompt_len <= self.max_model_len + + # TOKEN_IDS. + token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ + idx, :padded_prompt_len].reshape(1, -1)) + prefill_token_ids.append(token_ids.to(self.device)) + + # POSITIONS. + positions = self.prefill_positions[:, :padded_prompt_len] + prefill_position_ids.append(positions.to(self.device)) + + # SLOT_MAPPING. + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_numbers = self.input_batch.block_table_cpu_tensor[ + idx, positions // self.block_size].reshape(1, -1) + block_offsets = positions % self.block_size + slot_mapping = block_numbers * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[:, prompt_len:] = _PAD_SLOT_ID + slot_mapping = slot_mapping.long() + + # ATTN_METADATA. + prefill_attn_metadata.append( + PallasAttentionMetadata( + is_prompt=True, + slot_mapping=slot_mapping.to(self.device), + block_tables=None, + context_lens=None, + )) + + return PrefillInputData( + request_ids=prefill_request_ids, + prompt_lens=prefill_prompt_lens, + token_ids=prefill_token_ids, + position_ids=prefill_position_ids, + attn_metadata=prefill_attn_metadata, + ) + + def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData: + # Decodes run as one single padded batch with shape [batch, 1] + # + # We need to set _PAD_SLOT_ID for the padding tokens in the + # slot_mapping, such that the attention KV cache insertion + # logic knows to ignore those indicies. Otherwise, the + # padding data can be dummy since we have a causal mask. + + if num_decodes == 0: + return DecodeInputData(num_decodes=0) + + # PAD FOR STATIC SHAPES. + padded_batch_size = _get_padded_batch_size(num_decodes) + + # POSITIONS. [batch, 1] + # We slice at the end, since we use the positions for gathering. + positions = torch.from_numpy( + self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) + index = positions.to(torch.int64) + positions = positions[:padded_batch_size] + + # TOKEN_IDS. 
[batch, 1] + token_ids = torch.gather( + input=torch.from_numpy(self.input_batch.token_ids_cpu), + dim=1, + index=index, + )[:padded_batch_size] + + # SLOT_MAPPING [batch, 1] + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_number = torch.gather( + input=self.input_batch.block_table_cpu_tensor, + dim=1, + index=(index // self.block_size)) + block_offsets = index % self.block_size + slot_mapping = block_number * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[num_decodes:] = _PAD_SLOT_ID + slot_mapping = slot_mapping[:padded_batch_size] + + # BLOCK_TABLE [batch, max_num_blocks_per_req] + block_table = self.input_batch.block_table_cpu_tensor[: + padded_batch_size] + + # CONTEXT_LENS [batch_size] + context_lens = (positions.reshape(-1) + 1) + + # CPU<>TPU sync happens here. + return DecodeInputData(num_decodes=num_decodes, + token_ids=token_ids.to(self.device), + position_ids=positions.to(self.device), + attn_metadata=PallasAttentionMetadata( + is_prompt=False, + slot_mapping=slot_mapping.to(self.device), + block_tables=block_table.to(self.device), + context_lens=context_lens.to(self.device), + )) + + def _prepare_inputs( + self, scheduler_output: "SchedulerOutput" + ) -> Tuple[PrefillInputData, Optional[DecodeInputData]]: + + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + + # NOTE: assert that all the decodes are "decodes". + if idx < num_decodes: + assert num_tokens == 1 + + return ( + self._prepare_prefill_inputs(num_scheduled_tokens), + self._prepare_decode_inputs(num_decodes), + ) + + def _prepare_sampling( + self, + scheduler_output: "SchedulerOutput", + ) -> SamplingMetadata: + skip_copy = True + if (scheduler_output.finished_req_ids + or scheduler_output.preempted_req_ids): + skip_copy = False + if (scheduler_output.scheduled_new_reqs + or scheduler_output.scheduled_resumed_reqs): + skip_copy = False + # Create the sampling metadata. + sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) + return sampling_metadata + + @torch.no_grad() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + self._update_states(scheduler_output) + prefill_data, decode_data = self._prepare_inputs(scheduler_output) + num_reqs = self.input_batch.num_reqs + sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) + + ######################### DECODES ######################### + # Decodes run as one single batch with [padded_batch, 1] + if decode_data.num_decodes > 0: + + # FORWARD. + selected_token_ids = self.model(decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches, + is_prompt=False) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. 
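
As an aside on the decode inputs used here: the slot mapping assembled in `_prepare_decode_inputs` above maps each request's next position to a physical KV-cache slot with a single gather. The snippet below is a standalone toy example (illustrative block table and positions, not taken from the patch).

import torch

block_size = 4
# Two requests; request 0 owns blocks [7, 2], request 1 owns blocks [5, 9].
block_table = torch.tensor([[7, 2], [5, 9]])
# Next-token position (== num_computed_tokens) for each request.
positions = torch.tensor([[5], [1]])

block_number = torch.gather(block_table, dim=1, index=positions // block_size)
slot_mapping = block_number * block_size + positions % block_size
print(slot_mapping)  # tensor([[ 9], [21]])
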
+ token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] + sampled_token_ids_list = token_ids.tolist() + sampled_token_ids[:decode_data.num_decodes] = token_ids + + # UPDATE REQUEST STATE. + for i, req_id in enumerate( + self.input_batch.req_ids[:decode_data.num_decodes]): + req_state = self.requests[req_id] + + # TODO: ASSERT NO CHUNKED PREFILL. + assert scheduler_output.num_scheduled_tokens[req_id] == 1 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len == req_state.num_tokens + + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + ######################### PREFILLS ######################### + # Prefills run separately with shape [1, padded_prefill_len], + # due to lack of variable length attention kernel so far. + for idx, (req_id, prompt_len, token_ids, position_ids, + attn_metadata) in enumerate(prefill_data.zipped()): + + # FORWARD. + selected_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + self.kv_caches, + is_prompt=True) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_id = selected_token_ids.cpu()[prompt_len - 1].item() + sampled_token_ids[decode_data.num_decodes + idx] = token_id + req_state = self.requests[req_id] + + # TODO: ASSERT NO PREFIX CACHING. + assert req_state.num_computed_tokens == 0 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + # TODO: ASSERT NO CHUNKED PREFILL. + assert seq_len == req_state.num_tokens + assert prompt_len == seq_len + + # UPDATE REQUEST STATE. + req_idx = self.input_batch.req_id_to_index[req_id] + self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids[:num_reqs], + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids_cpu=sampled_token_ids, + logprob_token_ids_cpu=None, + logprobs_cpu=None, + ) + + def load_model(self) -> None: + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + + # xm_tp_rank = xr.global_ordinal() + # with patch( + # "vllm.model_executor.layers.vocab_parallel_embedding." 
+ # "get_tensor_model_parallel_rank", + # return_value=xm_tp_rank): + # model = get_model(vllm_config=self.vllm_config) + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + self.model = ModelWrapper(model) + + def _dummy_run(self, batch_size: int, seq_len: int, + kv_caches: List[torch.Tensor], is_prompt: bool) -> None: + """Dummy warmup run for memory usage and graph compilation.""" + + input_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = None if is_prompt else torch.zeros( + (batch_size, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device, + ) + context_lens = None if is_prompt else torch.ones( + (batch_size, ), + dtype=torch.int32, + device=self.device, + ) + attn_metadata = PallasAttentionMetadata( + is_prompt=is_prompt, + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + ) + + # NOTE: There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). + if is_prompt: + torch._dynamo.mark_dynamic(input_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + + # Dummy run. + self.model(input_ids, + position_ids, + attn_metadata, + kv_caches, + is_prompt=is_prompt) + + def profile_run(self) -> None: + """Profile to measure peak memory during forward pass.""" + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value `None`. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + dummy_kv_caches = [( + torch.tensor([], dtype=torch.float32, device=self.device), + torch.tensor([], dtype=torch.float32, device=self.device), + ) for _ in range(self.num_attn_layers)] + + # Round to multiple of 16. + seq_len = (self.max_num_tokens + 15) // 16 * 16 + + # Run empty forward. + self._dummy_run(batch_size=1, + seq_len=seq_len, + kv_caches=dummy_kv_caches, + is_prompt=True) + + def capture_model(self) -> None: + """Compile the model.""" + + logger.info("Compiling the model with different input shapes.") + + # Prefill shapes. 
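
The loop below compiles one XLA graph per prefill shape. As a hedged illustration of which shapes that produces, this standalone helper re-derives the sweep with hypothetical limits (`max_model_len=2048`, `max_num_batched_tokens=512`); the real values come from the model and scheduler configs.

def prefill_warmup_shapes(max_model_len: int, max_num_batched_tokens: int):
    shapes = []
    for batch_size in [1]:
        seq_len = 16
        while True:
            shapes.append((batch_size, seq_len))   # one compile per shape
            if seq_len >= max_model_len:
                break
            if batch_size * seq_len >= max_num_batched_tokens:
                break
            seq_len *= 2
    return shapes

print(prefill_warmup_shapes(2048, 512))
# [(1, 16), (1, 32), (1, 64), (1, 128), (1, 256), (1, 512)]

The decode sweep that follows grows the batch size instead (8, 16, then in steps of 16 up to `max_num_seqs`) at a fixed seq_len of 1.
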
+ start = time.perf_counter() + for batch_size in [1]: + seq_len = 16 + while True: + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + is_prompt=True) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + if seq_len >= self.model_config.max_model_len: + break + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + seq_len = seq_len * 2 + + end = time.perf_counter() + logger.info("Compilation for prefill done in %.2f s.", end - start) + + # Decode shapes. + start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + is_prompt=False) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode done in %.2f s.", end - start) + + def initialize_kv_cache(self, num_blocks: int) -> None: + assert len(self.kv_caches) == 0 + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + for _ in range(self.num_attn_layers): + self.kv_caches.append(( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), + )) + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalDataDict"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.zeros((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) + + # Attention-related. + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. 
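
The sampling fields below all follow the same staging pattern: an optionally pinned CPU tensor that is edited per request, paired with a device tensor that receives a bulk non-blocking copy in `make_sampling_metadata`. A minimal standalone sketch of that pattern (kept CPU-only so it stays runnable; on TPU the first tensor would live on the XLA device):

import torch

max_num_reqs = 4
pin = torch.cuda.is_available()  # stand-in for is_pin_memory_available()
temperature = torch.empty(max_num_reqs, dtype=torch.float32)      # device copy
temperature_cpu = torch.empty(max_num_reqs, dtype=torch.float32,
                              pin_memory=pin)                     # CPU staging

temperature_cpu[:2] = torch.tensor([0.7, 1.0])      # per-request edits on CPU
temperature[:2].copy_(temperature_cpu[:2], non_blocking=True)     # bulk upload
print(temperature[:2])
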
+ self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + self.num_prefills = 0 + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # 
The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def num_decodes(self) -> int: + return self.num_reqs - self.num_prefills + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 + + +class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): + + def __init__(self, model: nn.Module): + self.model = model + compiled_callable = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_callable) + + def __call__(self, *args, is_prompt: bool, **kwargs): + if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: + # not fully compiled yet, or not using the custom dispatcher, + # let PyTorch handle it + return self.compiled_callable(*args, **kwargs) + # the 3 compiled codes are: + # 0: for profiling + # 1: for prompt + # 2: for decode + # dispatch to 
the compiled code directly, skip PyTorch + if is_prompt: + with self.dispatch_to_code(1): + return self.forward(*args, **kwargs) + else: + with self.dispatch_to_code(2): + return self.forward(*args, **kwargs) + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + attn_metadata: PallasAttentionMetadata, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + attn_metadata: The Pallas attention metadata. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, None) + + # Greedy sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + return argmax_token_ids.squeeze(dim=1) + + +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. + if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. 
+ if x <= 16: + return 16 + return 1 << (x - 1).bit_length() diff --git a/vllm/v1/worker/__tpu_worker_v1.py b/vllm/v1/worker/__tpu_worker_v1.py new file mode 100644 index 0000000000000..866c1dbf6ea98 --- /dev/null +++ b/vllm/v1/worker/__tpu_worker_v1.py @@ -0,0 +1,198 @@ +"""A TPU worker class.""" + +import os +from typing import TYPE_CHECKING, Tuple + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_model_runner import TPUModelRunner + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + + +class TPUWorker: + + def __init__(self, vllm_config: VllmConfig, local_rank: int, rank: int, + distributed_init_method: str): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + def initialize(self): + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # NOTE: This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = TPUModelRunner(self.vllm_config) + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. 
+ world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + + self.model_runner.profile_run() + + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_tpu_memory = m["bytes_limit"] + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024) + logger.debug("Total Memory: %sGB", + total_tpu_memory // 1024 // 1024 // 1024) + + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + num_tpu_blocks = int( + (total_tpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 + return num_tpu_blocks, 0 + + def initialize_cache(self, num_tpu_blocks: int) -> None: + """Allocate TPU and CPU KV cache with the specified number of blocks.""" + + if num_tpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_tpu_blocks + max_model_len = self.model_config.max_model_len + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.model_runner.initialize_kv_cache(num_tpu_blocks) + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + xm.mark_step() + xm.wait_device_ops() + m = xm.get_memory_info(self.device) + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak GB Used Post KV Cache: %sGB", + peak_memory // 1024 // 1024 // 1024) + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + output = self.model_runner.execute_model(scheduler_output) + # TODO(woosuk): Send the output to the engine process. + return output + + +# TODO: this is a duplicate. 
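
For a feel of the numbers `_get_cache_block_size` produces, here is a worked example with illustrative Llama-7B-like dimensions (32 KV heads, head size 128, 32 attention layers, bf16, 16-token blocks); none of these values are read from a real config.

block_size = 16          # tokens per KV block
num_kv_heads = 32
head_size = 128
num_attention_layers = 32
dtype_size = 2           # bf16

key_block = block_size * num_kv_heads * head_size    # 65_536 elements
value_block = key_block
total_bytes = num_attention_layers * (key_block + value_block) * dtype_size
print(total_bytes)  # 8_388_608 bytes, i.e. 8 MiB per KV block
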
+def _get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total diff --git a/vllm/worker/__tpu_model_runner_v0.py b/vllm/worker/__tpu_model_runner_v0.py new file mode 100644 index 0000000000000..a721186137328 --- /dev/null +++ b/vllm/worker/__tpu_model_runner_v0.py @@ -0,0 +1,835 @@ +import time +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, Union) +from unittest.mock import patch + +import numpy as np +import torch +import torch.nn as nn +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. +_PAD_SLOT_ID = 1_000_000_000 +# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. +_ENABLE_TOP_P = False +# FIXME(woosuk): A temporary hack to support `n > 1`. +# This can significantly affect the performance if too large. 
+_MAX_NUM_SAMPLES = 128 + + +@dataclass(frozen=True) +class ModelInputForTPU(ModelRunnerInputBase): + token_ids: torch.Tensor + position_ids: torch.Tensor + attn_metadata: AttentionMetadata + input_lens: torch.Tensor + t: torch.Tensor + p: torch.Tensor + num_samples: int + n: List[int] + seq_groups: List[List[int]] + is_first_multi_step: bool = True + is_last_step: bool = True + virtual_engine: int = 0 + async_callback: Optional[Callable] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = { + "token_ids": self.token_ids, + "position_ids": self.position_ids, + "input_lens": self.input_lens, + "t": self.t, + "p": self.p, + "num_samples": self.num_samples, + "n": self.n, + "seq_groups": self.seq_groups, + "is_first_multi_step": self.is_first_multi_step, + "is_last_step": self.is_last_step, + "virtual_engine": self.virtual_engine, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["ModelInputForTPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForTPU": + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): + + def __init__( + self, + vllm_config: VllmConfig, + is_driver_worker: bool = False, + ): + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.is_driver_worker = is_driver_worker + + self.block_size = self.cache_config.block_size + self.max_num_blocks_per_seq = (self.model_config.max_model_len // + self.block_size) + self.block_tables = np.zeros( + (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), + dtype=np.int32) + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + False, + ) + self.cached_step_outputs: List[torch.Tensor] = [] + + smem_size = 512 * 1024 + block_table_size = 4 * self.block_tables.size + if block_table_size >= smem_size: + logger.warning( + "The max_model_len (%d) is too large. This may degrade the " + "performance due to the insufficient smem size. Consider " + "setting --max-model-len to a smaller value.", + self.model_config.max_model_len) + + def load_model(self) -> None: + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." 
+ "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + self.model = ModelWrapper(model) + + def _dummy_run( + self, + batch_size: int, + seq_len: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + is_prompt: bool, + ) -> None: + if is_prompt: + seq_len = (seq_len + 15) // 16 * 16 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=None, + context_lens=None, + ) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + else: + assert seq_len == 1 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_seq), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size * seq_len, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, + ) + t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 + + # NOTE(woosuk): There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). + if is_prompt: + # Prefll + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + # Decode + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(input_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(t, 0) + torch._dynamo.mark_dynamic(p, 0) + # Dummy run. 
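
The `mark_dynamic` calls above are what keep the torch.compile side from re-tracing for every shape. The snippet below is a simplified, CPU-runnable sketch of the idea (plain `torch.compile` instead of the `openxla` backend used here); as the note above says, XLA itself still compiles one graph per concrete shape.

import torch

@torch.compile
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

for batch_size in (8, 16, 32):
    x = torch.zeros(batch_size, 4)
    # Mark dim 0 (the batch) as dynamic so one traced graph covers all sizes.
    torch._dynamo.mark_dynamic(x, 0)
    double(x)
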
+ self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + num_samples, + kv_caches, + is_prompt=is_prompt) + + def warmup_model( + self, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> None: + # Prefill + logger.info("Compiling the model with different input shapes...") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while True: + self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if seq_len >= self.model_config.max_model_len: + break + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + seq_len = seq_len * 2 + + end = time.time() + logger.info("Compilation for prefill done in %.2f s.", end - start) + + # Decode + start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode done in %.2f s.", end - start) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + prompt_lens: List[int] = [] + slot_mapping: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + # Could include output tokens when a request is preempted. + prompt_tokens = seq_data.get_token_ids() + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + input_tokens.extend(prompt_tokens) + input_positions.extend(list(range(prompt_len))) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + for i in range(prompt_len): + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + # Add paddings to EACH prompt to the smallest power of 2 that is + # greater than or equal to the prompt length. + # We pad the seq_len to reduce the compilation overhead. + # We execute each prompt individually (i.e., with batch_size 1) + # because the FlashAttention kernel does not support ragged inputs. + # TODO(woosuk): Use SplashAttention to support ragged inputs. 
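
A quick worked example of the padding rule applied just below (a standalone copy mirroring `_get_padded_prefill_len`, values illustrative):

def padded_prefill_len(x: int) -> int:
    # Next power of two, with a floor of 16.
    return 16 if x <= 16 else 1 << (x - 1).bit_length()

for prompt_len in (5, 16, 17, 100, 1000):
    print(prompt_len, "->", padded_prefill_len(prompt_len))
# 5 -> 16, 16 -> 16, 17 -> 32, 100 -> 128, 1000 -> 1024

The `num_paddings` filler tokens get position 0 and `_PAD_SLOT_ID` in the slot mapping, so the KV-cache writes for them are ignored.
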
+ padded_prompt_len = _get_padded_prefill_len(prompt_len) + num_paddings = padded_prompt_len - prompt_len + input_tokens += [0] * num_paddings + input_positions += [0] * num_paddings + slot_mapping += [_PAD_SLOT_ID] * num_paddings + + assert len(prompt_lens) > 0 + num_prefills = len(prompt_lens) + input_tokens = torch.tensor(input_tokens, + dtype=torch.int32, + device="cpu") + input_positions = torch.tensor(input_positions, + dtype=torch.int32, + device="cpu") + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cpu") + prompt_lens = torch.tensor(prompt_lens, + dtype=torch.int32, + device="cpu") + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + num_prefill_tokens=0, # NOTE: This is not used. + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=None, + context_lens=None, + ) + return input_tokens, input_positions, attn_metadata, prompt_lens + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + context_lens: List[int] = [] + + batch_idx = 0 + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + context_lens.append(seq_len) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + self.block_tables[batch_idx, :len(block_table)] = block_table + batch_idx += 1 + + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + + batch_size = _get_padded_batch_size(batch_idx) + num_paddings = batch_size - batch_idx + input_tokens = input_tokens + [[0]] * num_paddings + input_positions = input_positions + [[0]] * num_paddings + slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings + context_lens = context_lens + [0] * num_paddings + + input_tokens = torch.tensor(input_tokens, + dtype=torch.int32, + device="cpu") + input_positions = torch.tensor(input_positions, + dtype=torch.int32, + device="cpu") + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cpu") + context_lens = torch.tensor(context_lens, + dtype=torch.int32, + device="cpu") + block_tables = torch.tensor(self.block_tables[:batch_size], + dtype=torch.int32, + device="cpu") + input_lens = torch.tensor([1] * batch_size, + dtype=torch.int32, + device="cpu") + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, + ) + return input_tokens, input_positions, attn_metadata, input_lens + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + padded_batch_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + assert len(seq_group_metadata_list) > 0 + t = [] + p = [] + n = [] + for 
seq_group_metadata in seq_group_metadata_list: + sampling_params = seq_group_metadata.sampling_params + t.append(sampling_params.temperature) + if sampling_params.top_p != 1 and not _ENABLE_TOP_P: + raise NotImplementedError( + "Top-p sampling is currently disabled for the TPU backend " + "due to performance issues.") + p.append(sampling_params.top_p) + if sampling_params.top_k != -1: + raise NotImplementedError( + "Top-k sampling is currently disabled for the TPU backend " + "due to performance issues.") + if sampling_params.n > _MAX_NUM_SAMPLES: + raise NotImplementedError( + f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU " + "backend.") + n.append(sampling_params.n) + if sampling_params.logprobs is not None: + raise NotImplementedError( + "logprobs is not currently supported by the TPU backend.") + if sampling_params.prompt_logprobs is not None: + raise NotImplementedError( + "prompt_logprobs is not currently supported by the TPU " + "backend.") + + # Repeat the sampling params if the seq group has multiple seqs. + num_seqs = len(seq_group_metadata.seq_data) + t += [t[-1]] * (num_seqs - 1) + p += [p[-1]] * (num_seqs - 1) + n += [n[-1]] * (num_seqs - 1) + + num_paddings = padded_batch_size - len(t) + t += [1.0] * num_paddings + p += [1.0] * num_paddings + + t = torch.tensor(t, dtype=torch.float32, device="cpu") + p = torch.tensor(p, dtype=torch.float32, device="cpu") + return t, p, n + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelInputForTPU: + del finished_requests_ids # Unused. + assert virtual_engine == 0 + assert len(seq_group_metadata_list) > 0 + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. 
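As a small sketch of how the sampling tensors built above are padded out to the batch shape (temperatures here are invented), the padding rows get the neutral value 1.0 and correspond to padding sequences whose outputs are discarded:

import torch

t = [0.0, 0.7]                      # per-sequence temperatures (toy values)
padded_batch_size = 4
t += [1.0] * (padded_batch_size - len(t))
t = torch.tensor(t, dtype=torch.float32, device="cpu")
# tensor([0.0000, 0.7000, 1.0000, 1.0000]); the last two rows are padding sequences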
+ is_prompt = seq_group_metadata_list[0].is_prompt + if is_prompt: + inputs = self._prepare_prompt(seq_group_metadata_list) + else: + inputs = self._prepare_decode(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, input_lens = inputs + padded_batch_size = input_tokens.shape[0] + t, p, n = self._prepare_sample(seq_group_metadata_list, + padded_batch_size) + num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 + + seq_groups = [ + list(metadata.seq_data.keys()) + for metadata in seq_group_metadata_list + ] + return ModelInputForTPU(input_tokens, input_positions, attn_metadata, + input_lens, t, p, num_samples, n, seq_groups) + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: + model_input = ModelInputForTPU.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=self.attn_backend) + return model_input + + @torch.no_grad() + def execute_model( + self, + model_input: ModelInputForTPU, + kv_caches: Optional[List[Any]], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> List[SamplerOutput]: + assert intermediate_tensors is None + if not model_input.is_first_multi_step: + if not model_input.is_last_step: + return [] + + use_async_out_proc = model_input.async_callback is not None + sampler_outputs = [] + num_outputs = len(self.cached_step_outputs) + for i in range(num_outputs): + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + sampler_outputs.append(sampler_output) + + if i < num_outputs - 1 and use_async_out_proc: + assert model_input.async_callback is not None + ctx = model_input.async_callback.keywords[ # type: ignore + "ctx"] + ctx.append_output( + outputs=[sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=i == 0) + model_input.async_callback() + if use_async_out_proc: + return [sampler_outputs[-1]] + else: + return sampler_outputs + + is_prompt = model_input.attn_metadata.num_prefills > 0 + if is_prompt: + assert num_steps == 1 + # NOTE(woosuk): Since the FlashAttention kernel does not support + # ragged inputs, we split the prompts into different batches and + # process them separately. This is a temporary hack that should be + # optimized by using SplashAttention. + orig_slot_mapping = model_input.attn_metadata.slot_mapping + batch_size = model_input.input_lens.shape[0] + start_idx = 0 + next_token_ids = [] + for i in range(batch_size): + # Get the actual prefill_len. 
+ prefill_len = model_input.input_lens[i:i + 1].item() + prefill_len = _get_padded_prefill_len(prefill_len) + end_idx = start_idx + prefill_len + + token_ids = model_input.token_ids[None, start_idx:end_idx].to( + self.device) + position_ids = model_input.position_ids[None, + start_idx:end_idx].to( + self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.num_prefills = 1 + attn_metadata.slot_mapping = orig_slot_mapping[ + None, start_idx:end_idx].to(self.device) + input_lens = model_input.input_lens[i:i + 1].to(self.device) + t = model_input.t[i:i + 1].to(self.device) + p = model_input.p[i:i + 1].to(self.device) + output_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + model_input.num_samples, + kv_caches, + is_prompt=True) + next_token_ids.append(output_token_ids[0]) + start_idx = end_idx + + if model_input.async_callback is not None: + model_input.async_callback() + # Retrieve the outputs to CPU. + next_token_ids = [ + output_token_ids.cpu().tolist() + for output_token_ids in next_token_ids + ] + + # NOTE(woosuk): Minimal code to construct the sampler outputs. + # The TPU backend does not reuse the sampler, since the TPU backend + # does not support advanced sampling parameters such as logprobs. + zero_logprob = Logprob(0.0) + sampler_outputs = [] + for i, seq_group in enumerate(model_input.seq_groups): + seq_ids = seq_group + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + seq_outputs = [] + for j in range(model_input.n[i]): + next_token_id = next_token_ids[i][j] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + sampler_outputs.append( + CompletionSequenceGroupOutput(seq_outputs, None)) + return [SamplerOutput(sampler_outputs)] + else: + token_ids = model_input.token_ids.to(self.device) + position_ids = model_input.position_ids.to(self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.slot_mapping = attn_metadata.slot_mapping.to( + self.device) + attn_metadata.block_tables = attn_metadata.block_tables.to( + self.device) + attn_metadata.context_lens = attn_metadata.context_lens.to( + self.device) + t = model_input.t.to(self.device) + p = model_input.p.to(self.device) + input_lens = model_input.input_lens.to(self.device) + for i in range(num_steps): + slot_mapping = attn_metadata.slot_mapping + output_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + model_input.num_samples, + kv_caches, + is_prompt=False) + self.cached_step_outputs.append(output_token_ids) + + if i < num_steps - 1: + # Prepare the inputs for the next step. + token_ids = output_token_ids.unsqueeze(dim=1).int() + position_ids = position_ids + 1 + attn_metadata.context_lens = attn_metadata.context_lens + 1 + + block_tables = attn_metadata.block_tables + block_number = block_tables.gather( + 1, + position_ids.long() // self.block_size) + block_offset = position_ids % self.block_size + + is_padding = slot_mapping == _PAD_SLOT_ID + slot_mapping = block_number * self.block_size + block_offset + slot_mapping = slot_mapping.long() + slot_mapping = torch.where(is_padding, _PAD_SLOT_ID, + slot_mapping) + attn_metadata.slot_mapping = slot_mapping + + if model_input.async_callback is not None: + model_input.async_callback() + + if num_steps > 1: + return [] + # Retrieve the outputs to CPU. 
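In the multi-step decode loop above, the slot for the next token is recomputed from the block table; a toy, self-contained check of that lookup (block_size and table contents invented):

import torch

block_size = 16
block_tables = torch.tensor([[7, 3, 9]])      # one sequence mapped to physical blocks 7, 3, 9
position_ids = torch.tensor([[33]])           # next position to be written

block_number = block_tables.gather(1, position_ids.long() // block_size)
block_offset = position_ids % block_size
slot_mapping = block_number * block_size + block_offset
assert slot_mapping.item() == 9 * 16 + 1      # position 33 lives in logical block 2 -> physical block 9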
+ next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + return [sampler_output] + + +class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): + + def __init__(self, model: nn.Module): + self.model = model + compiled_callable = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_callable) + + def __call__(self, *args, is_prompt: bool, **kwargs): + if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: + # not fully compiled yet, or not using the custom dispatcher, + # let PyTorch handle it + return self.compiled_callable(*args, **kwargs) + # the 3 compiled codes are: + # 0: for profiling + # 1: for prompt + # 2: for decode + # dispatch to the compiled code directly, skip PyTorch + if is_prompt: + with self.dispatch_to_code(1): + return self.forward(*args, **kwargs) + else: + with self.dispatch_to_code(2): + return self.forward(*args, **kwargs) + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + input_lens: torch.Tensor, + t: torch.Tensor, + p: torch.Tensor, + num_samples: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + attn_metadata: The Pallas attention metadata. + input_lens: The actual input lengths of shape [batch_size]. + t: The sampling temperature of shape [batch_size]. + p: The top-p probability of shape [batch_size]. + num_samples: Number of samples to draw from each logits vector. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + batch_size, seq_len = token_ids.shape + # Calculate the positions to sample from. + start_indicies = torch.arange( + batch_size, dtype=torch.int32, device=input_lens.device) * seq_len + logits_indices = start_indicies + input_lens - 1 + + # FIXME(woosuk): This is a temporary hack to avoid using the existing + # sampler and sampling metadata. + sampling_metadata = SamplingMetadata( + seq_groups=[], + selected_token_indices=logits_indices, + categorized_sample_indices={}, + num_prompts=attn_metadata.num_prefills, + ) + + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. 
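The comment above describes flattening the per-head KV-cache dimensions into the slot mapping, which the patch implements immediately below; a standalone sketch of that index arithmetic with toy shapes (not the patch's tensors), so it can be checked in isolation:

import torch

num_kv_heads, num_blocks, block_size = 2, 4, 8     # toy KV-cache geometry
slot_mapping = torch.tensor([5, 12])               # slots within a single head's space

head_offsets = torch.arange(num_kv_heads) * (num_blocks * block_size)
expanded = slot_mapping.repeat_interleave(num_kv_heads).view(-1, num_kv_heads)
expanded = (expanded + head_offsets.view(1, -1)).flatten()
# Head 0 keeps the original slots; head 1 is shifted by num_blocks * block_size = 32.
assert expanded.tolist() == [5, 37, 12, 44]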
+ num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Argmax sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = argmax_token_ids.repeat(1, num_samples) + + # Zero temperature means greedy decoding. Avoid division by zero. + nonzero_t = torch.where(t != 0, t, 1.0) + logits = logits / nonzero_t.unsqueeze(dim=1) + if _ENABLE_TOP_P: + logits = _apply_top_p(logits, p.unsqueeze(dim=1)) + + # Random sampling. + probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + sampled_token_ids = torch.multinomial(probs, + num_samples, + replacement=True) + if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + sampled_token_ids = sampled_token_ids.squeeze(dim=-1) + next_token_ids = torch.where(t != 0, sampled_token_ids, + argmax_token_ids) + return next_token_ids + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. + if x <= 16: + return 16 + return 1 << (x - 1).bit_length() + + +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. 
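For reference, the batch padding described above (and implemented just below) maps batch sizes as follows; this is a quick sketch, not the patch's helper itself:

def padded_batch_size(batch_size: int) -> int:
    # Minimum batch of 8, otherwise round up to a multiple of 16.
    return 8 if batch_size <= 8 else (batch_size + 15) // 16 * 16

assert [padded_batch_size(b) for b in (1, 8, 9, 16, 17)] == [8, 8, 16, 16, 32]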
+ if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 + + +def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + logits_sorted = torch.sort(logits, dim=-1, descending=True).values + sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1) + cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True) + cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) + logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) + return logits + + +def _make_decode_output( + next_token_ids: List[int], + seq_groups: List[List[int]], +) -> SamplerOutput: + zero_logprob = Logprob(0.0) + sampler_outputs = [] + batch_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group + seq_outputs = [] + for seq_id in seq_ids: + next_token_id = next_token_ids[batch_idx] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + batch_idx += 1 + sampler_outputs.append(CompletionSequenceGroupOutput( + seq_outputs, None)) + return SamplerOutput(sampler_outputs) From a9fc40890c418fd0e7a8f657759e40fc0c971096 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 15:52:09 +0000 Subject: [PATCH 07/18] finished rebase --- vllm/platforms/tpu.py | 6 ++++++ vllm/v1/core/scheduler.py | 11 +++++++++++ vllm/v1/executor/abstract.py | 9 +++++++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 77f5c8401424b..dd0eae57e0354 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -72,3 +72,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" else: parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker" + + @classmethod + def is_pin_memory_available(cls): + # TODO: Verify if it is indeed the case + logger.warning("Pin memory is not supported on TPU.") + return False diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index b26716f5c02e6..c3d1560b3e7f8 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -27,6 +27,10 @@ def __init__( cache_config: CacheConfig, lora_config: Optional[LoRAConfig], ) -> None: + # TODO: Refactor! Properly handle for TPU. + cache_config.enable_prefix_caching = False + scheduler_config.chunked_prefill_enabled = False + self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config @@ -205,6 +209,13 @@ def schedule(self) -> "SchedulerOutput": num_computed_tokens -= self.block_size num_new_tokens = self.block_size computed_blocks.pop() + + # If chunked prefill is not enabled, breakout of the loop. 
+ # TODO: Verify if needed + if (not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget): + break + num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 5d74d4b01f500..0af61a3875b7c 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -3,6 +3,7 @@ from vllm.config import VllmConfig from vllm.v1.outputs import ModelRunnerOutput +from vllm.platforms import current_platform class Executor(ABC): @@ -21,8 +22,12 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]: executor_class = MultiprocExecutor else: assert (distributed_executor_backend is None) - from vllm.v1.executor.uniproc_executor import UniprocExecutor - executor_class = UniprocExecutor + if current_platform.is_tpu(): + from vllm.v1.executor.uniproc_tpu_executor import UniprocTPUExecutor + executor_class = UniprocTPUExecutor + else: + from vllm.v1.executor.uniproc_executor import UniprocExecutor + executor_class = UniprocExecutor return executor_class @abstractmethod From fda64cb565092208cf36a920797c110fc1f20a8f Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 15:56:48 +0000 Subject: [PATCH 08/18] remove tmp files --- vllm/v1/executor/__tpu_executor_v1.py | 80 -- vllm/v1/worker/__tpu_model_runner_v1.py | 981 ------------------------ vllm/v1/worker/__tpu_worker_v1.py | 198 ----- vllm/worker/__tpu_model_runner_v0.py | 835 -------------------- 4 files changed, 2094 deletions(-) delete mode 100644 vllm/v1/executor/__tpu_executor_v1.py delete mode 100644 vllm/v1/worker/__tpu_model_runner_v1.py delete mode 100644 vllm/v1/worker/__tpu_worker_v1.py delete mode 100644 vllm/worker/__tpu_model_runner_v0.py diff --git a/vllm/v1/executor/__tpu_executor_v1.py b/vllm/v1/executor/__tpu_executor_v1.py deleted file mode 100644 index 5e6e63086946d..0000000000000 --- a/vllm/v1/executor/__tpu_executor_v1.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Optional, Tuple - -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.utils import get_distributed_init_method, get_ip, get_open_port -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.tpu_worker import TPUWorker - -logger = init_logger(__name__) - -# import torch_xla.debug.profiler as xp - - -class TPUExecutor: - - def __init__(self, vllm_config: VllmConfig) -> None: - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - self.worker = self._create_worker() - self.worker.initialize() - self.worker.load_model() - - # self.server = xp.start_server(9012) - - def _create_worker( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None) -> TPUWorker: - """Return worker init args for a given rank.""" - - if distributed_init_method is None: - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - - return TPUWorker( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - 
distributed_init_method=distributed_init_method, - ) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the - underlying worker. - """ - return self.worker.determine_num_available_blocks() - - def initialize_cache(self, num_tpu_blocks: int) -> None: - """Initialize the KV cache by invoking the underlying worker. - """ - # NOTE: This is logged in the executor because there can be >1 worker - # with other executors. We could log in the engine level, but work - # remains to abstract away the device for non-GPU configurations. - logger.info("# TPU blocks: %d", num_tpu_blocks) - self.worker.initialize_cache(num_tpu_blocks) - self.worker.compile_or_warm_up_model() - - def execute_model( - self, - scheduler_output, - ) -> ModelRunnerOutput: - # xp.trace_detached('localhost:9012', "./profiles") - output = self.worker.execute_model(scheduler_output) - return output - - def check_health(self) -> None: - # TPUExecutor will always be healthy as long as - # it's running. - return diff --git a/vllm/v1/worker/__tpu_model_runner_v1.py b/vllm/v1/worker/__tpu_model_runner_v1.py deleted file mode 100644 index 7963fe4973b55..0000000000000 --- a/vllm/v1/worker/__tpu_model_runner_v1.py +++ /dev/null @@ -1,981 +0,0 @@ -import time -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple - -import numpy as np -import torch -import torch.distributed -import torch.nn as nn -import torch_xla.core.xla_model as xm - -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MultiModalDataDict -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_pin_memory_available -from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, - PallasAttentionMetadata) -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.sample.metadata import SamplingMetadata - -if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput - -logger = init_logger(__name__) - -# Here we utilize the behavior that out-of-bound index is ignored. -# FIXME: Find a more reliable way to prevent possible bugs. 
-_PAD_SLOT_ID = 1_000_000_000 - - -@dataclass -class PrefillInputData: - - request_ids: List - prompt_lens: List - token_ids: List - position_ids: List - attn_metadata: List - - def zipped(self): - return zip(self.request_ids, self.prompt_lens, self.token_ids, - self.position_ids, self.attn_metadata) - - -@dataclass -class DecodeInputData: - - num_decodes: int - token_ids: Optional[torch.Tensor] = None - position_ids: Optional[torch.Tensor] = None - attn_metadata: PallasAttentionMetadata = None - - -class TPUModelRunner: - - def __init__( - self, - vllm_config: VllmConfig, - ): - # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - model_config = self.model_config - cache_config = self.cache_config - scheduler_config = self.scheduler_config - parallel_config = self.parallel_config - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - self.dtype = self.model_config.dtype - if cache_config.cache_dtype == "auto": - self.kv_cache_dtype = self.dtype - else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] - - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens - - # Model-related. - self.num_attn_layers = model_config.get_num_attention_layers( - parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() - - # List[k_cache, v_cache] - self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] - - # Request states. - self.requests: Dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.scheduler_config.max_num_seqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - ) - - self.prefill_positions = torch.tensor( - range(self.max_model_len), - device="cpu", - ).to(torch.int32).reshape(1, -1) - - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: - # Remove stopped requests from the cached states. - # Keep the states of the pre-empted requests. - for req_id in scheduler_output.finished_req_ids: - self.requests.pop(req_id, None) - - # Remove the requests from the persistent batch. - stopped_req_ids = set().union( - scheduler_output.preempted_req_ids, - scheduler_output.finished_req_ids, - ) - removed_req_indices: List[int] = [] - for req_id in stopped_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) - - # Update the states of the running requests. 
- for req_data in scheduler_output.scheduled_running_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - req_index = self.input_batch.req_id_to_index[req_id] - - # Update the num_computed_tokens. - req_state.num_computed_tokens = req_data.num_computed_tokens - self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - - # Update the block table. - num_new_blocks = len(req_data.new_block_ids) - if num_new_blocks == 0: - continue - start_index = len(req_state.block_ids) - end_index = start_index + num_new_blocks - req_state.block_ids.extend(req_data.new_block_ids) - self.input_batch.block_table_cpu[ - req_index, start_index:end_index] = req_data.new_block_ids - - req_ids_to_add: List[str] = [] - # Add new requests to the cached states. - for req_data in scheduler_output.scheduled_new_reqs: - req_id = req_data.req_id - sampling_params = req_data.sampling_params - if sampling_params.sampling_type == SamplingType.RANDOM_SEED: - generator = torch.Generator(device=self.device) - generator.manual_seed(sampling_params.seed) - else: - generator = None - - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=req_data.prompt_token_ids, - prompt=req_data.prompt, - multi_modal_data=req_data.multi_modal_data, - sampling_params=sampling_params, - generator=generator, - block_ids=req_data.block_ids, - num_computed_tokens=req_data.num_computed_tokens, - output_token_ids=[], - ) - req_ids_to_add.append(req_id) - - # Update the cached states of the resumed requests. - for req_data in scheduler_output.scheduled_resumed_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - - req_state.block_ids = req_data.block_ids - req_state.num_computed_tokens = req_data.num_computed_tokens - req_ids_to_add.append(req_id) - - # THIS MOVES ALL THE DECODES TO THE FIRST N IN BATCH. - # Condense the batched states if there are empty indices. - removed_req_indices = sorted(removed_req_indices, reverse=True) - if removed_req_indices: - self.input_batch.condense(removed_req_indices) - - # ALL THE PREFILLS ARE THE LAST M IN THE BATCH. - # These are added at the end after the bacth is condensed. - self.input_batch.num_prefills = len(req_ids_to_add) - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - self.input_batch.add_request(req_state, None) - - def _prepare_prefill_inputs( - self, - num_scheduled_tokens: List[int], - ) -> PrefillInputData: - # Each prefill run separately with shape [1, padded_prompt_len]. - # So we create lists that will be used in execute_model(). - - prefill_request_ids = [] - prefill_prompt_lens = [] - prefill_token_ids = [] - prefill_position_ids = [] - prefill_attn_metadata = [] - - # DECODES are the first num_decodes REQUESTS. - # PREFILLS are the next num_reqs - num_decodes REQUESTS. - num_reqs = self.input_batch.num_reqs - num_decodes = self.input_batch.num_decodes - for idx in range(num_decodes, num_reqs): - prefill_request_ids.append(self.input_batch.req_ids[idx]) - - # STATIC SHAPE: prefills are padded to the next power of 2. - prompt_len = num_scheduled_tokens[idx] - padded_prompt_len = _get_padded_prefill_len(prompt_len) - prefill_prompt_lens.append(prompt_len) - assert padded_prompt_len <= self.max_model_len - - # TOKEN_IDS. - token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ - idx, :padded_prompt_len].reshape(1, -1)) - prefill_token_ids.append(token_ids.to(self.device)) - - # POSITIONS. 
- positions = self.prefill_positions[:, :padded_prompt_len] - prefill_position_ids.append(positions.to(self.device)) - - # SLOT_MAPPING. - # The "slot" is the "physical index" of a token in the KV cache. - # Look up the block_idx in the block table (logical<>physical map) - # to compute this. - block_numbers = self.input_batch.block_table_cpu_tensor[ - idx, positions // self.block_size].reshape(1, -1) - block_offsets = positions % self.block_size - slot_mapping = block_numbers * self.block_size + block_offsets - # Set an out of range value for the padding tokens so that they - # are ignored when inserting into the KV cache. - slot_mapping[:, prompt_len:] = _PAD_SLOT_ID - slot_mapping = slot_mapping.long() - - # ATTN_METADATA. - prefill_attn_metadata.append( - PallasAttentionMetadata( - is_prompt=True, - slot_mapping=slot_mapping.to(self.device), - block_tables=None, - context_lens=None, - )) - - return PrefillInputData( - request_ids=prefill_request_ids, - prompt_lens=prefill_prompt_lens, - token_ids=prefill_token_ids, - position_ids=prefill_position_ids, - attn_metadata=prefill_attn_metadata, - ) - - def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData: - # Decodes run as one single padded batch with shape [batch, 1] - # - # We need to set _PAD_SLOT_ID for the padding tokens in the - # slot_mapping, such that the attention KV cache insertion - # logic knows to ignore those indicies. Otherwise, the - # padding data can be dummy since we have a causal mask. - - if num_decodes == 0: - return DecodeInputData(num_decodes=0) - - # PAD FOR STATIC SHAPES. - padded_batch_size = _get_padded_batch_size(num_decodes) - - # POSITIONS. [batch, 1] - # We slice at the end, since we use the positions for gathering. - positions = torch.from_numpy( - self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) - index = positions.to(torch.int64) - positions = positions[:padded_batch_size] - - # TOKEN_IDS. [batch, 1] - token_ids = torch.gather( - input=torch.from_numpy(self.input_batch.token_ids_cpu), - dim=1, - index=index, - )[:padded_batch_size] - - # SLOT_MAPPING [batch, 1] - # The "slot" is the "physical index" of a token in the KV cache. - # Look up the block_idx in the block table (logical<>physical map) - # to compute this. - block_number = torch.gather( - input=self.input_batch.block_table_cpu_tensor, - dim=1, - index=(index // self.block_size)) - block_offsets = index % self.block_size - slot_mapping = block_number * self.block_size + block_offsets - # Set an out of range value for the padding tokens so that they - # are ignored when inserting into the KV cache. - slot_mapping[num_decodes:] = _PAD_SLOT_ID - slot_mapping = slot_mapping[:padded_batch_size] - - # BLOCK_TABLE [batch, max_num_blocks_per_req] - block_table = self.input_batch.block_table_cpu_tensor[: - padded_batch_size] - - # CONTEXT_LENS [batch_size] - context_lens = (positions.reshape(-1) + 1) - - # CPU<>TPU sync happens here. 
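The decode-input gathers above pull one token and one slot per request from the persistent CPU batch; a toy illustration of the token gather (all values invented):

import torch

token_ids_cpu = torch.tensor([[11, 12, 13, 14],
                              [21, 22, 23,  0]])   # [num_reqs, max_model_len], toy contents
positions = torch.tensor([[3], [2]])               # num_computed_tokens per request

# Each decode step feeds the token stored at the current position.
token_ids = torch.gather(token_ids_cpu, dim=1, index=positions)
assert token_ids.tolist() == [[14], [23]]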
- return DecodeInputData(num_decodes=num_decodes, - token_ids=token_ids.to(self.device), - position_ids=positions.to(self.device), - attn_metadata=PallasAttentionMetadata( - is_prompt=False, - slot_mapping=slot_mapping.to(self.device), - block_tables=block_table.to(self.device), - context_lens=context_lens.to(self.device), - )) - - def _prepare_inputs( - self, scheduler_output: "SchedulerOutput" - ) -> Tuple[PrefillInputData, Optional[DecodeInputData]]: - - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - - num_reqs = self.input_batch.num_reqs - num_decodes = self.input_batch.num_decodes - - # Get the number of scheduled tokens for each request. - # TODO: The Python loop can be slow. Optimize. - num_scheduled_tokens = [] - for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens.append(num_tokens) - - # NOTE: assert that all the decodes are "decodes". - if idx < num_decodes: - assert num_tokens == 1 - - return ( - self._prepare_prefill_inputs(num_scheduled_tokens), - self._prepare_decode_inputs(num_decodes), - ) - - def _prepare_sampling( - self, - scheduler_output: "SchedulerOutput", - ) -> SamplingMetadata: - skip_copy = True - if (scheduler_output.finished_req_ids - or scheduler_output.preempted_req_ids): - skip_copy = False - if (scheduler_output.scheduled_new_reqs - or scheduler_output.scheduled_resumed_reqs): - skip_copy = False - # Create the sampling metadata. - sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) - return sampling_metadata - - @torch.no_grad() - def execute_model( - self, - scheduler_output: "SchedulerOutput", - ) -> ModelRunnerOutput: - self._update_states(scheduler_output) - prefill_data, decode_data = self._prepare_inputs(scheduler_output) - num_reqs = self.input_batch.num_reqs - sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) - - ######################### DECODES ######################### - # Decodes run as one single batch with [padded_batch, 1] - if decode_data.num_decodes > 0: - - # FORWARD. - selected_token_ids = self.model(decode_data.token_ids, - decode_data.position_ids, - decode_data.attn_metadata, - self.kv_caches, - is_prompt=False) - - # NOTE: TPU<>CPU sync happens here. - # We need to call .cpu() first to avoid recompilation. - token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] - sampled_token_ids_list = token_ids.tolist() - sampled_token_ids[:decode_data.num_decodes] = token_ids - - # UPDATE REQUEST STATE. - for i, req_id in enumerate( - self.input_batch.req_ids[:decode_data.num_decodes]): - req_state = self.requests[req_id] - - # TODO: ASSERT NO CHUNKED PREFILL. - assert scheduler_output.num_scheduled_tokens[req_id] == 1 - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len == req_state.num_tokens - - token_id = sampled_token_ids_list[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - req_state.output_token_ids.append(token_id) - - ######################### PREFILLS ######################### - # Prefills run separately with shape [1, padded_prefill_len], - # due to lack of variable length attention kernel so far. - for idx, (req_id, prompt_len, token_ids, position_ids, - attn_metadata) in enumerate(prefill_data.zipped()): - - # FORWARD. 
- selected_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - self.kv_caches, - is_prompt=True) - - # NOTE: TPU<>CPU sync happens here. - # We need to call .cpu() first to avoid recompilation. - token_id = selected_token_ids.cpu()[prompt_len - 1].item() - sampled_token_ids[decode_data.num_decodes + idx] = token_id - req_state = self.requests[req_id] - - # TODO: ASSERT NO PREFIX CACHING. - assert req_state.num_computed_tokens == 0 - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - - # TODO: ASSERT NO CHUNKED PREFILL. - assert seq_len == req_state.num_tokens - assert prompt_len == seq_len - - # UPDATE REQUEST STATE. - req_idx = self.input_batch.req_id_to_index[req_id] - self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id - req_state.output_token_ids.append(token_id) - - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids[:num_reqs], - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids_cpu=sampled_token_ids, - logprob_token_ids_cpu=None, - logprobs_cpu=None, - ) - - def load_model(self) -> None: - - # NOTE(woosuk): While the executor assigns the TP ranks to the worker - # process, the ranks can be different from the ranks internally assigned - # by the xm runtime. Therefore, there is a mismatch in the rank - # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. - # This is not a problem in linear layers because all-reduce is - # rank-agnostic. However, it matters for all-gather as the ranks - # determine the order of concatenating the output tensors. - # As a workaround, we use the xm's rank assignment only when loading - # the embedding weights. - - # xm_tp_rank = xr.global_ordinal() - # with patch( - # "vllm.model_executor.layers.vocab_parallel_embedding." - # "get_tensor_model_parallel_rank", - # return_value=xm_tp_rank): - # model = get_model(vllm_config=self.vllm_config) - model = get_model(vllm_config=self.vllm_config) - model = model.eval() - xm.wait_device_ops() - self.model = ModelWrapper(model) - - def _dummy_run(self, batch_size: int, seq_len: int, - kv_caches: List[torch.Tensor], is_prompt: bool) -> None: - """Dummy warmup run for memory usage and graph compilation.""" - - input_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((batch_size, seq_len), - dtype=torch.int64, - device=self.device) - block_tables = None if is_prompt else torch.zeros( - (batch_size, self.max_num_blocks_per_req), - dtype=torch.int32, - device=self.device, - ) - context_lens = None if is_prompt else torch.ones( - (batch_size, ), - dtype=torch.int32, - device=self.device, - ) - attn_metadata = PallasAttentionMetadata( - is_prompt=is_prompt, - slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - ) - - # NOTE: There are two stages of compilation: torch.compile and - # XLA compilation. Using `mark_dynamic` can reduce the torch.compile - # overhead by reusing the FX graph for different shapes. - # However, the XLA graph will still require static shapes and needs to - # be re-compiled for every different shapes. This overhead is inevitable - # in the first run, but can be skipped afterwards as we cache the XLA - # graphs in the disk (VLLM_XLA_CACHE_PATH). 
- if is_prompt: - torch._dynamo.mark_dynamic(input_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) - else: - torch._dynamo.mark_dynamic(input_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - - # Dummy run. - self.model(input_ids, - position_ids, - attn_metadata, - kv_caches, - is_prompt=is_prompt) - - def profile_run(self) -> None: - """Profile to measure peak memory during forward pass.""" - - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value `None`. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - dummy_kv_caches = [( - torch.tensor([], dtype=torch.float32, device=self.device), - torch.tensor([], dtype=torch.float32, device=self.device), - ) for _ in range(self.num_attn_layers)] - - # Round to multiple of 16. - seq_len = (self.max_num_tokens + 15) // 16 * 16 - - # Run empty forward. - self._dummy_run(batch_size=1, - seq_len=seq_len, - kv_caches=dummy_kv_caches, - is_prompt=True) - - def capture_model(self) -> None: - """Compile the model.""" - - logger.info("Compiling the model with different input shapes.") - - # Prefill shapes. - start = time.perf_counter() - for batch_size in [1]: - seq_len = 16 - while True: - self._dummy_run(batch_size, - seq_len, - self.kv_caches, - is_prompt=True) - xm.wait_device_ops() - logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - if seq_len >= self.model_config.max_model_len: - break - num_tokens = batch_size * seq_len - if num_tokens >= self.scheduler_config.max_num_batched_tokens: - break - seq_len = seq_len * 2 - - end = time.perf_counter() - logger.info("Compilation for prefill done in %.2f s.", end - start) - - # Decode shapes. 
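The decode warmup below starts at batch size 8, doubles to 16, then steps by 16; a small sketch of the sequence of compiled batch shapes, assuming max_num_seqs is 64:

def decode_warmup_batch_sizes(max_num_seqs: int):
    # Mirrors the loop below: double until 16, then advance in steps of 16.
    batch_size = 8
    while True:
        yield batch_size
        if batch_size >= max_num_seqs:
            break
        batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

assert list(decode_warmup_batch_sizes(64)) == [8, 16, 32, 48, 64]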
- start = time.time() - seq_len = 1 - batch_size = 8 # Must be in sync with _get_padded_batch_size() - while True: - self._dummy_run(batch_size, - seq_len, - self.kv_caches, - is_prompt=False) - xm.wait_device_ops() - logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - - if batch_size >= self.scheduler_config.max_num_seqs: - break - batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 - - end = time.time() - logger.info("Compilation for decode done in %.2f s.", end - start) - - def initialize_kv_cache(self, num_blocks: int) -> None: - assert len(self.kv_caches) == 0 - kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size) - for _ in range(self.num_attn_layers): - self.kv_caches.append(( - torch.zeros(kv_cache_shape, - dtype=self.kv_cache_dtype, - device=self.device), - torch.zeros(kv_cache_shape, - dtype=self.kv_cache_dtype, - device=self.device), - )) - - -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - multi_modal_data: Optional["MultiModalDataDict"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] - - @property - def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - -class InputBatch: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_blocks_per_req: int, - device: torch.device, - pin_memory: bool, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req - self.device = device - self.pin_memory = pin_memory - - self.req_ids: List[Optional[str]] = [None] * max_num_reqs - self.req_id_to_index: Dict[str, int] = {} - - self.token_ids_cpu = np.zeros((max_num_reqs, max_model_len), - dtype=np.int32) - self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) - - # Attention-related. - self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() - - # Sampling-related. 
- self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() - - # req_index -> generator - self.generators: Dict[int, torch.Generator] = {} - - self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() - - self.num_prefills = 0 - - def add_request( - self, - request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs - - req_id = request.req_id - self.req_ids[req_index] = req_id - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. - num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids - - sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_id) - else: - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: - self.top_k_reqs.add(req_id) - - self.generators[req_index] = request.generator - - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) - - def remove_request(self, req_id: str) -> Optional[int]: - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - self.req_ids[req_index] = None - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) - return req_index - - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() - - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: - # 
The batched states are empty. - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = empty_req_indices.pop() - if empty_index >= last_req_index: - break - - # Swap the states. - req_id = self.req_ids[last_req_index] - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - # TODO(woosuk): Optimize the copy of token_ids_cpu and - # block_table_cpu. - self.token_ids_cpu[empty_index] = self.token_ids_cpu[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - # Decrement last_req_index since it is now empty. - last_req_index -= 1 - - def make_sampling_metadata( - self, - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - ) - - @property - def num_reqs(self) -> int: - return len(self.req_id_to_index) - - @property - def num_decodes(self) -> int: - return self.num_reqs - self.num_prefills - - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 - - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 - - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 - - @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 - - @property - def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 - - -class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): - - def __init__(self, model: nn.Module): - self.model = model - compiled_callable = torch.compile(self.forward, - backend="openxla", - fullgraph=True, - dynamic=False) - super().__init__(compiled_callable) - - def __call__(self, *args, is_prompt: bool, **kwargs): - if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: - # not fully compiled yet, or not using the custom dispatcher, - # let PyTorch handle it - return self.compiled_callable(*args, **kwargs) - # the 3 compiled codes are: - # 0: for profiling - # 1: for prompt - # 2: for decode - # dispatch to 
the compiled code directly, skip PyTorch - if is_prompt: - with self.dispatch_to_code(1): - return self.forward(*args, **kwargs) - else: - with self.dispatch_to_code(2): - return self.forward(*args, **kwargs) - - def forward( - self, - token_ids: torch.Tensor, - position_ids: torch.Tensor, - attn_metadata: PallasAttentionMetadata, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> torch.Tensor: - """Executes the forward pass of the model and samples the next token. - - Args: - token_ids: The input token IDs of shape [batch_size, seq_len]. - position_ids: The input position IDs of shape [batch_size, seq_len]. - attn_metadata: The Pallas attention metadata. - kv_caches: The key and value caches. They can be None during the - memory profiling at initialization. - """ - - # Skip this in memory profiling at initialization. - if kv_caches[0][0].numel() > 0: - # index_copy_(slot_mapping) only works when the inserted dimension - # is 0. However, the KV cache in the Pallas backend has the shape - # [num_kv_heads, num_blocks, block_size, head_size]. To make it - # work, we need to flatten the first three dimensions and modify - # the slot_mapping accordingly. - num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape - slot_mapping = attn_metadata.slot_mapping - slot_mapping = slot_mapping.flatten() - head_indicies = torch.arange(0, - num_kv_heads, - device=slot_mapping.device, - dtype=slot_mapping.dtype) - head_indicies *= block_size * num_blocks - slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( - -1, num_kv_heads) - slot_mapping = slot_mapping + head_indicies.view(1, -1) - slot_mapping = slot_mapping.flatten() - attn_metadata.slot_mapping = slot_mapping - - hidden_states = self.model( - token_ids, - position_ids, - kv_caches, - attn_metadata, - ) - hidden_states = hidden_states.flatten(0, 1) - logits = self.model.compute_logits(hidden_states, None) - - # Greedy sampling. - argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - return argmax_token_ids.squeeze(dim=1) - - -def _get_padded_batch_size(batch_size: int) -> int: - # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. - # To meet this requirement in the simplest way, we set the minimal batch - # size to 8. - if batch_size <= 8: - return 8 - else: - return ((batch_size + 15) // 16) * 16 - - -def _get_padded_prefill_len(x: int) -> int: - # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence - # length to be a multiple of 16. We pad the prompt length to the nearest - # multiple of 16. This is also good for performance. 
- if x <= 16: - return 16 - return 1 << (x - 1).bit_length() diff --git a/vllm/v1/worker/__tpu_worker_v1.py b/vllm/v1/worker/__tpu_worker_v1.py deleted file mode 100644 index 866c1dbf6ea98..0000000000000 --- a/vllm/v1/worker/__tpu_worker_v1.py +++ /dev/null @@ -1,198 +0,0 @@ -"""A TPU worker class.""" - -import os -from typing import TYPE_CHECKING, Tuple - -import torch -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr - -import vllm.envs as envs -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.logger import init_logger -from vllm.model_executor import set_random_seed -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.tpu_model_runner import TPUModelRunner - -if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput - -logger = init_logger(__name__) - - -class TPUWorker: - - def __init__(self, vllm_config: VllmConfig, local_rank: int, rank: int, - distributed_init_method: str): - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - - def initialize(self): - os.environ["PJRT_DEVICE"] = "TPU" - torch.set_grad_enabled(False) - torch.set_default_dtype(self.model_config.dtype) - - # NOTE: This is just to initialize the TP group and broadcast - # the input objects on CPU. The all-reduce and all-gather ops on TPU - # are invoked by `xm.all_reduce` and `xm.all_gather` which use their - # own context. - init_distributed_environment( - world_size=self.parallel_config.world_size, - rank=self.rank, - local_rank=self.local_rank, - distributed_init_method=self.distributed_init_method, - backend="gloo", - ) - ensure_model_parallel_initialized( - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size) - - # Device initialization should happen after initializing the distributed - # runtime. - self.device = xm.xla_device() - self.device_config.device = self.device - - # Init ModelRunner here, so that we have access to self.device. - self.model_runner = TPUModelRunner(self.vllm_config) - - # Set random seed. - set_random_seed(self.model_config.seed) - xm.set_rng_state(self.model_config.seed, self.device) - - # Increase the cache size limit, which is the maximum number of - # dynamo graphs that can be compiled. - # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and - # 30-40 graphs for decode. 128 is an arbitrary safe number. - torch._dynamo.config.cache_size_limit = 128 - # Use persistent cache to avoid XLA recompilation. - # NOTE(woosuk): Set per-rank cache path since different ranks - # can have slightly different XLA graphs. 
- world_size = self.parallel_config.world_size - rank = xr.global_ordinal() - per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, - f"tp{world_size}_rank{rank}") - xr.initialize_cache(per_rank_path, readonly=False) - - def load_model(self): - self.model_runner.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - - self.model_runner.profile_run() - - # Synchronize before measuring the memory usage. - xm.wait_device_ops() - - # Get the maximum amount of memory used by the model weights and - # intermediate activations. - m = xm.get_memory_info(self.device) - total_tpu_memory = m["bytes_limit"] - peak_memory = m[ - "peak_bytes_used"] # Weights + intermediate activations. - logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024) - logger.debug("Total Memory: %sGB", - total_tpu_memory // 1024 // 1024 // 1024) - - cache_block_size = _get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - num_tpu_blocks = int( - (total_tpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 - return num_tpu_blocks, 0 - - def initialize_cache(self, num_tpu_blocks: int) -> None: - """Allocate TPU and CPU KV cache with the specified number of blocks.""" - - if num_tpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - - max_seq_len = self.cache_config.block_size * num_tpu_blocks - max_model_len = self.model_config.max_model_len - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") - - self.model_runner.initialize_kv_cache(num_tpu_blocks) - - # Get the maximum amount of memory used by the model weights and - # intermediate activations. - xm.mark_step() - xm.wait_device_ops() - m = xm.get_memory_info(self.device) - peak_memory = m[ - "peak_bytes_used"] # Weights + intermediate activations. - logger.debug("Peak GB Used Post KV Cache: %sGB", - peak_memory // 1024 // 1024 // 1024) - - def compile_or_warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model() - - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) - - def execute_model( - self, - scheduler_output: "SchedulerOutput", - ) -> ModelRunnerOutput: - output = self.model_runner.execute_model(scheduler_output) - # TODO(woosuk): Send the output to the engine process. - return output - - -# TODO: this is a duplicate. 
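For reference, the block-size arithmetic in the `_get_cache_block_size` helper that follows works out as in this standalone sketch; every dimension here is an assumed, illustrative value rather than one taken from this patch.

    # Sketch of the per-block KV cache footprint and the resulting block count
    # (all numbers below are assumptions for illustration only).
    block_size = 16            # tokens per KV cache block (assumed)
    num_kv_heads = 8           # KV heads per tensor-parallel rank (assumed)
    head_size = 128            # Pallas requires a multiple of 128
    num_attention_layers = 32  # assumed
    dtype_size = 2             # bytes per element for bfloat16

    key_cache_block = block_size * num_kv_heads * head_size
    value_cache_block = key_cache_block
    total_elems = num_attention_layers * (key_cache_block + value_cache_block)
    cache_block_bytes = dtype_size * total_elems          # 2 MiB per block here

    # With the headroom left after profiling, the worker derives the TPU block
    # count and rounds it down to a multiple of 8.
    free_bytes = 8 * 1024**3   # assumed post-profiling headroom
    num_tpu_blocks = (free_bytes // cache_block_bytes) // 8 * 8
    print(cache_block_bytes, num_tpu_blocks)               # 2097152 4096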
-def _get_cache_block_size( - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, -) -> int: - head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) - num_attention_layers = model_config.get_num_attention_layers( - parallel_config) - - key_cache_block = cache_config.block_size * num_heads * head_size - value_cache_block = key_cache_block - total = num_attention_layers * (key_cache_block + value_cache_block) - if cache_config.cache_dtype == "auto": - dtype = model_config.dtype - else: - dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - dtype_size = get_dtype_size(dtype) - return dtype_size * total diff --git a/vllm/worker/__tpu_model_runner_v0.py b/vllm/worker/__tpu_model_runner_v0.py deleted file mode 100644 index a721186137328..0000000000000 --- a/vllm/worker/__tpu_model_runner_v0.py +++ /dev/null @@ -1,835 +0,0 @@ -import time -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Type, Union) -from unittest.mock import patch - -import numpy as np -import torch -import torch.nn as nn -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr - -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SequenceGroupMetadata, SequenceOutput) -from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, - _add_attn_metadata_broadcastable_dict, - _init_attn_metadata_from_tensor_dict) - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - -# Here we utilize the behavior that out-of-bound index is ignored. -# FIXME(woosuk): Find a more reliable way to prevent possible bugs. -_PAD_SLOT_ID = 1_000_000_000 -# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. -_ENABLE_TOP_P = False -# FIXME(woosuk): A temporary hack to support `n > 1`. -# This can significantly affect the performance if too large. 
-_MAX_NUM_SAMPLES = 128 - - -@dataclass(frozen=True) -class ModelInputForTPU(ModelRunnerInputBase): - token_ids: torch.Tensor - position_ids: torch.Tensor - attn_metadata: AttentionMetadata - input_lens: torch.Tensor - t: torch.Tensor - p: torch.Tensor - num_samples: int - n: List[int] - seq_groups: List[List[int]] - is_first_multi_step: bool = True - is_last_step: bool = True - virtual_engine: int = 0 - async_callback: Optional[Callable] = None - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = { - "token_ids": self.token_ids, - "position_ids": self.position_ids, - "input_lens": self.input_lens, - "t": self.t, - "p": self.p, - "num_samples": self.num_samples, - "n": self.n, - "seq_groups": self.seq_groups, - "is_first_multi_step": self.is_first_multi_step, - "is_last_step": self.is_last_step, - "virtual_engine": self.virtual_engine, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type["ModelInputForTPU"], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForTPU": - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): - - def __init__( - self, - vllm_config: VllmConfig, - is_driver_worker: bool = False, - ): - ModelRunnerBase.__init__(self, vllm_config=vllm_config) - self.is_driver_worker = is_driver_worker - - self.block_size = self.cache_config.block_size - self.max_num_blocks_per_seq = (self.model_config.max_model_len // - self.block_size) - self.block_tables = np.zeros( - (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), - dtype=np.int32) - self.attn_backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.cache_config.cache_dtype, - self.block_size, - self.model_config.is_attention_free, - False, - ) - self.cached_step_outputs: List[torch.Tensor] = [] - - smem_size = 512 * 1024 - block_table_size = 4 * self.block_tables.size - if block_table_size >= smem_size: - logger.warning( - "The max_model_len (%d) is too large. This may degrade the " - "performance due to the insufficient smem size. Consider " - "setting --max-model-len to a smaller value.", - self.model_config.max_model_len) - - def load_model(self) -> None: - self.device = self.device_config.device - - # NOTE(woosuk): While the executor assigns the TP ranks to the worker - # process, the ranks can be different from the ranks internally assigned - # by the xm runtime. Therefore, there is a mismatch in the rank - # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. - # This is not a problem in linear layers because all-reduce is - # rank-agnostic. However, it matters for all-gather as the ranks - # determine the order of concatenating the output tensors. - # As a workaround, we use the xm's rank assignment only when loading - # the embedding weights. - xm_tp_rank = xr.global_ordinal() - with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." 
- "get_tensor_model_parallel_rank", - return_value=xm_tp_rank): - model = get_model(vllm_config=self.vllm_config) - model = model.eval() - xm.wait_device_ops() - self.model = ModelWrapper(model) - - def _dummy_run( - self, - batch_size: int, - seq_len: int, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - is_prompt: bool, - ) -> None: - if is_prompt: - seq_len = (seq_len + 15) // 16 * 16 - token_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((batch_size, seq_len), - dtype=torch.int64, - device=self.device) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=batch_size, - num_prefill_tokens=batch_size * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - block_tables=None, - context_lens=None, - ) - input_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - else: - assert seq_len == 1 - token_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((batch_size, seq_len), - dtype=torch.int64, - device=self.device) - block_tables = torch.zeros( - (batch_size, self.max_num_blocks_per_seq), - dtype=torch.int32, - device=self.device) - context_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - input_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size * seq_len, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - block_tables=block_tables, - context_lens=context_lens, - ) - t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 - - # NOTE(woosuk): There are two stages of compilation: torch.compile and - # XLA compilation. Using `mark_dynamic` can reduce the torch.compile - # overhead by reusing the FX graph for different shapes. - # However, the XLA graph will still require static shapes and needs to - # be re-compiled for every different shapes. This overhead is inevitable - # in the first run, but can be skipped afterwards as we cache the XLA - # graphs in the disk (VLLM_XLA_CACHE_PATH). - if is_prompt: - # Prefll - torch._dynamo.mark_dynamic(token_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) - else: - # Decode - torch._dynamo.mark_dynamic(token_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(input_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - torch._dynamo.mark_dynamic(t, 0) - torch._dynamo.mark_dynamic(p, 0) - # Dummy run. 
- self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, - num_samples, - kv_caches, - is_prompt=is_prompt) - - def warmup_model( - self, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> None: - # Prefill - logger.info("Compiling the model with different input shapes...") - start = time.time() - for batch_size in [1]: - seq_len = 16 - while True: - self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True) - xm.wait_device_ops() - logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - - if seq_len >= self.model_config.max_model_len: - break - num_tokens = batch_size * seq_len - if num_tokens >= self.scheduler_config.max_num_batched_tokens: - break - seq_len = seq_len * 2 - - end = time.time() - logger.info("Compilation for prefill done in %.2f s.", end - start) - - # Decode - start = time.time() - seq_len = 1 - batch_size = 8 # Must be in sync with _get_padded_batch_size() - while True: - self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False) - xm.wait_device_ops() - logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - - if batch_size >= self.scheduler_config.max_num_seqs: - break - batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 - - end = time.time() - logger.info("Compilation for decode done in %.2f s.", end - start) - - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - prompt_lens: List[int] = [] - slot_mapping: List[int] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - # Could include output tokens when a request is preempted. - prompt_tokens = seq_data.get_token_ids() - prompt_len = len(prompt_tokens) - prompt_lens.append(prompt_len) - - input_tokens.extend(prompt_tokens) - input_positions.extend(list(range(prompt_len))) - - assert seq_group_metadata.block_tables is not None - block_table = seq_group_metadata.block_tables[seq_id] - for i in range(prompt_len): - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - # Add paddings to EACH prompt to the smallest power of 2 that is - # greater than or equal to the prompt length. - # We pad the seq_len to reduce the compilation overhead. - # We execute each prompt individually (i.e., with batch_size 1) - # because the FlashAttention kernel does not support ragged inputs. - # TODO(woosuk): Use SplashAttention to support ragged inputs. 
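The per-prompt padding and slot computation around this point can be exercised in isolation. The sketch below mirrors it with a toy block size and a hypothetical block table: the prompt length is padded to the next power of two (minimum 16), each real position is mapped to a physical KV-cache slot, and padding positions get the out-of-range pad slot.

    # Toy illustration of per-prompt padding and slot mapping (assumed values).
    _PAD_SLOT_ID = 1_000_000_000

    def padded_prefill_len(x: int) -> int:
        # Next power of two with a floor of 16, mirroring _get_padded_prefill_len.
        return 16 if x <= 16 else 1 << (x - 1).bit_length()

    block_size = 4                       # assumed; smaller than a real config
    block_table = [7, 3, 9]              # hypothetical physical block ids
    prompt_len = 10
    padded_len = padded_prefill_len(prompt_len)   # -> 16

    slot_mapping = []
    for pos in range(padded_len):
        if pos < prompt_len:
            block_number = block_table[pos // block_size]
            slot_mapping.append(block_number * block_size + pos % block_size)
        else:
            # Padding tokens point at an out-of-range slot so the KV-cache
            # insertion ignores them.
            slot_mapping.append(_PAD_SLOT_ID)

    print(padded_len, slot_mapping)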
- padded_prompt_len = _get_padded_prefill_len(prompt_len) - num_paddings = padded_prompt_len - prompt_len - input_tokens += [0] * num_paddings - input_positions += [0] * num_paddings - slot_mapping += [_PAD_SLOT_ID] * num_paddings - - assert len(prompt_lens) > 0 - num_prefills = len(prompt_lens) - input_tokens = torch.tensor(input_tokens, - dtype=torch.int32, - device="cpu") - input_positions = torch.tensor(input_positions, - dtype=torch.int32, - device="cpu") - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.int64, - device="cpu") - prompt_lens = torch.tensor(prompt_lens, - dtype=torch.int32, - device="cpu") - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - num_prefill_tokens=0, # NOTE: This is not used. - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - block_tables=None, - context_lens=None, - ) - return input_tokens, input_positions, attn_metadata, prompt_lens - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] - context_lens: List[int] = [] - - batch_idx = 0 - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) - - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append([position]) - context_lens.append(seq_len) - - assert seq_group_metadata.block_tables is not None - block_table = seq_group_metadata.block_tables[seq_id] - self.block_tables[batch_idx, :len(block_table)] = block_table - batch_idx += 1 - - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append([slot]) - - batch_size = _get_padded_batch_size(batch_idx) - num_paddings = batch_size - batch_idx - input_tokens = input_tokens + [[0]] * num_paddings - input_positions = input_positions + [[0]] * num_paddings - slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings - context_lens = context_lens + [0] * num_paddings - - input_tokens = torch.tensor(input_tokens, - dtype=torch.int32, - device="cpu") - input_positions = torch.tensor(input_positions, - dtype=torch.int32, - device="cpu") - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.int64, - device="cpu") - context_lens = torch.tensor(context_lens, - dtype=torch.int32, - device="cpu") - block_tables = torch.tensor(self.block_tables[:batch_size], - dtype=torch.int32, - device="cpu") - input_lens = torch.tensor([1] * batch_size, - dtype=torch.int32, - device="cpu") - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - block_tables=block_tables, - context_lens=context_lens, - ) - return input_tokens, input_positions, attn_metadata, input_lens - - def _prepare_sample( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - padded_batch_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: - assert len(seq_group_metadata_list) > 0 - t = [] - p = [] - n = [] - for 
seq_group_metadata in seq_group_metadata_list: - sampling_params = seq_group_metadata.sampling_params - t.append(sampling_params.temperature) - if sampling_params.top_p != 1 and not _ENABLE_TOP_P: - raise NotImplementedError( - "Top-p sampling is currently disabled for the TPU backend " - "due to performance issues.") - p.append(sampling_params.top_p) - if sampling_params.top_k != -1: - raise NotImplementedError( - "Top-k sampling is currently disabled for the TPU backend " - "due to performance issues.") - if sampling_params.n > _MAX_NUM_SAMPLES: - raise NotImplementedError( - f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU " - "backend.") - n.append(sampling_params.n) - if sampling_params.logprobs is not None: - raise NotImplementedError( - "logprobs is not currently supported by the TPU backend.") - if sampling_params.prompt_logprobs is not None: - raise NotImplementedError( - "prompt_logprobs is not currently supported by the TPU " - "backend.") - - # Repeat the sampling params if the seq group has multiple seqs. - num_seqs = len(seq_group_metadata.seq_data) - t += [t[-1]] * (num_seqs - 1) - p += [p[-1]] * (num_seqs - 1) - n += [n[-1]] * (num_seqs - 1) - - num_paddings = padded_batch_size - len(t) - t += [1.0] * num_paddings - p += [1.0] * num_paddings - - t = torch.tensor(t, dtype=torch.float32, device="cpu") - p = torch.tensor(p, dtype=torch.float32, device="cpu") - return t, p, n - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, - ) -> ModelInputForTPU: - del finished_requests_ids # Unused. - assert virtual_engine == 0 - assert len(seq_group_metadata_list) > 0 - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. 
- is_prompt = seq_group_metadata_list[0].is_prompt - if is_prompt: - inputs = self._prepare_prompt(seq_group_metadata_list) - else: - inputs = self._prepare_decode(seq_group_metadata_list) - input_tokens, input_positions, attn_metadata, input_lens = inputs - padded_batch_size = input_tokens.shape[0] - t, p, n = self._prepare_sample(seq_group_metadata_list, - padded_batch_size) - num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 - - seq_groups = [ - list(metadata.seq_data.keys()) - for metadata in seq_group_metadata_list - ] - return ModelInputForTPU(input_tokens, input_positions, attn_metadata, - input_lens, t, p, num_samples, n, seq_groups) - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: - model_input = ModelInputForTPU.from_broadcasted_tensor_dict( - tensor_dict, attn_backend=self.attn_backend) - return model_input - - @torch.no_grad() - def execute_model( - self, - model_input: ModelInputForTPU, - kv_caches: Optional[List[Any]], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> List[SamplerOutput]: - assert intermediate_tensors is None - if not model_input.is_first_multi_step: - if not model_input.is_last_step: - return [] - - use_async_out_proc = model_input.async_callback is not None - sampler_outputs = [] - num_outputs = len(self.cached_step_outputs) - for i in range(num_outputs): - next_token_ids = self.cached_step_outputs.pop(0) - next_token_ids = next_token_ids.cpu().tolist() - sampler_output = _make_decode_output(next_token_ids, - model_input.seq_groups) - sampler_outputs.append(sampler_output) - - if i < num_outputs - 1 and use_async_out_proc: - assert model_input.async_callback is not None - ctx = model_input.async_callback.keywords[ # type: ignore - "ctx"] - ctx.append_output( - outputs=[sampler_output], - seq_group_metadata_list=ctx.seq_group_metadata_list, - scheduler_outputs=ctx.scheduler_outputs, - is_async=False, - is_last_step=False, - is_first_step_output=i == 0) - model_input.async_callback() - if use_async_out_proc: - return [sampler_outputs[-1]] - else: - return sampler_outputs - - is_prompt = model_input.attn_metadata.num_prefills > 0 - if is_prompt: - assert num_steps == 1 - # NOTE(woosuk): Since the FlashAttention kernel does not support - # ragged inputs, we split the prompts into different batches and - # process them separately. This is a temporary hack that should be - # optimized by using SplashAttention. - orig_slot_mapping = model_input.attn_metadata.slot_mapping - batch_size = model_input.input_lens.shape[0] - start_idx = 0 - next_token_ids = [] - for i in range(batch_size): - # Get the actual prefill_len. 
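Because each prompt was padded to its own power-of-two length in `_prepare_prompt`, the prefill loop that follows walks the flattened token buffer in per-prompt windows. A minimal sketch of that indexing, with made-up prompt lengths:

    # Sketch of how the flattened, per-prompt-padded buffer is sliced back into
    # [1, padded_len] batches for the ragged-input workaround (assumed lengths).
    def padded_prefill_len(x: int) -> int:
        return 16 if x <= 16 else 1 << (x - 1).bit_length()

    input_lens = [10, 33, 17]            # hypothetical true prompt lengths
    windows = []
    start_idx = 0
    for prompt_len in input_lens:
        end_idx = start_idx + padded_prefill_len(prompt_len)
        windows.append((start_idx, end_idx))
        start_idx = end_idx

    print(windows)   # [(0, 16), (16, 80), (80, 112)]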
- prefill_len = model_input.input_lens[i:i + 1].item() - prefill_len = _get_padded_prefill_len(prefill_len) - end_idx = start_idx + prefill_len - - token_ids = model_input.token_ids[None, start_idx:end_idx].to( - self.device) - position_ids = model_input.position_ids[None, - start_idx:end_idx].to( - self.device) - attn_metadata = model_input.attn_metadata - attn_metadata.num_prefills = 1 - attn_metadata.slot_mapping = orig_slot_mapping[ - None, start_idx:end_idx].to(self.device) - input_lens = model_input.input_lens[i:i + 1].to(self.device) - t = model_input.t[i:i + 1].to(self.device) - p = model_input.p[i:i + 1].to(self.device) - output_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, - model_input.num_samples, - kv_caches, - is_prompt=True) - next_token_ids.append(output_token_ids[0]) - start_idx = end_idx - - if model_input.async_callback is not None: - model_input.async_callback() - # Retrieve the outputs to CPU. - next_token_ids = [ - output_token_ids.cpu().tolist() - for output_token_ids in next_token_ids - ] - - # NOTE(woosuk): Minimal code to construct the sampler outputs. - # The TPU backend does not reuse the sampler, since the TPU backend - # does not support advanced sampling parameters such as logprobs. - zero_logprob = Logprob(0.0) - sampler_outputs = [] - for i, seq_group in enumerate(model_input.seq_groups): - seq_ids = seq_group - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - seq_outputs = [] - for j in range(model_input.n[i]): - next_token_id = next_token_ids[i][j] - seq_outputs.append( - SequenceOutput(seq_id, next_token_id, - {next_token_id: zero_logprob})) - sampler_outputs.append( - CompletionSequenceGroupOutput(seq_outputs, None)) - return [SamplerOutput(sampler_outputs)] - else: - token_ids = model_input.token_ids.to(self.device) - position_ids = model_input.position_ids.to(self.device) - attn_metadata = model_input.attn_metadata - attn_metadata.slot_mapping = attn_metadata.slot_mapping.to( - self.device) - attn_metadata.block_tables = attn_metadata.block_tables.to( - self.device) - attn_metadata.context_lens = attn_metadata.context_lens.to( - self.device) - t = model_input.t.to(self.device) - p = model_input.p.to(self.device) - input_lens = model_input.input_lens.to(self.device) - for i in range(num_steps): - slot_mapping = attn_metadata.slot_mapping - output_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, - model_input.num_samples, - kv_caches, - is_prompt=False) - self.cached_step_outputs.append(output_token_ids) - - if i < num_steps - 1: - # Prepare the inputs for the next step. - token_ids = output_token_ids.unsqueeze(dim=1).int() - position_ids = position_ids + 1 - attn_metadata.context_lens = attn_metadata.context_lens + 1 - - block_tables = attn_metadata.block_tables - block_number = block_tables.gather( - 1, - position_ids.long() // self.block_size) - block_offset = position_ids % self.block_size - - is_padding = slot_mapping == _PAD_SLOT_ID - slot_mapping = block_number * self.block_size + block_offset - slot_mapping = slot_mapping.long() - slot_mapping = torch.where(is_padding, _PAD_SLOT_ID, - slot_mapping) - attn_metadata.slot_mapping = slot_mapping - - if model_input.async_callback is not None: - model_input.async_callback() - - if num_steps > 1: - return [] - # Retrieve the outputs to CPU. 
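The multi-step decode loop above recomputes `slot_mapping` between steps from the advanced positions while preserving the pad slots. A toy, self-contained version of that update (block table, positions, and sizes are assumed):

    import torch

    _PAD_SLOT_ID = 1_000_000_000
    block_size = 4                                    # assumed
    # Hypothetical per-sequence block tables and already-advanced positions;
    # the last row stands in for a padding sequence.
    block_tables = torch.tensor([[7, 3], [2, 8], [0, 0]])
    position_ids = torch.tensor([[5], [2], [0]])
    slot_mapping = torch.tensor([[23], [10], [_PAD_SLOT_ID]])

    block_number = block_tables.gather(1, position_ids.long() // block_size)
    block_offset = position_ids % block_size
    is_padding = slot_mapping == _PAD_SLOT_ID
    new_slots = (block_number * block_size + block_offset).long()
    new_slots = torch.where(is_padding, _PAD_SLOT_ID, new_slots)
    print(new_slots)   # tensor([[13], [10], [1000000000]])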
- next_token_ids = self.cached_step_outputs.pop(0) - next_token_ids = next_token_ids.cpu().tolist() - sampler_output = _make_decode_output(next_token_ids, - model_input.seq_groups) - return [sampler_output] - - -class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): - - def __init__(self, model: nn.Module): - self.model = model - compiled_callable = torch.compile(self.forward, - backend="openxla", - fullgraph=True, - dynamic=False) - super().__init__(compiled_callable) - - def __call__(self, *args, is_prompt: bool, **kwargs): - if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: - # not fully compiled yet, or not using the custom dispatcher, - # let PyTorch handle it - return self.compiled_callable(*args, **kwargs) - # the 3 compiled codes are: - # 0: for profiling - # 1: for prompt - # 2: for decode - # dispatch to the compiled code directly, skip PyTorch - if is_prompt: - with self.dispatch_to_code(1): - return self.forward(*args, **kwargs) - else: - with self.dispatch_to_code(2): - return self.forward(*args, **kwargs) - - def forward( - self, - token_ids: torch.Tensor, - position_ids: torch.Tensor, - attn_metadata: AttentionMetadata, - input_lens: torch.Tensor, - t: torch.Tensor, - p: torch.Tensor, - num_samples: int, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> torch.Tensor: - """Executes the forward pass of the model and samples the next token. - - Args: - token_ids: The input token IDs of shape [batch_size, seq_len]. - position_ids: The input position IDs of shape [batch_size, seq_len]. - attn_metadata: The Pallas attention metadata. - input_lens: The actual input lengths of shape [batch_size]. - t: The sampling temperature of shape [batch_size]. - p: The top-p probability of shape [batch_size]. - num_samples: Number of samples to draw from each logits vector. - kv_caches: The key and value caches. They can be None during the - memory profiling at initialization. - """ - batch_size, seq_len = token_ids.shape - # Calculate the positions to sample from. - start_indicies = torch.arange( - batch_size, dtype=torch.int32, device=input_lens.device) * seq_len - logits_indices = start_indicies + input_lens - 1 - - # FIXME(woosuk): This is a temporary hack to avoid using the existing - # sampler and sampling metadata. - sampling_metadata = SamplingMetadata( - seq_groups=[], - selected_token_indices=logits_indices, - categorized_sample_indices={}, - num_prompts=attn_metadata.num_prefills, - ) - - # Skip this in memory profiling at initialization. - if kv_caches[0][0].numel() > 0: - # index_copy_(slot_mapping) only works when the inserted dimension - # is 0. However, the KV cache in the Pallas backend has the shape - # [num_kv_heads, num_blocks, block_size, head_size]. To make it - # work, we need to flatten the first three dimensions and modify - # the slot_mapping accordingly. 
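A self-contained toy version of the slot_mapping rewrite described in the comment above: each per-token slot is replicated once per KV head and offset by that head's slice of the flattened [num_kv_heads * num_blocks * block_size] cache (the cache geometry here is assumed for illustration).

    import torch

    # Assumed toy cache geometry: 2 KV heads, 3 blocks of 4 slots each.
    num_kv_heads, num_blocks, block_size = 2, 3, 4
    slot_mapping = torch.tensor([5, 9])           # per-token physical slots

    head_indices = torch.arange(num_kv_heads, dtype=slot_mapping.dtype)
    head_indices *= block_size * num_blocks       # offset of each head's region

    expanded = slot_mapping.repeat_interleave(num_kv_heads).view(-1, num_kv_heads)
    expanded = expanded + head_indices.view(1, -1)
    print(expanded.flatten())                     # tensor([ 5, 17,  9, 21])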
- num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape - slot_mapping = attn_metadata.slot_mapping - slot_mapping = slot_mapping.flatten() - head_indicies = torch.arange(0, - num_kv_heads, - device=slot_mapping.device, - dtype=slot_mapping.dtype) - head_indicies *= block_size * num_blocks - slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( - -1, num_kv_heads) - slot_mapping = slot_mapping + head_indicies.view(1, -1) - slot_mapping = slot_mapping.flatten() - attn_metadata.slot_mapping = slot_mapping - - hidden_states = self.model( - token_ids, - position_ids, - kv_caches, - attn_metadata, - ) - hidden_states = hidden_states.flatten(0, 1) - logits = self.model.compute_logits(hidden_states, sampling_metadata) - - # Argmax sampling. - argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - argmax_token_ids = argmax_token_ids.repeat(1, num_samples) - - # Zero temperature means greedy decoding. Avoid division by zero. - nonzero_t = torch.where(t != 0, t, 1.0) - logits = logits / nonzero_t.unsqueeze(dim=1) - if _ENABLE_TOP_P: - logits = _apply_top_p(logits, p.unsqueeze(dim=1)) - - # Random sampling. - probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - sampled_token_ids = torch.multinomial(probs, - num_samples, - replacement=True) - if num_samples == 1: - argmax_token_ids = argmax_token_ids.squeeze(dim=-1) - sampled_token_ids = sampled_token_ids.squeeze(dim=-1) - next_token_ids = torch.where(t != 0, sampled_token_ids, - argmax_token_ids) - return next_token_ids - - -def _get_padded_prefill_len(x: int) -> int: - # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence - # length to be a multiple of 16. We pad the prompt length to the nearest - # multiple of 16. This is also good for performance. - if x <= 16: - return 16 - return 1 << (x - 1).bit_length() - - -def _get_padded_batch_size(batch_size: int) -> int: - # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. - # To meet this requirement in the simplest way, we set the minimal batch - # size to 8. 
- if batch_size <= 8: - return 8 - else: - return ((batch_size + 15) // 16) * 16 - - -def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - logits_sorted = torch.sort(logits, dim=-1, descending=True).values - sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1) - cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True) - cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) - logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) - return logits - - -def _make_decode_output( - next_token_ids: List[int], - seq_groups: List[List[int]], -) -> SamplerOutput: - zero_logprob = Logprob(0.0) - sampler_outputs = [] - batch_idx = 0 - for seq_group in seq_groups: - seq_ids = seq_group - seq_outputs = [] - for seq_id in seq_ids: - next_token_id = next_token_ids[batch_idx] - seq_outputs.append( - SequenceOutput(seq_id, next_token_id, - {next_token_id: zero_logprob})) - batch_idx += 1 - sampler_outputs.append(CompletionSequenceGroupOutput( - seq_outputs, None)) - return SamplerOutput(sampler_outputs) From 774a112011cdeba19eff06641d1bd5ac09d89cee Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 16:10:03 +0000 Subject: [PATCH 09/18] fix refs --- vllm/v1/executor/uniproc_tpu_executor.py | 4 +- vllm/v1/worker/tpu_model_runner_new.py | 1389 ---------------------- vllm/v1/worker/tpu_worker_new.py | 244 ---- 3 files changed, 2 insertions(+), 1635 deletions(-) delete mode 100644 vllm/v1/worker/tpu_model_runner_new.py delete mode 100644 vllm/v1/worker/tpu_worker_new.py diff --git a/vllm/v1/executor/uniproc_tpu_executor.py b/vllm/v1/executor/uniproc_tpu_executor.py index 957fe20355cee..492de560627b5 100644 --- a/vllm/v1/executor/uniproc_tpu_executor.py +++ b/vllm/v1/executor/uniproc_tpu_executor.py @@ -6,7 +6,7 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.tpu_worker_new import TPUWorker +from vllm.v1.worker.tpu_worker import TPUWorker logger = init_logger(__name__) @@ -38,7 +38,7 @@ def _create_worker( if distributed_init_method is None: distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - + return TPUWorker( vllm_config=self.vllm_config, local_rank=local_rank, diff --git a/vllm/v1/worker/tpu_model_runner_new.py b/vllm/v1/worker/tpu_model_runner_new.py deleted file mode 100644 index 2637319a492c9..0000000000000 --- a/vllm/v1/worker/tpu_model_runner_new.py +++ /dev/null @@ -1,1389 +0,0 @@ -import gc -import time -import enum -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Tuple, cast, Optional -from unittest.mock import patch - -import numpy as np -import torch -import torch.distributed -import torch.nn as nn - -# TPU XLA related -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr - -from vllm.attention import AttentionMetadata -from vllm.config import CompilationLevel, VllmConfig -from vllm.distributed.parallel_state import graph_capture -from vllm.forward_context import set_forward_context -from vllm.inputs import INPUT_REGISTRY -from vllm.logger import init_logger -from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.sampling_params import SamplingType -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LayerBlockType, cdiv, is_pin_memory_available) -from vllm.v1.attention.backends.pallas 
import PallasMetadata, PallasAttentionBackend -from vllm.v1.engine.mm_input_mapper import MMInputMapperClient -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch - -if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput - -logger = init_logger(__name__) - -# Here we utilize the behavior that out-of-bound index is ignored. -# FIXME(woosuk): Find a more reliable way to prevent possible bugs. -_PAD_SLOT_ID = 1_000_000_000 -# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. -_ENABLE_TOP_P = False -# FIXME(woosuk): A temporary hack to support `n > 1`. -# This can significantly affect the performance if too large. -_MAX_NUM_SAMPLES = 128 - - -class ExecutionMode(enum.Enum): - PREFILL = enum.auto() - DECODE = enum.auto() - PREFIX_PREFILL = enum.auto() - - def is_prefill(self) -> bool: - return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) - - -@dataclass -class PrefillInputData: - - request_ids: List - prompt_lens: List - token_ids: List - position_ids: List - attn_metadata: List - - def zipped(self): - return zip(self.request_ids, self.prompt_lens, self.token_ids, - self.position_ids, self.attn_metadata) - - -@dataclass -class DecodeInputData: - - num_decodes: int - token_ids: Optional[torch.Tensor] = None - position_ids: Optional[torch.Tensor] = None - attn_metadata: PallasMetadata = None - - -class TPUModelRunner: - - def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - ): - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - self.device_config = vllm_config.device_config - - model_config = self.model_config - cache_config = self.cache_config - scheduler_config = self.scheduler_config - parallel_config = self.parallel_config - self.device = device - self.pin_memory = is_pin_memory_available() - self.dtype = self.model_config.dtype - if cache_config.cache_dtype == "auto": - self.kv_cache_dtype = self.dtype - else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] - - self.is_multimodal_model = model_config.is_multimodal_model - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens - self.max_num_reqs = scheduler_config.max_num_seqs - - # Model-related. 
- self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() - self.hidden_size = model_config.get_hidden_size() - - # Multi-modal data support - self.input_registry = INPUT_REGISTRY - self.mm_registry = MULTIMODAL_REGISTRY - - # NOTE: Initialized input mapper is only used for processing dummy - # multimodal data into multimodal kwargs for GPU memory profiling. - self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) - self.mm_input_mapper_profiling.use_cache = False - - self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 - self.encoder_cache_size = self.scheduler_config.encoder_cache_size - - # Lazy initialization - # self.model: nn.Module # Set after load_model - self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] - # req_id -> (input_id -> encoder_output) - self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} - - # Request states. - self.requests: Dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), - ) - - self.prefill_positions = torch.tensor(range(self.max_model_len), - device="cpu", - dtype=torch.int32).reshape( - 1, -1) - - self.new_req_ids = None - - # TODO: Remove this - # self.use_cuda_graph = (self.vllm_config.compilation_config.level - # == CompilationLevel.PIECEWISE - # and not self.model_config.enforce_eager) - # # TODO(woosuk): Provide an option to tune the max cudagraph batch size. - # # The convention is different. - # # self.cudagraph_batch_sizes sorts in ascending order. - # # The batch sizes in the config are in descending order. - # self.cudagraph_batch_sizes = list( - # reversed(self.vllm_config.compilation_config.capture_sizes)) - - # # Cache the device properties. - # self.device_properties = torch.cuda.get_device_properties(self.device) - # self.num_sms = self.device_properties.multi_processor_count - - # # Persistent buffers for CUDA graphs. - # self.input_ids = torch.zeros(self.max_num_tokens, - # dtype=torch.int32, - # device=self.device) - # self.positions = torch.zeros(self.max_num_tokens, - # dtype=torch.int64, - # device=self.device) - # self.inputs_embeds = torch.zeros( - # (self.max_num_tokens, self.hidden_size), - # dtype=self.dtype, - # device=self.device) - - # # OPTIMIZATION: Cache the tensors rather than creating them every step. - # self.arange_np = np.arange(max(self.max_num_reqs + 1, - # self.max_model_len), - # dtype=np.int32) - # # NOTE(woosuk): These tensors are "stateless", i.e., they are literally - # # a faster version of creating a new tensor every time. Thus, we should - # # not make any assumptions about the values in these tensors. 
- # self.input_ids_cpu = torch.zeros(self.max_num_tokens, - # dtype=torch.int32, - # device="cpu", - # pin_memory=self.pin_memory) - # self.input_ids_np = self.input_ids_cpu.numpy() - # self.positions_cpu = torch.zeros(self.max_num_tokens, - # dtype=torch.int64, - # device="cpu", - # pin_memory=self.pin_memory) - # self.positions_np = self.positions_cpu.numpy() - # self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - # dtype=torch.int32, - # device="cpu", - # pin_memory=self.pin_memory) - # self.slot_mapping_np = self.slot_mapping_cpu.numpy() - # self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, - # dtype=torch.int32, - # device="cpu", - # pin_memory=self.pin_memory) - # self.query_start_loc_np = self.query_start_loc_cpu.numpy() - # self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, - # dtype=torch.int32, - # device="cpu", - # pin_memory=self.pin_memory) - # self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() - - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: - # Remove stopped requests from the cached states. - # Keep the states of the pre-empted requests. - for req_id in scheduler_output.finished_req_ids: - self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) - - # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) - - # Remove the requests from the persistent batch. - stopped_req_ids = set().union( - scheduler_output.preempted_req_ids, - scheduler_output.finished_req_ids, - ) - removed_req_indices: List[int] = [] - for req_id in stopped_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) - - # Update the states of the running requests. - for req_data in scheduler_output.scheduled_running_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - req_index = self.input_batch.req_id_to_index[req_id] - - # Update the num_computed_tokens. - req_state.num_computed_tokens = req_data.num_computed_tokens - self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - - # Update the block table. - num_new_blocks = len(req_data.new_block_ids) - if num_new_blocks == 0: - continue - start_index = len(req_state.block_ids) - req_state.block_ids.extend(req_data.new_block_ids) - self.input_batch.block_table.append_row(req_index, start_index, - req_data.new_block_ids) - - req_ids_to_add: List[str] = [] - # Add new requests to the cached states. - for new_req_data in scheduler_output.scheduled_new_reqs: - req_id = new_req_data.req_id - sampling_params = new_req_data.sampling_params - if sampling_params.sampling_type == SamplingType.RANDOM_SEED: - generator = torch.Generator(device=self.device) - generator.manual_seed(sampling_params.seed) - else: - generator = None - - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - prompt=new_req_data.prompt, - mm_inputs=new_req_data.mm_inputs, - mm_positions=new_req_data.mm_positions, - sampling_params=sampling_params, - generator=generator, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - ) - req_ids_to_add.append(req_id) - - # Update the cached states of the resumed requests. 
- for res_req_data in scheduler_output.scheduled_resumed_reqs: - req_id = res_req_data.req_id - req_state = self.requests[req_id] - - req_state.block_ids = res_req_data.block_ids - req_state.num_computed_tokens = res_req_data.num_computed_tokens - req_ids_to_add.append(req_id) - - # For TPU, we keep all of the decode requests before the - # prefill requests in the batch sequence. - # 1. First condense, so all decodes move to start - # 2. Then add new prefills to the end of the batch - removed_req_indices = sorted(removed_req_indices, reverse=True) - if removed_req_indices: - self.input_batch.condense(removed_req_indices) - - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - self.input_batch.add_request(req_state, None) # Append last - self.new_req_ids = req_ids_to_add - - def _prepare_prefill_inputs( - self, - num_scheduled_tokens: List[int], - ) -> PrefillInputData: - # Each prefill run separately with shape [1, padded_prompt_len]. - # So we create lists that will be used in execute_model(). - - prefill_request_ids = [] - prefill_prompt_lens = [] - prefill_token_ids = [] - prefill_position_ids = [] - prefill_attn_metadata = [] - - # DECODES are the first num_decodes REQUESTS. - # PREFILLS are the next num_reqs - num_decodes REQUESTS. - num_reqs = self.input_batch.num_reqs - num_decodes = num_reqs - self.new_req_ids - for idx in range(num_decodes, num_reqs): - prefill_request_ids.append(self.input_batch.req_ids[idx]) - - prompt_len = num_scheduled_tokens[idx] - prefill_prompt_lens.append(prompt_len) - - # STATIC SHAPE: prefills are padded to the next power of 2. - padded_prompt_len = _get_padded_prefill_len(prompt_len) - assert padded_prompt_len <= self.max_model_len - - # TOKEN_IDS. - token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ - idx, :padded_prompt_len].reshape(1, -1)) - prefill_token_ids.append(token_ids.to(self.device)) - - # POSITIONS. - positions = self.prefill_positions[:, :padded_prompt_len] - prefill_position_ids.append(positions.to(self.device)) - - # SLOT_MAPPING. - # The "slot" is the "physical index" of a token in the KV cache. - # Look up the block_idx in the block table (logical<>physical map) - # to compute this. - block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor( - ) - block_numbers = block_table_cpu_tensor[idx, positions // - self.block_size].reshape( - 1, -1) - block_offsets = positions % self.block_size - slot_mapping = block_numbers * self.block_size + block_offsets - # Set an out of range value for the padding tokens so that they - # are ignored when inserting into the KV cache. - slot_mapping[:, prompt_len:] = _PAD_SLOT_ID - slot_mapping = slot_mapping.long() - - prefill_attn_metadata.append( - PallasMetadata( - num_prefills=1, - num_prefill_tokens=padded_prompt_len, - num_decode_tokens=0, - slot_mapping=slot_mapping.to(self.device), - block_tables=None, - context_lens=None, - effective_query_lens=None, - )) - - return PrefillInputData( - request_ids=prefill_request_ids, - prompt_lens=prefill_prompt_lens, - token_ids=prefill_token_ids, - position_ids=prefill_position_ids, - attn_metadata=prefill_attn_metadata, - ) - - def _prepare_decode_inputs(self) -> DecodeInputData: - # Decodes run as one single padded batch with shape [batch, 1] - # - # We need to set _PAD_SLOT_ID for the padding tokens in the - # slot_mapping, such that the attention KV cache insertion - # logic knows to ignore those indicies. Otherwise, the - # padding data can be dummy since we have a causal mask. 
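The decode batch construction that follows can be sketched on its own: the number of decodes is padded up (a floor of 8, then multiples of 16, mirroring `_get_padded_batch_size`), slots are looked up through the block table, and the padding rows are pointed at the out-of-range pad slot so they never write into the KV cache. All sizes and tables below are assumed toy values.

    import torch

    _PAD_SLOT_ID = 1_000_000_000
    block_size = 4                                   # assumed

    def padded_batch_size(n: int) -> int:
        return 8 if n <= 8 else (n + 15) // 16 * 16

    num_decodes = 3
    padded = padded_batch_size(num_decodes)          # -> 8

    # Hypothetical positions of the next token per request (padded rows are 0).
    positions = torch.tensor([[5], [2], [9], [0], [0], [0], [0], [0]])
    block_table = torch.tensor([[7, 3, 9]] * padded)  # toy block table

    block_number = block_table.gather(1, positions // block_size)
    slot_mapping = block_number * block_size + positions % block_size
    # Padding rows must not write into the KV cache.
    slot_mapping[num_decodes:] = _PAD_SLOT_ID
    print(slot_mapping.squeeze(1))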
- - # DECODES are the first num_decodes REQUESTS. - # PREFILLS are the next num_reqs - num_decodes REQUESTS. - num_reqs = self.input_batch.num_reqs - num_decodes = num_reqs - self.new_req_ids - - if num_decodes == 0: - return DecodeInputData(num_decodes=0) - - # PAD FOR STATIC SHAPES. - padded_batch_size = _get_padded_batch_size(num_decodes) - - # POSITIONS. [batch, 1] - # We slice at the end, since we use the positions for gathering. - positions = torch.from_numpy( - self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) - index = positions.to(torch.int64) - positions = positions[:padded_batch_size] - - # TOKEN_IDS. [batch, 1] - token_ids = torch.gather( - input=torch.from_numpy(self.input_batch.token_ids_cpu), - dim=1, - index=index, - )[:padded_batch_size] - - # SLOT_MAPPING [batch, 1] - # The "slot" is the "physical index" of a token in the KV cache. - # Look up the block_idx in the block table (logical<>physical map) - # to compute this. - block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor() - block_number = torch.gather(input=block_table_cpu_tensor, - dim=1, - index=(index // self.block_size)) - block_offsets = index % self.block_size - slot_mapping = block_number * self.block_size + block_offsets - # Set an out of range value for the padding tokens so that they - # are ignored when inserting into the KV cache. - slot_mapping[num_decodes:] = _PAD_SLOT_ID - slot_mapping = slot_mapping[:padded_batch_size] - - # BLOCK_TABLE [batch, max_num_blocks_per_req] - block_table = block_table_cpu_tensor[:padded_batch_size] - - # CONTEXT_LENS [batch_size] - context_lens = (positions.reshape(-1) + 1) - - # CPU<>TPU sync happens here. - return DecodeInputData(num_decodes=num_decodes, - token_ids=token_ids.to(self.device), - position_ids=positions.to(self.device), - attn_metadata=PallasMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=padded_batch_size, - slot_mapping=slot_mapping.to(self.device), - block_tables=block_table.to(self.device), - context_lens=context_lens.to(self.device), - effective_query_lens=None, - )) - - def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - num_reqs = self.input_batch.num_reqs - assert num_reqs > 0 - - num_decodes = num_reqs - self.new_req_ids - - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - # TODO: Verify this works with TPUs - self.input_batch.block_table.commit(num_reqs) - - # Get the number of scheduled tokens for each request. - # TODO: The Python loop can be slow. Optimize. - num_scheduled_tokens = [] - max_num_scheduled_tokens = 0 - for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - assert req_id is not None - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens.append(num_tokens) - max_num_scheduled_tokens = max(max_num_scheduled_tokens, - num_tokens) - - # NOTE: Assert that all the decodes are "decodes". - if idx < num_decodes: - assert num_tokens == 1 - assert max_num_scheduled_tokens > 0 - - return ( - self._prepare_prefill_inputs(num_scheduled_tokens), - self._prepare_decode_inputs(num_decodes), - ) - - # # OPTIMIZATION: Start copying the block table first. - # # This way, we can overlap the copy with the following CPU operations. - # self.input_batch.block_table.commit(num_reqs) - - # # Get the number of scheduled tokens for each request. 
- # # TODO: The Python loop can be slow. Optimize. - # num_scheduled_tokens = [] - # max_num_scheduled_tokens = 0 - # for req_id in self.input_batch.req_ids[:num_reqs]: - # assert req_id is not None - # num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # num_scheduled_tokens.append(num_tokens) - # max_num_scheduled_tokens = max(max_num_scheduled_tokens, - # num_tokens) - # num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) - # assert max_num_scheduled_tokens > 0 - - # # Get request indices. - # # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - # req_indices = np.repeat(self.arange_np[:num_reqs], - # num_scheduled_tokens) - - # # Get batched arange. - # # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # arange = np.concatenate( - # [self.arange_np[:n] for n in num_scheduled_tokens]) - - # # Get positions. - # positions_np = self.positions_np[:total_num_scheduled_tokens] - # np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - # arange, - # out=positions_np) - - # # Get token indices. - # # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # # where M is the max_model_len. - # token_indices = (positions_np + - # req_indices * self.input_batch.token_ids_cpu.shape[1]) - # # NOTE(woosuk): We use torch.index_select instead of np.take here - # # because torch.index_select is much faster than np.take for large - # # tensors. - # torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - # 0, - # torch.from_numpy(token_indices), - # out=self.input_ids_cpu[:total_num_scheduled_tokens]) - - # # Calculate the slot mapping. - # # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # # where K is the max_num_blocks_per_req and the block size is 2. - # # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # # because M (max_model_len) is not necessarily divisible by block_size. - # block_table_indices = (req_indices * self.max_num_blocks_per_req + - # positions_np // self.block_size) - # # NOTE(woosuk): We use torch.index_select instead of np.take here - # # because torch.index_select is much faster than np.take for large - # # tensors. - # block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - # block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - # block_offsets = positions_np % self.block_size - # np.add(block_numbers * self.block_size, - # block_offsets, - # out=self.slot_mapping_np[:total_num_scheduled_tokens]) - - # # Prepare the attention metadata. - # self.query_start_loc_np[0] = 0 - # np.cumsum(num_scheduled_tokens, - # out=self.query_start_loc_np[1:num_reqs + 1]) - - # seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + - # num_scheduled_tokens) - # max_seq_len = seq_lens.max() - # self.seq_start_loc_np[0] = 0 - # np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1]) - - # # Copy the tensors to the GPU. 
- # self.input_ids[:total_num_scheduled_tokens].copy_( - # self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - # self.positions[:total_num_scheduled_tokens].copy_( - # self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - # query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( - # self.device, non_blocking=True) - # seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to( - # self.device, non_blocking=True) - # slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( - # self.device, non_blocking=True).long() - - # # Prepare for cascade attention if needed. - # common_prefix_len = (scheduler_output.num_common_prefix_blocks * - # self.block_size) - # if common_prefix_len == 0: - # # Common case. - # use_cascade = False - # else: - # # NOTE(woosuk): Cascade attention uses two attention kernels: one - # # for the common prefix and the other for the rest. For the first - # # kernel, we concatenate all the query tokens (possibly from - # # different requests) and treat them as if they are from the same - # # request. Then, we use bi-directional attention to process the - # # common prefix in the KV cache. Importantly, this means that the - # # first kernel does not do any masking. - - # # Consider the following example: - # # Request 1's input query: [D, E, X] - # # Request 1's kv cache: [A, B, C, D, E, X] - # # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) - # # Request 2's input query: [E, Y] - # # Request 2's kv cache: [A, B, C, D, E, Y] - # # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) - - # # If we use [A, B, C, D, E] as the common prefix, then the - # # first kernel will compute the bi-directional attention between - # # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. - # # However, this is wrong because D in Request 1 should not attend to - # # E in the common prefix (i.e., we need masking). - # # To avoid this, [A, B, C, D] should be the common prefix. - # # That is, the common prefix should be capped by the minimum - # # num_computed_tokens among the requests, and plus one to include - # # the first token of the query. - - # # In practice, we use [A, B, C] as the common prefix, instead of - # # [A, B, C, D] (i.e., the common prefix is capped by the minimum - # # num_computed_tokens, without plus one). - # # This is because of an implementation detail: We want to always - # # use two kernels for cascade attention. Let's imagine: - # # Request 3's input query: [D] - # # Request 3's kv cache: [A, B, C, D] - # # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) - # # If we use [A, B, C, D] as the common prefix for Request 1-3, - # # then Request 3 will be processed only by the first kernel, - # # and the second kernel will get an empty input. While this is not - # # a fundamental problem, our current implementation does not support - # # this case. - # common_prefix_len = min( - # common_prefix_len, - # self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) - # # common_prefix_len should be a multiple of the block size. - # common_prefix_len = (common_prefix_len // self.block_size * - # self.block_size) - # use_cascade = FlashAttentionBackend.use_cascade_attention( - # common_prefix_len=common_prefix_len, - # query_lens=num_scheduled_tokens, - # num_query_heads=self.num_query_heads, - # num_kv_heads=self.num_kv_heads, - # use_alibi=False, # FIXME - # use_sliding_window=self.sliding_window is not None, - # num_sms=self.num_sms, - # ) - - # if use_cascade: - # # TODO: Optimize. 
- # cu_prefix_query_lens = torch.tensor( - # [0, total_num_scheduled_tokens], - # dtype=torch.int32, - # device=self.device) - # cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], - # dtype=torch.int32, - # device=self.device) - # cu_suffix_kv_lens = ( - # self.seq_start_loc_np[:num_reqs + 1] - - # self.arange_np[:num_reqs + 1] * common_prefix_len) - # cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to( - # self.device) - # else: - # cu_prefix_query_lens = None - # cu_prefix_kv_lens = None - # cu_suffix_kv_lens = None - - # attn_metadata = FlashAttentionMetadata( - # num_actual_tokens=total_num_scheduled_tokens, - # max_query_len=max_num_scheduled_tokens, - # query_start_loc=query_start_loc, - # max_seq_len=max_seq_len, - # seq_start_loc=seq_start_loc, - # block_table=( - # self.input_batch.block_table.get_device_tensor()[:num_reqs]), - # slot_mapping=slot_mapping, - # use_cascade=use_cascade, - # common_prefix_len=common_prefix_len, - # cu_prefix_query_lens=cu_prefix_query_lens, - # cu_prefix_kv_lens=cu_prefix_kv_lens, - # cu_suffix_kv_lens=cu_suffix_kv_lens, - # ) - # # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial - # # request in the batch. While we should not sample any token from this - # # partial request, we do so for simplicity. We will ignore the sampled - # # token from the partial request. - # # TODO: Support prompt logprobs. - # logits_indices = query_start_loc[1:] - 1 - # return attn_metadata, logits_indices - - def _prepare_sampling( - self, - scheduler_output: "SchedulerOutput", - ) -> SamplingMetadata: - skip_copy = True - if (scheduler_output.finished_req_ids - or scheduler_output.preempted_req_ids): - skip_copy = False - if (scheduler_output.scheduled_new_reqs - or scheduler_output.scheduled_resumed_reqs): - skip_copy = False - # Create the sampling metadata. - req_id_output_token_ids: Dict[str, List[int]] = \ - {req_id: req.output_token_ids \ - for req_id, req in self.requests.items()} - - sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, skip_copy) - return sampling_metadata - - def _execute_encoder(self, scheduler_output: "SchedulerOutput"): - scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs - if not scheduled_encoder_inputs: - return - - # Batch the multi-modal inputs. - mm_inputs: List[MultiModalKwargs] = [] - req_input_ids: List[Tuple[str, int]] = [] - for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): - req_state = self.requests[req_id] - for input_id in encoder_input_ids: - mm_inputs.append(req_state.mm_inputs[input_id]) - req_input_ids.append((req_id, input_id)) - batched_mm_inputs = MultiModalKwargs.batch(mm_inputs) - batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, - device=self.device) - - # Run the encoder. - # `encoder_outputs` is either of the following: - # 1. A tensor of shape [num_images, feature_size, hidden_size] - # in case when feature_size is fixed across all images. - # 2. A list (length: num_images) of tensors, each of shape - # [feature_size, hidden_size] in case when the feature size is - # dynamic depending on input images. - encoder_outputs = self.model.get_multimodal_embeddings( - **batched_mm_inputs) - - # Cache the encoder outputs. 
- for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} - self.encoder_cache[req_id][input_id] = output - - def _gather_encoder_outputs( - self, - scheduler_output: "SchedulerOutput", - ) -> List[torch.Tensor]: - encoder_outputs: List[torch.Tensor] = [] - num_reqs = self.input_batch.num_reqs - for req_id in self.input_batch.req_ids[:num_reqs]: - assert req_id is not None - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - req_state = self.requests[req_id] - num_computed_tokens = req_state.num_computed_tokens - mm_positions = req_state.mm_positions - for i, pos_info in enumerate(mm_positions): - start_pos = pos_info["offset"] - num_encoder_tokens = pos_info["length"] - - # The encoder output is needed if the two ranges overlap: - # [num_computed_tokens, - # num_computed_tokens + num_scheduled_tokens) and - # [start_pos, start_pos + num_encoder_tokens) - if start_pos >= num_computed_tokens + num_scheduled_tokens: - # The encoder output is not needed in this step. - break - if start_pos + num_encoder_tokens <= num_computed_tokens: - # The encoder output is already processed and stored - # in the decoder's KV cache. - continue - - start_idx = max(num_computed_tokens - start_pos, 0) - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) - assert start_idx < end_idx - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] - encoder_output = self.encoder_cache[req_id][i] - encoder_outputs.append(encoder_output[start_idx:end_idx]) - return encoder_outputs - - def execute_model_xxx(): - prefill_data, decode_data = self._prepare_inputs(scheduler_output) - num_reqs = self.input_batch.num_reqs - sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) - - ######################### DECODES ######################### - # Decodes run as one single batch with [padded_batch, 1] - if decode_data.num_decodes > 0: - - # FORWARD. - selected_token_ids = self.model(decode_data.token_ids, - decode_data.position_ids, - decode_data.attn_metadata, - self.kv_caches, - is_prompt=False) - - # NOTE: TPU<>CPU sync happens here. - # We need to call .cpu() first to avoid recompilation. - token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] - sampled_token_ids_list = token_ids.tolist() - sampled_token_ids[:decode_data.num_decodes] = token_ids - - # UPDATE REQUEST STATE. - for i, req_id in enumerate( - self.input_batch.req_ids[:decode_data.num_decodes]): - req_state = self.requests[req_id] - - # TODO: ASSERT NO CHUNKED PREFILL. - assert scheduler_output.num_scheduled_tokens[req_id] == 1 - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len == req_state.num_tokens - - token_id = sampled_token_ids_list[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - req_state.output_token_ids.append(token_id) - - ######################### PREFILLS ######################### - # Prefills run separately with shape [1, padded_prefill_len], - # due to lack of variable length attention kernel so far. - for idx, (req_id, prompt_len, token_ids, position_ids, - attn_metadata) in enumerate(prefill_data.zipped()): - - # FORWARD. - selected_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - self.kv_caches, - is_prompt=True) - - # NOTE: TPU<>CPU sync happens here. - # We need to call .cpu() first to avoid recompilation. 
- token_id = selected_token_ids.cpu()[prompt_len - 1].item() - sampled_token_ids[decode_data.num_decodes + idx] = token_id - req_state = self.requests[req_id] - - # TODO: ASSERT NO PREFIX CACHING. - assert req_state.num_computed_tokens == 0 - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - - # TODO: ASSERT NO CHUNKED PREFILL. - assert seq_len == req_state.num_tokens - assert prompt_len == seq_len - - # UPDATE REQUEST STATE. - req_idx = self.input_batch.req_id_to_index[req_id] - self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id - req_state.output_token_ids.append(token_id) - - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids[:num_reqs], - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids_cpu=sampled_token_ids, - logprob_token_ids_cpu=None, - logprobs_cpu=None, - ) - - @torch.inference_mode() - def execute_model( - self, - scheduler_output: "SchedulerOutput", - ) -> ModelRunnerOutput: - self._update_states(scheduler_output) - - # TODO: Ressurect this code - # if self.is_multimodal_model: - # # Run the multimodal encoder if any. - # self._execute_encoder(scheduler_output) - # encoder_outputs = self._gather_encoder_outputs(scheduler_output) - # else: - # encoder_outputs = [] - - # Prepare the decoder inputs. - prefill_data, decode_data = self._prepare_inputs(scheduler_output) - - num_reqs = self.input_batch.num_reqs - sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) - - # attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - num_input_tokens = num_scheduled_tokens - # attn_metadata.num_input_tokens = num_input_tokens - - # TODO: Resurrect this code - # if self.is_multimodal_model: - # # NOTE(woosuk): To unify token ids and soft tokens (vision - # # embeddings), we always use embeddings (rather than token ids) - # # as input to the multimodal model, even when the input is text. - # input_ids = self.input_ids[:num_scheduled_tokens] - # if encoder_outputs: - # inputs_embeds = self.model.get_input_embeddings( - # input_ids, encoder_outputs) - # else: - # inputs_embeds = self.model.get_input_embeddings(input_ids) - # # TODO(woosuk): Avoid the copy. Optimize. - # self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) - # inputs_embeds = self.inputs_embeds[:num_input_tokens] - # input_ids = None - # else: - # # For text-only models, we use token ids as input. - # # While it is possible to use embeddings as input just like the - # # multimodal models, it is not desirable for performance since - # # then the embedding layer is not included in the CUDA graph. - # input_ids = self.input_ids[:num_input_tokens] - # inputs_embeds = None - - ######################### DECODES ######################### - # Decodes run as one single batch with [padded_batch, 1] - if decode_data.num_decodes > 0: - # FORWARD. - selected_token_ids = self.model(decode_data.token_ids, - decode_data.position_ids, - decode_data.attn_metadata, - self.kv_caches, - is_prompt=False) - - # NOTE: TPU<>CPU sync happens here. - # We need to call .cpu() first to avoid recompilation. - token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] - sampled_token_ids_list = token_ids.tolist() - sampled_token_ids[:decode_data.num_decodes] = token_ids - - # UPDATE REQUEST STATE. - for i, req_id in enumerate( - self.input_batch.req_ids[:decode_data.num_decodes]): - req_state = self.requests[req_id] - - # TODO: ASSERT NO CHUNKED PREFILL. 
- assert scheduler_output.num_scheduled_tokens[req_id] == 1 - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len == req_state.num_tokens - - # TODO: Verify if req_id_to_index mapping is needed here! - token_id = sampled_token_ids_list[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - req_state.output_token_ids.append(token_id) - - ######################### PREFILLS ######################### - # Prefills run separately with shape [1, padded_prefill_len], - # due to lack of variable length attention kernel so far. - for idx, (req_id, prompt_len, token_ids, position_ids, - attn_metadata) in enumerate(prefill_data.zipped()): - # FORWARD. - selected_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - self.kv_caches, - is_prompt=True) - - # NOTE: TPU<>CPU sync happens here. - # We need to call .cpu() first to avoid recompilation. - token_id = selected_token_ids.cpu()[prompt_len - 1].item() - sampled_token_ids[decode_data.num_decodes + idx] = token_id - req_state = self.requests[req_id] - - # TODO: ASSERT NO PREFIX CACHING. - assert req_state.num_computed_tokens == 0 - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - - # TODO: ASSERT NO CHUNKED PREFILL. - assert seq_len == req_state.num_tokens - assert prompt_len == seq_len - - # UPDATE REQUEST STATE. - req_idx = self.input_batch.req_id_to_index[req_id] - self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id - req_state.output_token_ids.append(token_id) - - # TODO: Remove - # # Sample the next token and get logprobs if needed. - # sampling_metadata = self._prepare_sampling(scheduler_output) - # sampler_output = self.model.sample( - # logits=logits, - # sampling_metadata=sampling_metadata, - # ) - - # sampled_token_ids = sampler_output.sampled_token_ids - # # TODO(woosuk): The following loop can be slow since it iterates over - # # the requests one by one. Optimize. - # num_reqs = self.input_batch.num_reqs - # for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - # assert req_id is not None - # req_state = self.requests[req_id] - # seq_len = (req_state.num_computed_tokens + - # scheduler_output.num_scheduled_tokens[req_id]) - # assert seq_len <= req_state.num_tokens - # if seq_len == req_state.num_tokens: - # # Append the sampled token to the output token ids. - # token_id = sampled_token_ids[i] - # self.input_batch.token_ids_cpu[i, seq_len] = token_id - # self.input_batch.num_tokens[i] += 1 - # req_state.output_token_ids.append(token_id) - # else: - # # Ignore the sampled token from the partial request. - # # Rewind the generator state as if the token was not sampled. 
- # generator = self.input_batch.generators.get(i) - # if generator is not None: - # # This relies on cuda-specific torch-internal impl details - # generator.set_offset(generator.get_offset() - 4) - - # if sampler_output.logprob_token_ids is None: - # logprob_token_ids = None - # else: - # logprob_token_ids = sampler_output.logprob_token_ids.cpu() - # if sampler_output.logprobs is None: - # logprobs = None - # else: - # logprobs = sampler_output.logprobs.cpu() - - # num_reqs entries should be non-None - assert all( - req_id is not None for req_id in - self.input_batch.req_ids[:num_reqs]), "req_ids contains None" - req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids_cpu=sampled_token_ids, - logprob_token_ids_cpu=None, - logprobs_cpu=None, - ) - - return model_runner_output - - def load_model(self) -> None: - self.device = self.device_config.device - - # NOTE(woosuk): While the executor assigns the TP ranks to the worker - # process, the ranks can be different from the ranks internally assigned - # by the xm runtime. Therefore, there is a mismatch in the rank - # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. - # This is not a problem in linear layers because all-reduce is - # rank-agnostic. However, it matters for all-gather as the ranks - # determine the order of concatenating the output tensors. - # As a workaround, we use the xm's rank assignment only when loading - # the embedding weights. - xm_tp_rank = xr.global_ordinal() - with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." - "get_tensor_model_parallel_rank", - return_value=xm_tp_rank): - model = get_model(vllm_config=self.vllm_config) - model = model.eval() - xm.wait_device_ops() - model = ModelWrapper(model) - self.model = torch.compile(model, - backend="openxla", - fullgraph=True, - dynamic=False) - - @torch.inference_mode() - def _dummy_run( - self, - batch_size: int, - seq_len: int, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - exec_mode: ExecutionMode, - ) -> None: - exec_mode = ExecutionMode(exec_mode) - if exec_mode.is_prefill(): - seq_len = (seq_len + 15) // 16 * 16 - token_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((batch_size, seq_len), - dtype=torch.int64, - device=self.device) - input_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - if exec_mode == ExecutionMode.PREFILL: - attn_metadata = PallasMetadata( - num_prefills=batch_size, - num_prefill_tokens=batch_size * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - block_tables=None, - context_lens=None, - effective_query_lens=None, - ) - - else: - context_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - - block_tables = torch.zeros( - (batch_size, self.max_num_blocks_per_req), - dtype=torch.int32, - device=self.device) - - effective_query_lens = torch.ones_like(context_lens) - - attn_metadata = PallasMetadata( - num_prefills=batch_size, - num_prefill_tokens=batch_size * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - block_tables=block_tables, - context_lens=context_lens, - effective_query_lens=effective_query_lens, - ) - else: - assert seq_len == 
1 - token_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((batch_size, seq_len), - dtype=torch.int64, - device=self.device) - block_tables = torch.zeros( - (batch_size, self.max_num_blocks_per_req), - dtype=torch.int32, - device=self.device) - context_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - input_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - attn_metadata = PallasMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size * seq_len, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - block_tables=block_tables, - context_lens=context_lens, - ) - - t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 - - # NOTE(woosuk): There are two stages of compilation: torch.compile and - # XLA compilation. Using `mark_dynamic` can reduce the torch.compile - # overhead by reusing the FX graph for different shapes. - # However, the XLA graph will still require static shapes and needs to - # be re-compiled for every different shapes. This overhead is inevitable - # in the first run, but can be skipped afterwards as we cache the XLA - # graphs in the disk (VLLM_XLA_CACHE_PATH). - if exec_mode.is_prefill(): - # Prefll - torch._dynamo.mark_dynamic(token_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) - else: - # Decode - torch._dynamo.mark_dynamic(token_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(input_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - torch._dynamo.mark_dynamic(t, 0) - torch._dynamo.mark_dynamic(p, 0) - - # Dummy run. - self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, - num_samples, kv_caches) - - def profile_run(self) -> None: - """Profile to measure peak memory during forward pass.""" - - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value `None`. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - dummy_kv_caches = [( - torch.tensor([], dtype=torch.float32, device=self.device), - torch.tensor([], dtype=torch.float32, device=self.device), - ) for _ in range(self.num_attn_layers)] - - # Run empty forward. 
- self._dummy_run( - batch_size=1, - seq_len=self.max_num_tokens, # Will be rounded to 16 multiple - kv_caches=dummy_kv_caches, - exec_mode=ExecutionMode.PREFILL) - - def capture_model(self) -> None: - """Compile the model.""" - - logger.info("Compiling the model with different input shapes.") - - # Capture prefill shapes - start = time.perf_counter() - for batch_size in [1]: - seq_len = 16 - while True: - self._dummy_run(batch_size, - seq_len, - self.kv_caches, - exec_mode=ExecutionMode.PREFILL) - xm.wait_device_ops() - logger.info(" -- batch_size: %d, seq_len: %d", batch_size, - seq_len) - - if seq_len >= self.model_config.max_model_len: - break - - num_tokens = batch_size * seq_len - if num_tokens >= self.scheduler_config.max_num_batched_tokens: - break - - # Move to next seq_len - seq_len = seq_len * 2 - - end = time.perf_counter() - logger.info("Compilation for prefill shapes is done in %.2f [secs].", - end - start) - - # Capture decode shapes. - start = time.time() - seq_len = 1 - batch_size = 8 # Must be in sync with _get_padded_batch_size() - while True: - self._dummy_run(batch_size, - seq_len, - self.kv_caches, - exec_mode=ExecutionMode.DECODE) - xm.wait_device_ops() - logger.info(" -- batch_size: %d, seq_len: %d", batch_size, - seq_len) - - if batch_size >= self.scheduler_config.max_num_seqs: - break - - batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 - - end = time.time() - logger.info("Compilation for decode shapes is done in %.2f [secs].", - end - start) - - def initialize_kv_cache(self, num_blocks: int) -> None: - assert len(self.kv_caches) == 0 - kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size) - for _ in range(self.num_attn_layers): - self.kv_caches.append(( - torch.zeros(kv_cache_shape, - dtype=self.kv_cache_dtype, - device=self.device), - torch.zeros(kv_cache_shape, - dtype=self.kv_cache_dtype, - device=self.device), - )) - - -# TODO: This is duplicate from V0, refactor -class ModelWrapper(nn.Module): - - def __init__(self, model: nn.Module): - super().__init__() - self.model = model - - def forward( - self, - token_ids: torch.Tensor, - position_ids: torch.Tensor, - attn_metadata: AttentionMetadata, - input_lens: torch.Tensor, - t: torch.Tensor, - p: torch.Tensor, - num_samples: int, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> torch.Tensor: - """Executes the forward pass of the model and samples the next token. - - Args: - token_ids: The input token IDs of shape [batch_size, seq_len]. - position_ids: The input position IDs of shape [batch_size, seq_len]. - attn_metadata: The Pallas attention metadata. - input_lens: The actual input lengths of shape [batch_size]. - t: The sampling temperature of shape [batch_size]. - p: The top-p probability of shape [batch_size]. - num_samples: Number of samples to draw from each logits vector. - kv_caches: The key and value caches. They can be None during the - memory profiling at initialization. - """ - batch_size, seq_len = token_ids.shape - # Calculate the positions to sample from. - start_indicies = torch.arange( - batch_size, dtype=torch.int32, device=input_lens.device) * seq_len - logits_indices = start_indicies + input_lens - 1 - - # FIXME(woosuk): This is a temporary hack to avoid using the existing - # sampler and sampling metadata. 
- sampling_metadata = SamplingMetadata( - seq_groups=[], - selected_token_indices=logits_indices, - categorized_sample_indices={}, - num_prompts=attn_metadata.num_prefills, - ) - - # Skip this in memory profiling at initialization. - if kv_caches[0][0].numel() > 0: - # index_copy_(slot_mapping) only works when the inserted dimension - # is 0. However, the KV cache in the Pallas backend has the shape - # [num_kv_heads, num_blocks, block_size, head_size]. To make it - # work, we need to flatten the first three dimensions and modify - # the slot_mapping accordingly. - num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape - slot_mapping = attn_metadata.slot_mapping - slot_mapping = slot_mapping.flatten() - head_indicies = torch.arange(0, - num_kv_heads, - device=slot_mapping.device, - dtype=slot_mapping.dtype) - head_indicies *= block_size * num_blocks - slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( - -1, num_kv_heads) - slot_mapping = slot_mapping + head_indicies.view(1, -1) - slot_mapping = slot_mapping.flatten() - attn_metadata.slot_mapping = slot_mapping - - hidden_states = self.model( - token_ids, - position_ids, - kv_caches, - attn_metadata, - ) - hidden_states = hidden_states.flatten(0, 1) - logits = self.model.compute_logits(hidden_states, sampling_metadata) - - # Argmax sampling. - argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - argmax_token_ids = argmax_token_ids.repeat(1, num_samples) - - # Zero temperature means greedy decoding. Avoid division by zero. - nonzero_t = torch.where(t != 0, t, 1.0) - logits = logits / nonzero_t.unsqueeze(dim=1) - if _ENABLE_TOP_P: - logits = _apply_top_p(logits, p.unsqueeze(dim=1)) - - # Random sampling. - probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - sampled_token_ids = torch.multinomial(probs, - num_samples, - replacement=True) - if num_samples == 1: - argmax_token_ids = argmax_token_ids.squeeze(dim=-1) - sampled_token_ids = sampled_token_ids.squeeze(dim=-1) - next_token_ids = torch.where(t != 0, sampled_token_ids, - argmax_token_ids) - return next_token_ids - - -# TODO: Duplicate with V0, refactor -def _get_padded_prefill_len(x: int) -> int: - # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence - # length to be a multiple of 16. We pad the prompt length to the nearest - # multiple of 16. This is also good for performance. - if x <= 16: - return 16 - return 1 << (x - 1).bit_length() - - -# TODO: Duplicate with V0, refactor -def _get_padded_batch_size(batch_size: int) -> int: - # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. - # To meet this requirement in the simplest way, we set the minimal batch - # size to 8. 
- if batch_size <= 8: - return 8 - else: - return ((batch_size + 15) // 16) * 16 - - -# TODO: Duplicate with V0, refactor -def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - logits_sorted = torch.sort(logits, dim=-1, descending=True).values - sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1) - cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True) - cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) - logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) - return logits diff --git a/vllm/v1/worker/tpu_worker_new.py b/vllm/v1/worker/tpu_worker_new.py deleted file mode 100644 index c696ae5f5349d..0000000000000 --- a/vllm/v1/worker/tpu_worker_new.py +++ /dev/null @@ -1,244 +0,0 @@ -"""A GPU worker class.""" -import gc -import os -from typing import TYPE_CHECKING, Optional, Tuple - -import torch -import torch.distributed -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr - -import vllm.envs as envs -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) -from vllm.logger import init_logger -from vllm.model_executor import set_random_seed -from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size -from vllm.v1.core.scheduler import SchedulerOutput -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.tpu_model_runner_new import TPUModelRunner - -logger = init_logger(__name__) - -if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput - - -class TPUWorker: - - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - ): - - # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. 
Traces will be saved to: %s", - torch_profiler_trace_dir) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - else: - self.profiler = None - - def initialize(self): - os.environ["PJRT_DEVICE"] = "TPU" - torch.set_grad_enabled(False) - torch.set_default_dtype(self.model_config.dtype) - - # NOTE(woosuk): This is just to initialize the TP group and broadcast - # the input objects on CPU. The all-reduce and all-gather ops on TPU - # are invoked by `xm.all_reduce` and `xm.all_gather` which use their - # own context. - init_distributed_environment( - world_size=self.parallel_config.world_size, - rank=self.rank, - local_rank=self.local_rank, - distributed_init_method=self.distributed_init_method, - backend="gloo", - ) - ensure_model_parallel_initialized( - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size) - - # Device initialization should happen after initializing the distributed - # runtime. - self.device = xm.xla_device() - self.device_config.device = self.device - - # Set random seed. - set_random_seed(self.model_config.seed) - xm.set_rng_state(self.model_config.seed, self.device) - - # Increase the cache size limit, which is the maximum number of - # dynamo graphs that can be compiled. - # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and - # 30-40 graphs for decode. 128 is an arbitrary safe number. - torch._dynamo.config.cache_size_limit = 128 - # Use persistent cache to avoid XLA recompilation. - # NOTE(woosuk): Set per-rank cache path since different ranks - # can have slightly different XLA graphs. - world_size = self.parallel_config.world_size - rank = xr.global_ordinal() - per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, - f"tp{world_size}_rank{rank}") - xr.initialize_cache(per_rank_path, readonly=False) - - # Init ModelRunner here, so that we have access to self.device. - self.model_runner = TPUModelRunner(self.vllm_config, self.device) - - def load_model(self) -> None: - self.model_runner.load_model() - - @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - - self.model_runner.profile_run() - - # Synchronize before measuring the memory usage. - xm.wait_device_ops() - - # Get the maximum amount of memory used by the model weights and - # intermediate activations. - m = xm.get_memory_info(self.device) - total_tpu_memory = m["bytes_limit"] - peak_memory = m[ - "peak_bytes_used"] # Weights + intermediate activations. 
- logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024) - logger.debug("Total Memory: %sGB", - total_tpu_memory // 1024 // 1024 // 1024) - - cache_block_size = _get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - num_tpu_blocks = int( - (total_tpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 - return num_tpu_blocks, 0 - - def initialize_cache(self, num_tpu_blocks: int) -> None: - """Allocate TPU and CPU KV cache with the specified number of blocks.""" - if num_tpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - - max_seq_len = self.cache_config.block_size * num_tpu_blocks - max_model_len = self.model_config.max_model_len - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") - - self.model_runner.initialize_kv_cache(num_tpu_blocks) - - # For debug: Get the maximum amount of memory used by the model weights and - # intermediate activations. - # TODO: Remove this? - xm.mark_step() - xm.wait_device_ops() - m = xm.get_memory_info(self.device) - peak_memory = m[ - "peak_bytes_used"] # Weights + intermediate activations. - logger.debug("Peak GB Used Post KV Cache: %sGB", - peak_memory // 1024 // 1024 // 1024) - - def compile_or_warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model() - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) - - @torch.inference_mode() - def execute_model( - self, - scheduler_output: "SchedulerOutput", - ) -> ModelRunnerOutput: - output = self.model_runner.execute_model(scheduler_output) - return output if self.rank == 0 else None - - def profile(self, is_start: bool = True): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - if is_start: - self.profiler.start() - else: - self.profiler.stop() - - def check_health(self) -> None: - # worker will always be healthy as long as it's running. 
- return - - -# TODO: Duplicate, refactor -def _get_cache_block_size( - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, -) -> int: - head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) - num_attention_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - - key_cache_block = cache_config.block_size * num_heads * head_size - value_cache_block = key_cache_block - total = num_attention_layers * (key_cache_block + value_cache_block) - if cache_config.cache_dtype == "auto": - dtype = model_config.dtype - else: - dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - dtype_size = get_dtype_size(dtype) - return dtype_size * total From d534ecf8b415611f97e177b16a2a40091bf3ae3d Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 16:10:43 +0000 Subject: [PATCH 10/18] add files --- vllm/v1/worker/tpu_model_runner.py | 1311 ++++++++++++++++++++++++++++ vllm/v1/worker/tpu_worker.py | 244 ++++++ 2 files changed, 1555 insertions(+) create mode 100644 vllm/v1/worker/tpu_model_runner.py create mode 100644 vllm/v1/worker/tpu_worker.py diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py new file mode 100644 index 0000000000000..12ecd65a89623 --- /dev/null +++ b/vllm/v1/worker/tpu_model_runner.py @@ -0,0 +1,1311 @@ +import gc +import time +import enum +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Tuple, cast, Optional +from unittest.mock import patch + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn + +# TPU XLA related +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +from vllm.attention import AttentionMetadata +from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed.parallel_state import graph_capture +from vllm.forward_context import set_forward_context +from vllm.inputs import INPUT_REGISTRY +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.sampling_params import SamplingType +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, + LayerBlockType, cdiv, is_pin_memory_available) +from vllm.v1.attention.backends.pallas import PallasMetadata, PallasAttentionBackend +from vllm.v1.engine.mm_input_mapper import MMInputMapperClient +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. +_PAD_SLOT_ID = 1_000_000_000 +# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. +_ENABLE_TOP_P = False +# FIXME(woosuk): A temporary hack to support `n > 1`. +# This can significantly affect the performance if too large. 
+_MAX_NUM_SAMPLES = 128 + + +class ExecutionMode(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + PREFIX_PREFILL = enum.auto() + + def is_prefill(self) -> bool: + return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + + +@dataclass +class PrefillInputData: + + request_ids: List + prompt_lens: List + token_ids: List + position_ids: List + attn_metadata: List + + def zipped(self): + return zip(self.request_ids, self.prompt_lens, self.token_ids, + self.position_ids, self.attn_metadata) + + +@dataclass +class DecodeInputData: + + num_decodes: int + token_ids: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + attn_metadata: PallasMetadata = None + + +class TPUModelRunner: + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self.device_config = vllm_config.device_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + self.is_multimodal_model = model_config.is_multimodal_model + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + self.max_num_reqs = scheduler_config.max_num_seqs + + # Model-related. + self.num_attn_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + self.num_query_heads = model_config.get_num_attention_heads( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + self.hidden_size = model_config.get_hidden_size() + + # Multi-modal data support + self.input_registry = INPUT_REGISTRY + self.mm_registry = MULTIMODAL_REGISTRY + + # NOTE: Initialized input mapper is only used for processing dummy + # multimodal data into multimodal kwargs for GPU memory profiling. + self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) + self.mm_input_mapper_profiling.use_cache = False + + self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 + self.encoder_cache_size = self.scheduler_config.encoder_cache_size + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + # req_id -> (input_id -> encoder_output) + self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + # Persistent batch. 
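+        # NOTE: On TPU the persistent batch keeps all decode requests before
+        # all prefill requests (see _update_states below), so a single split
+        # index is enough to separate the two phases when preparing inputs.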
+ self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=model_config.get_vocab_size(), + ) + + self.prefill_positions = torch.tensor(range(self.max_model_len), + device="cpu", + dtype=torch.int32).reshape( + 1, -1) + + self.new_req_ids = None + + # TODO: Remove this + # self.use_cuda_graph = (self.vllm_config.compilation_config.level + # == CompilationLevel.PIECEWISE + # and not self.model_config.enforce_eager) + # # TODO(woosuk): Provide an option to tune the max cudagraph batch size. + # # The convention is different. + # # self.cudagraph_batch_sizes sorts in ascending order. + # # The batch sizes in the config are in descending order. + # self.cudagraph_batch_sizes = list( + # reversed(self.vllm_config.compilation_config.capture_sizes)) + + # # Cache the device properties. + # self.device_properties = torch.cuda.get_device_properties(self.device) + # self.num_sms = self.device_properties.multi_processor_count + + # # Persistent buffers for CUDA graphs. + # self.input_ids = torch.zeros(self.max_num_tokens, + # dtype=torch.int32, + # device=self.device) + # self.positions = torch.zeros(self.max_num_tokens, + # dtype=torch.int64, + # device=self.device) + # self.inputs_embeds = torch.zeros( + # (self.max_num_tokens, self.hidden_size), + # dtype=self.dtype, + # device=self.device) + + # # OPTIMIZATION: Cache the tensors rather than creating them every step. + # self.arange_np = np.arange(max(self.max_num_reqs + 1, + # self.max_model_len), + # dtype=np.int32) + # # NOTE(woosuk): These tensors are "stateless", i.e., they are literally + # # a faster version of creating a new tensor every time. Thus, we should + # # not make any assumptions about the values in these tensors. + # self.input_ids_cpu = torch.zeros(self.max_num_tokens, + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # self.input_ids_np = self.input_ids_cpu.numpy() + # self.positions_cpu = torch.zeros(self.max_num_tokens, + # dtype=torch.int64, + # device="cpu", + # pin_memory=self.pin_memory) + # self.positions_np = self.positions_cpu.numpy() + # self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # self.slot_mapping_np = self.slot_mapping_cpu.numpy() + # self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # self.query_start_loc_np = self.query_start_loc_cpu.numpy() + # self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + # dtype=torch.int32, + # device="cpu", + # pin_memory=self.pin_memory) + # self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the requests from the persistent batch. 
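+        # Finished and preempted requests both leave the persistent batch;
+        # preempted requests keep their entry in self.requests so they can
+        # be re-added later as scheduled_resumed_reqs.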
+ stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. + for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + req_state.block_ids.extend(req_data.new_block_ids) + self.input_batch.block_table.append_row(req_index, start_index, + req_data.new_block_ids) + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt=new_req_data.prompt, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for res_req_data in scheduler_output.scheduled_resumed_reqs: + req_id = res_req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = res_req_data.block_ids + req_state.num_computed_tokens = res_req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # For TPU, we keep all of the decode requests before the + # prefill requests in the batch sequence. + # 1. First condense, so all decodes move to start + # 2. Then add new prefills to the end of the batch + removed_req_indices = sorted(removed_req_indices, reverse=True) + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state, None) # Append last + self.new_req_ids = req_ids_to_add + + def _prepare_prefill_inputs( + self, + num_scheduled_tokens: List[int], + ) -> PrefillInputData: + # Each prefill run separately with shape [1, padded_prompt_len]. + # So we create lists that will be used in execute_model(). + + prefill_request_ids = [] + prefill_prompt_lens = [] + prefill_token_ids = [] + prefill_position_ids = [] + prefill_attn_metadata = [] + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = num_reqs - self.new_req_ids + for idx in range(num_decodes, num_reqs): + prefill_request_ids.append(self.input_batch.req_ids[idx]) + + prompt_len = num_scheduled_tokens[idx] + prefill_prompt_lens.append(prompt_len) + + # STATIC SHAPE: prefills are padded to the next power of 2. 
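+            # e.g., prompt_len 9 -> 16, 17 -> 32, 100 -> 128
+            # (_get_padded_prefill_len keeps lengths at least 16 so they stay
+            # multiples of 16 for the Pallas attention kernel).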
+ padded_prompt_len = _get_padded_prefill_len(prompt_len) + assert padded_prompt_len <= self.max_model_len + + # TOKEN_IDS. + token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ + idx, :padded_prompt_len].reshape(1, -1)) + prefill_token_ids.append(token_ids.to(self.device)) + + # POSITIONS. + positions = self.prefill_positions[:, :padded_prompt_len] + prefill_position_ids.append(positions.to(self.device)) + + # SLOT_MAPPING. + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor( + ) + block_numbers = block_table_cpu_tensor[idx, positions // + self.block_size].reshape( + 1, -1) + block_offsets = positions % self.block_size + slot_mapping = block_numbers * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[:, prompt_len:] = _PAD_SLOT_ID + slot_mapping = slot_mapping.long() + + prefill_attn_metadata.append( + PallasMetadata( + num_prefills=1, + num_prefill_tokens=padded_prompt_len, + num_decode_tokens=0, + slot_mapping=slot_mapping.to(self.device), + block_tables=None, + context_lens=None, + effective_query_lens=None, + )) + + return PrefillInputData( + request_ids=prefill_request_ids, + prompt_lens=prefill_prompt_lens, + token_ids=prefill_token_ids, + position_ids=prefill_position_ids, + attn_metadata=prefill_attn_metadata, + ) + + def _prepare_decode_inputs(self) -> DecodeInputData: + # Decodes run as one single padded batch with shape [batch, 1] + # + # We need to set _PAD_SLOT_ID for the padding tokens in the + # slot_mapping, such that the attention KV cache insertion + # logic knows to ignore those indicies. Otherwise, the + # padding data can be dummy since we have a causal mask. + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = num_reqs - self.new_req_ids + + if num_decodes == 0: + return DecodeInputData(num_decodes=0) + + # PAD FOR STATIC SHAPES. + padded_batch_size = _get_padded_batch_size(num_decodes) + + # POSITIONS. [batch, 1] + # We slice at the end, since we use the positions for gathering. + positions = torch.from_numpy( + self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) + index = positions.to(torch.int64) + positions = positions[:padded_batch_size] + + # TOKEN_IDS. [batch, 1] + token_ids = torch.gather( + input=torch.from_numpy(self.input_batch.token_ids_cpu), + dim=1, + index=index, + )[:padded_batch_size] + + # SLOT_MAPPING [batch, 1] + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor() + block_number = torch.gather(input=block_table_cpu_tensor, + dim=1, + index=(index // self.block_size)) + block_offsets = index % self.block_size + slot_mapping = block_number * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. 
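+        # (_PAD_SLOT_ID is far outside any valid slot index, so the write is
+        # dropped; see the out-of-bound-index note at the top of this file.)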
+ slot_mapping[num_decodes:] = _PAD_SLOT_ID + slot_mapping = slot_mapping[:padded_batch_size] + + # BLOCK_TABLE [batch, max_num_blocks_per_req] + block_table = block_table_cpu_tensor[:padded_batch_size] + + # CONTEXT_LENS [batch_size] + context_lens = (positions.reshape(-1) + 1) + + # CPU<>TPU sync happens here. + return DecodeInputData(num_decodes=num_decodes, + token_ids=token_ids.to(self.device), + position_ids=positions.to(self.device), + attn_metadata=PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=padded_batch_size, + slot_mapping=slot_mapping.to(self.device), + block_tables=block_table.to(self.device), + context_lens=context_lens.to(self.device), + effective_query_lens=None, + )) + + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + num_decodes = num_reqs - self.new_req_ids + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + # TODO: Verify this works with TPUs + self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + max_num_scheduled_tokens = 0 + for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + assert req_id is not None + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + + # NOTE: Assert that all the decodes are "decodes". + if idx < num_decodes: + assert num_tokens == 1 + assert max_num_scheduled_tokens > 0 + + return ( + self._prepare_prefill_inputs(num_scheduled_tokens), + self._prepare_decode_inputs(num_decodes), + ) + + # # OPTIMIZATION: Start copying the block table first. + # # This way, we can overlap the copy with the following CPU operations. + # self.input_batch.block_table.commit(num_reqs) + + # # Get the number of scheduled tokens for each request. + # # TODO: The Python loop can be slow. Optimize. + # num_scheduled_tokens = [] + # max_num_scheduled_tokens = 0 + # for req_id in self.input_batch.req_ids[:num_reqs]: + # assert req_id is not None + # num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # num_scheduled_tokens.append(num_tokens) + # max_num_scheduled_tokens = max(max_num_scheduled_tokens, + # num_tokens) + # num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + # assert max_num_scheduled_tokens > 0 + + # # Get request indices. + # # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + # req_indices = np.repeat(self.arange_np[:num_reqs], + # num_scheduled_tokens) + + # # Get batched arange. + # # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # arange = np.concatenate( + # [self.arange_np[:n] for n in num_scheduled_tokens]) + + # # Get positions. + # positions_np = self.positions_np[:total_num_scheduled_tokens] + # np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + # arange, + # out=positions_np) + + # # Get token indices. + # # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # # where M is the max_model_len. 
+ # token_indices = (positions_np + + # req_indices * self.input_batch.token_ids_cpu.shape[1]) + # # NOTE(woosuk): We use torch.index_select instead of np.take here + # # because torch.index_select is much faster than np.take for large + # # tensors. + # torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + # 0, + # torch.from_numpy(token_indices), + # out=self.input_ids_cpu[:total_num_scheduled_tokens]) + + # # Calculate the slot mapping. + # # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # # where K is the max_num_blocks_per_req and the block size is 2. + # # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # # because M (max_model_len) is not necessarily divisible by block_size. + # block_table_indices = (req_indices * self.max_num_blocks_per_req + + # positions_np // self.block_size) + # # NOTE(woosuk): We use torch.index_select instead of np.take here + # # because torch.index_select is much faster than np.take for large + # # tensors. + # block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + # block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + # block_offsets = positions_np % self.block_size + # np.add(block_numbers * self.block_size, + # block_offsets, + # out=self.slot_mapping_np[:total_num_scheduled_tokens]) + + # # Prepare the attention metadata. + # self.query_start_loc_np[0] = 0 + # np.cumsum(num_scheduled_tokens, + # out=self.query_start_loc_np[1:num_reqs + 1]) + + # seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + + # num_scheduled_tokens) + # max_seq_len = seq_lens.max() + # self.seq_start_loc_np[0] = 0 + # np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1]) + + # # Copy the tensors to the GPU. + # self.input_ids[:total_num_scheduled_tokens].copy_( + # self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + # self.positions[:total_num_scheduled_tokens].copy_( + # self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) + # query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( + # self.device, non_blocking=True) + # seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to( + # self.device, non_blocking=True) + # slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( + # self.device, non_blocking=True).long() + + # # Prepare for cascade attention if needed. + # common_prefix_len = (scheduler_output.num_common_prefix_blocks * + # self.block_size) + # if common_prefix_len == 0: + # # Common case. + # use_cascade = False + # else: + # # NOTE(woosuk): Cascade attention uses two attention kernels: one + # # for the common prefix and the other for the rest. For the first + # # kernel, we concatenate all the query tokens (possibly from + # # different requests) and treat them as if they are from the same + # # request. Then, we use bi-directional attention to process the + # # common prefix in the KV cache. Importantly, this means that the + # # first kernel does not do any masking. 
+ + # # Consider the following example: + # # Request 1's input query: [D, E, X] + # # Request 1's kv cache: [A, B, C, D, E, X] + # # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) + # # Request 2's input query: [E, Y] + # # Request 2's kv cache: [A, B, C, D, E, Y] + # # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) + + # # If we use [A, B, C, D, E] as the common prefix, then the + # # first kernel will compute the bi-directional attention between + # # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. + # # However, this is wrong because D in Request 1 should not attend to + # # E in the common prefix (i.e., we need masking). + # # To avoid this, [A, B, C, D] should be the common prefix. + # # That is, the common prefix should be capped by the minimum + # # num_computed_tokens among the requests, and plus one to include + # # the first token of the query. + + # # In practice, we use [A, B, C] as the common prefix, instead of + # # [A, B, C, D] (i.e., the common prefix is capped by the minimum + # # num_computed_tokens, without plus one). + # # This is because of an implementation detail: We want to always + # # use two kernels for cascade attention. Let's imagine: + # # Request 3's input query: [D] + # # Request 3's kv cache: [A, B, C, D] + # # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) + # # If we use [A, B, C, D] as the common prefix for Request 1-3, + # # then Request 3 will be processed only by the first kernel, + # # and the second kernel will get an empty input. While this is not + # # a fundamental problem, our current implementation does not support + # # this case. + # common_prefix_len = min( + # common_prefix_len, + # self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + # # common_prefix_len should be a multiple of the block size. + # common_prefix_len = (common_prefix_len // self.block_size * + # self.block_size) + # use_cascade = FlashAttentionBackend.use_cascade_attention( + # common_prefix_len=common_prefix_len, + # query_lens=num_scheduled_tokens, + # num_query_heads=self.num_query_heads, + # num_kv_heads=self.num_kv_heads, + # use_alibi=False, # FIXME + # use_sliding_window=self.sliding_window is not None, + # num_sms=self.num_sms, + # ) + + # if use_cascade: + # # TODO: Optimize. + # cu_prefix_query_lens = torch.tensor( + # [0, total_num_scheduled_tokens], + # dtype=torch.int32, + # device=self.device) + # cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], + # dtype=torch.int32, + # device=self.device) + # cu_suffix_kv_lens = ( + # self.seq_start_loc_np[:num_reqs + 1] - + # self.arange_np[:num_reqs + 1] * common_prefix_len) + # cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to( + # self.device) + # else: + # cu_prefix_query_lens = None + # cu_prefix_kv_lens = None + # cu_suffix_kv_lens = None + + # attn_metadata = FlashAttentionMetadata( + # num_actual_tokens=total_num_scheduled_tokens, + # max_query_len=max_num_scheduled_tokens, + # query_start_loc=query_start_loc, + # max_seq_len=max_seq_len, + # seq_start_loc=seq_start_loc, + # block_table=( + # self.input_batch.block_table.get_device_tensor()[:num_reqs]), + # slot_mapping=slot_mapping, + # use_cascade=use_cascade, + # common_prefix_len=common_prefix_len, + # cu_prefix_query_lens=cu_prefix_query_lens, + # cu_prefix_kv_lens=cu_prefix_kv_lens, + # cu_suffix_kv_lens=cu_suffix_kv_lens, + # ) + # # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # # request in the batch. 
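# Editor's note: a small sketch of the common-prefix cap described in the
# comments here: the prefix is limited by the smallest num_computed_tokens
# across the batch and then rounded down to a block-size multiple. Numbers
# are illustrative.
import numpy as np

block_size = 16
num_common_prefix_blocks = 4                # reported by the scheduler
num_computed_tokens = np.array([3, 4])      # per request

common_prefix_len = num_common_prefix_blocks * block_size
common_prefix_len = min(common_prefix_len, int(num_computed_tokens.min()))
common_prefix_len = common_prefix_len // block_size * block_size
print(common_prefix_len)  # 0: the usable shared prefix is shorter than a block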
While we should not sample any token from this + # # partial request, we do so for simplicity. We will ignore the sampled + # # token from the partial request. + # # TODO: Support prompt logprobs. + # logits_indices = query_start_loc[1:] - 1 + # return attn_metadata, logits_indices + + def _prepare_sampling( + self, + scheduler_output: "SchedulerOutput", + ) -> SamplingMetadata: + skip_copy = True + if (scheduler_output.finished_req_ids + or scheduler_output.preempted_req_ids): + skip_copy = False + if (scheduler_output.scheduled_new_reqs + or scheduler_output.scheduled_resumed_reqs): + skip_copy = False + # Create the sampling metadata. + req_id_output_token_ids: Dict[str, List[int]] = \ + {req_id: req.output_token_ids \ + for req_id, req in self.requests.items()} + + sampling_metadata = self.input_batch.make_sampling_metadata( + req_id_output_token_ids, skip_copy) + return sampling_metadata + + def _execute_encoder(self, scheduler_output: "SchedulerOutput"): + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return + + # Batch the multi-modal inputs. + mm_inputs: List[MultiModalKwargs] = [] + req_input_ids: List[Tuple[str, int]] = [] + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + for input_id in encoder_input_ids: + mm_inputs.append(req_state.mm_inputs[input_id]) + req_input_ids.append((req_id, input_id)) + batched_mm_inputs = MultiModalKwargs.batch(mm_inputs) + batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, + device=self.device) + + # Run the encoder. + # `encoder_outputs` is either of the following: + # 1. A tensor of shape [num_images, feature_size, hidden_size] + # in case when feature_size is fixed across all images. + # 2. A list (length: num_images) of tensors, each of shape + # [feature_size, hidden_size] in case when the feature size is + # dynamic depending on input images. + encoder_outputs = self.model.get_multimodal_embeddings( + **batched_mm_inputs) + + # Cache the encoder outputs. + for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): + if req_id not in self.encoder_cache: + self.encoder_cache[req_id] = {} + self.encoder_cache[req_id][input_id] = output + + def _gather_encoder_outputs( + self, + scheduler_output: "SchedulerOutput", + ) -> List[torch.Tensor]: + encoder_outputs: List[torch.Tensor] = [] + num_reqs = self.input_batch.num_reqs + for req_id in self.input_batch.req_ids[:num_reqs]: + assert req_id is not None + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + mm_positions = req_state.mm_positions + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info["offset"] + num_encoder_tokens = pos_info["length"] + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, + # num_computed_tokens + num_scheduled_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_scheduled_tokens: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. 
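# Editor's note: the range-overlap test described above, restated as a
# standalone helper (the name and the example values are illustrative).
def encoder_slice(num_computed_tokens: int, num_scheduled_tokens: int,
                  start_pos: int, num_encoder_tokens: int):
    """Return the [start_idx, end_idx) slice of the cached encoder output
    needed for this step, or None if the two ranges do not overlap."""
    if start_pos >= num_computed_tokens + num_scheduled_tokens:
        return None  # placeholder not reached yet
    if start_pos + num_encoder_tokens <= num_computed_tokens:
        return None  # already folded into the decoder's KV cache
    start_idx = max(num_computed_tokens - start_pos, 0)
    end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
                  num_encoder_tokens)
    return start_idx, end_idx

# 8 encoder tokens placed at position 10, 12 tokens already computed and
# 4 scheduled this step -> rows 2..6 of the cached encoder output are needed.
print(encoder_slice(12, 4, 10, 8))  # (2, 6)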
+ continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens) + assert start_idx < end_idx + assert req_id in self.encoder_cache + assert i in self.encoder_cache[req_id] + encoder_output = self.encoder_cache[req_id][i] + encoder_outputs.append(encoder_output[start_idx:end_idx]) + return encoder_outputs + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + self._update_states(scheduler_output) + + # TODO: Ressurect this code + # if self.is_multimodal_model: + # # Run the multimodal encoder if any. + # self._execute_encoder(scheduler_output) + # encoder_outputs = self._gather_encoder_outputs(scheduler_output) + # else: + # encoder_outputs = [] + + # Prepare the decoder inputs. + prefill_data, decode_data = self._prepare_inputs(scheduler_output) + + num_reqs = self.input_batch.num_reqs + sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) + + # attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + num_input_tokens = num_scheduled_tokens + # attn_metadata.num_input_tokens = num_input_tokens + + # TODO: Resurrect this code + # if self.is_multimodal_model: + # # NOTE(woosuk): To unify token ids and soft tokens (vision + # # embeddings), we always use embeddings (rather than token ids) + # # as input to the multimodal model, even when the input is text. + # input_ids = self.input_ids[:num_scheduled_tokens] + # if encoder_outputs: + # inputs_embeds = self.model.get_input_embeddings( + # input_ids, encoder_outputs) + # else: + # inputs_embeds = self.model.get_input_embeddings(input_ids) + # # TODO(woosuk): Avoid the copy. Optimize. + # self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) + # inputs_embeds = self.inputs_embeds[:num_input_tokens] + # input_ids = None + # else: + # # For text-only models, we use token ids as input. + # # While it is possible to use embeddings as input just like the + # # multimodal models, it is not desirable for performance since + # # then the embedding layer is not included in the CUDA graph. + # input_ids = self.input_ids[:num_input_tokens] + # inputs_embeds = None + + ######################### DECODES ######################### + # Decodes run as one single batch with [padded_batch, 1] + if decode_data.num_decodes > 0: + # FORWARD. + selected_token_ids = self.model(decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches, + is_prompt=False) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] + sampled_token_ids_list = token_ids.tolist() + sampled_token_ids[:decode_data.num_decodes] = token_ids + + # UPDATE REQUEST STATE. + for i, req_id in enumerate( + self.input_batch.req_ids[:decode_data.num_decodes]): + req_state = self.requests[req_id] + + # TODO: ASSERT NO CHUNKED PREFILL. + assert scheduler_output.num_scheduled_tokens[req_id] == 1 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len == req_state.num_tokens + + # TODO: Verify if req_id_to_index mapping is needed here! 
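# Editor's note: a standalone sketch of the state update performed just below.
# Each row of token_ids_cpu holds prompt + generated tokens; seq_len counts the
# tokens a request holds before this step's output is appended, so the newly
# sampled token lands at index seq_len. Shapes and values are illustrative.
import numpy as np

token_ids_cpu = np.zeros((4, 32), dtype=np.int32)  # [max_num_reqs, max_model_len]
num_decodes = 2
sampled_token_ids_list = [101, 202]
seq_lens = [5, 9]  # num_computed_tokens + num_scheduled_tokens, per decode

for i in range(num_decodes):
    token_ids_cpu[i, seq_lens[i]] = sampled_token_ids_list[i]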
+ token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + ######################### PREFILLS ######################### + # Prefills run separately with shape [1, padded_prefill_len], + # due to lack of variable length attention kernel so far. + for idx, (req_id, prompt_len, token_ids, position_ids, + attn_metadata) in enumerate(prefill_data.zipped()): + # FORWARD. + selected_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + self.kv_caches, + is_prompt=True) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_id = selected_token_ids.cpu()[prompt_len - 1].item() + sampled_token_ids[decode_data.num_decodes + idx] = token_id + req_state = self.requests[req_id] + + # TODO: ASSERT NO PREFIX CACHING. + assert req_state.num_computed_tokens == 0 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + # TODO: ASSERT NO CHUNKED PREFILL. + assert seq_len == req_state.num_tokens + assert prompt_len == seq_len + + # UPDATE REQUEST STATE. + req_idx = self.input_batch.req_id_to_index[req_id] + self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + # TODO: Remove + # # Sample the next token and get logprobs if needed. + # sampling_metadata = self._prepare_sampling(scheduler_output) + # sampler_output = self.model.sample( + # logits=logits, + # sampling_metadata=sampling_metadata, + # ) + + # sampled_token_ids = sampler_output.sampled_token_ids + # # TODO(woosuk): The following loop can be slow since it iterates over + # # the requests one by one. Optimize. + # num_reqs = self.input_batch.num_reqs + # for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + # assert req_id is not None + # req_state = self.requests[req_id] + # seq_len = (req_state.num_computed_tokens + + # scheduler_output.num_scheduled_tokens[req_id]) + # assert seq_len <= req_state.num_tokens + # if seq_len == req_state.num_tokens: + # # Append the sampled token to the output token ids. + # token_id = sampled_token_ids[i] + # self.input_batch.token_ids_cpu[i, seq_len] = token_id + # self.input_batch.num_tokens[i] += 1 + # req_state.output_token_ids.append(token_id) + # else: + # # Ignore the sampled token from the partial request. + # # Rewind the generator state as if the token was not sampled. 
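# Editor's note: each prefill above runs with shape [1, padded_prompt_len], so
# the next token must be read at index prompt_len - 1; everything at index
# >= prompt_len is padding. A sketch with a stand-in output tensor.
import torch

prompt_len, padded_len = 13, 16  # padded to the next power of two, min 16
selected_token_ids = torch.arange(padded_len, dtype=torch.int32)  # stand-in

next_token = selected_token_ids.cpu()[prompt_len - 1].item()
print(next_token)  # 12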
+ # generator = self.input_batch.generators.get(i) + # if generator is not None: + # # This relies on cuda-specific torch-internal impl details + # generator.set_offset(generator.get_offset() - 4) + + # if sampler_output.logprob_token_ids is None: + # logprob_token_ids = None + # else: + # logprob_token_ids = sampler_output.logprob_token_ids.cpu() + # if sampler_output.logprobs is None: + # logprobs = None + # else: + # logprobs = sampler_output.logprobs.cpu() + + # num_reqs entries should be non-None + assert all( + req_id is not None for req_id in + self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids_cpu=sampled_token_ids, + logprob_token_ids_cpu=None, + logprobs_cpu=None, + ) + + return model_runner_output + + def load_model(self) -> None: + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." + "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + model = ModelWrapper(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) + + @torch.inference_mode() + def _dummy_run( + self, + batch_size: int, + seq_len: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + exec_mode: ExecutionMode, + ) -> None: + exec_mode = ExecutionMode(exec_mode) + if exec_mode.is_prefill(): + seq_len = (seq_len + 15) // 16 * 16 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + if exec_mode == ExecutionMode.PREFILL: + attn_metadata = PallasMetadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=None, + context_lens=None, + effective_query_lens=None, + ) + + else: + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + + effective_query_lens = torch.ones_like(context_lens) + + attn_metadata = PallasMetadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) + else: + assert seq_len == 
1 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + attn_metadata = PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size * seq_len, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, + ) + + t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 + + # NOTE(woosuk): There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). + if exec_mode.is_prefill(): + # Prefll + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + # Decode + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(input_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(t, 0) + torch._dynamo.mark_dynamic(p, 0) + + # Dummy run. + self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, + num_samples, kv_caches) + + def profile_run(self) -> None: + """Profile to measure peak memory during forward pass.""" + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value `None`. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + dummy_kv_caches = [( + torch.tensor([], dtype=torch.float32, device=self.device), + torch.tensor([], dtype=torch.float32, device=self.device), + ) for _ in range(self.num_attn_layers)] + + # Run empty forward. 
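# Editor's note: a sketch of the dimensions marked dynamic above. For decode
# the batch dimension (dim 0) varies across warm-up shapes, for prefill the
# sequence dimension (dim 1) does; marking them dynamic lets torch.compile
# reuse one FX graph while XLA still specializes per concrete shape.
import torch

token_ids = torch.zeros((8, 1), dtype=torch.int32)        # decode: [batch, 1]
position_ids = torch.zeros((8, 1), dtype=torch.int32)
torch._dynamo.mark_dynamic(token_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)

prefill_tokens = torch.zeros((1, 32), dtype=torch.int32)  # prefill: [1, seq]
torch._dynamo.mark_dynamic(prefill_tokens, 1)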
+ self._dummy_run( + batch_size=1, + seq_len=self.max_num_tokens, # Will be rounded to 16 multiple + kv_caches=dummy_kv_caches, + exec_mode=ExecutionMode.PREFILL) + + def capture_model(self) -> None: + """Compile the model.""" + + logger.info("Compiling the model with different input shapes.") + + # Capture prefill shapes + start = time.perf_counter() + for batch_size in [1]: + seq_len = 16 + while True: + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + exec_mode=ExecutionMode.PREFILL) + xm.wait_device_ops() + logger.info(" -- batch_size: %d, seq_len: %d", batch_size, + seq_len) + + if seq_len >= self.model_config.max_model_len: + break + + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + + # Move to next seq_len + seq_len = seq_len * 2 + + end = time.perf_counter() + logger.info("Compilation for prefill shapes is done in %.2f [secs].", + end - start) + + # Capture decode shapes. + start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + exec_mode=ExecutionMode.DECODE) + xm.wait_device_ops() + logger.info(" -- batch_size: %d, seq_len: %d", batch_size, + seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode shapes is done in %.2f [secs].", + end - start) + + def initialize_kv_cache(self, num_blocks: int) -> None: + assert len(self.kv_caches) == 0 + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + for _ in range(self.num_attn_layers): + self.kv_caches.append(( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), + )) + + +# TODO: This is duplicate from V0, refactor +class ModelWrapper(nn.Module): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + input_lens: torch.Tensor, + t: torch.Tensor, + p: torch.Tensor, + num_samples: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + attn_metadata: The Pallas attention metadata. + input_lens: The actual input lengths of shape [batch_size]. + t: The sampling temperature of shape [batch_size]. + p: The top-p probability of shape [batch_size]. + num_samples: Number of samples to draw from each logits vector. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + batch_size, seq_len = token_ids.shape + # Calculate the positions to sample from. + start_indicies = torch.arange( + batch_size, dtype=torch.int32, device=input_lens.device) * seq_len + logits_indices = start_indicies + input_lens - 1 + + # FIXME(woosuk): This is a temporary hack to avoid using the existing + # sampler and sampling metadata. 
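# Editor's note: the warm-up sweeps in capture_model above compile prefill
# shapes with seq_len doubling from 16 and decode shapes with batch sizes
# 8, 16, 32, 48, ... A pure-Python sketch of that schedule under illustrative
# limits (max_model_len=2048, max_num_batched_tokens=8192, max_num_seqs=64).
prefill_seq_lens, seq_len = [], 16
while True:
    prefill_seq_lens.append(seq_len)
    if seq_len >= 2048 or seq_len >= 8192:
        break
    seq_len *= 2
print(prefill_seq_lens)    # [16, 32, 64, 128, 256, 512, 1024, 2048]

decode_batch_sizes, batch_size = [], 8
while True:
    decode_batch_sizes.append(batch_size)
    if batch_size >= 64:
        break
    batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
print(decode_batch_sizes)  # [8, 16, 32, 48, 64]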
+ sampling_metadata = SamplingMetadata( + seq_groups=[], + selected_token_indices=logits_indices, + categorized_sample_indices={}, + num_prompts=attn_metadata.num_prefills, + ) + + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Argmax sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = argmax_token_ids.repeat(1, num_samples) + + # Zero temperature means greedy decoding. Avoid division by zero. + nonzero_t = torch.where(t != 0, t, 1.0) + logits = logits / nonzero_t.unsqueeze(dim=1) + if _ENABLE_TOP_P: + logits = _apply_top_p(logits, p.unsqueeze(dim=1)) + + # Random sampling. + probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + sampled_token_ids = torch.multinomial(probs, + num_samples, + replacement=True) + if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + sampled_token_ids = sampled_token_ids.squeeze(dim=-1) + next_token_ids = torch.where(t != 0, sampled_token_ids, + argmax_token_ids) + return next_token_ids + + +# TODO: Duplicate with V0, refactor +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. + if x <= 16: + return 16 + return 1 << (x - 1).bit_length() + + +# TODO: Duplicate with V0, refactor +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. 
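# Editor's note: a quick demonstration of the padding rules described here,
# using minimal local mirrors of the two helpers (illustrative stand-ins, not
# imports from this patch).
def padded_prefill_len(x: int) -> int:
    # Multiples of 16 are required by the kernel; rounding up to the next
    # power of two keeps the number of compiled prefill shapes small.
    return 16 if x <= 16 else 1 << (x - 1).bit_length()

def padded_batch_size(n: int) -> int:
    # Minimum batch of 8, otherwise the next multiple of 16.
    return 8 if n <= 8 else ((n + 15) // 16) * 16

assert [padded_prefill_len(x) for x in (1, 16, 17, 100)] == [16, 16, 32, 128]
assert [padded_batch_size(n) for n in (1, 8, 9, 17)] == [8, 8, 16, 32]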
+ if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 + + +# TODO: Duplicate with V0, refactor +def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + logits_sorted = torch.sort(logits, dim=-1, descending=True).values + sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1) + cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True) + cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) + logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) + return logits diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py new file mode 100644 index 0000000000000..df22d2db2db14 --- /dev/null +++ b/vllm/v1/worker/tpu_worker.py @@ -0,0 +1,244 @@ +"""A GPU worker class.""" +import gc +import os +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.distributed +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size +from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_model_runner import TPUModelRunner + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + + +class TPUWorker: + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + ): + + # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. 
Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def initialize(self): + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. + world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = TPUModelRunner(self.vllm_config, self.device) + + def load_model(self) -> None: + self.model_runner.load_model() + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + + self.model_runner.profile_run() + + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_tpu_memory = m["bytes_limit"] + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. 
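# Editor's note: a sketch of the block-count arithmetic performed below (and
# in _get_cache_block_size). All numbers are illustrative, not measured.
total_tpu_memory = 16 * 1024**3       # bytes_limit from xm.get_memory_info
peak_memory = 7 * 1024**3             # weights + activations after profiling
gpu_memory_utilization = 0.90
block_size, num_kv_heads, head_size = 16, 8, 128
num_attention_layers, dtype_size = 28, 2   # e.g. bfloat16

# Per-block bytes: one K block and one V block for every attention layer.
cache_block_size = (2 * num_attention_layers * block_size
                    * num_kv_heads * head_size * dtype_size)
num_tpu_blocks = int((total_tpu_memory * gpu_memory_utilization
                      - peak_memory) // cache_block_size)
num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8   # round down to 8
print(cache_block_size, num_tpu_blocks)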
+ logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024) + logger.debug("Total Memory: %sGB", + total_tpu_memory // 1024 // 1024 // 1024) + + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + num_tpu_blocks = int( + (total_tpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 + return num_tpu_blocks, 0 + + def initialize_cache(self, num_tpu_blocks: int) -> None: + """Allocate TPU and CPU KV cache with the specified number of blocks.""" + if num_tpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_tpu_blocks + max_model_len = self.model_config.max_model_len + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.model_runner.initialize_kv_cache(num_tpu_blocks) + + # For debug: Get the maximum amount of memory used by the model weights and + # intermediate activations. + # TODO: Remove this? + xm.mark_step() + xm.wait_device_ops() + m = xm.get_memory_info(self.device) + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak GB Used Post KV Cache: %sGB", + peak_memory // 1024 // 1024 // 1024) + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + output = self.model_runner.execute_model(scheduler_output) + return output if self.rank == 0 else None + + def profile(self, is_start: bool = True): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + if is_start: + self.profiler.start() + else: + self.profiler.stop() + + def check_health(self) -> None: + # worker will always be healthy as long as it's running. 
+ return + + +# TODO: Duplicate, refactor +def _get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total From e9057a7873af2ac33557f411f93951f717c4824a Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 16:51:13 +0000 Subject: [PATCH 11/18] add test --- tests/entrypoints/openai/test_accuracy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index b1d4461d164aa..086ad97587e60 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -20,7 +20,7 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 -DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"] +DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests", "--enforce-eager"] MORE_ARGS_LIST = [ [], # Default ["--enable-chunked-prefill"], # Chunked @@ -61,12 +61,15 @@ def run_test(more_args): ) measured_value = results["results"][TASK][FILTER] + print("measured_value = {}".format(measured_value)) + assert (measured_value - RTOL < EXPECTED_VALUE and measured_value + RTOL > EXPECTED_VALUE ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda(), +@pytest.mark.skipif(not current_platform.is_cuda() + and not current_platform.is_tpu(), reason="V1 currently only supported on CUDA") def test_lm_eval_accuracy_v1_engine(monkeypatch): """Run with the V1 Engine.""" From d40ef186885b9f6d691b593d3adcb061bc16310e Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 18:04:31 +0000 Subject: [PATCH 12/18] tmp not working yet --- tests/entrypoints/openai/test_accuracy.py | 2 +- vllm/platforms/tpu.py | 5 +- vllm/v1/worker/tpu_model_runner.py | 184 +++++++++++----------- vllm/v1/worker/tpu_worker.py | 74 +++++---- 4 files changed, 140 insertions(+), 125 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 086ad97587e60..f51fd6c574715 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -20,7 +20,7 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 -DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests", "--enforce-eager"] +DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests", "-O", "2"] MORE_ARGS_LIST = [ [], # Default ["--enable-chunked-prefill"], # Chunked diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index dd0eae57e0354..d37dddafc10b1 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -55,8 +55,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if compilation_config.level == CompilationLevel.NO_COMPILATION: # TPU does not support NO_COMPILATION compilation_config.level = 
CompilationLevel.DYNAMO_ONCE - assert compilation_config.level < CompilationLevel.PIECEWISE,\ - "TPU does not support Inductor." + compilation_config.level = 2 + # assert compilation_config.level < CompilationLevel.PIECEWISE,\ + # "TPU does not support Inductor. compilation_config.level = {}".format(compilation_config.level) if compilation_config.backend == "": compilation_config.backend = "openxla" diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 12ecd65a89623..99f7db4dcc703 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -15,15 +15,13 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata -from vllm.config import CompilationLevel, VllmConfig -from vllm.distributed.parallel_state import graph_capture -from vllm.forward_context import set_forward_context +from vllm.config import VllmConfig from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.pallas import PallasMetadata, PallasAttentionBackend from vllm.v1.engine.mm_input_mapper import MMInputMapperClient @@ -831,8 +829,7 @@ def execute_model( selected_token_ids = self.model(decode_data.token_ids, decode_data.position_ids, decode_data.attn_metadata, - self.kv_caches, - is_prompt=False) + self.kv_caches) # NOTE: TPU<>CPU sync happens here. # We need to call .cpu() first to avoid recompilation. @@ -862,11 +859,10 @@ def execute_model( for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata) in enumerate(prefill_data.zipped()): # FORWARD. - selected_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - self.kv_caches, - is_prompt=True) + selected_token_ids = self.model(decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches) # NOTE: TPU<>CPU sync happens here. # We need to call .cpu() first to avoid recompilation. @@ -957,12 +953,15 @@ def load_model(self) -> None: # determine the order of concatenating the output tensors. # As a workaround, we use the xm's rank assignment only when loading # the embedding weights. - xm_tp_rank = xr.global_ordinal() - with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." - "get_tensor_model_parallel_rank", - return_value=xm_tp_rank): - model = get_model(vllm_config=self.vllm_config) + + # TODO: Why this is commented out? + # xm_tp_rank = xr.global_ordinal() + # with patch( + # "vllm.model_executor.layers.vocab_parallel_embedding." 
+ # "get_tensor_model_parallel_rank", + # return_value=xm_tp_rank): + + model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() model = ModelWrapper(model) @@ -991,9 +990,9 @@ def _dummy_run( slot_mapping = torch.zeros((batch_size, seq_len), dtype=torch.int64, device=self.device) - input_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) + # input_lens = torch.ones((batch_size, ), + # dtype=torch.int32, + # device=self.device) if exec_mode == ExecutionMode.PREFILL: attn_metadata = PallasMetadata( num_prefills=batch_size, @@ -1046,9 +1045,9 @@ def _dummy_run( context_lens = torch.ones((batch_size, ), dtype=torch.int32, device=self.device) - input_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) + # input_lens = torch.ones((batch_size, ), + # dtype=torch.int32, + # device=self.device) attn_metadata = PallasMetadata( num_prefills=0, num_prefill_tokens=0, @@ -1059,9 +1058,9 @@ def _dummy_run( context_lens=context_lens, ) - t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 + # t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + # p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + # num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 # NOTE(woosuk): There are two stages of compilation: torch.compile and # XLA compilation. Using `mark_dynamic` can reduce the torch.compile @@ -1079,38 +1078,40 @@ def _dummy_run( # Decode torch._dynamo.mark_dynamic(token_ids, 0) torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(input_lens, 0) + # torch._dynamo.mark_dynamic(input_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - torch._dynamo.mark_dynamic(t, 0) - torch._dynamo.mark_dynamic(p, 0) + # torch._dynamo.mark_dynamic(t, 0) + # torch._dynamo.mark_dynamic(p, 0) # Dummy run. - self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, - num_samples, kv_caches) - - def profile_run(self) -> None: - """Profile to measure peak memory during forward pass.""" - - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value `None`. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - dummy_kv_caches = [( - torch.tensor([], dtype=torch.float32, device=self.device), - torch.tensor([], dtype=torch.float32, device=self.device), - ) for _ in range(self.num_attn_layers)] - - # Run empty forward. - self._dummy_run( - batch_size=1, - seq_len=self.max_num_tokens, # Will be rounded to 16 multiple - kv_caches=dummy_kv_caches, - exec_mode=ExecutionMode.PREFILL) + # TODO: Fix this! + # self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, + # num_samples, kv_caches) + self.model(token_ids, position_ids, attn_metadata, kv_caches) + + # def profile_run(self) -> None: + # """Profile to measure peak memory during forward pass.""" + + # # use an empty tensor instead of `None`` to force Dynamo to pass + # # it by reference, rather by specializing on the value `None`. 
+ # # the `dtype` argument does not matter, and we use `float32` as + # # a placeholder (it has wide hardware support). + # # it is important to create tensors inside the loop, rather than + # # multiplying the list, to avoid Dynamo from treating them as + # # tensor aliasing. + # dummy_kv_caches = [( + # torch.tensor([], dtype=torch.float32, device=self.device), + # torch.tensor([], dtype=torch.float32, device=self.device), + # ) for _ in range(self.num_attn_layers)] + + # # Run empty forward. + # self._dummy_run( + # batch_size=1, + # seq_len=self.max_num_tokens, # Will be rounded to 16 multiple + # kv_caches=dummy_kv_caches, + # exec_mode=ExecutionMode.PREFILL) def capture_model(self) -> None: """Compile the model.""" @@ -1193,10 +1194,6 @@ def forward( token_ids: torch.Tensor, position_ids: torch.Tensor, attn_metadata: AttentionMetadata, - input_lens: torch.Tensor, - t: torch.Tensor, - p: torch.Tensor, - num_samples: int, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. @@ -1212,20 +1209,21 @@ def forward( kv_caches: The key and value caches. They can be None during the memory profiling at initialization. """ - batch_size, seq_len = token_ids.shape + # batch_size, seq_len = token_ids.shape # Calculate the positions to sample from. - start_indicies = torch.arange( - batch_size, dtype=torch.int32, device=input_lens.device) * seq_len - logits_indices = start_indicies + input_lens - 1 + # start_indicies = torch.arange( + # batch_size, dtype=torch.int32, device=input_lens.device) * seq_len + # logits_indices = start_indicies + input_lens - 1 + # TODO: Ressurect # FIXME(woosuk): This is a temporary hack to avoid using the existing # sampler and sampling metadata. - sampling_metadata = SamplingMetadata( - seq_groups=[], - selected_token_indices=logits_indices, - categorized_sample_indices={}, - num_prompts=attn_metadata.num_prefills, - ) + # sampling_metadata = SamplingMetadata( + # seq_groups=[], + # selected_token_indices=logits_indices, + # categorized_sample_indices={}, + # num_prompts=attn_metadata.num_prefills, + # ) # Skip this in memory profiling at initialization. if kv_caches[0][0].numel() > 0: @@ -1254,30 +1252,40 @@ def forward( kv_caches, attn_metadata, ) + hidden_states = hidden_states.flatten(0, 1) - logits = self.model.compute_logits(hidden_states, sampling_metadata) + logits = self.model.compute_logits(hidden_states, None) - # Argmax sampling. + # Greedy sampling. argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - argmax_token_ids = argmax_token_ids.repeat(1, num_samples) - - # Zero temperature means greedy decoding. Avoid division by zero. - nonzero_t = torch.where(t != 0, t, 1.0) - logits = logits / nonzero_t.unsqueeze(dim=1) - if _ENABLE_TOP_P: - logits = _apply_top_p(logits, p.unsqueeze(dim=1)) - - # Random sampling. - probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - sampled_token_ids = torch.multinomial(probs, - num_samples, - replacement=True) - if num_samples == 1: - argmax_token_ids = argmax_token_ids.squeeze(dim=-1) - sampled_token_ids = sampled_token_ids.squeeze(dim=-1) - next_token_ids = torch.where(t != 0, sampled_token_ids, - argmax_token_ids) - return next_token_ids + # argmax_token_ids = argmax_token_ids.repeat(1, num_samples) + return argmax_token_ids.squeeze(dim=1) + + # TODO: Ressurect this code + # hidden_states = hidden_states.flatten(0, 1) + # logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # # Argmax sampling. 
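# Editor's note: a standalone sketch of the temperature / top-p sampling path
# that is commented out here (greedy rows keep the argmax result). Shapes and
# values are illustrative.
import torch

def sample(logits: torch.Tensor, t: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    greedy = torch.argmax(logits, dim=-1)                  # used where t == 0
    scaled = logits / torch.where(t != 0, t, 1.0).unsqueeze(-1)
    # Top-p: drop tokens outside the nucleus of cumulative probability p.
    sorted_logits = torch.sort(scaled, dim=-1, descending=True).values
    cum_probs = torch.cumsum(sorted_logits.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(cum_probs < p.unsqueeze(-1), dim=-1, keepdim=True)
    cutoff_logit = torch.gather(sorted_logits, -1, cutoff_index)
    scaled = scaled.masked_fill(scaled < cutoff_logit, -float("inf"))
    sampled = torch.multinomial(scaled.softmax(dim=-1), 1).squeeze(-1)
    return torch.where(t != 0, sampled, greedy)

out = sample(torch.randn(4, 32000),
             t=torch.tensor([0.0, 0.7, 1.0, 0.0]),
             p=torch.tensor([0.9, 0.95, 0.9, 0.8]))
print(out.shape)  # torch.Size([4])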
+ # argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + # argmax_token_ids = argmax_token_ids.repeat(1, num_samples) + + # # Zero temperature means greedy decoding. Avoid division by zero. + # nonzero_t = torch.where(t != 0, t, 1.0) + # logits = logits / nonzero_t.unsqueeze(dim=1) + # if _ENABLE_TOP_P: + # logits = _apply_top_p(logits, p.unsqueeze(dim=1)) + + # # Random sampling. + # probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + # sampled_token_ids = torch.multinomial(probs, + # num_samples, + # replacement=True) + # if num_samples == 1: + # argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + # sampled_token_ids = sampled_token_ids.squeeze(dim=-1) + # next_token_ids = torch.where(t != 0, sampled_token_ids, + # argmax_token_ids) + # return next_token_ids # TODO: Duplicate with V0, refactor diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index df22d2db2db14..eba96e9f15146 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Tuple import torch import torch.distributed @@ -11,15 +11,13 @@ import vllm.envs as envs from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) + init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.tpu_model_runner import TPUModelRunner +from vllm.v1.worker.tpu_model_runner import TPUModelRunner, ExecutionMode logger = init_logger(__name__) @@ -128,41 +126,49 @@ def load_model(self) -> None: @torch.inference_mode() def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - - self.model_runner.profile_run() - + num_layers = self.model_config.get_num_layers(self.parallel_config) + head_size = self.model_config.get_head_size() + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + kv_caches = [(torch.tensor([], dtype=torch.float32, + device=self.device), + torch.tensor([], dtype=torch.float32, + device=self.device)) + for _ in range(num_layers)] + self.model_runner._dummy_run( + batch_size=1, + seq_len=self.scheduler_config.max_num_batched_tokens, + kv_caches=kv_caches, + exec_mode=ExecutionMode.PREFILL, + ) # Synchronize before measuring the memory usage. xm.wait_device_ops() # Get the maximum amount of memory used by the model weights and # intermediate activations. 
m = xm.get_memory_info(self.device) - total_tpu_memory = m["bytes_limit"] - peak_memory = m[ - "peak_bytes_used"] # Weights + intermediate activations. - logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024) - logger.debug("Total Memory: %sGB", - total_tpu_memory // 1024 // 1024 // 1024) - - cache_block_size = _get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - num_tpu_blocks = int( - (total_tpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 - return num_tpu_blocks, 0 + total_memory_size = m["bytes_limit"] + profiled = m["peak_bytes_used"] # Weights + intermediate activations. + + # Calculate the TPU KV cache size based on profiling. + usable_memory_size = int(total_memory_size * + self.cache_config.gpu_memory_utilization) + tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) + dtype_btyes = get_dtype_size(self.cache_dtype) + block_size_bytes = (dtype_btyes * self.cache_config.block_size * + num_layers * 2 * head_size * num_kv_heads) + num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes + num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. + + # Calculate the CPU KV cache size based on the config. + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + block_size_bytes) + num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. + return num_tpu_blocks, num_cpu_blocks def initialize_cache(self, num_tpu_blocks: int) -> None: """Allocate TPU and CPU KV cache with the specified number of blocks.""" From 422aecc327a7974893bca8b56e3089e0a5d4f596 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 18:36:43 +0000 Subject: [PATCH 13/18] made progress --- tests/entrypoints/openai/test_accuracy.py | 2 +- vllm/v1/core/scheduler.py | 1 + vllm/v1/worker/tpu_model_runner.py | 20 ++++++++++---------- vllm/v1/worker/tpu_worker.py | 12 ++++++++++-- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index f51fd6c574715..81809472c8b29 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -20,7 +20,7 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 -DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests", "-O", "2"] +DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests", "-O2", "--max-num-seqs", "128"] MORE_ARGS_LIST = [ [], # Default ["--enable-chunked-prefill"], # Chunked diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index c3d1560b3e7f8..515c3b328873b 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -30,6 +30,7 @@ def __init__( # TODO: Refactor! Properly handle for TPU. 
cache_config.enable_prefix_caching = False scheduler_config.chunked_prefill_enabled = False + print(" --- scheduler_config.max_num_seqs = {}".format(scheduler_config.max_num_seqs)) self.scheduler_config = scheduler_config self.cache_config = cache_config diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 99f7db4dcc703..d9cdb1f6a5bf7 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -160,7 +160,7 @@ def __init__( dtype=torch.int32).reshape( 1, -1) - self.new_req_ids = None + self.num_new_reqs = None # TODO: Remove this # self.use_cuda_graph = (self.vllm_config.compilation_config.level @@ -313,7 +313,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in req_ids_to_add: req_state = self.requests[req_id] self.input_batch.add_request(req_state, None) # Append last - self.new_req_ids = req_ids_to_add + self.num_new_reqs = len(req_ids_to_add) def _prepare_prefill_inputs( self, @@ -331,7 +331,7 @@ def _prepare_prefill_inputs( # DECODES are the first num_decodes REQUESTS. # PREFILLS are the next num_reqs - num_decodes REQUESTS. num_reqs = self.input_batch.num_reqs - num_decodes = num_reqs - self.new_req_ids + num_decodes = num_reqs - self.num_new_reqs for idx in range(num_decodes, num_reqs): prefill_request_ids.append(self.input_batch.req_ids[idx]) @@ -397,7 +397,7 @@ def _prepare_decode_inputs(self) -> DecodeInputData: # DECODES are the first num_decodes REQUESTS. # PREFILLS are the next num_reqs - num_decodes REQUESTS. num_reqs = self.input_batch.num_reqs - num_decodes = num_reqs - self.new_req_ids + num_decodes = num_reqs - self.num_new_reqs if num_decodes == 0: return DecodeInputData(num_decodes=0) @@ -460,7 +460,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - num_decodes = num_reqs - self.new_req_ids + num_decodes = num_reqs - self.num_new_reqs # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. 
@@ -773,7 +773,7 @@ def _gather_encoder_outputs( encoder_outputs.append(encoder_output[start_idx:end_idx]) return encoder_outputs - @torch.inference_mode() + # @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", @@ -970,7 +970,7 @@ def load_model(self) -> None: fullgraph=True, dynamic=False) - @torch.inference_mode() + # @torch.inference_mode() def _dummy_run( self, batch_size: int, @@ -1155,9 +1155,9 @@ def capture_model(self) -> None: self.kv_caches, exec_mode=ExecutionMode.DECODE) xm.wait_device_ops() - logger.info(" -- batch_size: %d, seq_len: %d", batch_size, - seq_len) - + logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d", batch_size, + seq_len, self.scheduler_config.max_num_seqs) + if batch_size >= self.scheduler_config.max_num_seqs: break diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index eba96e9f15146..698e39a0b7222 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -75,6 +75,14 @@ def __init__( else: self.profiler = None + assert self.device_config.device_type == "tpu" + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + def initialize(self): os.environ["PJRT_DEVICE"] = "TPU" torch.set_grad_enabled(False) @@ -124,7 +132,7 @@ def initialize(self): def load_model(self) -> None: self.model_runner.load_model() - @torch.inference_mode() + # @torch.inference_mode() def determine_num_available_blocks(self) -> Tuple[int, int]: num_layers = self.model_config.get_num_layers(self.parallel_config) head_size = self.model_config.get_head_size() @@ -207,7 +215,7 @@ def compile_or_warm_up_model(self) -> None: # the model initialization and profiling. 
set_random_seed(self.model_config.seed) - @torch.inference_mode() + # @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", From 6065facdca283e27f46058b90efa155ac4823a5f Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 19:18:36 +0000 Subject: [PATCH 14/18] more progress --- tests/entrypoints/openai/test_accuracy.py | 2 +- vllm/v1/worker/tpu_model_runner.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 81809472c8b29..e7dc5c38c4694 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -20,7 +20,7 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 -DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests", "-O2", "--max-num-seqs", "128"] +DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests", "-O2", "--max-num-seqs", "64"] MORE_ARGS_LIST = [ [], # Default ["--enable-chunked-prefill"], # Chunked diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index d9cdb1f6a5bf7..7757ed287906a 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -373,6 +373,7 @@ def _prepare_prefill_inputs( num_prefill_tokens=padded_prompt_len, num_decode_tokens=0, slot_mapping=slot_mapping.to(self.device), + multi_modal_placeholder_index_maps=None, block_tables=None, context_lens=None, effective_query_lens=None, @@ -449,6 +450,7 @@ def _prepare_decode_inputs(self) -> DecodeInputData: num_prefill_tokens=0, num_decode_tokens=padded_batch_size, slot_mapping=slot_mapping.to(self.device), + multi_modal_placeholder_index_maps=None, block_tables=block_table.to(self.device), context_lens=context_lens.to(self.device), effective_query_lens=None, @@ -485,7 +487,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): return ( self._prepare_prefill_inputs(num_scheduled_tokens), - self._prepare_decode_inputs(num_decodes), + self._prepare_decode_inputs(), ) # # OPTIMIZATION: Start copying the block table first. @@ -824,6 +826,7 @@ def execute_model( ######################### DECODES ######################### # Decodes run as one single batch with [padded_batch, 1] + sampled_token_ids_list = [] if decode_data.num_decodes > 0: # FORWARD. selected_token_ids = self.model(decode_data.token_ids, @@ -834,7 +837,7 @@ def execute_model( # NOTE: TPU<>CPU sync happens here. # We need to call .cpu() first to avoid recompilation. token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] - sampled_token_ids_list = token_ids.tolist() + sampled_token_ids_list.extend(token_ids.tolist()) sampled_token_ids[:decode_data.num_decodes] = token_ids # UPDATE REQUEST STATE. @@ -859,14 +862,15 @@ def execute_model( for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata) in enumerate(prefill_data.zipped()): # FORWARD. - selected_token_ids = self.model(decode_data.token_ids, - decode_data.position_ids, - decode_data.attn_metadata, + selected_token_ids = self.model(token_ids, + position_ids, + attn_metadata, self.kv_caches) # NOTE: TPU<>CPU sync happens here. # We need to call .cpu() first to avoid recompilation. 
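# Sketch of how the hunks above assemble one step's sampled tokens into a single
# Python list: the decode batch is extended in first (it occupies the front of
# the batch), then each prefill appends the token sampled at its prompt's last
# position. Argument names are illustrative.
import torch

def collect_sampled_tokens(decode_token_ids: torch.Tensor,   # [num_decodes], on CPU
                           prefill_last_token_ids: list) -> list:
    sampled_token_ids_list = []
    sampled_token_ids_list.extend(decode_token_ids.tolist())  # decodes first
    sampled_token_ids_list.extend(prefill_last_token_ids)     # then one per prefill
    return sampled_token_ids_list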
token_id = selected_token_ids.cpu()[prompt_len - 1].item() + sampled_token_ids_list.append(token_id) sampled_token_ids[decode_data.num_decodes + idx] = token_id req_state = self.requests[req_id] @@ -934,7 +938,7 @@ def execute_model( model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids_cpu=sampled_token_ids, + sampled_token_ids=sampled_token_ids_list, logprob_token_ids_cpu=None, logprobs_cpu=None, ) From f1da4b0552ebd77922780aa06df1b49ebf1a7496 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 19:23:22 +0000 Subject: [PATCH 15/18] runs, no correctness yet --- vllm/v1/request.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 45450165eaefe..11cbdff23f358 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -97,6 +97,7 @@ def append_output_token_ids( ) -> None: if isinstance(token_ids, int): token_ids = [token_ids] + self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) From 6ea94b0a73a3ba13e7c6de49b58835215b529c48 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 10 Jan 2025 21:00:04 +0000 Subject: [PATCH 16/18] fixes --- examples/offline_inference/offline_inference.py | 2 +- vllm/v1/worker/gpu_input_batch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/offline_inference.py b/examples/offline_inference/offline_inference.py index 23cc6e8539431..b62c0648499a0 100644 --- a/examples/offline_inference/offline_inference.py +++ b/examples/offline_inference/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 40494e64b22f0..2d907bcc04d61 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -69,7 +69,7 @@ def __init__( self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) # Block table. self.block_table = BlockTable( From cefce4a3b78391a180d712bb6e49f63166f001fb Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Sat, 11 Jan 2025 03:14:24 +0000 Subject: [PATCH 17/18] tmp wip --- .../offline_inference/offline_inference.py | 2 +- vllm/v1/worker/tpu_model_runner.py | 85 ++++++++++++++----- 2 files changed, 65 insertions(+), 22 deletions(-) diff --git a/examples/offline_inference/offline_inference.py b/examples/offline_inference/offline_inference.py index b62c0648499a0..b54f979779aea 100644 --- a/examples/offline_inference/offline_inference.py +++ b/examples/offline_inference/offline_inference.py @@ -8,7 +8,7 @@ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams()#temperature=0.8, top_p=0.95) # Create an LLM. 
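# The gpu_input_batch change above swaps np.empty for np.zeros. A tiny
# illustration of why: np.empty leaves the buffer uninitialized, so padded rows
# that are read before any request writes them feed garbage into the position
# and slot-mapping math, while np.zeros gives them a well-defined default.
import numpy as np

max_num_reqs = 4
num_computed_uninit = np.empty(max_num_reqs, dtype=np.int32)  # arbitrary contents
num_computed = np.zeros(max_num_reqs, dtype=np.int32)         # all zeros, safe default
positions = num_computed.reshape(-1, 1)                       # padded rows -> position 0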
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 7757ed287906a..b6d2afe922e83 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -21,8 +21,8 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, - LayerBlockType, cdiv, is_pin_memory_available) +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv, + is_pin_memory_available) from vllm.v1.attention.backends.pallas import PallasMetadata, PallasAttentionBackend from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.outputs import ModelRunnerOutput @@ -313,6 +313,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in req_ids_to_add: req_state = self.requests[req_id] self.input_batch.add_request(req_state, None) # Append last + self.num_new_reqs = len(req_ids_to_add) def _prepare_prefill_inputs( @@ -333,7 +334,8 @@ def _prepare_prefill_inputs( num_reqs = self.input_batch.num_reqs num_decodes = num_reqs - self.num_new_reqs for idx in range(num_decodes, num_reqs): - prefill_request_ids.append(self.input_batch.req_ids[idx]) + req_id = self.input_batch.req_ids[idx] + prefill_request_ids.append(req_id) prompt_len = num_scheduled_tokens[idx] prefill_prompt_lens.append(prompt_len) @@ -345,10 +347,12 @@ def _prepare_prefill_inputs( # TOKEN_IDS. token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ idx, :padded_prompt_len].reshape(1, -1)) + token_ids[:, prompt_len:] = 0 prefill_token_ids.append(token_ids.to(self.device)) # POSITIONS. positions = self.prefill_positions[:, :padded_prompt_len] + positions[:, prompt_len:] = 0 prefill_position_ids.append(positions.to(self.device)) # SLOT_MAPPING. @@ -367,16 +371,26 @@ def _prepare_prefill_inputs( slot_mapping[:, prompt_len:] = _PAD_SLOT_ID slot_mapping = slot_mapping.long() + # BLOCK_TABLE [batch, max_num_blocks_per_req] + block_table = block_table_cpu_tensor[idx:idx + 1, :] + + context_lens_tensor = torch.tensor([prompt_len], + dtype=torch.int32, + device=self.device) + prompt_lens_tensor = torch.tensor([prompt_len], + dtype=torch.int32, + device=self.device) + prefill_attn_metadata.append( PallasMetadata( num_prefills=1, - num_prefill_tokens=padded_prompt_len, + num_prefill_tokens=prompt_len, # NOTE: This is not used. num_decode_tokens=0, slot_mapping=slot_mapping.to(self.device), multi_modal_placeholder_index_maps=None, - block_tables=None, - context_lens=None, - effective_query_lens=None, + block_tables=None, #block_table.to(self.device), + context_lens=None, #context_lens_tensor, + effective_query_lens=None, #prompt_lens_tensor, )) return PrefillInputData( @@ -418,7 +432,7 @@ def _prepare_decode_inputs(self) -> DecodeInputData: input=torch.from_numpy(self.input_batch.token_ids_cpu), dim=1, index=index, - )[:padded_batch_size] + )[:padded_batch_size].to(torch.int32) # SLOT_MAPPING [batch, 1] # The "slot" is the "physical index" of a token in the KV cache. @@ -434,6 +448,7 @@ def _prepare_decode_inputs(self) -> DecodeInputData: # are ignored when inserting into the KV cache. 
slot_mapping[num_decodes:] = _PAD_SLOT_ID slot_mapping = slot_mapping[:padded_batch_size] + slot_mapping = slot_mapping.long() # BLOCK_TABLE [batch, max_num_blocks_per_req] block_table = block_table_cpu_tensor[:padded_batch_size] @@ -464,10 +479,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): num_decodes = num_reqs - self.num_new_reqs + # TODO: Ressurect # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. # TODO: Verify this works with TPUs - self.input_batch.block_table.commit(num_reqs) + # self.input_batch.block_table.commit(num_reqs) # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. @@ -483,6 +499,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # NOTE: Assert that all the decodes are "decodes". if idx < num_decodes: assert num_tokens == 1 + assert max_num_scheduled_tokens > 0 return ( @@ -775,7 +792,7 @@ def _gather_encoder_outputs( encoder_outputs.append(encoder_output[start_idx:end_idx]) return encoder_outputs - # @torch.inference_mode() + @torch.no_grad() def execute_model( self, scheduler_output: "SchedulerOutput", @@ -853,7 +870,9 @@ def execute_model( # TODO: Verify if req_id_to_index mapping is needed here! token_id = sampled_token_ids_list[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_idx = self.input_batch.req_id_to_index[req_id] + self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + self.input_batch.num_tokens[req_idx] += 1 req_state.output_token_ids.append(token_id) ######################### PREFILLS ######################### @@ -862,10 +881,8 @@ def execute_model( for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata) in enumerate(prefill_data.zipped()): # FORWARD. - selected_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - self.kv_caches) + selected_token_ids = self.model(token_ids, position_ids, + attn_metadata, self.kv_caches) # NOTE: TPU<>CPU sync happens here. # We need to call .cpu() first to avoid recompilation. @@ -886,6 +903,7 @@ def execute_model( # UPDATE REQUEST STATE. req_idx = self.input_batch.req_id_to_index[req_id] self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + self.input_batch.num_tokens[req_idx] += 1 req_state.output_token_ids.append(token_id) # TODO: Remove @@ -957,7 +975,7 @@ def load_model(self) -> None: # determine the order of concatenating the output tensors. # As a workaround, we use the xm's rank assignment only when loading # the embedding weights. - + # TODO: Why this is commented out? # xm_tp_rank = xr.global_ordinal() # with patch( @@ -1159,9 +1177,10 @@ def capture_model(self) -> None: self.kv_caches, exec_mode=ExecutionMode.DECODE) xm.wait_device_ops() - logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d", batch_size, - seq_len, self.scheduler_config.max_num_seqs) - + logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d", + batch_size, seq_len, + self.scheduler_config.max_num_seqs) + if batch_size >= self.scheduler_config.max_num_seqs: break @@ -1260,10 +1279,34 @@ def forward( hidden_states = hidden_states.flatten(0, 1) logits = self.model.compute_logits(hidden_states, None) + # Greedy sampling. + # argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + # # argmax_token_ids = argmax_token_ids.repeat(1, num_samples) + # return argmax_token_ids.squeeze(dim=-1) + + ###### # Greedy sampling. 
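# Sketch of the per-request bookkeeping the execute_model hunks above converge
# on after each step: look up the request's row, write the new token into the
# CPU token buffer, and bump its token count. Argument names are illustrative.
import numpy as np

def commit_sampled_token(token_ids_cpu: np.ndarray,   # [max_num_reqs, max_model_len]
                         num_tokens: np.ndarray,      # [max_num_reqs]
                         req_id_to_index: dict, req_id: str,
                         seq_len: int, token_id: int,
                         output_token_ids: list) -> None:
    req_idx = req_id_to_index[req_id]         # row owned by this request
    token_ids_cpu[req_idx, seq_len] = token_id
    num_tokens[req_idx] += 1                  # the request now holds one more token
    output_token_ids.append(token_id)         # mirrored into the request state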
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - # argmax_token_ids = argmax_token_ids.repeat(1, num_samples) - return argmax_token_ids.squeeze(dim=1) + argmax_token_ids = argmax_token_ids.repeat(1, 1) + + # Zero temperature means greedy decoding. Avoid division by zero. + # nonzero_t = torch.where(t != 0, t, 1.0) + # logits = logits / nonzero_t.unsqueeze(dim=1) + # if _ENABLE_TOP_P: + # logits = _apply_top_p(logits, p.unsqueeze(dim=1)) + + # # Random sampling. + # probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + # sampled_token_ids = torch.multinomial(probs, + # num_samples, + # replacement=True) + # if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + # sampled_token_ids = sampled_token_ids.squeeze(dim=-1) + # next_token_ids = torch.where(t != 0, sampled_token_ids, + # argmax_token_ids) + return argmax_token_ids + #### # TODO: Ressurect this code # hidden_states = hidden_states.flatten(0, 1) From fca776539657187155b25296acfe5c1c8e0c2ca5 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Sat, 11 Jan 2025 14:26:04 +0000 Subject: [PATCH 18/18] works! --- .../offline_inference/offline_inference.py | 2 +- tests/entrypoints/openai/test_accuracy.py | 2 +- vllm/v1/worker/tpu_model_runner.py | 92 +++++++++++-------- 3 files changed, 56 insertions(+), 40 deletions(-) diff --git a/examples/offline_inference/offline_inference.py b/examples/offline_inference/offline_inference.py index b54f979779aea..53451d3a3dcd9 100644 --- a/examples/offline_inference/offline_inference.py +++ b/examples/offline_inference/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams()#temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16) +llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16, enforce_eager=True) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index e7dc5c38c4694..e11a16c08fb21 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -20,7 +20,7 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 -DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests", "-O2", "--max-num-seqs", "64"] +DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests", "--enforce-eager", "--max-num-seqs", "64"] MORE_ARGS_LIST = [ [], # Default ["--enable-chunked-prefill"], # Chunked diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index b6d2afe922e83..8f7e1851fff4b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -334,26 +334,32 @@ def _prepare_prefill_inputs( num_reqs = self.input_batch.num_reqs num_decodes = num_reqs - self.num_new_reqs for idx in range(num_decodes, num_reqs): + print("prepare prefill idx = {}".format(idx)) req_id = self.input_batch.req_ids[idx] prefill_request_ids.append(req_id) + print(" req_id = {}".format(req_id)) prompt_len = num_scheduled_tokens[idx] prefill_prompt_lens.append(prompt_len) + print(" prompt_len = {}".format(prompt_len)) # STATIC SHAPE: prefills are padded to the next power of 2. 
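# The "STATIC SHAPE" comment above is the central TPU constraint: prefills are
# padded so XLA only ever compiles a small, fixed set of sequence lengths. A
# sketch of such a helper, assuming it rounds up to the next power of two with
# a small floor; the exact floor value is an assumption, not taken from the patch.
def get_padded_prefill_len(prompt_len: int, min_len: int = 16) -> int:
    if prompt_len <= min_len:
        return min_len
    return 1 << (prompt_len - 1).bit_length()  # next power of two >= prompt_len

# e.g. get_padded_prefill_len(5) -> 16, get_padded_prefill_len(600) -> 1024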
padded_prompt_len = _get_padded_prefill_len(prompt_len) assert padded_prompt_len <= self.max_model_len + print(" padded_prompt_len = {}".format(padded_prompt_len)) # TOKEN_IDS. token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ idx, :padded_prompt_len].reshape(1, -1)) token_ids[:, prompt_len:] = 0 prefill_token_ids.append(token_ids.to(self.device)) + print(" token_ids.shape = {} token_ids.vals = {}".format(token_ids.shape, token_ids)) # POSITIONS. - positions = self.prefill_positions[:, :padded_prompt_len] + positions = self.prefill_positions[:, :padded_prompt_len].clone() positions[:, prompt_len:] = 0 prefill_position_ids.append(positions.to(self.device)) + print(" positions.shape = {} positions.vals = {}".format(positions.shape, positions)) # SLOT_MAPPING. # The "slot" is the "physical index" of a token in the KV cache. @@ -364,22 +370,25 @@ def _prepare_prefill_inputs( block_numbers = block_table_cpu_tensor[idx, positions // self.block_size].reshape( 1, -1) + print(" block_numbers.shape = {} block_numbers.vals = {}".format(block_numbers.shape, block_numbers)) + block_offsets = positions % self.block_size slot_mapping = block_numbers * self.block_size + block_offsets # Set an out of range value for the padding tokens so that they # are ignored when inserting into the KV cache. slot_mapping[:, prompt_len:] = _PAD_SLOT_ID slot_mapping = slot_mapping.long() + print(" slot_mapping.shape = {} slot_mapping.vals = {}".format(slot_mapping.shape, slot_mapping)) # BLOCK_TABLE [batch, max_num_blocks_per_req] - block_table = block_table_cpu_tensor[idx:idx + 1, :] + # block_table = block_table_cpu_tensor[idx:idx + 1, :] - context_lens_tensor = torch.tensor([prompt_len], - dtype=torch.int32, - device=self.device) - prompt_lens_tensor = torch.tensor([prompt_len], - dtype=torch.int32, - device=self.device) + # context_lens_tensor = torch.tensor([prompt_len], + # dtype=torch.int32, + # device=self.device) + # prompt_lens_tensor = torch.tensor([prompt_len], + # dtype=torch.int32, + # device=self.device) prefill_attn_metadata.append( PallasMetadata( @@ -417,6 +426,7 @@ def _prepare_decode_inputs(self) -> DecodeInputData: if num_decodes == 0: return DecodeInputData(num_decodes=0) + print("prepare num_decodes = {}".format(num_decodes)) # PAD FOR STATIC SHAPES. padded_batch_size = _get_padded_batch_size(num_decodes) @@ -425,7 +435,10 @@ def _prepare_decode_inputs(self) -> DecodeInputData: positions = torch.from_numpy( self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) index = positions.to(torch.int64) + index[num_decodes:] = 0 positions = positions[:padded_batch_size] + positions[num_decodes:] = 0 + print(" positions.shape = {} positions.vals = {}".format(positions.shape, positions)) # TOKEN_IDS. [batch, 1] token_ids = torch.gather( @@ -433,6 +446,8 @@ def _prepare_decode_inputs(self) -> DecodeInputData: dim=1, index=index, )[:padded_batch_size].to(torch.int32) + token_ids[num_decodes:] = 0 + print(" token_ids.shape = {} token_ids.vals = {}".format(token_ids.shape, token_ids)) # SLOT_MAPPING [batch, 1] # The "slot" is the "physical index" of a token in the KV cache. 
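# Sketch of the decode-side input gather above: for each running request, pick
# the token at its last computed position from the CPU token buffer, then zero
# out the rows that exist only for shape padding. Shapes and names are
# illustrative assumptions.
import torch

def gather_decode_tokens(token_ids_cpu: torch.Tensor,        # [max_num_reqs, max_model_len]
                         num_computed_tokens: torch.Tensor,  # [max_num_reqs]
                         num_decodes: int, padded_batch_size: int) -> torch.Tensor:
    index = num_computed_tokens.to(torch.int64).reshape(-1, 1).clone()
    index[num_decodes:] = 0                       # padded rows read position 0
    token_ids = torch.gather(token_ids_cpu, dim=1, index=index)
    token_ids = token_ids[:padded_batch_size].to(torch.int32)
    token_ids[num_decodes:] = 0                   # and are masked to token id 0
    return token_ids                              # [padded_batch_size, 1]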
@@ -442,6 +457,7 @@ def _prepare_decode_inputs(self) -> DecodeInputData: block_number = torch.gather(input=block_table_cpu_tensor, dim=1, index=(index // self.block_size)) + print(" block_number.shape = {} block_number.vals = {}".format(block_number.shape, block_number)) block_offsets = index % self.block_size slot_mapping = block_number * self.block_size + block_offsets # Set an out of range value for the padding tokens so that they @@ -450,11 +466,14 @@ def _prepare_decode_inputs(self) -> DecodeInputData: slot_mapping = slot_mapping[:padded_batch_size] slot_mapping = slot_mapping.long() + print(" slot_mapping.shape = {} slot_mapping.vals = {}".format(slot_mapping.shape, slot_mapping)) # BLOCK_TABLE [batch, max_num_blocks_per_req] block_table = block_table_cpu_tensor[:padded_batch_size] # CONTEXT_LENS [batch_size] context_lens = (positions.reshape(-1) + 1) + context_lens[num_decodes:] = 0 + print(" context_lens.shape = {} context_lens.vals = {}".format(context_lens.shape, context_lens)) # CPU<>TPU sync happens here. return DecodeInputData(num_decodes=num_decodes, @@ -811,7 +830,7 @@ def execute_model( prefill_data, decode_data = self._prepare_inputs(scheduler_output) num_reqs = self.input_batch.num_reqs - sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) + # sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) # attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -851,11 +870,12 @@ def execute_model( decode_data.attn_metadata, self.kv_caches) + print("DECODE selected_token_ids.shape = {}".format(selected_token_ids.shape)) # NOTE: TPU<>CPU sync happens here. # We need to call .cpu() first to avoid recompilation. token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] sampled_token_ids_list.extend(token_ids.tolist()) - sampled_token_ids[:decode_data.num_decodes] = token_ids + # sampled_token_ids[:decode_data.num_decodes] = token_ids # UPDATE REQUEST STATE. for i, req_id in enumerate( @@ -870,9 +890,8 @@ def execute_model( # TODO: Verify if req_id_to_index mapping is needed here! token_id = sampled_token_ids_list[i] - req_idx = self.input_batch.req_id_to_index[req_id] - self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id - self.input_batch.num_tokens[req_idx] += 1 + self.input_batch.token_ids_cpu[i, seq_len] = token_id + self.input_batch.num_tokens[i] += 1 req_state.output_token_ids.append(token_id) ######################### PREFILLS ######################### @@ -884,11 +903,12 @@ def execute_model( selected_token_ids = self.model(token_ids, position_ids, attn_metadata, self.kv_caches) + print("PREFILL selected_token_ids.shape = {}".format(selected_token_ids.shape)) # NOTE: TPU<>CPU sync happens here. # We need to call .cpu() first to avoid recompilation. - token_id = selected_token_ids.cpu()[prompt_len - 1].item() + token_id = selected_token_ids.cpu()[prompt_len-1].item() sampled_token_ids_list.append(token_id) - sampled_token_ids[decode_data.num_decodes + idx] = token_id + # sampled_token_ids[decode_data.num_decodes + idx] = token_id req_state = self.requests[req_id] # TODO: ASSERT NO PREFIX CACHING. @@ -975,15 +995,12 @@ def load_model(self) -> None: # determine the order of concatenating the output tensors. # As a workaround, we use the xm's rank assignment only when loading # the embedding weights. - - # TODO: Why this is commented out? - # xm_tp_rank = xr.global_ordinal() - # with patch( - # "vllm.model_executor.layers.vocab_parallel_embedding." 
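# Companion sketch for the decode slot mapping above: with exactly one new token
# per request, a single torch.gather over the block table yields each row's
# physical slot; padded rows again receive the sentinel. _PAD_SLOT_ID and the
# shapes are assumed, mirroring the prefill sketch earlier.
import torch

_PAD_SLOT_ID = -1  # assumed sentinel ignored on cache insert

def decode_slot_mapping(block_table: torch.Tensor,  # [padded_batch, max_blocks]
                        positions: torch.Tensor,    # [padded_batch, 1], int64
                        num_decodes: int, block_size: int) -> torch.Tensor:
    block_number = torch.gather(block_table, dim=1, index=positions // block_size)
    block_offsets = positions % block_size
    slot_mapping = block_number * block_size + block_offsets
    slot_mapping[num_decodes:] = _PAD_SLOT_ID   # padding rows never write the cache
    return slot_mapping.long()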
- # "get_tensor_model_parallel_rank", - # return_value=xm_tp_rank): - - model = get_model(vllm_config=self.vllm_config) + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." + "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() model = ModelWrapper(model) @@ -1190,19 +1207,18 @@ def capture_model(self) -> None: logger.info("Compilation for decode shapes is done in %.2f [secs].", end - start) - def initialize_kv_cache(self, num_blocks: int) -> None: + def initialize_kv_cache(self, num_tpu_blocks: int) -> None: assert len(self.kv_caches) == 0 - kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size) + + tpu_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_tpu_blocks, self.block_size, self.num_kv_heads, self.head_size) + for _ in range(self.num_attn_layers): - self.kv_caches.append(( - torch.zeros(kv_cache_shape, - dtype=self.kv_cache_dtype, - device=self.device), - torch.zeros(kv_cache_shape, - dtype=self.kv_cache_dtype, - device=self.device), - )) + tpu_k_cache = torch.zeros(tpu_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device) + tpu_v_cache = torch.zeros_like(tpu_k_cache) + self.kv_caches.append((tpu_k_cache, tpu_v_cache)) # TODO: This is duplicate from V0, refactor @@ -1287,7 +1303,7 @@ def forward( ###### # Greedy sampling. argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - argmax_token_ids = argmax_token_ids.repeat(1, 1) + # argmax_token_ids = argmax_token_ids.repeat(1, 1) # Zero temperature means greedy decoding. Avoid division by zero. # nonzero_t = torch.where(t != 0, t, 1.0) @@ -1302,11 +1318,11 @@ def forward( # replacement=True) # if num_samples == 1: argmax_token_ids = argmax_token_ids.squeeze(dim=-1) - # sampled_token_ids = sampled_token_ids.squeeze(dim=-1) + # sampled_token_ids = sampled_token_ids.squeeze(dim=-1) # next_token_ids = torch.where(t != 0, sampled_token_ids, # argmax_token_ids) return argmax_token_ids - #### + #### # TODO: Ressurect this code # hidden_states = hidden_states.flatten(0, 1)