Sync from v0.13
This commit is contained in:
352
vllm/v1/sample/logits_processor/__init__.py
Normal file
352
vllm/v1/sample/logits_processor/__init__.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
import inspect
|
||||
import itertools
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.torch_utils import guard_cuda_initialization
|
||||
from vllm.v1.sample.logits_processor.builtin import (
|
||||
LogitBiasLogitsProcessor,
|
||||
MinPLogitsProcessor,
|
||||
MinTokensLogitsProcessor,
|
||||
process_dict_updates,
|
||||
)
|
||||
from vllm.v1.sample.logits_processor.interface import (
|
||||
BatchUpdate,
|
||||
LogitsProcessor,
|
||||
MoveDirectionality,
|
||||
)
|
||||
from vllm.v1.sample.logits_processor.state import BatchUpdateBuilder, LogitsProcessors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Error message when the user tries to initialize vLLM with a pooling model
# and custom logits processors
STR_POOLING_REJECTS_LOGITSPROCS = (
    "Pooling models do not support custom logits processors."
)

# Error message when the user tries to initialize vLLM with speculative
# decoding enabled and custom logits processors
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
    "Custom logits processors are not supported when speculative decoding is enabled."
)

# Entry-point group under which logits processor plugins are registered.
LOGITSPROCS_GROUP = "vllm.logits_processors"

# Logits processors constructed for every non-pooling model, ahead of any
# custom ones (see build_logitsprocs).
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
    MinTokensLogitsProcessor,
    LogitBiasLogitsProcessor,
    MinPLogitsProcessor,
]
|
||||
|
||||
|
||||
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    """Load every logits processor registered under the plugin entry-point
    group ``LOGITSPROCS_GROUP``.

    Returns:
        List of loaded logitproc types (empty if no plugins are installed)

    Raises:
        RuntimeError: if any installed plugin fails to load
    """

    from importlib.metadata import entry_points

    plugins = entry_points(group=LOGITSPROCS_GROUP)
    if len(plugins) == 0:
        logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP)
        return []

    logger.debug("Loading installed logitsprocs plugins (group %s):", LOGITSPROCS_GROUP)
    loaded: list[type[LogitsProcessor]] = []
    for ep in plugins:
        try:
            logger.debug(
                "- Loading logitproc plugin entrypoint=%s target=%s",
                ep.name,
                ep.value,
            )
            # Guard against plugins touching CUDA at import time.
            with guard_cuda_initialization():
                loaded.append(ep.load())
        except Exception as e:
            logger.error("Failed to load LogitsProcessor plugin %s: %s", ep, e)
            raise RuntimeError(f"Failed to load LogitsProcessor plugin {ep}") from e
    return loaded
|
||||
|
||||
|
||||
def _load_logitsprocs_by_fqcns(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """Load logit processor types, identifying them by fully-qualified class
    names (FQCNs).

    Effectively, a mixed list of logitproc types and FQCN strings is converted
    into a list of entirely logitproc types, by loading from the FQCNs.

    FQCN syntax is <module>:<type> i.e. x.y.z:CustomLogitProc

    Already-loaded logitproc types must be subclasses of LogitsProcessor

    Args:
        logits_processors: Potentially mixed list of logitsprocs types and FQCN
            strings for logitproc types

    Returns:
        List of logitproc types

    Raises:
        ValueError: if a provided or loaded object is not a LogitsProcessor
            subclass, or if a loaded object is not a type
        RuntimeError: if a module referenced by an FQCN fails to import
    """
    if not logits_processors:
        return []

    logger.debug(
        "%s additional custom logits processors specified, checking whether "
        "they need to be loaded.",
        len(logits_processors),
    )

    classes: list[type[LogitsProcessor]] = []
    for ldx, logitproc in enumerate(logits_processors):
        if isinstance(logitproc, type):
            # Entry is already a loaded class; only validate its base class.
            logger.debug(" - Already-loaded logit processor: %s", logitproc.__name__)
            if not issubclass(logitproc, LogitsProcessor):
                raise ValueError(
                    f"{logitproc.__name__} is not a subclass of LogitsProcessor"
                )
            classes.append(logitproc)
            continue

        # Entry is an FQCN string: "<module path>:<dotted qualname>"
        logger.debug("- Loading logits processor %s", logitproc)
        module_path, qualname = logitproc.split(":")

        try:
            # Load module; guard against CUDA being touched at import time.
            with guard_cuda_initialization():
                module = importlib.import_module(module_path)
        except Exception as e:
            logger.error(
                "Failed to load %sth LogitsProcessor plugin %s: %s",
                ldx,
                logitproc,
                e,
            )
            raise RuntimeError(
                f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
            ) from e

        # Walk down dotted name to get logitproc class
        obj = module
        for attr in qualname.split("."):
            obj = getattr(obj, attr)
        if not isinstance(obj, type):
            raise ValueError("Loaded logit processor must be a type.")
        if not issubclass(obj, LogitsProcessor):
            raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor")
        classes.append(obj)

    return classes
|
||||
|
||||
|
||||
def _load_custom_logitsprocs(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """Load all custom logits processors.

    Installed logitsproc plugins are loaded first, followed by any custom
    logitsprocs passed by the user at initialization time.

    Args:
        logits_processors: potentially mixed list of logitproc types and
            logitproc type fully-qualified names (FQCNs)
            which need to be loaded

    Returns:
        A list of all loaded logitproc types
    """
    from vllm.platforms import current_platform

    if current_platform.is_tpu():
        # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs
        return []

    loaded = _load_logitsprocs_plugins()
    loaded.extend(_load_logitsprocs_by_fqcns(logits_processors))
    return loaded
|
||||
|
||||
|
||||
def build_logitsprocs(
    vllm_config: "VllmConfig",
    device: torch.device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (),
) -> LogitsProcessors:
    """Construct and initialize all builtin and custom logits processors.

    Args:
        vllm_config: engine configuration passed to each logitsproc constructor
        device: target device for logitsproc tensors
        is_pin_memory: whether CPU-side staging tensors should be pinned
        is_pooling_model: pooling models do not support logits processors
        custom_logitsprocs: user-provided logitsproc types and/or FQCN strings

    Returns:
        `LogitsProcessors` container; empty for pooling models and when
        speculative decoding is enabled

    Raises:
        ValueError: if custom logitsprocs are requested for a pooling model or
            together with speculative decoding
    """
    if is_pooling_model:
        if custom_logitsprocs:
            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
        logger.debug(
            "Skipping logits processor loading because pooling models"
            " do not support logits processors."
        )
        return LogitsProcessors()

    # Check if speculative decoding is enabled.
    if vllm_config.speculative_config:
        if custom_logitsprocs:
            raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
        logger.warning(
            "min_p, logit_bias, and min_tokens parameters won't currently work "
            "with speculative decoding enabled."
        )
        return LogitsProcessors()

    custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
    # Instantiate builtin logitsprocs first, then custom ones, all with the
    # same (vllm_config, device, is_pin_memory) constructor signature.
    return LogitsProcessors(
        ctor(vllm_config, device, is_pin_memory)
        for ctor in itertools.chain(
            BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes
        )
    )
|
||||
|
||||
|
||||
# Memoized variant of _load_custom_logitsprocs. lru_cache keys on the
# argument, so callers must pass a hashable value (a tuple or None).
cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
|
||||
|
||||
|
||||
def validate_logits_processors_parameters(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
    sampling_params: SamplingParams,
):
    """Run every custom logits processor's `validate_params` check against
    `sampling_params`; each check raises ValueError on invalid params."""
    # Freeze the sequence into a tuple so the lru_cache key is hashable.
    key = None if logits_processors is None else tuple(logits_processors)
    for lp_cls in cached_load_custom_logitsprocs(key):
        lp_cls.validate_params(sampling_params)
|
||||
|
||||
|
||||
class AdapterLogitsProcessor(LogitsProcessor):
    """Wrapper for per-request logits processors

    To wrap a specific per-request logits processor,
    * Subclass `AdapterLogitsProcessor`
    * Implement `self.is_argmax_invariant()` base-class method
    * Implement `self.new_req_logits_processor(params)`

    `self.__init__(vllm_config, device, is_pin_memory)` does not need to be
    overridden in general. However, to implement custom constructor behavior -
    especially any logic which operates on or stores `vllm_config`, `device`,
    or `is_pin_memory` - `self.__init__(vllm_config, device, is_pin_memory)`
    must be overridden and the override must call
    `super().__init__(vllm_config, device, is_pin_memory)`
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        """Subclass must invoke
        `super().__init__(vllm_config, device, is_pin_memory)`.

        Subclass constructor may find it useful to utilize the `vllm_config`,
        `device` and `is_pin_memory` argument. However regardless of whether
        these arguments are used, the vLLM logits processor interface requires
        all three arguments to be present.
        """

        # Map req index -> logits processor state
        #
        # State representation is a partial[Tensor] comprising a request-level
        # logits processor with the output token ids argument and (if required)
        # the prompt token ids argument pre-populated
        #
        # Note that the partial carries a *reference* to output token ids, and
        # will thus always operate on the list as it is currently, not as it
        # was when the partial was created.
        self.req_info: dict[int, partial[torch.Tensor]] = {}

    @abstractmethod
    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """Consume request info; return a per-request logits processor.

        Return None if logits processor does not need to be applied to request

        Args:
            params: request sampling params

        Returns:
            None if logits processor should not be applied to request; otherwise
            returns a `RequestLogitsProcessor` instance

        """
        raise NotImplementedError

    def _new_state(
        self,
        params: SamplingParams,
        prompt_ids: list[int] | None,
        output_ids: list[int],
    ) -> partial[torch.Tensor] | None:
        """Return state representation for new request

        Returns None if logits processor is not applicable to request

        Args:
            params: request sampling params
            prompt_ids: request prompt token ids
            output_ids: decoded tokens so far for this request

        Returns:
            logits processor partial[Tensor] or None

        """
        if req_lp := self.new_req_logits_processor(params):
            # A 3-parameter request-level logits processor also takes the
            # prompt token ids; a 2-parameter one takes only the output token
            # ids. Pre-bind everything except the logits tensor itself, which
            # is supplied later in apply().
            args = (
                [prompt_ids, output_ids]
                if (len(inspect.signature(req_lp).parameters) == 3)
                else [output_ids]
            )
            return partial(req_lp, *args)  # type: ignore[misc]
        return None

    def update_state(self, batch_update: BatchUpdate | None):
        # Delegate added/removed/moved bookkeeping to the shared dict helper.
        process_dict_updates(
            self.req_info,
            batch_update,
            self._new_state,
        )

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.req_info:
            # Apply per-request logits processors to corresponding rows of
            # logits tensor
            for req_idx, req_lp in self.req_info.items():
                req_logits = logits[req_idx]
                new_logits = req_lp(req_logits)
                if new_logits is not req_logits:
                    # Modify logits tensor row in-place if necessary
                    logits[req_idx] = new_logits
        return logits
|
||||
|
||||
|
||||
# Public API of the logits_processor package.
__all__ = [
    "LogitsProcessor",
    "LogitBiasLogitsProcessor",
    "MinPLogitsProcessor",
    "MinTokensLogitsProcessor",
    "BatchUpdate",
    "BatchUpdateBuilder",
    "MoveDirectionality",
    "LogitsProcessors",
    "build_logitsprocs",
    "STR_POOLING_REJECTS_LOGITSPROCS",
    "LOGITSPROCS_GROUP",
    "AdapterLogitsProcessor",
]
|
||||
278
vllm/v1/sample/logits_processor/builtin.py
Normal file
278
vllm/v1/sample/logits_processor/builtin.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.v1.sample.logits_processor.interface import (
|
||||
BatchUpdate,
|
||||
LogitsProcessor,
|
||||
MoveDirectionality,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MinPLogitsProcessor(LogitsProcessor):
    """Masks out tokens whose probability is below `min_p * max_probability`
    for requests that set a non-zero `min_p` sampling parameter."""

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        # Number of batch slots with a non-zero min_p value.
        self.min_p_count: int = 0

        # CPU-side staging tensor: one min_p value per batch slot.
        self.min_p_cpu_tensor = torch.zeros(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory
        )
        # Numpy view sharing storage with the CPU tensor, for scalar updates.
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()

        # On non-CPU devices, keep a separate device-resident copy of min_p.
        self.use_double_tensor = torch.device(device).type != "cpu"

        if self.use_double_tensor:
            # Pre-allocated device tensor
            self.min_p_device: torch.Tensor = torch.empty(
                (max_num_reqs,), dtype=torch.float32, device=device
            )
        else:
            self.min_p_device = self.min_p_cpu_tensor
        # Current slice of the device tensor
        self.min_p: torch.Tensor = self.min_p_device[:0]

    def is_argmax_invariant(self) -> bool:
        """Min-p never impacts greedy sampling"""
        return True

    def get_min_p_by_index(self, index: int) -> float:
        """Return the min_p value currently stored for batch slot `index`."""
        return float(self.min_p_cpu[index])

    def update_state(self, batch_update: BatchUpdate | None):
        if not batch_update:
            return

        needs_update = False
        # Process added requests.
        for index, params, _, _ in batch_update.added:
            min_p = params.min_p
            min_p_before = self.min_p_cpu[index]
            if min_p_before != min_p:
                needs_update = True
                self.min_p_cpu[index] = min_p
                # Maintain the count of slots with a non-zero min_p.
                if min_p and not min_p_before:
                    self.min_p_count += 1
                elif not min_p and min_p_before:
                    self.min_p_count -= 1

        if self.min_p_count:
            # Process removed requests.
            if batch_update.removed:
                needs_update = True
                for index in batch_update.removed:
                    if self.min_p_cpu[index]:
                        self.min_p_cpu[index] = 0
                        self.min_p_count -= 1

            # Process moved requests, unidirectional (a->b) and swap (a<->b).
            for adx, bdx, direct in batch_update.moved:
                min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
                if min_p_a != min_p_b:
                    needs_update = True
                    self.min_p_cpu[bdx] = min_p_a
                    if direct == MoveDirectionality.SWAP:
                        self.min_p_cpu[adx] = min_p_b
                if direct == MoveDirectionality.UNIDIRECTIONAL:
                    # One-way move vacates the source slot; the destination's
                    # previous value is overwritten, so drop it from the count.
                    if min_p_a:
                        self.min_p_cpu[adx] = 0
                    if min_p_b:
                        self.min_p_count -= 1

        # Update tensors if needed.
        size = batch_update.batch_size
        if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
            self.min_p = self.min_p_device[:size]
            if self.use_double_tensor:
                self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
            # Shape (size, 1) so it broadcasts across the vocab dimension.
            self.min_p.unsqueeze_(1)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.min_p_count:
            return logits

        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
        # Adjust min_p
        adjusted_min_p = max_probabilities.mul_(self.min_p)
        # Identify valid tokens using threshold comparison
        invalid_token_mask = probability_values < adjusted_min_p
        # Apply mask using boolean indexing
        logits.masked_fill_(invalid_token_mask, -float("inf"))
        return logits
|
||||
|
||||
|
||||
class LogitBiasLogitsProcessor(LogitsProcessor):
    """Adds per-request, per-token-id additive biases to the logits."""

    def __init__(self, _, device: torch.device, is_pin_memory: bool):
        self.device = device
        self.pin_memory = is_pin_memory
        # Map req index -> {token id -> bias} for requests with logit_bias set.
        self.biases: dict[int, dict[int, float]] = {}

        # Flattened bias values aligned element-wise with self.logits_slice.
        self.bias_tensor: torch.Tensor = torch.tensor(())
        # (req_idx_tensor, tok_id_tensor) advanced-indexing pair into logits.
        self.logits_slice = (
            self._device_tensor([], torch.int32),
            self._device_tensor([], torch.int32),
        )

    def is_argmax_invariant(self) -> bool:
        """Logit bias can rebalance token probabilities and change the
        outcome of argmax in greedy sampling."""
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        needs_update = process_dict_updates(
            self.biases, batch_update, lambda params, _, __: params.logit_bias or None
        )

        # Update tensors if needed.
        if needs_update:
            # Flatten the per-request bias dicts into parallel index/value
            # lists so apply() can do one vectorized indexed add.
            reqs: list[int] = []
            tok_ids: list[int] = []
            biases: list[float] = []
            for req, lb in self.biases.items():
                reqs.extend([req] * len(lb))
                tok_ids.extend(lb.keys())
                biases.extend(lb.values())

            self.bias_tensor = self._device_tensor(biases, torch.float32)
            self.logits_slice = (
                self._device_tensor(reqs, torch.int32),
                self._device_tensor(tok_ids, torch.int32),
            )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """Build a (possibly pinned) CPU tensor and copy it to the device."""
        return torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        ).to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.biases:
            logits[self.logits_slice] += self.bias_tensor
        return logits
|
||||
|
||||
|
||||
class MinTokensLogitsProcessor(LogitsProcessor):
    """Suppresses stop tokens until a request has generated `min_tokens`."""

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        # index -> (min_toks, output_token_ids, stop_token_ids)
        self.device = device
        self.pin_memory = is_pin_memory
        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}

        # (req_idx_tensor, eos_tok_id_tensor)
        self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (
            self._device_tensor([], torch.int32),
            self._device_tensor([], torch.int32),
        )

        # Scalar -inf used to overwrite the stop-token logits in apply().
        self.neg_inf_tensor = torch.tensor(
            -float("inf"), dtype=torch.float32, device=self.device
        )

    def is_argmax_invariant(self) -> bool:
        """By censoring stop tokens, min-tokens can change the outcome
        of the argmax operation in greedy sampling."""
        return False

    @staticmethod
    def add_request(
        params: SamplingParams, _: list[int] | None, output_tok_ids: list[int]
    ) -> tuple[int, Sequence[int], set[int]] | None:
        """Return per-request state, or None if min_tokens does not apply
        (unset, or the request already produced enough tokens)."""
        min_tokens = params.min_tokens
        if not min_tokens or len(output_tok_ids) >= min_tokens:
            return None
        # output_tok_ids is a live reference to the request's running output
        # list, so the length check in update_state sees the latest count.
        return min_tokens, output_tok_ids, params.all_stop_token_ids

    def update_state(self, batch_update: BatchUpdate | None):
        needs_update = process_dict_updates(
            self.min_toks, batch_update, self.add_request
        )
        if self.min_toks:
            # Check for any requests that have attained their min tokens.
            to_remove = tuple(
                index
                for index, (min_toks, out_tok_ids, _) in self.min_toks.items()
                if len(out_tok_ids) >= min_toks
            )
            if to_remove:
                needs_update = True
                for index in to_remove:
                    del self.min_toks[index]

        # Update tensors if needed.
        if needs_update:
            # Flatten (req, stop_token_ids) pairs into parallel index lists
            # so apply() can do one vectorized index_put_.
            reqs: list[int] = []
            tok_ids: list[int] = []
            for req, (_, _, stop_tok_ids) in self.min_toks.items():
                reqs.extend([req] * len(stop_tok_ids))
                tok_ids.extend(stop_tok_ids)

            self.logits_slice = (
                self._device_tensor(reqs, torch.int32),
                self._device_tensor(tok_ids, torch.int32),
            )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """Build a (possibly pinned) CPU tensor and copy it to the device."""
        return torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        ).to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.min_toks:
            # Inhibit EOS token for requests which have not reached min length
            logits.index_put_(self.logits_slice, self.neg_inf_tensor)
        return logits
|
||||
|
||||
|
||||
def process_dict_updates(
    req_entries: dict[int, T],
    batch_update: BatchUpdate | None,
    new_state: Callable[[SamplingParams, list[int] | None, list[int]], T | None],
) -> bool:
    """Utility function to update dict state for sparse LogitsProcessors.

    Applies a `BatchUpdate`'s added/removed/moved operations to a dict
    mapping batch index -> per-request state, in that dataclass's documented
    order (removed entries are only consulted after added ones because an
    added request may reuse a removed slot).

    Args:
        req_entries: per-request state keyed on batch index; mutated in place
        batch_update: persistent batch change info; None means no changes
        new_state: factory called as `new_state(params, prompt_tok_ids,
            output_tok_ids)` for each added request; returning None means the
            logits processor does not apply to that request

    Returns:
        True if `req_entries` was modified
    """

    if not batch_update:
        # Nothing to do.
        return False

    updated = False
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None:
            req_entries[index] = state
            updated = True
        elif req_entries.pop(index, None) is not None:
            # An added request may replace (and deactivate) a prior occupant
            # of the same batch slot.
            updated = True

    if req_entries:
        # Process removed requests.
        for index in batch_update.removed:
            # Compare against None (not truthiness) so a falsy state value
            # (e.g. 0 or "") still registers as a modification, matching the
            # added-requests handling above.
            if req_entries.pop(index, None) is not None:
                updated = True

        # Process moved requests, unidirectional (a->b) and
        # swapped (a<->b)
        for a_index, b_index, direct in batch_update.moved:
            a_entry = req_entries.pop(a_index, None)
            b_entry = req_entries.pop(b_index, None)
            if a_entry is not None:
                req_entries[b_index] = a_entry
                updated = True
            if b_entry is not None:
                updated = True
                if direct == MoveDirectionality.SWAP:
                    req_entries[a_index] = b_entry

    return updated
|
||||
106
vllm/v1/sample/logits_processor/interface.py
Normal file
106
vllm/v1/sample/logits_processor/interface.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
class MoveDirectionality(Enum):
    """Direction of a request move within the persistent batch."""

    # One-way i1->i2 req move within batch
    UNIDIRECTIONAL = auto()
    # Two-way i1<->i2 req swap within batch
    SWAP = auto()
|
||||
|
||||
|
||||
# Batch indices of any removed requests.
RemovedRequest = int

# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int] | None, list[int]]

# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs"""

    # Current num reqs in batch
    batch_size: int

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Key assumption: the `output_tok_ids` list (which is an element of each
    # tuple in `added`) is a reference to the request's running output tokens
    # list; via this reference, the logits processors always see the latest
    # list of generated output tokens.
    #
    # NOTE:
    # * Added or moved requests may replace existing requests with the same
    #   index.
    # * Operations should be processed in the following order:
    #   - removed, added, moved
    removed: Sequence[RemovedRequest]
    added: Sequence[AddedRequest]
    moved: Sequence[MovedRequest]
|
||||
|
||||
|
||||
class LogitsProcessor(ABC):
    """Abstract batch-level logits processor interface.

    Implementations are constructed once per engine and are notified of
    persistent-batch changes via `update_state` before each forward pass;
    `apply` then transforms the batch logits tensor.
    """

    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        """Validate sampling params for this logits processor.

        Raise ValueError for invalid ones.
        """
        # Default: no validation; subclasses override as needed.
        return None

    @abstractmethod
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply LogitsProcessor to batch logits tensor.

        The updated tensor must be returned but may be
        modified in-place.
        """
        raise NotImplementedError

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """True if logits processor has no impact on the
        argmax computation in greedy sampling.
        NOTE: may or may not have the same value for all
        instances of a given LogitsProcessor subclass,
        depending on subclass implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def update_state(
        self,
        batch_update: Optional["BatchUpdate"],
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update: Non-None iff there have been changes
                to the batch makeup.
        """
        raise NotImplementedError
|
||||
165
vllm/v1/sample/logits_processor/state.py
Normal file
165
vllm/v1/sample/logits_processor/state.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterator
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.v1.sample.logits_processor.interface import (
|
||||
AddedRequest,
|
||||
BatchUpdate,
|
||||
MovedRequest,
|
||||
RemovedRequest,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
||||
|
||||
|
||||
class BatchUpdateBuilder:
    """Helps track persistent batch state changes and build
    a batch update data structure for logitsprocs

    Assumptions:
    * All information about requests removed from persistent batch
      during a step is aggregated in self._removed through calls to
      self.removed_append() at the beginning of a step. This must happen
      before the first time that self.removed, self.pop_removed()
      or self.peek_removed() are invoked in a given step
    * After the first time that self.removed, self.pop_removed()
      or self.peek_removed() are read in a step, no new removals
      are registered using self.removed_append()
    * Elements of self._removed are never directly modified, added or
      removed (i.e. modification is only via self.removed_append() and
      self.pop_removed())

    Guarantees under above assumptions:
    * self.removed is always sorted in descending order
    * self.pop_removed() and self.peek_removed() both return
      the lowest removed request index in the current step
    """

    _removed: list[RemovedRequest]
    _is_removed_sorted: bool
    added: list[AddedRequest]
    moved: list[MovedRequest]

    def __init__(
        self,
        removed: list[RemovedRequest] | None = None,
        added: list[AddedRequest] | None = None,
        moved: list[MovedRequest] | None = None,
    ) -> None:
        self._removed = removed or []
        self.added = added or []
        self.moved = moved or []
        self._is_removed_sorted = False

        # Used to track changes in the pooling case
        # where we don't populate the added list.
        self.batch_changed = False

    def _ensure_removed_sorted(self) -> None:
        """Sort removed request indices in
        descending order.
        Idempotent after first call in a
        given step, until reset.
        """
        if not self._is_removed_sorted:
            self._removed.sort(reverse=True)
            self._is_removed_sorted = True

    @property
    def removed(self) -> list[RemovedRequest]:
        """Removed request indices sorted in
        descending order"""
        self._ensure_removed_sorted()
        return self._removed

    def removed_append(self, index: int) -> None:
        """Register the removal of a request from the persistent batch.

        Must not be called after the first time self.removed,
        self.pop_removed() or self.peek_removed() are invoked.

        Args:
            index: request index
        """
        if self._is_removed_sorted:
            # Appending after the sort would silently break the
            # descending-order guarantee, so fail loudly instead.
            raise RuntimeError(
                "Cannot register new removed request after self.removed has been read."
            )
        self._removed.append(index)
        self.batch_changed = True

    def has_removed(self) -> bool:
        """True if any removals were registered this step."""
        return bool(self._removed)

    def peek_removed(self) -> int | None:
        """Return lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            # Descending sort puts the lowest index at the end.
            return self._removed[-1]
        return None

    def pop_removed(self) -> int | None:
        """Pop lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed.pop()
        return None

    def reset(self) -> bool:
        """Returns True if there were any changes to the batch."""
        self._is_removed_sorted = False
        self._removed.clear()
        self.added.clear()
        self.moved.clear()
        batch_changed = self.batch_changed
        self.batch_changed = False
        return batch_changed

    def get_and_reset(self, batch_size: int) -> BatchUpdate | None:
        """Generate a logitsprocs batch update data structure and reset
        internal batch update builder state.

        Args:
            batch_size: current persistent batch size

        Returns:
            Frozen logitsprocs batch update instance; `None` if no updates
        """
        # Reset removal-sorting logic
        self._is_removed_sorted = False
        self.batch_changed = False
        if not any((self._removed, self.moved, self.added)):
            # No update; short-circuit
            return None
        # Build batch state update
        batch_update = BatchUpdate(
            batch_size=batch_size,
            removed=self._removed,
            moved=self.moved,
            added=self.added,
        )
        # Hand ownership of the current lists to the BatchUpdate and start
        # fresh ones, rather than clearing in place (the BatchUpdate keeps
        # references to these lists).
        self._removed = []
        self.moved = []
        self.added = []
        return batch_update
|
||||
|
||||
|
||||
class LogitsProcessors:
|
||||
"""Encapsulates initialized logitsproc objects."""
|
||||
|
||||
def __init__(self, logitsprocs: Iterator["LogitsProcessor"] | None = None) -> None:
|
||||
self.argmax_invariant: list[LogitsProcessor] = []
|
||||
self.non_argmax_invariant: list[LogitsProcessor] = []
|
||||
if logitsprocs:
|
||||
for logitproc in logitsprocs:
|
||||
(
|
||||
self.argmax_invariant
|
||||
if logitproc.is_argmax_invariant()
|
||||
else self.non_argmax_invariant
|
||||
).append(logitproc)
|
||||
|
||||
@property
|
||||
def all(self) -> Iterator["LogitsProcessor"]:
|
||||
"""Iterator over all logits processors."""
|
||||
return chain(self.argmax_invariant, self.non_argmax_invariant)
|
||||
Reference in New Issue
Block a user