Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import inspect
import itertools
from abc import abstractmethod
from collections.abc import Sequence
from functools import lru_cache, partial
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
from vllm.sampling_params import SamplingParams
from vllm.utils.torch_utils import guard_cuda_initialization
from vllm.v1.sample.logits_processor.builtin import (
LogitBiasLogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
process_dict_updates,
)
from vllm.v1.sample.logits_processor.interface import (
BatchUpdate,
LogitsProcessor,
MoveDirectionality,
)
from vllm.v1.sample.logits_processor.state import BatchUpdateBuilder, LogitsProcessors
if TYPE_CHECKING:
from vllm.config import VllmConfig
# Module-level logger for this package.
logger = init_logger(__name__)

# Error message raised when the user tries to initialize vLLM with a pooling
# model together with custom logits processors (unsupported combination).
STR_POOLING_REJECTS_LOGITSPROCS = (
    "Pooling models do not support custom logits processors."
)

# Error message raised when the user tries to initialize vLLM with
# speculative decoding enabled together with custom logits processors.
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
    "Custom logits processors are not supported when speculative decoding is enabled."
)

# Entry-point group name used to discover installed logits-processor plugins.
LOGITSPROCS_GROUP = "vllm.logits_processors"

# Logits processors that ship with vLLM; build_logitsprocs constructs them
# in this order, ahead of any custom processors.
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
    MinTokensLogitsProcessor,
    LogitBiasLogitsProcessor,
    MinPLogitsProcessor,
]
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    """Discover and load every logits-processor plugin installed under the
    vllm.logits_processors entry-point group.

    Returns:
        List of loaded logits processor types (empty if none installed).

    Raises:
        RuntimeError: if any discovered plugin fails to load.
    """
    from importlib.metadata import entry_points

    plugins = entry_points(group=LOGITSPROCS_GROUP)
    if not plugins:
        logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP)
        return []

    logger.debug("Loading installed logitsprocs plugins (group %s):", LOGITSPROCS_GROUP)
    loaded: list[type[LogitsProcessor]] = []
    for ep in plugins:
        try:
            logger.debug(
                "- Loading logitproc plugin entrypoint=%s target=%s",
                ep.name,
                ep.value,
            )
            # Guard against plugins that touch CUDA at import time.
            with guard_cuda_initialization():
                loaded.append(ep.load())
        except Exception as e:
            logger.error("Failed to load LogitsProcessor plugin %s: %s", ep, e)
            raise RuntimeError(
                f"Failed to load LogitsProcessor plugin {ep}"
            ) from e
    return loaded
def _load_logitsprocs_by_fqcns(
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
"""Load logit processor types, identifying them by fully-qualified class
names (FQCNs).
Effectively, a mixed list of logitproc types and FQCN strings is converted
into a list of entirely logitproc types, by loading from the FQCNs.
FQCN syntax is <module>:<type> i.e. x.y.z:CustomLogitProc
Already-loaded logitproc types must be subclasses of LogitsProcessor
Args:
logits_processors: Potentially mixed list of logitsprocs types and FQCN
strings for logitproc types
Returns:
List of logitproc types
"""
if not logits_processors:
return []
logger.debug(
"%s additional custom logits processors specified, checking whether "
"they need to be loaded.",
len(logits_processors),
)
classes: list[type[LogitsProcessor]] = []
for ldx, logitproc in enumerate(logits_processors):
if isinstance(logitproc, type):
logger.debug(" - Already-loaded logit processor: %s", logitproc.__name__)
if not issubclass(logitproc, LogitsProcessor):
raise ValueError(
f"{logitproc.__name__} is not a subclass of LogitsProcessor"
)
classes.append(logitproc)
continue
logger.debug("- Loading logits processor %s", logitproc)
module_path, qualname = logitproc.split(":")
try:
# Load module
with guard_cuda_initialization():
module = importlib.import_module(module_path)
except Exception as e:
logger.error(
"Failed to load %sth LogitsProcessor plugin %s: %s",
ldx,
logitproc,
e,
)
raise RuntimeError(
f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
) from e
# Walk down dotted name to get logitproc class
obj = module
for attr in qualname.split("."):
obj = getattr(obj, attr)
if not isinstance(obj, type):
raise ValueError("Loaded logit processor must be a type.")
if not issubclass(obj, LogitsProcessor):
raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor")
classes.append(obj)
return classes
def _load_custom_logitsprocs(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """Load all custom logits processors.

    * First load all installed logitproc plugins
    * Second load custom logitsprocs pass by the user at initialization time

    Args:
        logits_processors: potentially mixed list of logitproc types and
                           logitproc type fully-qualified names (FQCNs)
                           which need to be loaded

    Returns:
        A list of all loaded logitproc types
    """
    from vllm.platforms import current_platform

    if current_platform.is_tpu():
        # No logitsprocs specified by caller
        # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs
        return []

    plugin_classes = _load_logitsprocs_plugins()
    fqcn_classes = _load_logitsprocs_by_fqcns(logits_processors)
    return plugin_classes + fqcn_classes
def build_logitsprocs(
    vllm_config: "VllmConfig",
    device: torch.device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (),
) -> LogitsProcessors:
    """Instantiate all builtin and custom logits processors for this engine.

    Args:
        vllm_config: engine configuration passed to each processor ctor
        device: target device for processor tensors
        is_pin_memory: whether CPU staging tensors should be pinned
        is_pooling_model: pooling models reject logits processors entirely
        custom_logitsprocs: extra processor types and/or FQCN strings

    Returns:
        A LogitsProcessors container with all constructed processors.

    Raises:
        ValueError: if custom processors are given for a pooling model or
                    with speculative decoding enabled.
    """
    if is_pooling_model:
        if custom_logitsprocs:
            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
        logger.debug(
            "Skipping logits processor loading because pooling models"
            " do not support logits processors."
        )
        return LogitsProcessors()

    # Check if speculative decoding is enabled.
    if vllm_config.speculative_config:
        if custom_logitsprocs:
            raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
        logger.warning(
            "min_p, logit_bias, and min_tokens parameters won't currently work "
            "with speculative decoding enabled."
        )
        return LogitsProcessors()

    all_classes = itertools.chain(
        BUILTIN_LOGITS_PROCESSORS,
        _load_custom_logitsprocs(custom_logitsprocs),
    )
    return LogitsProcessors(
        cls(vllm_config, device, is_pin_memory) for cls in all_classes
    )
# Memoized loader: repeated validation calls with the same logitsprocs
# tuple reuse the already-loaded classes instead of re-importing them.
cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)


def validate_logits_processors_parameters(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
    sampling_params: SamplingParams,
):
    """Ask each custom logits processor to validate `sampling_params`.

    Raises whatever the individual processors' validate_params raise
    (ValueError for invalid parameters).
    """
    # lru_cache needs a hashable key, so normalize the sequence to a tuple.
    if logits_processors is None:
        logits_processors = None
    else:
        logits_processors = tuple(logits_processors)
    for lp_cls in cached_load_custom_logitsprocs(logits_processors):
        lp_cls.validate_params(sampling_params)
class AdapterLogitsProcessor(LogitsProcessor):
    """Wrapper for per-request logits processors

    To wrap a specific per-request logits processor,
    * Subclass `AdapterLogitsProcessor`
    * Implement `self.is_argmax_invariant()` base-class method
    * Implement `self.new_req_logits_processor(params)`

    `self.__init__(vllm_config, device, is_pin_memory)` does not need to be
    overridden in general. However, to implement custom constructor behavior -
    especially any logic which operates on or stores `vllm_config`, `device`,
    or `is_pin_memory` - `self.__init__(vllm_config, device, is_pin_memory)`
    must be overridden and the override must call
    `super().__init__(vllm_config, device, is_pin_memory)`
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        """Subclass must invoke
        `super().__init__(vllm_config, device, is_pin_memory)`.

        Subclass constructor may find it useful to utilize the `vllm_config`,
        `device` and `is_pin_memory` argument. However regardless of whether
        these arguments are used, the vLLM logits processor interface requires
        all three arguments to be present.
        """
        # Map req index -> logits processor state.
        #
        # State is a partial[Tensor] binding the request-level logits
        # processor to its output-token-ids list (and, for 3-argument
        # processors, the prompt-token-ids list as well).
        #
        # The partial carries a *reference* to the output token ids list, so
        # it always operates on the list's current contents, not a snapshot
        # from when the partial was created.
        self.req_info: dict[int, partial[torch.Tensor]] = {}

    @abstractmethod
    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """Consume request info; return a per-request logits processor.

        Return None if logits processor does not need to be applied to request

        Args:
            params: request sampling params

        Returns:
            None if logits processor should not be applied to request;
            otherwise returns a `RequestLogitsProcessor` instance
        """
        raise NotImplementedError

    def _new_state(
        self,
        params: SamplingParams,
        prompt_ids: list[int] | None,
        output_ids: list[int],
    ) -> partial[torch.Tensor] | None:
        """Return state representation for new request

        Returns None if logits processor is not applicable to request

        Args:
            params: request sampling params
            prompt_ids: request prompt token ids
            output_ids: decoded tokens so far for this request

        Returns:
            logits processor partial[Tensor] or None
        """
        req_lp = self.new_req_logits_processor(params)
        if not req_lp:
            return None
        # A 3-parameter request-level processor also expects prompt ids.
        if len(inspect.signature(req_lp).parameters) == 3:
            bound_args = [prompt_ids, output_ids]
        else:
            bound_args = [output_ids]
        return partial(req_lp, *bound_args)  # type: ignore[misc]

    def update_state(self, batch_update: BatchUpdate | None):
        """Sync per-request state with the latest persistent-batch makeup."""
        process_dict_updates(
            self.req_info,
            batch_update,
            self._new_state,
        )

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Run each tracked request's processor on its row of `logits`."""
        if not self.req_info:
            return logits
        for req_idx, req_lp in self.req_info.items():
            row = logits[req_idx]
            updated_row = req_lp(row)
            if updated_row is not row:
                # Processor returned a fresh tensor; write it back in place.
                logits[req_idx] = updated_row
        return logits
# Public API of the logits-processor package.
__all__ = [
    "LogitsProcessor",
    "LogitBiasLogitsProcessor",
    "MinPLogitsProcessor",
    "MinTokensLogitsProcessor",
    "BatchUpdate",
    "BatchUpdateBuilder",
    "MoveDirectionality",
    "LogitsProcessors",
    "build_logitsprocs",
    "STR_POOLING_REJECTS_LOGITSPROCS",
    "LOGITSPROCS_GROUP",
    "AdapterLogitsProcessor",
]

View File

@@ -0,0 +1,278 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeVar
import torch
from vllm import SamplingParams
from vllm.v1.sample.logits_processor.interface import (
BatchUpdate,
LogitsProcessor,
MoveDirectionality,
)
if TYPE_CHECKING:
from vllm.config import VllmConfig
# Generic per-request state type used by process_dict_updates.
T = TypeVar("T")
class MinPLogitsProcessor(LogitsProcessor):
    """Masks out tokens whose probability falls below min_p * max_prob for
    every request with a non-zero `min_p` sampling parameter.

    State is kept on a CPU tensor (one min_p value per batch slot) and copied
    to the device tensor only when the batch makeup or values change.
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        # Number of batch slots with a non-zero min_p; apply() is a no-op
        # while this is zero.
        self.min_p_count: int = 0
        self.min_p_cpu_tensor = torch.zeros(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory
        )
        # NumPy view sharing storage with the CPU tensor (cheap scalar access).
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
        # A separate device tensor is only needed off-CPU.
        self.use_double_tensor = torch.device(device).type != "cpu"
        if self.use_double_tensor:
            # Pre-allocated device tensor
            self.min_p_device: torch.Tensor = torch.empty(
                (max_num_reqs,), dtype=torch.float32, device=device
            )
        else:
            self.min_p_device = self.min_p_cpu_tensor
        # Current slice of the device tensor
        self.min_p: torch.Tensor = self.min_p_device[:0]

    def is_argmax_invariant(self) -> bool:
        """Min-p never impacts greedy sampling"""
        return True

    def get_min_p_by_index(self, index: int) -> float:
        # Read the current min_p value for one batch slot from CPU state.
        return float(self.min_p_cpu[index])

    def update_state(self, batch_update: BatchUpdate | None):
        if not batch_update:
            return

        needs_update = False
        # Process added requests.
        for index, params, _, _ in batch_update.added:
            min_p = params.min_p
            min_p_before = self.min_p_cpu[index]
            if min_p_before != min_p:
                needs_update = True
                self.min_p_cpu[index] = min_p
                # Keep the active-request count in sync with 0 <-> non-0
                # transitions of the slot value.
                if min_p and not min_p_before:
                    self.min_p_count += 1
                elif not min_p and min_p_before:
                    self.min_p_count -= 1

        if self.min_p_count:
            # Process removed requests.
            if batch_update.removed:
                needs_update = True
                for index in batch_update.removed:
                    if self.min_p_cpu[index]:
                        self.min_p_cpu[index] = 0
                        self.min_p_count -= 1

            # Process moved requests, unidirectional (a->b) and swap (a<->b).
            for adx, bdx, direct in batch_update.moved:
                min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
                if min_p_a != min_p_b:
                    needs_update = True
                    self.min_p_cpu[bdx] = min_p_a
                    if direct == MoveDirectionality.SWAP:
                        self.min_p_cpu[adx] = min_p_b
                if direct == MoveDirectionality.UNIDIRECTIONAL:
                    # Source slot is vacated by a one-way move: clear it and
                    # drop the count for any overwritten destination value.
                    if min_p_a:
                        self.min_p_cpu[adx] = 0
                    if min_p_b:
                        self.min_p_count -= 1

        # Update tensors if needed.
        size = batch_update.batch_size
        if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
            self.min_p = self.min_p_device[:size]
            if self.use_double_tensor:
                self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
            self.min_p.unsqueeze_(1)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.min_p_count:
            return logits
        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
        # Adjust min_p
        adjusted_min_p = max_probabilities.mul_(self.min_p)
        # Identify valid tokens using threshold comparison
        invalid_token_mask = probability_values < adjusted_min_p
        # Apply mask using boolean indexing
        logits.masked_fill_(invalid_token_mask, -float("inf"))
        return logits
class LogitBiasLogitsProcessor(LogitsProcessor):
    """Adds per-request, per-token logit biases via a flat sparse
    (row index, token index) advanced-indexing slice."""

    def __init__(self, _, device: torch.device, is_pin_memory: bool):
        self.device = device
        self.pin_memory = is_pin_memory
        # req index -> {token id -> bias}
        self.biases: dict[int, dict[int, float]] = {}
        # Flattened bias values, parallel to self.logits_slice.
        self.bias_tensor: torch.Tensor = torch.tensor(())
        self.logits_slice = (
            self._device_tensor([], torch.int32),
            self._device_tensor([], torch.int32),
        )

    def is_argmax_invariant(self) -> bool:
        """Logit bias can rebalance token probabilities and change the
        outcome of argmax in greedy sampling."""
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        changed = process_dict_updates(
            self.biases, batch_update, lambda params, _, __: params.logit_bias or None
        )
        if not changed:
            return
        # Rebuild the flattened sparse index from the per-request bias maps.
        row_indices: list[int] = []
        token_indices: list[int] = []
        bias_values: list[float] = []
        for row, bias_map in self.biases.items():
            row_indices.extend([row] * len(bias_map))
            token_indices.extend(bias_map.keys())
            bias_values.extend(bias_map.values())
        self.bias_tensor = self._device_tensor(bias_values, torch.float32)
        self.logits_slice = (
            self._device_tensor(row_indices, torch.int32),
            self._device_tensor(token_indices, torch.int32),
        )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """Create a device tensor via an (optionally pinned) CPU staging copy."""
        staging = torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        )
        return staging.to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.biases:
            logits[self.logits_slice] += self.bias_tensor
        return logits
class MinTokensLogitsProcessor(LogitsProcessor):
    """Forces each request to generate at least `min_tokens` tokens by
    masking its stop tokens to -inf until the minimum length is reached."""

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        self.device = device
        self.pin_memory = is_pin_memory
        # index -> (min_toks, output_token_ids, stop_token_ids)
        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
        # (req_idx_tensor,eos_tok_id_tensor)
        self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (
            self._device_tensor([], torch.int32),
            self._device_tensor([], torch.int32),
        )
        # Scalar -inf reused by index_put_ in apply().
        self.neg_inf_tensor = torch.tensor(
            -float("inf"), dtype=torch.float32, device=self.device
        )

    def is_argmax_invariant(self) -> bool:
        """By censoring stop tokens, min-tokens can change the outcome
        of the argmax operation in greedy sampling."""
        return False

    @staticmethod
    def add_request(
        params: SamplingParams, _: list[int] | None, output_tok_ids: list[int]
    ) -> tuple[int, Sequence[int], set[int]] | None:
        """State factory: track a request only while it still needs tokens."""
        required = params.min_tokens
        if not required or len(output_tok_ids) >= required:
            return None
        return required, output_tok_ids, params.all_stop_token_ids

    def update_state(self, batch_update: BatchUpdate | None):
        """Refresh tracked requests; rebuild the sparse index if anything changed."""
        needs_update = process_dict_updates(
            self.min_toks, batch_update, self.add_request
        )
        if self.min_toks:
            # Drop requests that have attained their minimum token count.
            finished = tuple(
                idx
                for idx, (required, decoded, _) in self.min_toks.items()
                if len(decoded) >= required
            )
            if finished:
                needs_update = True
                for idx in finished:
                    del self.min_toks[idx]

        # Update tensors if needed.
        if needs_update:
            rows: list[int] = []
            stop_ids: list[int] = []
            for row, (_, _, stop_tok_ids) in self.min_toks.items():
                rows.extend([row] * len(stop_tok_ids))
                stop_ids.extend(stop_tok_ids)
            self.logits_slice = (
                self._device_tensor(rows, torch.int32),
                self._device_tensor(stop_ids, torch.int32),
            )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """Create a device tensor via an (optionally pinned) CPU staging copy."""
        staging = torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        )
        return staging.to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.min_toks:
            # Inhibit EOS token for requests which have not reached min length
            logits.index_put_(self.logits_slice, self.neg_inf_tensor)
        return logits
def process_dict_updates(
    req_entries: dict[int, T],
    batch_update: BatchUpdate | None,
    new_state: Callable[[SamplingParams, list[int] | None, list[int]], T | None],
) -> bool:
    """Utility function to update dict state for sparse LogitsProcessors.

    Args:
        req_entries: request index -> per-request state; mutated in place.
        batch_update: batch makeup changes, or None if the batch is unchanged.
        new_state: factory producing state for an added request from
            (params, prompt_tok_ids, output_tok_ids); returns None when the
            processor does not apply to that request.

    Returns:
        True if `req_entries` was modified.
    """
    if not batch_update:
        # Nothing to do.
        return False

    updated = False
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None:
            req_entries[index] = state
            updated = True
        elif req_entries.pop(index, None) is not None:
            # The added request replaced one this processor was tracking.
            updated = True

    if req_entries:
        # Process removed requests.
        for index in batch_update.removed:
            # Use an `is not None` check (not truthiness) so that a
            # falsy-but-present state value (e.g. 0 or an empty container)
            # still registers as a modification, consistent with the
            # added-requests path above.
            if req_entries.pop(index, None) is not None:
                updated = True

        # Process moved requests, unidirectional (a->b) and
        # swapped (a<->b)
        for a_index, b_index, direct in batch_update.moved:
            a_entry = req_entries.pop(a_index, None)
            b_entry = req_entries.pop(b_index, None)
            if a_entry is not None:
                req_entries[b_index] = a_entry
                updated = True
            if b_entry is not None:
                updated = True
                if direct == MoveDirectionality.SWAP:
                    req_entries[a_index] = b_entry
    return updated

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, Optional
import torch
from vllm import SamplingParams
if TYPE_CHECKING:
from vllm.config import VllmConfig
class MoveDirectionality(Enum):
    """Direction semantics of a request move within the persistent batch."""

    # One-way i1->i2 req move within batch
    UNIDIRECTIONAL = auto()
    # Two-way i1<->i2 req swap within batch
    SWAP = auto()
# Batch indices of any removed requests.
RemovedRequest = int

# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int] | None, list[int]]

# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
@dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs"""

    batch_size: int  # Current num reqs in batch

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Key assumption: the `output_tok_ids` list (which is an element of each
    # tuple in `added`) is a reference to the request's running output tokens
    # list; via this reference, the logits processors always see the latest
    # list of generated output tokens.
    #
    # NOTE:
    # * Added or moved requests may replace existing requests with the same
    #   index.
    # * Operations should be processed in the following order:
    #   - removed, added, moved
    removed: Sequence[RemovedRequest]
    added: Sequence[AddedRequest]
    moved: Sequence[MovedRequest]
class LogitsProcessor(ABC):
    """Abstract batch-level logits processor interface.

    Implementations are constructed once per engine, receive batch makeup
    changes via `update_state`, and transform the batch logits tensor via
    `apply` before sampling.
    """

    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        """Validate sampling params for this logits processor.

        Raise ValueError for invalid ones.
        """
        # Default: accept all params; subclasses override to reject.
        return None

    @abstractmethod
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply LogitsProcessor to batch logits tensor.

        The updated tensor must be returned but may be
        modified in-place.
        """
        raise NotImplementedError

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """True if logits processor has no impact on the
        argmax computation in greedy sampling.

        NOTE: may or may not have the same value for all
        instances of a given LogitsProcessor subclass,
        depending on subclass implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def update_state(
        self,
        batch_update: Optional["BatchUpdate"],
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update: Non-None iff there have been changes
                to the batch makeup.
        """
        raise NotImplementedError

View File

@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from itertools import chain
from typing import TYPE_CHECKING
from vllm.v1.sample.logits_processor.interface import (
AddedRequest,
BatchUpdate,
MovedRequest,
RemovedRequest,
)
if TYPE_CHECKING:
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
class BatchUpdateBuilder:
    """Helps track persistent batch state changes and build
    a batch update data structure for logitsprocs

    Assumptions:
    * All removals for a step are registered via self.removed_append()
      before the first read of self.removed, self.pop_removed() or
      self.peek_removed() in that step
    * After the first such read, no further removals are registered
      until reset
    * self._removed is only modified through self.removed_append() and
      self.pop_removed()

    Guarantees under above assumptions:
    * self.removed is always sorted in descending order
    * self.pop_removed() and self.peek_removed() both return
      the lowest removed request index in the current step
    """

    _removed: list[RemovedRequest]
    _is_removed_sorted: bool
    added: list[AddedRequest]
    moved: list[MovedRequest]

    def __init__(
        self,
        removed: list[RemovedRequest] | None = None,
        added: list[AddedRequest] | None = None,
        moved: list[MovedRequest] | None = None,
    ) -> None:
        self._removed = removed if removed else []
        self.added = added if added else []
        self.moved = moved if moved else []
        self._is_removed_sorted = False
        # Used to track changes in the pooling case
        # where we don't populate the added list.
        self.batch_changed = False

    def _ensure_removed_sorted(self) -> None:
        """Sort removed request indices in descending order.

        Idempotent after first call in a given step, until reset.
        """
        if self._is_removed_sorted:
            return
        self._removed.sort(reverse=True)
        self._is_removed_sorted = True

    @property
    def removed(self) -> list[RemovedRequest]:
        """Removed request indices sorted in descending order"""
        self._ensure_removed_sorted()
        return self._removed

    def removed_append(self, index: int) -> None:
        """Register the removal of a request from the persistent batch.

        Must not be called after the first time self.removed,
        self.pop_removed() or self.peek_removed() are invoked.

        Args:
            index: request index
        """
        if self._is_removed_sorted:
            raise RuntimeError(
                "Cannot register new removed request after self.removed has been read."
            )
        self._removed.append(index)
        self.batch_changed = True

    def has_removed(self) -> bool:
        return bool(self._removed)

    def peek_removed(self) -> int | None:
        """Return lowest removed request index"""
        if not self.has_removed():
            return None
        self._ensure_removed_sorted()
        # Descending sort keeps the smallest index at the tail.
        return self._removed[-1]

    def pop_removed(self) -> int | None:
        """Pop lowest removed request index"""
        if not self.has_removed():
            return None
        self._ensure_removed_sorted()
        return self._removed.pop()

    def reset(self) -> bool:
        """Returns True if there were any changes to the batch."""
        self._is_removed_sorted = False
        self._removed.clear()
        self.added.clear()
        self.moved.clear()
        changed, self.batch_changed = self.batch_changed, False
        return changed

    def get_and_reset(self, batch_size: int) -> BatchUpdate | None:
        """Generate a logitsprocs batch update data structure and reset
        internal batch update builder state.

        Args:
            batch_size: current persistent batch size

        Returns:
            Frozen logitsprocs batch update instance; `None` if no updates
        """
        # Reset removal-sorting logic
        self._is_removed_sorted = False
        self.batch_changed = False
        if self._removed or self.moved or self.added:
            batch_update = BatchUpdate(
                batch_size=batch_size,
                removed=self._removed,
                moved=self.moved,
                added=self.added,
            )
            # Hand the current lists off to the frozen update; start fresh.
            self._removed = []
            self.moved = []
            self.added = []
            return batch_update
        # No update; short-circuit
        return None
class LogitsProcessors:
    """Encapsulates initialized logitsproc objects."""

    def __init__(self, logitsprocs: Iterator["LogitsProcessor"] | None = None) -> None:
        """Partition the given processors by argmax-invariance."""
        self.argmax_invariant: list[LogitsProcessor] = []
        self.non_argmax_invariant: list[LogitsProcessor] = []
        for proc in logitsprocs or ():
            if proc.is_argmax_invariant():
                self.argmax_invariant.append(proc)
            else:
                self.non_argmax_invariant.append(proc)

    @property
    def all(self) -> Iterator["LogitsProcessor"]:
        """Iterator over all logits processors."""
        return chain(self.argmax_invariant, self.non_argmax_invariant)