# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeVar

import torch

from vllm import SamplingParams
from vllm.v1.sample.logits_processor.interface import (
    BatchUpdate,
    LogitsProcessor,
    MoveDirectionality,
)

if TYPE_CHECKING:
    from vllm.config import VllmConfig

T = TypeVar("T")


class MinPLogitsProcessor(LogitsProcessor):
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.min_p_count: int = 0

        self.min_p_cpu_tensor = torch.zeros(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory
        )
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()

        self.use_double_tensor = torch.device(device).type != "cpu"

        if self.use_double_tensor:
            # Pre-allocated device tensor
            self.min_p_device: torch.Tensor = torch.empty(
                (max_num_reqs,), dtype=torch.float32, device=device
            )
        else:
            self.min_p_device = self.min_p_cpu_tensor
        # Current slice of the device tensor
        self.min_p: torch.Tensor = self.min_p_device[:0]

    def is_argmax_invariant(self) -> bool:
        """Min-p never impacts greedy sampling"""
        return True

    def get_min_p_by_index(self, index: int) -> float:
        return float(self.min_p_cpu[index])

    def update_state(self, batch_update: BatchUpdate | None):
        if not batch_update:
            return

        needs_update = False
        # Process added requests.
        for index, params, _, _ in batch_update.added:
            min_p = params.min_p
            min_p_before = self.min_p_cpu[index]
            if min_p_before != min_p:
                needs_update = True
                self.min_p_cpu[index] = min_p
                if min_p and not min_p_before:
                    self.min_p_count += 1
                elif not min_p and min_p_before:
                    self.min_p_count -= 1

        if self.min_p_count:
            # Process removed requests.
            if batch_update.removed:
                needs_update = True
                for index in batch_update.removed:
                    if self.min_p_cpu[index]:
                        self.min_p_cpu[index] = 0
                        self.min_p_count -= 1

            # Process moved requests, unidirectional (a->b) and swap (a<->b).
            for adx, bdx, direct in batch_update.moved:
                min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
                if min_p_a != min_p_b:
                    needs_update = True
                    self.min_p_cpu[bdx] = min_p_a
                    if direct == MoveDirectionality.SWAP:
                        self.min_p_cpu[adx] = min_p_b

                if direct == MoveDirectionality.UNIDIRECTIONAL:
                    if min_p_a:
                        self.min_p_cpu[adx] = 0
                    if min_p_b:
                        self.min_p_count -= 1

        # Update tensors if needed.
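        # On GPU, the refresh below stages the updated pinned-CPU tensor into
        # the preallocated device tensor with a non-blocking copy; on CPU,
        # min_p_device aliases min_p_cpu_tensor, so re-slicing is sufficient.
        # The trailing unsqueeze_ turns the (size,) slice into a (size, 1)
        # column so it broadcasts across the vocab dimension in apply().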
        size = batch_update.batch_size
        if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
            self.min_p = self.min_p_device[:size]
            if self.use_double_tensor:
                self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
            self.min_p.unsqueeze_(1)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.min_p_count:
            return logits

        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
        # Adjust min_p
        adjusted_min_p = max_probabilities.mul_(self.min_p)
        # Identify invalid tokens using threshold comparison
        invalid_token_mask = probability_values < adjusted_min_p
        # Apply mask using boolean indexing
        logits[invalid_token_mask] = -float("inf")
        return logits


class LogitBiasLogitsProcessor(LogitsProcessor):
    def __init__(self, _, device: torch.device, is_pin_memory: bool):
        self.device = device
        self.pin_memory = is_pin_memory
        self.biases: dict[int, dict[int, float]] = {}

        self.bias_tensor: torch.Tensor = torch.tensor(())
        self.logits_slice = (
            self._device_tensor([], torch.int32),
            self._device_tensor([], torch.int32),
        )

    def is_argmax_invariant(self) -> bool:
        """Logit bias can rebalance token probabilities and change the
        outcome of argmax in greedy sampling."""
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        needs_update = process_dict_updates(
            self.biases, batch_update, lambda params, _, __: params.logit_bias or None
        )

        # Update tensors if needed.
        if needs_update:
            reqs: list[int] = []
            tok_ids: list[int] = []
            biases: list[float] = []
            for req, lb in self.biases.items():
                reqs.extend([req] * len(lb))
                tok_ids.extend(lb.keys())
                biases.extend(lb.values())

            self.bias_tensor = self._device_tensor(biases, torch.float32)
            self.logits_slice = (
                self._device_tensor(reqs, torch.int32),
                self._device_tensor(tok_ids, torch.int32),
            )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        return torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        ).to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.biases:
            logits[self.logits_slice] += self.bias_tensor
        return logits
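

# LogitBiasLogitsProcessor above and MinTokensLogitsProcessor below share one
# addressing idiom: two parallel index tensors select scattered
# (request row, token column) positions in a single advanced-indexing op.
# A minimal standalone sketch of that idiom follows; the helper name and the
# values are hypothetical, for illustration only, and not part of the vLLM API.
def _paired_index_sketch() -> torch.Tensor:
    logits = torch.zeros(2, 5)
    rows = torch.tensor([0, 0, 1])  # request indices, one entry per edit
    cols = torch.tensor([3, 4, 1])  # token ids, aligned element-wise with rows
    # One vectorized update: row 0 is edited at tokens 3 and 4, row 1 at token 1.
    logits[rows, cols] += torch.tensor([2.0, -1.0, 0.5])
    return logits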


class MinTokensLogitsProcessor(LogitsProcessor):
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        # index -> (min_toks, output_token_ids, stop_token_ids)
        self.device = device
        self.pin_memory = is_pin_memory
        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}

        # (req_idx_tensor, stop_tok_id_tensor)
        self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (
            self._device_tensor([], torch.int32),
            self._device_tensor([], torch.int32),
        )

    def is_argmax_invariant(self) -> bool:
        """By censoring stop tokens, min-tokens can change the outcome
        of the argmax operation in greedy sampling."""
        return False

    @staticmethod
    def add_request(
        params: SamplingParams, _: list[int] | None, output_tok_ids: list[int]
    ) -> tuple[int, Sequence[int], set[int]] | None:
        min_tokens = params.min_tokens
        if not min_tokens or len(output_tok_ids) >= min_tokens:
            return None
        return min_tokens, output_tok_ids, params.all_stop_token_ids

    def update_state(self, batch_update: BatchUpdate | None):
        needs_update = process_dict_updates(
            self.min_toks, batch_update, self.add_request
        )

        if self.min_toks:
            # Check for any requests that have attained their min tokens.
            to_remove = tuple(
                index
                for index, (min_toks, out_tok_ids, _) in self.min_toks.items()
                if len(out_tok_ids) >= min_toks
            )
            if to_remove:
                needs_update = True
                for index in to_remove:
                    del self.min_toks[index]

        # Update tensors if needed.
        if needs_update:
            reqs: list[int] = []
            tok_ids: list[int] = []
            for req, (_, _, stop_tok_ids) in self.min_toks.items():
                reqs.extend([req] * len(stop_tok_ids))
                tok_ids.extend(stop_tok_ids)

            self.logits_slice = (
                self._device_tensor(reqs, torch.int32),
                self._device_tensor(tok_ids, torch.int32),
            )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        return torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        ).to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.min_toks:
            # Inhibit stop tokens for requests that have not reached min length.
            logits[self.logits_slice] = -float("inf")
        return logits


def process_dict_updates(
    req_entries: dict[int, T],
    batch_update: BatchUpdate | None,
    new_state: Callable[[SamplingParams, list[int] | None, list[int]], T | None],
) -> bool:
    """Utility function to update dict state for sparse LogitsProcessors."""
    if not batch_update:
        # Nothing to do.
        return False

    updated = False
    # Process added requests.
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None:
            req_entries[index] = state
            updated = True
        elif req_entries.pop(index, None) is not None:
            updated = True

    if req_entries:
        # Process removed requests.
        for index in batch_update.removed:
            if req_entries.pop(index, None):
                updated = True

        # Process moved requests, unidirectional (a->b) and swapped (a<->b).
        for a_index, b_index, direct in batch_update.moved:
            a_entry = req_entries.pop(a_index, None)
            b_entry = req_entries.pop(b_index, None)
            if a_entry is not None:
                req_entries[b_index] = a_entry
                updated = True
            if b_entry is not None:
                updated = True
                if direct == MoveDirectionality.SWAP:
                    req_entries[a_index] = b_entry

    return updated
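

# A minimal reference of the min-p masking math in MinPLogitsProcessor.apply
# above, written against plain torch. The helper name `_min_p_reference` and
# its call pattern are hypothetical, for illustration only; vLLM applies the
# same math in-place on the batch logits.
def _min_p_reference(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    # min_p holds one threshold scale per request, shape (batch,).
    probs = torch.softmax(logits, dim=-1)
    # Per-row cutoff: each row's max probability scaled by that row's min_p.
    threshold = probs.amax(dim=-1, keepdim=True) * min_p.unsqueeze(1)
    # Mask out tokens whose probability falls below the cutoff.
    return logits.masked_fill(probs < threshold, float("-inf"))


# Example: _min_p_reference(torch.randn(2, 32), torch.tensor([0.1, 0.0]))
# leaves row 1 untouched, since min_p == 0 yields a zero cutoff and
# `probs < 0` is false everywhere.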