Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

View File

@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import inspect
import itertools
from abc import abstractmethod
from collections.abc import Sequence
from functools import lru_cache, partial
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
from vllm.sampling_params import SamplingParams
from vllm.utils.torch_utils import guard_cuda_initialization
from vllm.v1.sample.logits_processor.builtin import (
LogitBiasLogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
process_dict_updates,
)
from vllm.v1.sample.logits_processor.interface import (
BatchUpdate,
LogitsProcessor,
MoveDirectionality,
)
from vllm.v1.sample.logits_processor.state import BatchUpdateBuilder, LogitsProcessors
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
# Error message when the user tries to initialize vLLM with a pooling model
# and custom logitsprocs
STR_POOLING_REJECTS_LOGITSPROCS = (
"Pooling models do not support custom logits processors."
)
# Error message when the user tries to initialize vLLM with speculative
# decoding enabled and custom logitsprocs
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
"Custom logits processors are not supported when speculative decoding is enabled."
)
LOGITSPROCS_GROUP = "vllm.logits_processors"
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
MinTokensLogitsProcessor,
LogitBiasLogitsProcessor,
MinPLogitsProcessor,
]
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    """Import every logits processor plugin registered as an entry point.

    Plugins are discovered through the ``LOGITSPROCS_GROUP`` entry-point
    group.

    Returns:
        The objects returned by each entry point's ``load()``, in discovery
        order. An empty list when no plugin is installed.

    Raises:
        RuntimeError: If loading any entry point fails; the underlying
            exception is chained as the cause.
    """
    from importlib.metadata import entry_points

    plugin_entrypoints = entry_points(group=LOGITSPROCS_GROUP)
    if not plugin_entrypoints:
        logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP)
        return []

    logger.debug("Loading installed logitsprocs plugins (group %s):", LOGITSPROCS_GROUP)
    loaded: list[type[LogitsProcessor]] = []
    for ep in plugin_entrypoints:
        try:
            logger.debug(
                "- Loading logitproc plugin entrypoint=%s target=%s",
                ep.name,
                ep.value,
            )
            # NOTE(review): the guard presumably keeps the plugin import from
            # initializing CUDA as a side effect - confirm in torch_utils.
            with guard_cuda_initialization():
                plugin_cls = ep.load()
                loaded.append(plugin_cls)
        except Exception as exc:
            logger.error("Failed to load LogitsProcessor plugin %s: %s", ep, exc)
            raise RuntimeError(
                f"Failed to load LogitsProcessor plugin {ep}"
            ) from exc
    return loaded
def _load_logitsprocs_by_fqcns(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """Load logit processor types, identifying them by fully-qualified class
    names (FQCNs).

    Effectively, a mixed list of logitproc types and FQCN strings is converted
    into a list of entirely logitproc types, by loading from the FQCNs.

    FQCN syntax is <module>:<type> i.e. x.y.z:CustomLogitProc

    Already-loaded logitproc types must be subclasses of LogitsProcessor

    Args:
        logits_processors: Potentially mixed list of logitsprocs types and FQCN
            strings for logitproc types

    Returns:
        List of logitproc types

    Raises:
        ValueError: If an already-loaded type is not a LogitsProcessor
            subclass, if an FQCN string does not contain exactly one ":",
            or if the object an FQCN resolves to is not a LogitsProcessor
            subclass.
        RuntimeError: If importing the module named by an FQCN fails.
        AttributeError: If the dotted type name cannot be resolved inside
            the imported module.
    """
    if not logits_processors:
        return []

    logger.debug(
        "%s additional custom logits processors specified, checking whether "
        "they need to be loaded.",
        len(logits_processors),
    )

    classes: list[type[LogitsProcessor]] = []
    for ldx, logitproc in enumerate(logits_processors):
        if isinstance(logitproc, type):
            logger.debug(" - Already-loaded logit processor: %s", logitproc.__name__)
            if not issubclass(logitproc, LogitsProcessor):
                raise ValueError(
                    f"{logitproc.__name__} is not a subclass of LogitsProcessor"
                )
            classes.append(logitproc)
            continue

        logger.debug("- Loading logits processor %s", logitproc)
        # Validate the "<module>:<type>" shape explicitly. Unpacking the
        # result of split(":") directly would also raise ValueError, but with
        # an opaque "not enough/too many values to unpack" message.
        if logitproc.count(":") != 1:
            raise ValueError(
                f"Invalid logits processor FQCN {logitproc!r}: expected "
                "<module>:<type>, e.g. x.y.z:CustomLogitProc"
            )
        module_path, qualname = logitproc.split(":")

        try:
            # Load module
            with guard_cuda_initialization():
                module = importlib.import_module(module_path)
        except Exception as e:
            logger.error(
                "Failed to load %sth LogitsProcessor plugin %s: %s",
                ldx,
                logitproc,
                e,
            )
            raise RuntimeError(
                f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
            ) from e

        # Walk down dotted name to get logitproc class; re-raise lookup
        # failures with enough context to identify the offending entry.
        obj = module
        try:
            for attr in qualname.split("."):
                obj = getattr(obj, attr)
        except AttributeError as e:
            raise AttributeError(
                f"Cannot resolve {qualname!r} in module {module_path!r} "
                f"(logits processor {logitproc!r})"
            ) from e

        if not isinstance(obj, type):
            raise ValueError("Loaded logit processor must be a type.")
        if not issubclass(obj, LogitsProcessor):
            raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor")
        classes.append(obj)

    return classes
def _load_custom_logitsprocs(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """Collect every custom logits processor type.

    Installed plugins are loaded first, followed by the processors the user
    passed at initialization time.

    Args:
        logits_processors: potentially mixed list of logitproc types and
            logitproc type fully-qualified names (FQCNs) which need to be
            loaded

    Returns:
        Plugin types followed by user-specified types; always an empty list
        when the current platform is TPU.
    """
    from vllm.platforms import current_platform

    if not current_platform.is_tpu():
        plugin_classes = _load_logitsprocs_plugins()
        user_classes = _load_logitsprocs_by_fqcns(logits_processors)
        return plugin_classes + user_classes

    # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs
    return []
def build_logitsprocs(
    vllm_config: "VllmConfig",
    device: torch.device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (),
) -> LogitsProcessors:
    """Instantiate the builtin and custom logits processors.

    Args:
        vllm_config: engine configuration, forwarded to every processor
            constructor; its ``speculative_config`` is also inspected here
        device: forwarded to every processor constructor
        is_pin_memory: forwarded to every processor constructor
        is_pooling_model: when true, no processor is built
        custom_logitsprocs: processor types or FQCN strings to load in
            addition to the builtin ones

    Returns:
        An empty ``LogitsProcessors`` for pooling models or when speculative
        decoding is configured; otherwise one holding an instance of each
        builtin processor followed by each custom processor.

    Raises:
        ValueError: If custom processors are given for a pooling model or
            together with speculative decoding.
    """
    if is_pooling_model:
        if custom_logitsprocs:
            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
        logger.debug(
            "Skipping logits processor loading because pooling models"
            " do not support logits processors."
        )
        return LogitsProcessors()

    spec_decoding_enabled = vllm_config.speculative_config
    if spec_decoding_enabled:
        if custom_logitsprocs:
            raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
        logger.warning(
            "min_p, logit_bias, and min_tokens parameters won't currently work "
            "with speculative decoding enabled."
        )
        return LogitsProcessors()

    # Builtins come first, then plugins / user-specified processors.
    processor_classes = itertools.chain(
        BUILTIN_LOGITS_PROCESSORS,
        _load_custom_logitsprocs(custom_logitsprocs),
    )
    return LogitsProcessors(
        processor_cls(vllm_config, device, is_pin_memory)
        for processor_cls in processor_classes
    )
cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
def validate_logits_processors_parameters(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
    sampling_params: SamplingParams,
):
    """Run each custom logits processor's ``validate_params`` on the params.

    Args:
        logits_processors: processor types or FQCN strings, or None
        sampling_params: request sampling params to validate
    """
    # The loader is wrapped in lru_cache, so its argument must be hashable:
    # freeze the sequence into a tuple before the lookup.
    cache_key = None if logits_processors is None else tuple(logits_processors)
    for logitproc_cls in cached_load_custom_logitsprocs(cache_key):
        logitproc_cls.validate_params(sampling_params)
class AdapterLogitsProcessor(LogitsProcessor):
    """Batch-level adapter around per-request logits processors.

    To wrap a specific per-request logits processor:

    * Subclass `AdapterLogitsProcessor`
    * Implement the `self.is_argmax_invariant()` base-class method
    * Implement `self.new_req_logits_processor(params)`

    Overriding `self.__init__(vllm_config, device, is_pin_memory)` is usually
    unnecessary. A subclass that needs custom constructor behavior (for
    example to inspect or keep `vllm_config`, `device` or `is_pin_memory`)
    must override it and call
    `super().__init__(vllm_config, device, is_pin_memory)` from the override.
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        """Create the empty per-request state table.

        Subclass constructors must invoke
        `super().__init__(vllm_config, device, is_pin_memory)`. All three
        arguments are required by the vLLM logits processor interface even
        though this base constructor does not read them.
        """
        # req index -> state. The state is a partial[Tensor]: the
        # request-level processor with the output token ids (and, when the
        # processor takes them, the prompt token ids) already bound.
        #
        # The partial holds a *reference* to the output token ids list, so
        # each call sees the list's current contents rather than a snapshot
        # from when the partial was built.
        self.req_info: dict[int, partial[torch.Tensor]] = {}

    @abstractmethod
    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """Build the per-request logits processor for a new request.

        Args:
            params: request sampling params

        Returns:
            A `RequestLogitsProcessor` instance, or None when no processor
            should be applied to this request.
        """
        raise NotImplementedError

    def _new_state(
        self,
        params: SamplingParams,
        prompt_ids: list[int] | None,
        output_ids: list[int],
    ) -> partial[torch.Tensor] | None:
        """Build the stored state for a newly added request.

        Args:
            params: request sampling params
            prompt_ids: request prompt token ids
            output_ids: decoded tokens so far for this request

        Returns:
            The request processor as a partial[Tensor] with its token-id
            arguments bound, or None if no processor applies to the request.
        """
        req_lp = self.new_req_logits_processor(params)
        if not req_lp:
            return None

        # A three-parameter processor also receives the prompt token ids;
        # any other arity gets only the output token ids.
        arity = len(inspect.signature(req_lp).parameters)
        if arity == 3:
            return partial(req_lp, prompt_ids, output_ids)  # type: ignore[misc]
        return partial(req_lp, output_ids)  # type: ignore[misc]

    def update_state(self, batch_update: BatchUpdate | None):
        """Apply batch additions, removals and moves to the state table."""
        process_dict_updates(
            req_entries=self.req_info,
            batch_update=batch_update,
            new_state=self._new_state,
        )

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Run each stored request processor on its row of `logits`.

        A row is written back only when the processor returned a different
        tensor object than the row it was given. Returns `logits` itself.
        """
        for row_idx, row_processor in self.req_info.items():
            row = logits[row_idx]
            processed = row_processor(row)
            if processed is not row:
                logits[row_idx] = processed
        return logits
__all__ = [
"LogitsProcessor",
"LogitBiasLogitsProcessor",
"MinPLogitsProcessor",
"MinTokensLogitsProcessor",
"BatchUpdate",
"BatchUpdateBuilder",
"MoveDirectionality",
"LogitsProcessors",
"build_logitsprocs",
"STR_POOLING_REJECTS_LOGITSPROCS",
"LOGITSPROCS_GROUP",
"AdapterLogitsProcessor",
]

View File

@@ -0,0 +1,278 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeVar
import torch
from vllm import SamplingParams
from vllm.v1.sample.logits_processor.interface import (
BatchUpdate,
LogitsProcessor,
MoveDirectionality,
)
if TYPE_CHECKING:
from vllm.config import VllmConfig
T = TypeVar("T")
class MinPLogitsProcessor(LogitsProcessor):
    """Builtin logits processor implementing min-p filtering.

    Per-request min-p values live in a CPU tensor (exposed as a numpy view
    for cheap scalar edits). When the target device is not the CPU they are
    mirrored into a preallocated device tensor.
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        capacity = vllm_config.scheduler_config.max_num_seqs
        # Number of batch slots whose min-p is currently non-zero.
        self.min_p_count: int = 0

        self.min_p_cpu_tensor = torch.zeros(
            (capacity,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory
        )
        # numpy view sharing storage with the CPU tensor
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()

        self.use_double_tensor = torch.device(device).type != "cpu"
        # Separate preallocated device tensor only when the device is not the
        # CPU; otherwise the CPU tensor doubles as the "device" tensor.
        self.min_p_device: torch.Tensor = (
            torch.empty((capacity,), dtype=torch.float32, device=device)
            if self.use_double_tensor
            else self.min_p_cpu_tensor
        )
        # Current slice of the device tensor
        self.min_p: torch.Tensor = self.min_p_device[:0]

    def is_argmax_invariant(self) -> bool:
        """Min-p never impacts greedy sampling"""
        return True

    def get_min_p_by_index(self, index: int) -> float:
        """Return the min-p value stored for batch slot `index`."""
        return float(self.min_p_cpu[index])

    def update_state(self, batch_update: BatchUpdate | None):
        """Fold a batch update into the stored min-p values.

        Additions are handled first. Removals and moves are only processed
        while at least one slot has a non-zero min-p.
        """
        if not batch_update:
            return

        host = self.min_p_cpu
        dirty = False

        # Added requests.
        for slot, params, _, _ in batch_update.added:
            incoming = params.min_p
            previous = host[slot]
            if previous != incoming:
                dirty = True
                host[slot] = incoming
                if incoming and not previous:
                    self.min_p_count += 1
                elif not incoming and previous:
                    self.min_p_count -= 1

        if self.min_p_count:
            # Removed requests.
            if batch_update.removed:
                dirty = True
                for slot in batch_update.removed:
                    if host[slot]:
                        host[slot] = 0
                        self.min_p_count -= 1

            # Moved requests: one-way (a->b) and swap (a<->b).
            for src, dst, direction in batch_update.moved:
                src_val, dst_val = host[src], host[dst]
                if src_val != dst_val:
                    dirty = True
                    host[dst] = src_val
                    if direction == MoveDirectionality.SWAP:
                        host[src] = dst_val
                if direction == MoveDirectionality.UNIDIRECTIONAL:
                    if src_val:
                        host[src] = 0
                    if dst_val:
                        # The destination's previous value was overwritten.
                        self.min_p_count -= 1

        # Refresh the active slice when values or the batch size changed.
        batch_size = batch_update.batch_size
        if self.min_p_count and (dirty or self.min_p.shape[0] != batch_size):
            self.min_p = self.min_p_device[:batch_size]
            if self.use_double_tensor:
                self.min_p.copy_(
                    self.min_p_cpu_tensor[:batch_size], non_blocking=True
                )
            self.min_p.unsqueeze_(1)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Mask, in place, tokens whose probability is below min_p * max prob."""
        if not self.min_p_count:
            return logits

        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Highest probability in each row, scaled in place by the row's min-p
        row_peak = torch.amax(probs, dim=-1, keepdim=True)
        threshold = row_peak.mul_(self.min_p)
        below_threshold = probs < threshold
        logits.masked_fill_(below_threshold, -float("inf"))
        return logits
class LogitBiasLogitsProcessor(LogitsProcessor):
    """Builtin logits processor adding per-request, per-token logit biases."""

    def __init__(self, _, device: torch.device, is_pin_memory: bool):
        self.device = device
        self.pin_memory = is_pin_memory
        # req index -> {token id -> bias}
        self.biases: dict[int, dict[int, float]] = {}

        self.bias_tensor: torch.Tensor = torch.tensor(())
        # (request index tensor, token id tensor) addressing the biased logits
        empty_rows = self._device_tensor([], torch.int32)
        empty_cols = self._device_tensor([], torch.int32)
        self.logits_slice = (empty_rows, empty_cols)

    def is_argmax_invariant(self) -> bool:
        """Logit bias can rebalance token probabilities and change the
        outcome of argmax in greedy sampling."""
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        """Track per-request biases and rebuild the index/bias tensors when
        the tracked set changed."""

        def _bias_for_request(params, _prompt_ids, _output_ids):
            # An empty or missing logit_bias means "no state for this request"
            return params.logit_bias or None

        changed = process_dict_updates(self.biases, batch_update, _bias_for_request)
        if not changed:
            return

        # Flatten {req -> {token -> bias}} into three parallel lists.
        rows: list[int] = []
        cols: list[int] = []
        values: list[float] = []
        for req_idx, token_biases in self.biases.items():
            for token_id, bias in token_biases.items():
                rows.append(req_idx)
                cols.append(token_id)
                values.append(bias)

        self.bias_tensor = self._device_tensor(values, torch.float32)
        self.logits_slice = (
            self._device_tensor(rows, torch.int32),
            self._device_tensor(cols, torch.int32),
        )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """Build a CPU tensor from `data` and copy it to `self.device`
        with a non-blocking transfer."""
        staged = torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        )
        return staged.to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Add the stored biases to the addressed logits; returns `logits`."""
        if not self.biases:
            return logits
        logits[self.logits_slice] += self.bias_tensor
        return logits
class MinTokensLogitsProcessor(LogitsProcessor):
    """Builtin logits processor that suppresses stop tokens until a request
    has produced its minimum number of output tokens."""

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        self.device = device
        self.pin_memory = is_pin_memory
        # index -> (min_toks, output_token_ids, stop_token_ids)
        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}

        # (req_idx_tensor, eos_tok_id_tensor)
        no_rows = self._device_tensor([], torch.int32)
        no_cols = self._device_tensor([], torch.int32)
        self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (no_rows, no_cols)

        self.neg_inf_tensor = torch.tensor(
            -float("inf"), dtype=torch.float32, device=self.device
        )

    def is_argmax_invariant(self) -> bool:
        """By censoring stop tokens, min-tokens can change the outcome
        of the argmax operation in greedy sampling."""
        return False

    @staticmethod
    def add_request(
        params: SamplingParams, _: list[int] | None, output_tok_ids: list[int]
    ) -> tuple[int, Sequence[int], set[int]] | None:
        """Return the tracking state for a new request.

        Returns None when the request has no min_tokens setting or has
        already produced at least that many tokens.
        """
        target = params.min_tokens
        if target and len(output_tok_ids) < target:
            return target, output_tok_ids, params.all_stop_token_ids
        return None

    def update_state(self, batch_update: BatchUpdate | None):
        """Apply the batch update, drop requests that reached their minimum,
        and rebuild the index tensors if the tracked set changed."""
        dirty = process_dict_updates(self.min_toks, batch_update, self.add_request)

        # Requests that have attained their min tokens stop being tracked.
        finished = [
            req_idx
            for req_idx, (target, produced, _) in self.min_toks.items()
            if len(produced) >= target
        ]
        for req_idx in finished:
            del self.min_toks[req_idx]
        if finished:
            dirty = True

        if not dirty:
            return

        # One (request, stop token) pair per stop token of each tracked request
        rows: list[int] = []
        cols: list[int] = []
        for req_idx, (_, _, stop_tok_ids) in self.min_toks.items():
            for stop_tok in stop_tok_ids:
                rows.append(req_idx)
                cols.append(stop_tok)
        self.logits_slice = (
            self._device_tensor(rows, torch.int32),
            self._device_tensor(cols, torch.int32),
        )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """Build a CPU tensor from `data` and copy it to `self.device`
        with a non-blocking transfer."""
        staged = torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        )
        return staged.to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Set the tracked stop-token logits to -inf in place."""
        if not self.min_toks:
            return logits
        # Inhibit EOS token for requests which have not reached min length
        logits.index_put_(self.logits_slice, self.neg_inf_tensor)
        return logits
def process_dict_updates(
    req_entries: dict[int, T],
    batch_update: BatchUpdate | None,
    new_state: Callable[[SamplingParams, list[int] | None, list[int]], T | None],
) -> bool:
    """Utility function to update dict state for sparse LogitsProcessors.

    Args:
        req_entries: mapping of batch index -> per-request state; mutated in
            place
        batch_update: batch change description, or None when nothing changed
        new_state: called as ``new_state(params, prompt_tok_ids,
            output_tok_ids)`` for each added request; returns the state to
            store, or None if the request needs no entry

    Returns:
        True if any entry of ``req_entries`` was added, replaced, removed or
        moved; False otherwise.
    """
    if not batch_update:
        # Nothing to do.
        return False

    updated = False
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None:
            req_entries[index] = state
            updated = True
        elif req_entries.pop(index, None) is not None:
            # The added request replaces a tracked one but needs no state.
            updated = True

    if req_entries:
        # Process removed requests. Presence is tested with `is not None`,
        # consistent with the add/move handling, so removing a falsy state
        # value (e.g. 0 or an empty container) still counts as an update.
        for index in batch_update.removed:
            if req_entries.pop(index, None) is not None:
                updated = True

        # Process moved requests, unidirectional (a->b) and
        # swapped (a<->b)
        for a_index, b_index, direct in batch_update.moved:
            a_entry = req_entries.pop(a_index, None)
            b_entry = req_entries.pop(b_index, None)
            if a_entry is not None:
                req_entries[b_index] = a_entry
                updated = True
            if b_entry is not None:
                updated = True
                # On a one-way move the displaced b entry is dropped.
                if direct == MoveDirectionality.SWAP:
                    req_entries[a_index] = b_entry

    return updated

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, Optional
import torch
from vllm import SamplingParams
if TYPE_CHECKING:
from vllm.config import VllmConfig
class MoveDirectionality(Enum):
    """Kind of request move recorded for the batch: one-way or swap."""

    # One-way i1->i2 req move within batch
    UNIDIRECTIONAL = auto()
    # Two-way i1<->i2 req swap within batch
    SWAP = auto()
# Batch indices of any removed requests.
RemovedRequest = int
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int] | None, list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
@dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs

    Frozen record with the current batch size plus the requests removed
    from, added to, and moved within the persistent batch.
    """

    batch_size: int  # Current num reqs in batch

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Key assumption: the `output_tok_ids` list (which is an element of each
    # tuple in `added`) is a reference to the request's running output tokens
    # list; via this reference, the logits processors always see the latest
    # list of generated output tokens.
    #
    # NOTE:
    # * Added or moved requests may replace existing requests with the same
    #   index.
    # * Operations should be processed in the following order:
    #   - removed, added, moved
    # Batch indices of removed requests
    removed: Sequence[RemovedRequest]
    # (index, params, prompt_tok_ids, output_tok_ids) tuples for new requests
    added: Sequence[AddedRequest]
    # (index 1, index 2, directionality) tuples for moves / swaps
    moved: Sequence[MovedRequest]
class LogitsProcessor(ABC):
    """Abstract interface for logits processors operating on batch logits.

    Concrete subclasses must implement the constructor, `apply`,
    `is_argmax_invariant` and `update_state`; `validate_params` is an
    optional hook.
    """

    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        """Validate sampling params for this logits processor.

        Raise ValueError for invalid ones.

        The base implementation accepts everything and returns None.
        """
        return None

    @abstractmethod
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ) -> None:
        """Constructor signature subclasses must provide.

        The base implementation only raises NotImplementedError.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply LogitsProcessor to batch logits tensor.

        The updated tensor must be returned but may be
        modified in-place.
        """
        raise NotImplementedError

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """True if logits processor has no impact on the
        argmax computation in greedy sampling.

        NOTE: may or may not have the same value for all
        instances of a given LogitsProcessor subclass,
        depending on subclass implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def update_state(
        self,
        batch_update: Optional["BatchUpdate"],
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update: Non-None iff there have been changes
                to the batch makeup.
        """
        raise NotImplementedError

View File

@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from itertools import chain
from typing import TYPE_CHECKING
from vllm.v1.sample.logits_processor.interface import (
AddedRequest,
BatchUpdate,
MovedRequest,
RemovedRequest,
)
if TYPE_CHECKING:
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
class BatchUpdateBuilder:
    """Accumulates persistent batch changes and turns them into a
    `BatchUpdate` for logitsprocs.

    Assumptions:
    * Every request removed from the persistent batch during a step is
      registered through self.removed_append() at the beginning of the step,
      before self.removed, self.pop_removed() or self.peek_removed() are
      first used in that step
    * Once self.removed, self.pop_removed() or self.peek_removed() have been
      used in a step, no further removals are registered with
      self.removed_append()
    * self._removed is never touched directly; it only changes through
      self.removed_append() and self.pop_removed()

    Guarantees under the above assumptions:
    * self.removed is always sorted in descending order
    * self.pop_removed() and self.peek_removed() both return the lowest
      removed request index in the current step
    """

    _removed: list[RemovedRequest]
    _is_removed_sorted: bool
    added: list[AddedRequest]
    moved: list[MovedRequest]

    def __init__(
        self,
        removed: list[RemovedRequest] | None = None,
        added: list[AddedRequest] | None = None,
        moved: list[MovedRequest] | None = None,
    ) -> None:
        self._removed = removed if removed else []
        self.added = added if added else []
        self.moved = moved if moved else []
        self._is_removed_sorted = False

        # Used to track changes in the pooling case
        # where we don't populate the added list.
        self.batch_changed = False

    def _ensure_removed_sorted(self) -> None:
        """Sort the removed indices in descending order.

        Only the first call in a step does any work; later calls are no-ops
        until the sorted flag is reset.
        """
        if self._is_removed_sorted:
            return
        self._removed.sort(reverse=True)
        self._is_removed_sorted = True

    @property
    def removed(self) -> list[RemovedRequest]:
        """Removed request indices sorted in
        descending order"""
        self._ensure_removed_sorted()
        return self._removed

    def removed_append(self, index: int) -> None:
        """Record that the request at `index` left the persistent batch.

        Must not be called after the first time self.removed,
        self.pop_removed() or self.peek_removed() are invoked.

        Args:
            index: request index

        Raises:
            RuntimeError: If the removed list has already been sorted.
        """
        if self._is_removed_sorted:
            raise RuntimeError(
                "Cannot register new removed request after self.removed has been read."
            )
        self._removed.append(index)
        self.batch_changed = True

    def has_removed(self) -> bool:
        """True if at least one removal is currently registered."""
        return bool(self._removed)

    def peek_removed(self) -> int | None:
        """Return lowest removed request index, or None if there is none."""
        if not self.has_removed():
            return None
        self._ensure_removed_sorted()
        # Descending order puts the lowest index at the tail.
        return self._removed[-1]

    def pop_removed(self) -> int | None:
        """Pop lowest removed request index, or None if there is none."""
        if not self.has_removed():
            return None
        self._ensure_removed_sorted()
        return self._removed.pop()

    def reset(self) -> bool:
        """Clear all recorded changes.

        Returns True if there were any changes to the batch.
        """
        self._is_removed_sorted = False
        for recorded in (self._removed, self.added, self.moved):
            recorded.clear()
        changed, self.batch_changed = self.batch_changed, False
        return changed

    def get_and_reset(self, batch_size: int) -> BatchUpdate | None:
        """Generate a logitsprocs batch update data structure and reset
        internal batch update builder state.

        Args:
            batch_size: current persistent batch size

        Returns:
            Frozen logitsprocs batch update instance; `None` if no updates
        """
        # Reset removal-sorting logic
        self._is_removed_sorted = False
        self.batch_changed = False

        if not (self._removed or self.moved or self.added):
            # No update; short-circuit
            return None

        # The update takes ownership of the current lists; the builder
        # continues with fresh ones.
        update = BatchUpdate(
            batch_size=batch_size,
            removed=self._removed,
            moved=self.moved,
            added=self.added,
        )
        self._removed, self.moved, self.added = [], [], []
        return update
class LogitsProcessors:
"""Encapsulates initialized logitsproc objects."""
def __init__(self, logitsprocs: Iterator["LogitsProcessor"] | None = None) -> None:
self.argmax_invariant: list[LogitsProcessor] = []
self.non_argmax_invariant: list[LogitsProcessor] = []
if logitsprocs:
for logitproc in logitsprocs:
(
self.argmax_invariant
if logitproc.is_argmax_invariant()
else self.non_argmax_invariant
).append(logitproc)
@property
def all(self) -> Iterator["LogitsProcessor"]:
"""Iterator over all logits processors."""
return chain(self.argmax_invariant, self.non_argmax_invariant)

View File

@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.v1.sample.logits_processor import LogitsProcessors
@dataclass
class SamplingMetadata:
    """Sampling-related tensors, flags and per-request data grouped into one
    record.

    NOTE(review): where a field has no inline comment, its meaning is only
    suggested by its name - confirm against the code that consumes this
    object.
    """

    temperature: torch.Tensor | None
    all_greedy: bool
    all_random: bool

    top_p: torch.Tensor | None
    top_k: torch.Tensor | None

    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: int | None

    no_penalties: bool
    prompt_token_ids: torch.Tensor | None
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    output_token_ids: list[list[int]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: torch.Tensor | None

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessors

    # Speculative token ids
    spec_token_ids: list[list[int]] | None = None

View File

View File

@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
_SMALLEST_LOGIT = float("-inf")
def _apply_bad_words_single_batch(
logits: torch.Tensor,
bad_words_token_ids: list[list[int]],
past_tokens_ids: list[int],
) -> None:
for bad_word_ids in bad_words_token_ids:
if len(bad_word_ids) > len(past_tokens_ids) + 1:
continue
prefix_length = len(bad_word_ids) - 1
last_token_id = bad_word_ids[-1]
actual_prefix = past_tokens_ids[-prefix_length:] if prefix_length > 0 else []
expected_prefix = bad_word_ids[:prefix_length]
assert len(actual_prefix) == len(expected_prefix)
if actual_prefix == expected_prefix:
logits[last_token_id] = _SMALLEST_LOGIT
def apply_bad_words(
    logits: torch.Tensor,
    bad_words_token_ids: dict[int, list[list[int]]],
    past_tokens_ids: list[list[int]],
) -> None:
    """Apply bad-word banning in place, one batch row per request.

    Args:
        logits: batch logits, indexed by request index on the first axis
        bad_words_token_ids: req index -> bad word token-id sequences
        past_tokens_ids: tokens seen so far, indexed by request index
    """
    for req_idx in bad_words_token_ids:
        _apply_bad_words_single_batch(
            logits[req_idx],
            bad_words_token_ids[req_idx],
            past_tokens_ids[req_idx],
        )
def apply_bad_words_with_drafts(
    logits: torch.Tensor,
    bad_words_token_ids: dict[int, list[list[int]]],
    past_tokens_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> None:
    """Apply bad-word banning in place to the rows of every draft token.

    `logits` and `past_tokens_ids` are flattened over draft tokens: request
    `i` owns `num_draft_tokens[i]` consecutive rows, which start after the
    rows of all requests before it.

    Args:
        logits: logits with one row per draft token
        bad_words_token_ids: req index -> bad word token-id sequences; may
            cover only some of the requests
        past_tokens_ids: tokens seen so far, one entry per draft-token row
        num_draft_tokens: number of draft tokens of each request, indexed by
            request index
    """
    if not bad_words_token_ids:
        return

    # First flattened row of each request. The offset must count the draft
    # tokens of *every* preceding request, including those absent from
    # `bad_words_token_ids`; accumulating only over the dict's entries would
    # point at another request's rows.
    row_offsets: list[int] = []
    next_row = 0
    for draft_count in num_draft_tokens:
        row_offsets.append(next_row)
        next_row += draft_count

    for req_idx, bad_words_ids in bad_words_token_ids.items():
        first_row = row_offsets[req_idx]
        for row in range(first_row, first_row + num_draft_tokens[req_idx]):
            _apply_bad_words_single_batch(
                logits[row],
                bad_words_ids,
                past_tokens_ids[row],
            )

View File

@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Some utilities for logprobs, including logits."""
import torch
from vllm.platforms import current_platform
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """
    Count, for each row of x, the elements that are greater than or equal to
    that row's entry in values. Despite the function name, the comparison is
    inclusive (``>=``).

    The function is wrapped in torch.compile to generate an optimized kernel;
    otherwise, it will create additional copies of the input tensors and
    cause memory issues.

    Args:
        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).

    Returns:
        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
    """
    at_or_above = torch.ge(x, values)
    return at_or_above.sum(dim=-1)

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import make_tensor_with_pad
def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.

    The per-request output token lists are first converted to a padded
    tensor on the logits' device, then everything is handed to
    `apply_penalties`, whose result is returned.
    """
    _, vocab_size = logits.shape
    padded_outputs = _convert_to_tensors(
        output_token_ids, vocab_size, logits.device
    )

    # In the async scheduling case, rows that won't have penalties applied may
    # contain -1 placeholder token ids. Replace them with the pad value
    # (vocab_size) so that the scatter done in apply_penalties is valid.
    # NOTE(nick): The penalties implementation is currently quite inefficient
    # and will be reworked anyhow.
    placeholder_mask = padded_outputs == -1
    padded_outputs.masked_fill_(placeholder_mask, vocab_size)

    return apply_penalties(
        logits,
        prompt_token_ids,
        padded_outputs,
        presence_penalties,
        frequency_penalties,
        repetition_penalties,
    )
def _convert_to_tensors(
    output_token_ids: list[list[int]], vocab_size: int, device: torch.device
) -> torch.Tensor:
    """
    Turn the per-request output token lists into one padded int64 tensor
    on `device`.

    The tensor is built on the CPU (pinned when pinned memory is available)
    and then copied to `device` with a non-blocking transfer.
    """
    # vocab_size is used as the pad value because no real token id has
    # this value.
    cpu_tokens = make_tensor_with_pad(
        output_token_ids,
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return cpu_tokens.to(device, non_blocking=True)

View File

@@ -0,0 +1,384 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from packaging import version
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.model import LogprobsMode
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
logger = init_logger(__name__)
class TopKTopPSampler(nn.Module):
"""
Module that performs optional top-k and top-p filtering followed by
weighted random sampling of logits.
Implementations may update the logits tensor in-place.
"""
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
    """Pick the sampling implementation and bind it to ``self.forward``.

    Selection order:
    * CUDA, when processed logits/logprobs are not requested: FlashInfer
      (`forward_cuda`) if VLLM_USE_FLASHINFER_SAMPLER is set, otherwise
      `forward_native`.
    * CPU: `forward_native` on RISCV and POWERPC, otherwise `forward_cpu`.
    * aiter ops enabled, when processed logits/logprobs are not requested:
      `forward_hip` if `aiter.ops.sampling` imports, otherwise
      `forward_native`.
    * Anything else: `forward_native`.

    Args:
        logprobs_mode: stored on the instance; the values
            "processed_logits" and "processed_logprobs" disable the
            FlashInfer and aiter paths.

    Raises:
        RuntimeError: If VLLM_USE_FLASHINFER_SAMPLER is set but FlashInfer
            does not support the device's compute capability.
    """
    super().__init__()
    self.logprobs_mode = logprobs_mode
    # flashinfer optimization does not apply if intermediate
    # logprobs/logits after top_k/top_p need to be returned
    if (
        logprobs_mode not in ("processed_logits", "processed_logprobs")
        and current_platform.is_cuda()
    ):
        if envs.VLLM_USE_FLASHINFER_SAMPLER:
            # Imported lazily, only when the FlashInfer sampler is requested.
            from vllm.v1.attention.backends.flashinfer import FlashInferBackend

            capability = current_platform.get_device_capability()
            assert capability is not None
            if not FlashInferBackend.supports_compute_capability(capability):
                capability_str = capability.as_version_str()
                raise RuntimeError(
                    "FlashInfer does not support compute capability "
                    f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1."
                )
            # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
            logger.info_once(
                "Using FlashInfer for top-p & top-k sampling.",
                scope="global",
            )
            self.forward = self.forward_cuda
        else:
            logger.debug_once(
                "FlashInfer top-p/top-k sampling is available but disabled "
                "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
                "after verifying accuracy for your workloads."
            )
            self.forward = self.forward_native

    elif current_platform.is_cpu():
        arch = current_platform.get_cpu_architecture()
        # Fall back to native implementation for POWERPC and RISCV.
        # On PowerPC argmax produces incorrect output with torch.compile.
        # PR: https://github.com/vllm-project/vllm/pull/26987
        if arch in (CpuArchEnum.RISCV, CpuArchEnum.POWERPC):
            self.forward = self.forward_native
        else:
            self.forward = self.forward_cpu
    elif (
        logprobs_mode not in ("processed_logits", "processed_logprobs")
        and rocm_aiter_ops.is_enabled()
    ):
        try:
            # Imported for its side effect only; the ops are then reached
            # through torch.ops.aiter.
            import aiter.ops.sampling  # noqa: F401

            self.aiter_ops = torch.ops.aiter
            logger.info_once(
                "Using aiter sampler on ROCm (lazy import, sampling-only)."
            )
            self.forward = self.forward_hip
        except ImportError:
            logger.warning_once(
                "aiter.ops.sampling is not available on ROCm. "
                "Falling back to forward_native implementation."
            )
            self.forward = self.forward_native
    else:
        self.forward = self.forward_native

    # Module-level helper stored on the instance.
    # NOTE(review): defined outside this chunk - confirm its signature.
    self.apply_top_k_top_p = apply_top_k_top_p
def forward_native(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators), logits_to_return
def forward_cuda(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""More optimized implementation for top-k and top-p sampling."""
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
if (k is None and p is None) or generators:
if generators:
logger.debug_once(
"FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation."
)
return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
"FlashInfer does not support returning logits/logprobs"
)
# flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
# because of slicing operation in logits_processor.
return flashinfer_sample(logits.contiguous(), k, p, generators), None
def forward_cpu(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
PyTorch-native implementation of top-k and top-p sampling for CPU.
The logits tensor may be updated in-place.
"""
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
if len(generators) != logits.shape[0]:
return compiled_random_sample(logits), logits_to_return
else:
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
def forward_hip(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Optimized ROCm/aiter path (same structure as forward_cuda)."""
if (k is None and p is None) or generators:
if generators:
logger.warning_once(
"aiter sampler does not support per-request generators; "
"falling back to PyTorch-native."
)
return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in (
"processed_logits",
"processed_logprobs",
), "aiter sampler does not support returning logits/logprobs."
return self.aiter_sample(logits, k, p, generators), None
def aiter_sample(
self,
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from logits using aiter ops."""
use_top_k = k is not None
use_top_p = p is not None
# Joint k+p path
if use_top_p and use_top_k:
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
probs,
None,
*_to_tensor_scalar_tuple(k),
*_to_tensor_scalar_tuple(p),
deterministic=True,
)
return next_token_ids.view(-1)
# Top-p only path
elif use_top_p:
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
)
return next_token_ids.view(-1)
# Top-k only path
elif use_top_k:
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
renorm_probs = self.aiter_ops.top_k_renorm_probs(
probs, *_to_tensor_scalar_tuple(k)
)
return torch.multinomial(renorm_probs, num_samples=1).view(-1)
raise RuntimeError("aiter_sample was called with no active top-k or top-p.")
# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@torch.compile(dynamic=True)
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
    """Sample one index per row of `logits` without per-request generators.

    Computes a float32 softmax over the last dimension, divides it by freshly
    drawn exponential noise and returns the argmax over the last dimension.

    Args:
        logits: Logits to sample from.

    Returns:
        The argmax indices, flattened with `view(-1)`.
    """
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    # Exponential noise with the same shape as `probs`, drawn without an
    # explicit generator.
    q = torch.empty_like(probs)
    q.exponential_()
    # Out-of-place `div`: `probs` itself is not modified.
    return probs.div(q).argmax(dim=-1).view(-1)
def apply_top_k_top_p(
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
If a top-p is used, this function will sort the logits tensor,
which can be slow for large batches.
The logits tensor may be updated in-place.
"""
if p is None:
if k is None:
return logits
# Avoid sorting vocab for top-k only case.
return apply_top_k_only(logits, k)
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if k is not None:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if p is not None:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Mask every logit below the per-row k-th largest value with -inf.

    This implementation doesn't involve sorting the entire vocab.
    The logits tensor is updated in-place and returned; `k` itself is left
    untouched.

    Args:
        logits: 2-D logits tensor, one row per sampled position.
        k: Per-row top-k values; a value equal to `logits.shape[1]` disables
            filtering for that row.

    Returns:
        The same `logits` tensor after masking.
    """
    vocab_size = logits.shape[1]
    rows_without_top_k = k == vocab_size
    # Give the unfiltered rows a dummy k of 1 so that the gather below works.
    effective_k = k.masked_fill(rows_without_top_k, 1)
    largest_k = effective_k.max()
    # `topk(...).values` has shape [batch_size, largest_k]; turn the per-row k
    # into a 0-based column index in range [0, largest_k).
    kth_column = effective_k.sub_(1).unsqueeze(1)
    top_values = logits.topk(largest_k, dim=1).values
    threshold = top_values.gather(1, kth_column.long())
    # Unfiltered rows get a -inf threshold, so none of their logits is masked.
    threshold.masked_fill_(rows_without_top_k.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < threshold, -float("inf"))
    return logits
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Draw one index per row of `probs`.

    Each row is divided in-place by exponential noise and the argmax is
    returned. We use this function instead of torch.multinomial because
    torch.multinomial causes CPU-GPU synchronization.

    Args:
        probs: Probabilities, one row per request. Modified in-place.
        generators: Per-request generators, keyed by row index.

    Returns:
        The sampled indices, flattened with `view(-1)`.
    """
    noise = torch.empty_like(probs)
    # NOTE(woosuk): The common case is a batch without per-request seeds, so
    # the whole buffer is filled in one batched call first (skipped only when
    # every row brings its own generator) and the seeded rows are then
    # overwritten one by one.
    every_row_seeded = len(generators) == probs.shape[0]
    if not every_row_seeded:
        noise.exponential_()
    # TODO(woosuk): This can be slow because we handle each request
    # one by one. Optimize this.
    for row, generator in generators.items():
        noise[row].exponential_(generator=generator)
    return probs.div_(noise).argmax(dim=-1).view(-1)
def flashinfer_sample(
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the logits using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it avoids sorting the logits tensor
    via rejection sampling.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.

    Args:
        logits: Logits to sample from.
        k: Top-k values, or None.
        p: Top-p values, or None. At least one of `k` and `p` must be given.
        generators: Accepted for signature parity; not used here.

    Returns:
        The sampled token ids, flattened with `view(-1)`.

    Raises:
        ImportError: If the installed FlashInfer is older than 0.2.3.
    """
    import flashinfer

    installed_version = version.parse(flashinfer.__version__)
    if installed_version < version.parse("0.2.3"):
        raise ImportError(
            "FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
        )

    assert not (k is None and p is None)

    if k is not None and p is not None:
        # Both top-k and top-p: FlashInfer samples straight from the logits.
        sampled = flashinfer.sampling.top_k_top_p_sampling_from_logits(
            logits, k, p, deterministic=True
        )
        return sampled.view(-1)

    # Single-filter paths sample from float32 probabilities.
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    if k is None:
        # Top-p only.
        sampled = flashinfer.sampling.top_p_sampling_from_probs(
            probs, p, deterministic=True
        )
    else:
        # Top-k only.
        sampled = flashinfer.sampling.top_k_sampling_from_probs(
            probs, k, deterministic=True
        )
    return sampled.view(-1)
def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
return (x, 0)
else:
return (None, x)

View File

@@ -0,0 +1,805 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from dataclasses import replace
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
logger = init_logger(__name__)
# Token id stored in output slots that hold no sampled token: the output
# buffer of `rejection_sample` is pre-filled with it and
# `RejectionSampler.parse_output` filters it out.
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
# Temperature value that marks a request as greedy; `rejection_sample`
# compares the per-request temperatures against it.
GREEDY_TEMPERATURE: tl.constexpr = 0
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
# It is asserted against `metadata.max_spec_len` in `RejectionSampler.forward`
# and passed to `expand_kernel` as `MAX_NUM_TOKENS`.
MAX_SPEC_LEN = 128
class RejectionSampler(nn.Module):
    """
    The implementation strictly follows the algorithm described in
        https://arxiv.org/abs/2211.17192.
    However, we want to clarify the terminology used in the implementation:
    accepted tokens: tokens that are accepted based on the relationship
            between the "raw" draft and target probabilities.
    recovered tokens: tokens that are sampled based on the adjusted probability
        distribution, which is derived from both the draft and target
        probabilities.
    bonus tokens:
        If all proposed tokens are accepted, the bonus token is added to the
        end of the sequence. The bonus token is only sampled from the target
        probabilities. We pass in the bonus tokens instead of sampling them
        in the rejection sampler to allow for more flexibility in the
        sampling process. For example, we can use top_p, top_k sampling for
        bonus tokens, while spec decode does not support these sampling
        strategies.
    output tokens:
        Tokens are finally generated with the rejection sampler.
        output tokens = accepted tokens + recovered tokens + bonus tokens
    """

    def __init__(self, sampler: Sampler):
        """
        Args:
            sampler: Sampler used for the bonus tokens and for computing and
                gathering logprobs. Its `logprobs_mode` determines the two
                mode flags stored on this module.
        """
        super().__init__()
        self.sampler = sampler
        logprobs_mode = self.sampler.logprobs_mode
        # True for modes whose name starts with "processed".
        self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
        # True for modes whose name ends with "logits", i.e. logits rather
        # than logprobs are reported.
        self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: torch.Tensor | None,
        # [num_tokens + batch_size, vocab_size]
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """
        Args:
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            logits (torch.Tensor):
                Target model's logits probability distribution.
                Shape is [num_tokens + batch_size, vocab_size]. Here,
                probabilities from different requests are flattened into a
                single tensor because this is the shape of the output logits.
                NOTE: `logits` can be updated in place to save memory.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information.
        Returns:
            SamplerOutput:
                Contains the final output token IDs and their logprobs if
                requested.
        """
        assert metadata.max_spec_len <= MAX_SPEC_LEN

        bonus_logits_indices = metadata.bonus_logits_indices
        target_logits_indices = metadata.target_logits_indices

        # When indexing with a tensor (bonus_logits_indices), PyTorch
        # creates a new tensor with separate storage from the original
        # logits tensor. This means any in-place operations on bonus_logits
        # won't affect the original logits tensor.
        assert logits is not None
        bonus_logits = logits[bonus_logits_indices]
        bonus_sampler_output = self.sampler(
            logits=bonus_logits,
            sampling_metadata=replace(
                sampling_metadata,
                max_num_logprobs=-1,
            ),
            predict_bonus_token=True,
            # Override the logprobs mode to return logits because they are
            # needed later to compute the accepted token logprobs.
            logprobs_mode_override="processed_logits"
            if self.is_processed_logprobs_mode
            else "raw_logits",
        )
        bonus_token_ids = bonus_sampler_output.sampled_token_ids

        # Just like `bonus_logits`, `target_logits` is a new tensor with
        # separate storage from the original `logits` tensor. Therefore,
        # it is safe to update `target_logits` in place.
        raw_target_logits = logits[target_logits_indices]
        # Use float32 for the target_logits.
        raw_target_logits = raw_target_logits.to(torch.float32)
        target_logits = self.apply_logits_processors(
            raw_target_logits, sampling_metadata, metadata
        )
        # [num_tokens, vocab_size]
        # NOTE(woosuk): `target_logits` can be updated in place inside the
        # `apply_sampling_constraints` function.
        target_logits = apply_sampling_constraints(
            target_logits,
            metadata.cu_num_draft_tokens,
            sampling_metadata,
        )
        # Compute probability distribution from target logits.
        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

        output_token_ids = rejection_sample(
            metadata.draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            bonus_token_ids,
            sampling_metadata,
        )

        logprobs_tensors = None
        if sampling_metadata.max_num_logprobs is not None:
            logprobs_tensors = self._get_logprobs_tensors(
                sampling_metadata.max_num_logprobs,
                metadata,
                logits,
                target_logits if self.is_processed_logprobs_mode else raw_target_logits,
                bonus_sampler_output.logprobs_tensors.logprobs,
                output_token_ids,
            )

        return SamplerOutput(
            sampled_token_ids=output_token_ids,
            logprobs_tensors=logprobs_tensors,
        )

    def _get_logprobs_tensors(
        self,
        max_num_logprobs: int,
        metadata: SpecDecodeMetadata,
        logits: torch.Tensor,
        target_logits: torch.Tensor,
        bonus_logits: torch.Tensor,
        sampled_token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather the logprobs (or logits) of the tokens that were kept.

        Args:
            max_num_logprobs: Forwarded to `self.sampler.gather_logprobs`.
            metadata: Spec decode metadata providing the logits indices and
                `cu_num_sampled_tokens`.
            logits: Original logits tensor; only used as the shape template
                for the combined float32 buffer.
            target_logits: Logits written at `target_logits_indices`.
            bonus_logits: Logits written at `bonus_logits_indices`.
            sampled_token_ids: Output of the rejection sampler, with rejected
                slots set to `PLACEHOLDER_TOKEN_ID`.

        Returns:
            The result of `self.sampler.gather_logprobs` for the kept tokens.
        """
        # Shift the cumulative counts right by one so that entry i holds the
        # number of sampled positions before request i.
        cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
        cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]

        # Collect target and bonus logits.
        bonus_logits_indices = metadata.bonus_logits_indices
        target_logits_indices = metadata.target_logits_indices
        final_logits = torch.zeros_like(logits, dtype=torch.float32)
        final_logits[target_logits_indices] = target_logits.to(torch.float32)
        final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)

        # Compute accepted token indices.
        accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
        num_accepted_tokens = accepted_mask.sum(dim=-1)
        # Column position of every kept token, then offset by the start of
        # its request to index into `final_logits`.
        accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]
        accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(
            num_accepted_tokens
        )

        # Compute logprobs for accepted tokens.
        accepted_logits = final_logits[accepted_logit_indices]
        accepted_logprobs = (
            accepted_logits
            if self.is_logits_logprobs_mode
            else self.sampler.compute_logprobs(accepted_logits)
        )
        accepted_tokens = sampled_token_ids[accepted_mask]
        return self.sampler.gather_logprobs(
            accepted_logprobs,
            max_num_logprobs,
            accepted_tokens.to(torch.int64),
        )

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
        discard_req_indices: Sequence[int] = (),
        return_cu_num_tokens: bool = False,
    ) -> tuple[list[list[int]], list[int] | None]:
        """Parse the output of the rejection sampler.

        Args:
            output_token_ids: The sampled token IDs in shape
                [batch_size, max_spec_len + 1]. The rejected tokens are
                replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
                and will be filtered out in this function.
            vocab_size: The size of the vocabulary. Token ids that are not
                below it are filtered out as well.
            discard_req_indices: Optional row indices to discard tokens in.
                Any sequence of ints (list, tuple, ...) is accepted.
            return_cu_num_tokens: Whether to also return cumulative token counts.

        Returns:
            A tuple `(outputs, cu_num_tokens)`. `outputs` is a list of lists
            of token IDs, one per row. `cu_num_tokens` is None unless
            `return_cu_num_tokens` is set; otherwise it is `[0]` followed by
            the cumulative number of valid tokens per row, computed before the
            rows in `discard_req_indices` are blanked.
        """
        output_token_ids_np = output_token_ids.cpu().numpy()
        # Create mask for valid tokens.
        valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
            output_token_ids_np < vocab_size
        )
        cu_num_tokens = None
        if return_cu_num_tokens:
            cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
        if len(discard_req_indices) > 0:
            # Index with a list: NumPy interprets a tuple as one index per
            # axis, which would address a single element (or raise) instead
            # of blanking whole rows.
            valid_mask[list(discard_req_indices)] = False
        outputs = [
            row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
        ]
        return outputs, cu_num_tokens

    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        metadata: SpecDecodeMetadata,
    ) -> torch.Tensor:
        """
        Apply penalties, the allowed-token mask and bad-words exclusion to
        the target logits.

        `logits` may be updated in place.

        Args:
            logits: Target logits, one row per draft token.
            sampling_metadata: Source of the penalties, allowed-token mask,
                bad words and output / spec token ids.
            metadata: Spec decode metadata; `num_draft_tokens` is used to
                repeat per-request values once per draft token.

        Returns:
            The processed logits.
        """
        has_penalties = not sampling_metadata.no_penalties
        any_penalties_or_bad_words = (
            sampling_metadata.bad_words_token_ids or has_penalties
        )

        output_token_ids = sampling_metadata.output_token_ids
        if any_penalties_or_bad_words:
            # Penalties and bad words need one token history per draft
            # position rather than one per request.
            output_token_ids = self._combine_outputs_with_spec_tokens(
                output_token_ids,
                sampling_metadata.spec_token_ids,
            )

        # Calculate indices of target logits.
        if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
            num_requests = len(sampling_metadata.output_token_ids)
            num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
            original_indices = torch.arange(num_requests, device="cpu")
            # Request index repeated once per draft token of that request.
            repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
            repeat_indices = repeat_indices_cpu.to(
                device=logits.device, non_blocking=True
            )
            logits = self.apply_penalties(
                logits, sampling_metadata, metadata, repeat_indices, output_token_ids
            )

            # Apply allowed token ids.
            if sampling_metadata.allowed_token_ids_mask is not None:
                token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices]
                logits.masked_fill_(token_mask, float("-inf"))

        # Apply bad words exclusion.
        if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
            apply_bad_words_with_drafts(
                logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
            )

        return logits

    @staticmethod
    def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        metadata: SpecDecodeMetadata,
        repeat_indices: torch.Tensor,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:
        """
        Apply presence, frequency and repetition penalties to `logits`.

        Args:
            logits: Target logits, one row per draft token.
            sampling_metadata: Source of the prompt token ids and penalties.
            metadata: Spec decode metadata (not used in this method).
            repeat_indices: Request index for every row of `logits`, used to
                expand the per-request tensors.
            output_token_ids: Token history for every row of `logits`.

        Returns:
            `logits` unchanged when `sampling_metadata.no_penalties` is set,
            otherwise the result of `apply_all_penalties`.
        """
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None

        # Expand the per-request tensors to one row per draft token.
        prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices]
        presence_penalties = sampling_metadata.presence_penalties[repeat_indices]
        frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices]
        repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices]

        logits = apply_all_penalties(
            logits,
            prompt_token_ids,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            output_token_ids,
        )
        return logits

    @staticmethod
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
        spec_token_ids: list[list[int]] | None = None,
    ) -> list[list[int]]:
        """
        Build one token history per draft position.

        For a request with output tokens `out` and spec tokens
        `[s0, s1, ...]` the result contains `out`, `out + [s0]`,
        `out + [s0, s1]`, ... with `len(spec)` entries in total. Requests
        without spec tokens contribute no entry.

        Args:
            output_token_ids: Output token ids, one list per request.
            spec_token_ids: Spec token ids, one list per request, or None.

        Returns:
            `output_token_ids` itself when `spec_token_ids` is None,
            otherwise the flattened list of histories.
        """
        if spec_token_ids is None:
            return output_token_ids

        result = []
        for out, spec in zip(output_token_ids, spec_token_ids):
            if len(spec) == 0:
                continue
            result.append(out)
            # Each further position extends the previous history by one
            # spec token; the last spec token is never appended.
            for i in range(len(spec) - 1):
                result.append([*result[-1], spec[i]])
        return result
def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: torch.Tensor | None,
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Run rejection sampling over the draft tokens of a batch.

    Greedy requests are handled by `rejection_greedy_sample_kernel`, the
    remaining ones by `rejection_random_sample_kernel`.

    Args:
        draft_token_ids: Flattened draft token ids.
        num_draft_tokens: Number of draft tokens per request.
        max_spec_len: Maximum number of draft tokens of any request.
        cu_num_draft_tokens: Cumulative number of draft tokens per request.
        draft_probs: Draft probabilities, or None if not available.
        target_probs: Target probabilities for every draft position.
        bonus_token_ids: Bonus token of every request.
        sampling_metadata: Sampling parameters (greedy/random flags,
            temperature, generators).

    Returns:
        An int32 tensor of shape [batch_size, max_spec_len + 1] holding the
        output tokens; slots that were not written keep
        `PLACEHOLDER_TOKEN_ID`.
    """
    # Input sanity checks.
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    device = target_probs.device
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

    # Output buffer, pre-filled with the placeholder id.
    output_token_ids = torch.full(
        (batch_size, max_spec_len + 1),
        PLACEHOLDER_TOKEN_ID,
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        device=device,
    )

    # Per-request greedy mask; None means that every request is greedy.
    is_greedy = (
        None
        if sampling_metadata.all_greedy
        else sampling_metadata.temperature == GREEDY_TEMPERATURE
    )
    # Both kernels run one program per request.
    grid = (batch_size,)

    if not sampling_metadata.all_random:
        # Rejection sampling for greedy sampling requests.
        target_argmax = target_probs.argmax(dim=-1)
        rejection_greedy_sample_kernel[grid](
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
            is_greedy,
            max_spec_len,
        )
        if sampling_metadata.all_greedy:
            return output_token_ids

    # Uniform random numbers used by the acceptance test.
    # [num_tokens]
    uniform_probs = generate_uniform_probs(
        num_tokens,
        num_draft_tokens,
        sampling_metadata.generators,
        device,
    )

    # Replacement token for each position, used where a draft is rejected.
    # [num_tokens]
    recovered_token_ids = sample_recovered_tokens(
        max_spec_len,
        num_draft_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
        device,
    )

    # Rejection sampling for random sampling requests.
    rejection_random_sample_kernel[grid](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_token_ids,
        uniform_probs,
        is_greedy,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=draft_probs is None,
    )
    return output_token_ids
def apply_sampling_constraints(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Apply temperature, top-k and top-p to the target logits.

    The per-request sampling parameters are first expanded to one value per
    draft token. For all-greedy batches the logits are returned untouched.

    Args:
        logits: Input logits tensor to be processed; updated in place.
        cu_num_draft_tokens: Cumulative number of draft tokens.
        sampling_metadata: Metadata containing sampling parameters such as
            temperature, top-k, top-p and whether greedy sampling is used.

    Returns:
        torch.Tensor: Processed logits if non-greedy sampling is used,
        otherwise returns the original logits.
    """
    assert logits.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    if sampling_metadata.all_greedy:
        return logits

    num_tokens = logits.shape[0]

    def _per_token(batch_values: torch.Tensor | None) -> torch.Tensor | None:
        # Expand an optional [batch_size] tensor to [num_tokens].
        if batch_values is None:
            return None
        return expand_batch_to_tokens(
            batch_values,
            cu_num_draft_tokens,
            num_tokens,
        )

    # Greedy rows carry GREEDY_TEMPERATURE; it is swapped for 1 during the
    # expansion so that the division below leaves those rows unchanged.
    temperature = expand_batch_to_tokens(
        sampling_metadata.temperature,
        cu_num_draft_tokens,
        num_tokens,
        replace_from=GREEDY_TEMPERATURE,
        replace_to=1,
    )
    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
    logits.div_(temperature.unsqueeze(-1))

    top_k = _per_token(sampling_metadata.top_k)
    top_p = _per_token(sampling_metadata.top_p)

    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
    # which is slow for large vocab sizes. This may cause performance issues.
    return apply_top_k_top_p(logits, top_k, top_p)
def expand_batch_to_tokens(
    x: torch.Tensor,  # [batch_size]
    cu_num_tokens: torch.Tensor,  # [batch_size]
    num_tokens: int,
    replace_from: int = 0,
    replace_to: int = 0,
) -> torch.Tensor:
    """Repeat every per-request value of `x` once per token of that request.

    For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
    num_tokens = 6, and expanded_x = [a, a, b, b, b, c].

    Args:
        x: [batch_size] tensor to expand.
        cu_num_tokens: [batch_size] tensor containing the cumulative number of
            tokens per batch. Each element represents the total number of
            tokens up to and including that batch.
        num_tokens: Total number of tokens.
        replace_from: Value to be replaced if it is found in x.
        replace_to: Value to replace with when replace_from is found.

    Returns:
        expanded_x: [num_tokens] tensor.
    """
    batch_size = x.shape[0]
    assert cu_num_tokens.shape[0] == batch_size

    # Uninitialized output with the dtype/device of `x`; the kernel is
    # launched with one program per request to fill it.
    per_token_values = x.new_empty(num_tokens)
    grid = (batch_size,)
    expand_kernel[grid](
        per_token_values,
        x,
        cu_num_tokens,
        replace_from,
        replace_to,
        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
    )
    return per_token_values
def generate_uniform_probs(
    num_tokens: int,
    num_draft_tokens: list[int],
    generators: dict[int, torch.Generator],
    device: torch.device,
) -> torch.Tensor:
    """
    Draw one uniform random number per draft token.

    The whole tensor is first filled without a seed. For every request that
    has draft tokens and an entry in `generators`, its slice is then redrawn
    with that generator for reproducibility.

    Args:
        num_tokens: Total number of tokens.
        num_draft_tokens: Number of draft tokens per request.
        generators: A dictionary mapping indices in the batch to
            `torch.Generator` objects. May be empty.
        device: The device on which to allocate the tensor.

    Returns:
        A float64 tensor of shape `(num_tokens, )` containing uniform random
        values in the range [0, 1).
    """
    # NOTE(woosuk): We deliberately use float64 instead of float32 here
    # because when using float32, there's a non-negligible chance that
    # uniform_prob is sampled to be exact 0.0 as reported in
    # https://github.com/pytorch/pytorch/issues/16706. Using float64
    # mitigates the issue.
    uniform_probs = torch.rand(
        (num_tokens,),
        dtype=torch.float64,
        device=device,
    )

    offset = 0
    for req_idx, count in enumerate(num_draft_tokens):
        # Requests without draft tokens must not touch their generator.
        # This can be important for reproducibility.
        if count == 0:
            continue
        request_slice = slice(offset, offset + count)
        offset += count
        request_generator = generators.get(req_idx)
        if request_generator is None:
            continue
        uniform_probs[request_slice].uniform_(generator=request_generator)
    return uniform_probs
def sample_recovered_tokens(
    max_spec_len: int,
    num_draft_tokens: list[int],
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: torch.Tensor | None,
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    device: torch.device,
) -> torch.Tensor:
    """Sample the recovered token for every draft position.

    Exponential noise is drawn once per request and handed, together with
    the probabilities, to `sample_recovered_tokens_kernel`.

    Args:
        max_spec_len: Maximum number of draft tokens of any request.
        num_draft_tokens: Number of draft tokens per request.
        cu_num_draft_tokens: Cumulative number of draft tokens per request.
        draft_token_ids: Flattened draft token ids.
        draft_probs: Draft probabilities, or None if not available.
        target_probs: Target probabilities for every draft position.
        sampling_metadata: Provides the per-request generators.
        device: Device on which the noise tensor is allocated.

    Returns:
        A tensor shaped like `draft_token_ids` with the recovered token ids.
    """
    # NOTE(woosuk): Create only one distribution for each request.
    batch_size = len(num_draft_tokens)
    vocab_size = target_probs.shape[-1]
    noise = torch.empty(
        (batch_size, vocab_size),
        dtype=torch.float32,
        device=device,
    )
    noise.exponential_()
    for req_idx, generator in sampling_metadata.generators.items():
        # Do not generate random numbers for requests with no draft tokens.
        # This can be important for reproducibility.
        if num_draft_tokens[req_idx] > 0:
            noise[req_idx].exponential_(generator=generator)

    recovered_token_ids = torch.empty_like(draft_token_ids)
    # One program per (request, draft position) pair.
    grid = (batch_size, max_spec_len)
    padded_vocab_size = triton.next_power_of_2(vocab_size)
    sample_recovered_tokens_kernel[grid](
        recovered_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        noise,
        vocab_size,
        padded_vocab_size,
        NO_DRAFT_PROBS=draft_probs is None,
    )
    return recovered_token_ids
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    max_spec_len,
):
    # Greedy rejection sampling; each program handles one request. The target
    # argmax is written for every position up to and including the first
    # position where it differs from the draft token. The bonus token is
    # appended only if no position differed.
    req_idx = tl.program_id(0)
    # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
    # re-compilation may happen during runtime when is_greedy_ptr is None.
    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
    if not is_greedy:
        # Non-greedy requests are not handled by this kernel.
        return

    # Range of this request's tokens within the flattened token arrays.
    first_token = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    last_token = tl.load(cu_num_draft_tokens_ptr + req_idx)
    token_count = last_token - first_token

    # Start of this request's row in the output buffer.
    out_row_ptr = output_token_ids_ptr + req_idx * (max_spec_len + 1)

    mismatch_seen = False
    for pos in range(token_count):
        if not mismatch_seen:
            token_offset = first_token + pos
            draft_id = tl.load(draft_token_ids_ptr + token_offset)
            argmax_id = tl.load(target_argmax_ptr + token_offset)
            tl.store(out_row_ptr + pos, argmax_id)
            if draft_id != argmax_id:
                # Reject: later positions are left untouched.
                mismatch_seen = True

    if not mismatch_seen:
        # If all tokens are accepted, append the bonus token.
        bonus_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(out_row_ptr + token_count, bonus_id)
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    bonus_token_ids_ptr,  # [batch_size]
    recovered_token_ids_ptr,  # [num_tokens]
    uniform_probs_ptr,  # [num_tokens]
    is_greedy_ptr,  # [batch_size]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    # Random rejection sampling; each program handles one request. Draft
    # tokens are kept until the first rejection, where the recovered token is
    # written instead. The bonus token is appended only if nothing was
    # rejected.
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)
    if is_greedy:
        # Greedy requests are not handled by this kernel.
        return

    # Range of this request's tokens within the flattened token arrays.
    first_token = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    last_token = tl.load(cu_num_draft_tokens_ptr + req_idx)
    token_count = last_token - first_token

    # Start of this request's row in the output buffer.
    out_row_ptr = output_token_ids_ptr + req_idx * (max_spec_len + 1)

    rejection_seen = False
    for pos in range(token_count):
        if not rejection_seen:
            token_offset = first_token + pos
            draft_id = tl.load(draft_token_ids_ptr + token_offset)
            # Index of the draft token's entry in the probability tables.
            prob_index = token_offset * vocab_size + draft_id
            if NO_DRAFT_PROBS:
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr + prob_index)
            target_prob = tl.load(target_probs_ptr + prob_index)
            uniform_prob = tl.load(uniform_probs_ptr + token_offset)
            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, we reject.
            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                # Accept.
                chosen_id = draft_id
            else:
                # Reject. Use recovered token.
                rejection_seen = True
                chosen_id = tl.load(recovered_token_ids_ptr + token_offset)
            tl.store(out_row_ptr + pos, chosen_id)

    if not rejection_seen:
        # If all tokens are accepted, append the bonus token.
        bonus_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(out_row_ptr + token_count, bonus_id)
# Broadcasts one per-request value to every token slot of that request.
# One Triton program per request (program_id(0) is the request index): the
# value input[req_idx] is written to output[start_idx:end_idx], with a value
# equal to `replace_from` substituted by `replace_to` first.
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
output_ptr, # [num_tokens]
input_ptr, # [batch_size]
cu_num_tokens_ptr, # [batch_size]
replace_from,
replace_to,
MAX_NUM_TOKENS: tl.constexpr,
):
req_idx = tl.program_id(0)
# cu_num_tokens is read as an inclusive running total: this request's
# tokens occupy [start_idx, end_idx) in the output.
if req_idx == 0: # noqa: SIM108
start_idx = 0
else:
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
num_tokens = end_idx - start_idx
src_val = tl.load(input_ptr + req_idx)
src_val = tl.where(src_val == replace_from, replace_to, src_val)
# The store is MAX_NUM_TOKENS lanes wide and masked down to num_tokens, so
# at most MAX_NUM_TOKENS slots are written per request.
offset = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens)
# Picks the "recovered" token for every (request, draft position) pair.
# The launch grid is 2D: program_id(0) is the request index and
# program_id(1) is the position inside that request; programs whose position
# is beyond the request's draft length exit immediately.
# The recovered token is argmax(prob / q) over the vocabulary, where prob is
#   - the target probabilities with the draft token's entry forced to 0
#     when NO_DRAFT_PROBS is set, or
#   - max(target_prob - draft_prob, 0) otherwise.
@triton.jit
def sample_recovered_tokens_kernel(
output_token_ids_ptr, # [num_tokens]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
# cu_num_draft_tokens is read as an inclusive running total: this request's
# draft tokens occupy [start_idx, end_idx) in the flattened token arrays.
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
# Early exit for out-of-range positions.
pos = tl.program_id(1)
if pos >= num_draft_tokens:
return
# One lane per vocabulary entry; lanes >= vocab_size are padding and are
# masked out of every load below.
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
# The mask also excludes the draft token itself, so its probability is
# read as 0 (the `other` value).
prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)),
other=0,
)
else:
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
# q is indexed by request only, so every position of a request reads the
# same q row. Padding lanes read -inf.
# NOTE(review): the launch site appears to fill q with exponential samples
# per request -- confirm against the caller.
q = tl.load(
q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"),
)
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)

319
vllm/v1/sample/sampler.py Normal file
View File

@@ -0,0 +1,319 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""
import torch
import torch.nn as nn
from vllm.config.model import LogprobsMode
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
# Temperatures below this threshold are treated as greedy by the sampler in
# this file: they are replaced by 1.0 before the temperature division (when
# the batch is not all-random) and the argmax token is selected for them.
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        """Create the sampler.

        Args:
            logprobs_mode: Default logprobs mode. It is stored on the
                instance and also handed to the top-k/top-p sampler.
        """
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool = False,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> SamplerOutput:
        """Run the whole sampling pipeline on one batch of logits.

        Args:
            logits: 2D logits tensor with one row per request.
            sampling_metadata: Per-batch sampling configuration.
            predict_bonus_token: Forwarded to `apply_logits_processors`;
                when True the speculative tokens are appended to the output
                token ids used for bad-words and penalties.
            logprobs_mode_override: If given, used instead of
                `self.logprobs_mode` for this call, both here and in
                `sample`.

        Returns:
            A `SamplerOutput` whose `sampled_token_ids` is an int32 tensor
            of shape [num_requests, 1] and whose `logprobs_tensors` is None
            when no logprobs were requested.
        """
        logprobs_mode = logprobs_mode_override or self.logprobs_mode
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif logprobs_mode == "raw_logits":
                # The logits are modified in place further down, so the raw
                # values need their own storage. `.to(float32)` already
                # copies when the dtype differs.
                if logits.dtype == torch.float32:
                    raw_logprobs = logits.clone()
                else:
                    raw_logprobs = logits.to(torch.float32)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)

        logits = self.apply_logits_processors(
            logits, sampling_metadata, predict_bonus_token
        )

        # Sample the next token.
        # The override must reach `sample` as well; otherwise a "processed_*"
        # override is ignored on the greedy path and `raw_logprobs` can end
        # up unassigned. It is only passed when set, so the default call
        # shape stays identical for subclasses that override `sample`.
        # NOTE(review): `self.topk_topp_sampler` was configured with the
        # constructor's mode; whether it can honour a per-call override
        # cannot be determined from this file.
        if logprobs_mode_override is None:
            sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        else:
            sampled, processed_logprobs = self.sample(
                logits,
                sampling_metadata,
                logprobs_mode_override=logprobs_mode_override,
            )
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs

        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        if num_logprobs is None:
            logprobs_tensors = None
        elif num_logprobs == -1:
            # Return the full unsorted and unranked logprobs.
            logprobs_tensors = LogprobsTensors(
                torch.empty(0), raw_logprobs, torch.empty(0)
            )
        else:
            # Gather the logprobs and ranks of the topk and sampled token.
            logprobs_tensors = self.gather_logprobs(
                raw_logprobs, num_logprobs, token_ids=sampled
            )

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    @staticmethod
    def apply_temperature(
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        """Divide `logits` in place by the per-row temperature.

        Args:
            logits: Logits tensor, modified in place and returned.
            temp: 1D temperature tensor, one entry per row of `logits`.
            all_random: When False, temperatures below `_SAMPLING_EPS` are
                replaced by 1.0 so greedy rows are not divided by ~0.
        """
        # Use in-place division to avoid creating a new tensor.
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        return logits.div_(temp.unsqueeze(dim=1))

    @staticmethod
    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax token id of every row as a 1D tensor."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.

        Args:
            logits: Logits after the non-argmax-invariant processing.
            sampling_metadata: Per-batch sampling configuration.
            logprobs_mode_override: If given, used instead of
                `self.logprobs_mode` when deciding what to return as
                processed logprobs on the all-greedy path.

        Returns:
            `(sampled, processed_logprobs)`. `processed_logprobs` is None
            unless a "processed" mode produced a tensor.
        """
        logprobs_mode = logprobs_mode_override or self.logprobs_mode
        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(
            logits, sampling_metadata.temperature, sampling_metadata.all_random
        )

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: rows with temperature below the threshold keep the
        # greedy token, all other rows take the random one.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

    @staticmethod
    def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
        """Return the float32 log-softmax of `logits` over the last dim."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    @staticmethod
    def gather_logprobs(
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)

        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk. The sampled/prompt token comes
        # first, followed by the top-k entries.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    @staticmethod
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
        spec_token_ids: list[list[int]] | None = None,
    ) -> list[list[int]]:
        """Append each request's speculative tokens to its output tokens.

        Returns `output_token_ids` itself when `spec_token_ids` is None.
        Otherwise returns a new outer list; rows with no speculative tokens
        reuse the original inner list.
        """
        if spec_token_ids is None:
            return output_token_ids

        return [
            [*out, *spec] if spec else out
            for out, spec in zip(output_token_ids, spec_token_ids)
        ]

    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool,
    ) -> torch.Tensor:
        """Apply every logits transformation that can change the argmax.

        In order: allowed-token-ids mask, bad-words exclusion, the
        non-argmax-invariant logits processors, then penalties. `logits`
        may be modified in place.

        Args:
            logits: Logits tensor to transform.
            sampling_metadata: Per-batch sampling configuration.
            predict_bonus_token: When True (and bad words or penalties are
                active) the speculative tokens are appended to the output
                token ids before they are used.
        """
        bad_words_token_ids = sampling_metadata.bad_words_token_ids
        any_penalties_or_bad_words = (
            bool(bad_words_token_ids) or not sampling_metadata.no_penalties
        )

        output_token_ids = sampling_metadata.output_token_ids
        if predict_bonus_token and any_penalties_or_bad_words:
            # Combine base outputs with spec tokens when speculative decoding
            # is enabled.
            output_token_ids = self._combine_outputs_with_spec_tokens(
                output_token_ids,
                sampling_metadata.spec_token_ids,
            )

        # Apply allowed token ids.
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))

        # Apply bad words exclusion.
        if bad_words_token_ids:
            apply_bad_words(logits, bad_words_token_ids, output_token_ids)

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
        return logits

    @staticmethod
    def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:
        """Apply presence, frequency and repetition penalties.

        Returns `logits` unchanged when `sampling_metadata.no_penalties` is
        set; otherwise delegates to `apply_all_penalties`.
        """
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None
        return apply_all_penalties(
            logits,
            sampling_metadata.prompt_token_ids,
            sampling_metadata.presence_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.repetition_penalties,
            output_token_ids,
        )

View File

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import torch
from vllm.v1.worker.tpu_input_batch import InputBatch
# Neutral per-request sampling values. `from_input_batch` writes them into
# the padded tail (rows num_reqs..padded_num_reqs) of the CPU sampling
# tensors before they are copied to the device.
DEFAULT_SAMPLING_PARAMS = {
    "temperature": -1.0,
    "min_p": 0.0,
    # strictly disabled for now
    "top_k": 0,
    "top_p": 1.0,
    # Penalties are not supported yet, so they have no defaults here:
    # frequency_penalties=0.0,
    # presence_penalties=0.0,
    # repetition_penalties=0.0,
}
@dataclass
class TPUSupportedSamplingMetadata:
    """Sampling metadata restricted to what the TPU sampler supports."""

    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU, in particular all arguments should be traceable and no optionals
    # are allowed, to avoid graph recompilation on Nones.
    temperature: torch.Tensor = None
    min_p: torch.Tensor = None
    top_k: torch.Tensor = None
    top_p: torch.Tensor = None

    all_greedy: bool = True
    all_random: bool = False

    # Whether logprobs are to be gathered in this batch of request. To balance
    # out compile time and runtime, a fixed `max_number_logprobs` value is used
    # when gathering logprobs, regardless of the values specified in the batch.
    logprobs: bool = False

    # TODO No penalties for now
    # The un-annotated names below are plain class attributes, not dataclass
    # fields, so they cannot be passed to the constructor.
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=list)

    min_tokens = None  # impl is not vectorized

    logit_bias: list[dict[int, float] | None] = field(default_factory=list)

    allowed_token_ids_mask = None
    bad_words_token_ids = None

    # Generator not supported by xla
    _generators: dict[int, torch.Generator] = field(default_factory=dict)

    @property
    def generators(self) -> dict[int, torch.Generator]:
        """Return the per-request generators mapping (empty by default)."""
        # Generator not supported by torch/xla. This field must be immutable.
        return self._generators

    @classmethod
    def from_input_batch(
        cls,
        input_batch: InputBatch,
        padded_num_reqs: int,
        xla_device: torch.device,
        generate_params_if_all_greedy: bool = False,
    ) -> "TPUSupportedSamplingMetadata":
        """
        Copy sampling tensors slices from `input_batch` to on device tensors.

        `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
        slices dynamic shapes on device tensors. This impl moves the dynamic
        ops to CPU and produces tensors of fixed `padded_num_reqs` size.

        Note that the padded tail of `input_batch`'s CPU sampling tensors is
        overwritten in place with the default values.

        Args:
            input_batch: The input batch containing sampling parameters.
            padded_num_reqs: The padded number of requests.
            xla_device: The XLA device.
            generate_params_if_all_greedy: If True, generate sampling parameters
                even if all requests are greedy. this is useful for cases where
                we want to pre-compile a graph with sampling parameters, even if
                they are not strictly needed for greedy decoding.

        Returns:
            A metadata object. On the all-greedy early-return path only
            `all_greedy` and `logprobs` are set and the tensors stay None.
        """
        max_num_logprobs = input_batch.max_num_logprobs
        # None and 0 both mean "no logprobs requested".
        needs_logprobs = bool(max_num_logprobs) and max_num_logprobs > 0

        # Early return to avoid unnecessary cpu to tpu copy
        if input_batch.all_greedy is True and generate_params_if_all_greedy is False:
            return cls(all_greedy=True, logprobs=needs_logprobs)

        num_reqs = input_batch.num_reqs

        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> None:
            # Pad value is the default one. The tensor is updated in place.
            cpu_tensor[num_reqs:padded_num_reqs] = fill_val

        fill_slice(
            input_batch.temperature_cpu_tensor, DEFAULT_SAMPLING_PARAMS["temperature"]
        )
        fill_slice(input_batch.min_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["min_p"])
        fill_slice(input_batch.top_k_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_k"])
        fill_slice(input_batch.top_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_p"])

        # Slice the persistent CPU tensors to the fixed pre-compiled padded
        # shape and copy each slice to the XLA device.
        return cls(
            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].to(
                xla_device
            ),
            all_greedy=input_batch.all_greedy,
            all_random=input_batch.all_random,
            # TODO enable more and avoid returning None values
            top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device),
            top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device),
            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(xla_device),
            logprobs=needs_logprobs,
        )

View File

@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampler layer implementing TPU supported operations."""
import torch
import torch.nn as nn
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
# Temperatures below this threshold are treated as greedy by the sampler in
# this file: they are replaced by 1.0 before the temperature division (when
# the batch is not all-random) and the argmax token is selected for them.
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
    """Sampling layer built only from operations supported on TPU."""

    def __init__(self):
        # TODO(houseroad): Add support for logprobs_mode.
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> SamplerOutput:
        """Sample one token per request and wrap it in a `SamplerOutput`.

        Logprobs are never produced here (`logprobs_tensors` is None).
        """
        # Sampling is done on float32 logits.
        next_tokens = self.sample(logits.to(torch.float32), sampling_metadata)

        # These are TPU tensors. The sampled tokens are expanded to a 2D
        # tensor of shape [num_requests, 1]: one generated token per request.
        return SamplerOutput(
            sampled_token_ids=next_tokens.unsqueeze(-1),
            logprobs_tensors=None,
        )

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool = False,
    ) -> torch.Tensor:
        """Divide `logits` in place by the per-row temperature."""
        if all_random:
            divisor = temp
        else:
            # Greedy rows carry a temperature of ~0.0; divide them by 1.0
            # instead so the division stays finite.
            divisor = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        return logits.div_(divisor.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax token id of every row as a 1D tensor."""
        best = logits.argmax(dim=-1)
        return best.view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """Pick one token per row: greedy for ~0 temperature, random otherwise.

        Both the greedy and the random candidate are always computed; the
        per-row temperature then selects between them. `logits` is modified
        in place along the way.
        """
        # The greedy candidate must be taken before the logits are altered.
        argmax_tokens = self.greedy_sample(logits)

        temperature = sampling_metadata.temperature
        assert temperature is not None

        # Temperature scaling.
        scaled = self.apply_temperature(
            logits, temperature, sampling_metadata.all_random
        )

        # min_p filtering (optional).
        min_p = sampling_metadata.min_p
        if min_p is not None:
            scaled = self.apply_min_p(scaled, min_p)

        # top_k and/or top_p filtering.
        filtered = apply_top_k_top_p(
            scaled,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        # Random candidate drawn from the filtered distribution.
        probs = filtered.softmax(dim=-1, dtype=torch.float32)
        drawn_tokens = self.random_sample(probs, sampling_metadata.generators)

        return torch.where(
            temperature < _SAMPLING_EPS,
            argmax_tokens,
            drawn_tokens,
        )

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the float32 log-softmax of `logits` over the last dim."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Collect the logprobs of the top-k tokens and of the given tokens.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        # Top-k values and their vocabulary indices.
        top_values, top_ids = torch.topk(logprobs, num_logprobs, dim=-1)

        # Logprob of the prompt or sampled token itself.
        token_column = token_ids.unsqueeze(-1)
        chosen_values = logprobs.gather(-1, token_column)

        # Rank = number of vocabulary entries whose logprob is >= the token's.
        ranks = (logprobs >= chosen_values).sum(-1)

        # The given token comes first, followed by the top-k entries.
        # Indices are stored as int32 to reduce the tensor size.
        all_ids = torch.cat((token_column, top_ids), dim=1).to(torch.int32)
        all_values = torch.cat((chosen_values, top_values), dim=1)

        return LogprobsTensors(all_ids, all_values, ranks)

    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Mask out, in place, tokens whose probability falls below
        `min_p` times the row's maximum probability.
        """
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Per-row threshold: min_p scaled by the largest probability.
        row_max = torch.amax(probs, dim=-1, keepdim=True)
        threshold = min_p.unsqueeze(1) * row_max
        # Boolean masking instead of indexing (xla friendly).
        keep = probs >= threshold
        logits.masked_fill_(torch.logical_not(keep), -float("inf"))
        return logits

    def random_sample(
        self,
        probs: torch.Tensor,
        generators: dict[int, torch.Generator],
    ) -> torch.Tensor:
        """Draw one token per row by taking argmax(probs / exponential noise).

        `probs` is divided in place.
        """
        # NOTE(woosuk): To batch-process the requests without their own seeds,
        # which is the common case, we first assume that every request does
        # not have its own seed. Then, we overwrite the values for the requests
        # that have their own seeds.
        noise = torch.empty_like(probs).exponential_()
        for row, row_generator in (generators or {}).items():
            noise[row].exponential_(generator=row_generator)
        return probs.div_(noise).argmax(dim=-1).view(-1)
def apply_top_k_top_p(
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> torch.Tensor:
"""
Apply top-k and top-p optimized for TPU.
This algorithm avoids using torch.scatter which is extremely slow on TPU.
This is achieved by finding a "cut-off" element in the original logit, and
after thresholding the logit using this cut-off, the remaining elements
shall constitute the top-p set.
Note: in the case of tie (i.e. multiple cut-off elements present in the
logit), all tie elements are included in the top-p set. In other words,
this function does not break ties. Instead, these tie tokens have equal
chance of being chosen during final sampling, so we can consider the tie
being broken then.
"""
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
if k is not None:
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
top_k_count = top_k_count.unsqueeze(dim=1)
top_k_cutoff = probs_sort.gather(-1, top_k_count)
# Make sure the no top-k rows are no-op.
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
elements_to_discard = probs < top_k_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
if p is not None:
cumprob = torch.cumsum(probs_sort, dim=-1)
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False # at least one
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
top_p_cutoff = probs_sort.gather(-1, top_p_count)
elements_to_discard = probs < top_p_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
return logits