# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch

from dataclasses import dataclass
from typing import Optional, cast

import numpy as np
import torch
from typing_extensions import deprecated

from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                    MultiModalKwargsItems)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import (length_from_prompt_token_ids_or_embeds,
                        swap_dict_values)
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
                                             LogitsProcessors,
                                             MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable


@dataclass
class CachedRequestState:

    req_id: str
    prompt_token_ids: Optional[list[int]]
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    generator: Optional[torch.Generator]

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None
    prompt_embeds: Optional[torch.Tensor] = None
    deepstack_input_embeds: Optional[torch.Tensor] = None

    def __post_init__(self):
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds)

    @property
    def num_tokens(self) -> int:
        return self.num_prompt_tokens + len(self.output_token_ids)

    # Temporary back-compatibility for plugins that define model runner
    @property
    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
                "removed in v0.13. Please use `mm_kwargs` instead.")
    def mm_inputs(self) -> list[MultiModalKwargsItems]:
        return [
            MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
            if f.data is not None
        ]

    def get_token_id(self, idx: int) -> int:
        if idx < self.num_prompt_tokens:
            if self.prompt_token_ids is None:
                raise ValueError(
                    f"Tried to access token index {idx}, but that token was "
                    "provided via prompt_embeds, and its ID is unknown.")
            return self.prompt_token_ids[idx]
        elif idx - self.num_prompt_tokens < len(self.output_token_ids):
            return self.output_token_ids[idx - self.num_prompt_tokens]
        else:
            return -1


class InputBatch:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
    ):
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
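        # Shape: (max_num_reqs, max_model_len). Each row is a persistent batch
        # slot holding the request's prompt token ids followed by its output
        # token ids.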
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
                                        device="cpu",
                                        dtype=bool,
                                        pin_memory=False)
        # Store prompt embeddings per request to avoid OOM from a large
        # upfront allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size).
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        # Patch: per-request deepstack input embeddings, keyed by req_index.
        self.req_deepstack_input_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            num_speculative_tokens=num_speculative_tokens,
        )

        # Sampling-related.
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding.
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty related data structures.
        # Patch: the GPU-side penalty tensors are left commented out; penalties
        # are handled on the CPU in this version.
        # self.frequency_penalties = torch.empty((max_num_reqs, ),
        #                                        dtype=torch.float,
        #                                        device=device)
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures.
        # self.presence_penalties = torch.empty((max_num_reqs, ),
        #                                       dtype=torch.float,
        #                                       device=device)
        self.presence_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.presence_penalties_cpu = \
            self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures.
        # self.repetition_penalties = torch.empty((max_num_reqs, ),
        #                                         dtype=torch.float,
        #                                         device=device)
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding.
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            (max_num_reqs, ),
            dtype=torch.int64,
            device="cpu",
            pin_memory=pin_memory)
        self.num_accepted_tokens_cpu = \
            self.num_accepted_tokens_cpu_tensor.numpy()

        # LoRA related.
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering the persistent batch and generating logitsprocs batch
        # state updates. Should be reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO: convert this to a LogitsProcessor.
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, the value is False if the
        # corresponding token is allowed, since we use masked_fill_ to
        # set -inf.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
                                                          dtype=bool)

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # Store provided logitsprocs. If none are provided, initialize an
        # empty data structure.
        self.logitsprocs = logitsprocs or LogitsProcessors()

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        self.pooling_params: dict[str, PoolingParams] = {}

        # Cached reference to the GPU tensor of previously sampled tokens.
        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
        self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
        self.prev_req_id_to_index: Optional[dict[str, int]] = None

    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
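        # Prompt tokens occupy [0, num_prompt_tokens) of the row; output
        # tokens are written immediately after them.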
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        if request.prompt_token_ids is not None:
            self.token_ids_cpu[
                req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        # Patch: keep the request's deepstack input embeddings alongside its
        # prompt embeddings.
        if request.deepstack_input_embeds is not None:
            self.req_deepstack_input_embeds[req_index] = (
                request.deepstack_input_embeds)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of token ids in prompt (token_ids_cpu or prompt_embeds).
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if (self.is_spec_decode
                    and is_spec_decode_unsupported(sampling_params)):
                self.spec_decode_unsupported_reqs.add(req_id)
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Avoid a later division by zero in apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k

            self.frequency_penalties_cpu[
                req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[
                req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[
                req_index] = sampling_params.repetition_penalty
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests
            # that do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                self.num_logprobs[req_id] = (
                    self.vocab_size if sampling_params.logprobs == -1 else
                    sampling_params.logprobs)
            if sampling_params.prompt_logprobs is not None:
                self.num_prompt_logprobs[req_id] = (
                    self.vocab_size if sampling_params.prompt_logprobs == -1
                    else sampling_params.prompt_logprobs)

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device)
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu")
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
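                # The row defaults to True (fill with -inf); the explicitly
                # allowed token ids are flipped back to False below.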
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[
                    req_index] = sampling_params.bad_words_token_ids
        elif pooling_params := request.pooling_params:
            self.pooling_params[req_id] = pooling_params
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids)
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request LoRA ID.
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()
            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA.
            self.request_lora_mapping[req_index] = 0

        return req_index

    def _make_sampling_metadata(self) -> SamplingMetadata:
        num_reqs = self.num_reqs
        # Patch: the original copy_slice syncing of the CPU sampling tensors
        # into the persistent GPU buffers is disabled in this version.
        # if not self.all_greedy:
        #     temperature = copy_slice(self.temperature_cpu_tensor,
        #                              self.temperature, num_reqs)
        # else:
        #     temperature = None
        # if not self.no_top_p:
        #     copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        # if not self.no_top_k:
        #     copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
        # if not self.no_penalties:
        #     # Since syncing these tensors is expensive, only copy them
        #     # if necessary, i.e. if there are requests which require
        #     # penalties to be applied during sampling.
        #     copy_slice(self.frequency_penalties_cpu_tensor,
        #                self.frequency_penalties, num_reqs)
        #     copy_slice(self.presence_penalties_cpu_tensor,
        #                self.presence_penalties, num_reqs)
        #     copy_slice(self.repetition_penalties_cpu_tensor,
        #                self.repetition_penalties, num_reqs)

        needs_prompt_token_ids = (
            not self.no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any())
        if needs_prompt_token_ids:
            # The prompt tokens are used only for applying penalties or
            # step pooling during the sampling/pooling process.
            # Hence copy these tensors only when there are requests which
            # need penalties/step_pooler to be applied.
            prompt_token_ids = self._make_prompt_token_ids_tensor()
        else:
            prompt_token_ids = None

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, num_reqs)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            # Patch: copy the CPU slice to the device on every call instead of
            # reusing the persistent GPU buffer.
            temperature=self.temperature_cpu_tensor[:num_reqs].to(
                self.device),
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            # top_p=None if self.no_top_p else
            #     self.top_p_cpu_tensor[:num_reqs].to(self.device),
            # top_k=None if self.no_top_k else torch.tensor(
            #     [40 for _ in range(num_reqs)]).to(torch.int32).to(
            #         self.device),
            top_p=(torch.ones(num_reqs,
                              dtype=torch.float32,
                              device=self.device) if self.no_top_p else
                   self.top_p_cpu_tensor[:num_reqs].to(self.device)),
            # Patch: top_k is hard-coded to 40 for every request.
            top_k=torch.full((num_reqs, ),
                             40,
                             dtype=torch.int32,
                             device=self.device),
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            # Patch: penalties are passed as Python lists on the CPU.
            frequency_penalties=self.frequency_penalties_cpu_tensor[
                :num_reqs].tolist(),
            presence_penalties=self.presence_penalties_cpu_tensor[
                :num_reqs].tolist(),
            repetition_penalties=self.repetition_penalties_cpu_tensor[
                :num_reqs].tolist(),
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
            # Patch: extra CPU-side fields for the patched SamplingMetadata;
            # top_k_cpu mirrors the hard-coded value of 40 used above.
            temperature_cpu=self.temperature_cpu_tensor[:num_reqs],
            top_p_cpu=self.top_p_cpu_tensor[:num_reqs],
            top_k_cpu=torch.full((num_reqs, ), 40, dtype=torch.int32),
        )
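
    # NOTE: With the copy_slice block above disabled, this patched
    # _make_sampling_metadata no longer touches the persistent GPU buffers
    # (self.temperature, self.top_p, self.top_k); the CPU tensors are copied
    # to the device on every call, and the per-request top_k values gathered
    # in add_request are ignored in favor of the hard-coded 40.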