Files
2026-03-05 18:06:10 +08:00

275 lines
9.9 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeVar
import torch
from vllm import SamplingParams
from vllm.v1.sample.logits_processor.interface import (
BatchUpdate,
LogitsProcessor,
MoveDirectionality,
)
if TYPE_CHECKING:
from vllm.config import VllmConfig
# Generic type of the per-request state held in the ``req_entries`` dict
# handled by process_dict_updates (and returned by its ``new_state`` callback).
T = TypeVar("T")
class MinPLogitsProcessor(LogitsProcessor):
    """Min-p filtering for requests that set a non-zero ``min_p``.

    A token is masked to ``-inf`` when its probability is below
    ``min_p * (highest probability in that row)``.

    Per-slot ``min_p`` values are kept in a float32 CPU buffer with
    ``max_num_seqs`` entries. When the target device is not the CPU, the
    active prefix is copied into a separate, pre-allocated device buffer
    whenever it changes. On CPU the two buffers are the same tensor.
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        capacity = vllm_config.scheduler_config.max_num_seqs
        # How many batch slots currently hold a non-zero min_p.
        self.min_p_count: int = 0
        self.min_p_cpu_tensor = torch.zeros(
            (capacity,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory
        )
        # NumPy view sharing memory with min_p_cpu_tensor, used for cheap
        # per-element reads and writes.
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
        self.use_double_tensor = torch.device(device).type != "cpu"
        # Off-CPU: a dedicated destination buffer allocated once up front.
        # On CPU: simply reuse the host buffer.
        self.min_p_device: torch.Tensor = (
            torch.empty((capacity,), dtype=torch.float32, device=device)
            if self.use_double_tensor
            else self.min_p_cpu_tensor
        )
        # Slice of min_p_device covering the current batch; starts empty.
        self.min_p: torch.Tensor = self.min_p_device[:0]

    def is_argmax_invariant(self) -> bool:
        """The top token always survives min-p, so greedy results are unchanged."""
        return True

    def get_min_p_by_index(self, index: int) -> float:
        """Return the min_p currently stored for batch slot ``index``."""
        value = self.min_p_cpu[index]
        return float(value)

    def update_state(self, batch_update: BatchUpdate | None):
        """Sync per-slot min_p values with a batch update.

        Handles added, removed, and moved requests, then refreshes the
        device slice when anything changed.
        """
        if not batch_update:
            return

        slots = self.min_p_cpu
        changed = False

        # Added requests: store each request's min_p in its slot and keep
        # min_p_count equal to the number of non-zero slots.
        for slot, params, _, _ in batch_update.added:
            new_val = params.min_p
            old_val = slots[slot]
            if old_val == new_val:
                continue
            changed = True
            slots[slot] = new_val
            if new_val and not old_val:
                self.min_p_count += 1
            elif old_val and not new_val:
                self.min_p_count -= 1

        # When every slot is zero, removes and moves cannot change anything.
        if self.min_p_count:
            if batch_update.removed:
                changed = True
                for slot in batch_update.removed:
                    if slots[slot]:
                        slots[slot] = 0
                        self.min_p_count -= 1

            # Moves are either unidirectional (src -> dst) or swaps (src <-> dst).
            for src, dst, direction in batch_update.moved:
                src_val, dst_val = slots[src], slots[dst]
                if src_val != dst_val:
                    changed = True
                    slots[dst] = src_val
                    if direction == MoveDirectionality.SWAP:
                        slots[src] = dst_val
                if direction == MoveDirectionality.UNIDIRECTIONAL:
                    # The source slot is vacated, and whatever dst held before
                    # is overwritten.
                    if src_val:
                        slots[src] = 0
                    if dst_val:
                        self.min_p_count -= 1

        batch_size = batch_update.batch_size
        if not self.min_p_count:
            return
        if not changed and self.min_p.shape[0] == batch_size:
            return
        # Re-slice for this batch and refresh from the host copy when the
        # buffers are distinct. Then add a trailing dim so the slice broadcasts
        # against the keepdim row maxima computed in apply().
        active = self.min_p = self.min_p_device[:batch_size]
        if self.use_double_tensor:
            active.copy_(self.min_p_cpu_tensor[:batch_size], non_blocking=True)
        active.unsqueeze_(1)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Mask (in place) tokens below each row's min-p cutoff; return ``logits``."""
        if self.min_p_count == 0:
            return logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Per-row cutoff: the row's top probability scaled by its min_p.
        cutoff = probs.amax(dim=-1, keepdim=True).mul_(self.min_p)
        logits.masked_fill_(probs < cutoff, -float("inf"))
        return logits
class LogitBiasLogitsProcessor(LogitsProcessor):
    """Add per-request, per-token additive biases to the logits.

    The biases come from ``SamplingParams.logit_bias``.
    """

    def __init__(self, _, device: torch.device, is_pin_memory: bool):
        self.device = device
        self.pin_memory = is_pin_memory
        # batch slot -> {token id: additive bias}
        self.biases: dict[int, dict[int, float]] = {}
        # Flattened bias values, aligned element-wise with logits_slice.
        self.bias_tensor: torch.Tensor = torch.tensor(())
        # (row indices, token ids) addressing the biased logits; empty at start.
        no_rows = self._device_tensor([], torch.int32)
        no_cols = self._device_tensor([], torch.int32)
        self.logits_slice = (no_rows, no_cols)

    def is_argmax_invariant(self) -> bool:
        """Biases can reorder token scores, so greedy (argmax) picks may change."""
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        """Apply a batch update and rebuild the flattened index/bias tensors."""

        def _state_for(params, _prompt_ids, _output_ids):
            # Track only requests with a non-empty bias map.
            return params.logit_bias or None

        if not process_dict_updates(self.biases, batch_update, _state_for):
            return

        rows: list[int] = []
        cols: list[int] = []
        values: list[float] = []
        for slot, bias_map in self.biases.items():
            for token_id, bias in bias_map.items():
                rows.append(slot)
                cols.append(token_id)
                values.append(bias)

        self.bias_tensor = self._device_tensor(values, torch.float32)
        self.logits_slice = (
            self._device_tensor(rows, torch.int32),
            self._device_tensor(cols, torch.int32),
        )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """Build ``data`` on the host (pinned if configured), then move it to the device."""
        host = torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        )
        return host.to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Add the biases to ``logits`` in place and return it."""
        if not self.biases:
            return logits
        logits[self.logits_slice] += self.bias_tensor
        return logits
class MinTokensLogitsProcessor(LogitsProcessor):
    """Mask stop tokens until a request reaches its ``min_tokens`` output length.

    Only requests whose ``SamplingParams.min_tokens`` is set and not yet
    reached are tracked. Their stop token ids are forced to ``-inf``.
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        self.device = device
        self.pin_memory = is_pin_memory
        # batch slot -> (min_tokens, output token ids, stop token ids).
        # The output-id sequence is the object handed in with the request, and
        # its length is re-checked on every update_state() call. This
        # presumably relies on the caller appending to it in place.
        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
        # (row index tensor, stop-token id tensor) addressing the logits to mask.
        self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (
            self._device_tensor([], torch.int32),
            self._device_tensor([], torch.int32),
        )

    def is_argmax_invariant(self) -> bool:
        """Masking stop tokens can change which token argmax selects."""
        return False

    @staticmethod
    def add_request(
        params: SamplingParams, _: list[int] | None, output_tok_ids: list[int]
    ) -> tuple[int, Sequence[int], set[int]] | None:
        """Return the state to track for a new request, or None if not needed."""
        min_tokens = params.min_tokens
        if min_tokens and len(output_tok_ids) < min_tokens:
            return min_tokens, output_tok_ids, params.all_stop_token_ids
        return None

    def update_state(self, batch_update: BatchUpdate | None):
        """Apply a batch update, drop satisfied requests, and rebuild indices."""
        changed = process_dict_updates(self.min_toks, batch_update, self.add_request)

        # Drop requests that have now produced at least min_tokens outputs.
        satisfied = [
            slot
            for slot, (target, produced, _) in self.min_toks.items()
            if len(produced) >= target
        ]
        for slot in satisfied:
            del self.min_toks[slot]
        if satisfied:
            changed = True

        if not changed:
            return

        rows: list[int] = []
        stop_ids: list[int] = []
        for slot, (_, _, stops) in self.min_toks.items():
            rows += [slot] * len(stops)
            stop_ids += stops
        self.logits_slice = (
            self._device_tensor(rows, torch.int32),
            self._device_tensor(stop_ids, torch.int32),
        )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """Build ``data`` on the host (pinned if configured), then move it to the device."""
        host = torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        )
        return host.to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Set stop-token logits of tracked requests to -inf in place; return ``logits``."""
        if not self.min_toks:
            return logits
        # Inhibit stop tokens for requests that have not reached min length.
        logits[self.logits_slice] = -float("inf")
        return logits
def process_dict_updates(
    req_entries: dict[int, T],
    batch_update: BatchUpdate | None,
    new_state: Callable[[SamplingParams, list[int] | None, list[int]], T | None],
) -> bool:
    """Utility function to update dict state for sparse LogitsProcessors.

    Applies one batch update to a sparse ``batch slot -> state`` mapping.
    Only requests that need state have an entry.

    Args:
        req_entries: Mapping from batch slot index to per-request state.
            Mutated in place.
        batch_update: Added/removed/moved requests, or a falsy value when
            there is nothing to apply.
        new_state: Called as ``new_state(params, prompt_tok_ids,
            output_tok_ids)`` for every added request. It returns the state to
            store for that slot, or None when the request needs no entry.

    Returns:
        True if ``req_entries`` gained, lost, or relocated an entry, so any
        tensors the caller derives from it must be rebuilt. False otherwise.
    """
    if not batch_update:
        # Nothing to do.
        return False

    updated = False
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None:
            req_entries[index] = state
            updated = True
        elif req_entries.pop(index, None) is not None:
            # The new occupant needs no state; drop any entry left at this slot.
            updated = True

    if not req_entries:
        # Removes and moves can only affect existing entries.
        return updated

    # Process removed requests. Compare with None rather than relying on
    # truthiness, so that falsy states (e.g. 0 or an empty container) still
    # flag the update.
    for index in batch_update.removed:
        if req_entries.pop(index, None) is not None:
            updated = True

    # Process moved requests. A unidirectional move (a -> b) relocates a's
    # entry and discards b's previous one; a swap (a <-> b) exchanges them.
    for a_index, b_index, direct in batch_update.moved:
        a_entry = req_entries.pop(a_index, None)
        b_entry = req_entries.pop(b_index, None)
        if a_entry is not None:
            req_entries[b_index] = a_entry
            updated = True
        if b_entry is not None:
            updated = True
            if direct == MoveDirectionality.SWAP:
                req_entries[a_index] = b_entry
    return updated