# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import inspect
import itertools
from abc import abstractmethod
from collections.abc import Sequence
from functools import lru_cache, partial
from typing import TYPE_CHECKING

import torch

from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
from vllm.sampling_params import SamplingParams
from vllm.utils.torch_utils import guard_cuda_initialization
from vllm.v1.sample.logits_processor.builtin import (
    LogitBiasLogitsProcessor,
    MinPLogitsProcessor,
    MinTokensLogitsProcessor,
    process_dict_updates,
)
from vllm.v1.sample.logits_processor.interface import (
    BatchUpdate,
    LogitsProcessor,
    MoveDirectionality,
)
from vllm.v1.sample.logits_processor.state import BatchUpdateBuilder, LogitsProcessors

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)

# Error message when the user tries to initialize vLLM with a pooling model
# and custom logitsprocs
STR_POOLING_REJECTS_LOGITSPROCS = (
    "Pooling models do not support custom logits processors."
)

# Error message when the user tries to initialize vLLM with speculative
# decoding enabled and custom logitsprocs
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
    "Custom logits processors are not supported when speculative decoding is enabled."
)

# Entry-point group under which logits processor plugins are discovered
LOGITSPROCS_GROUP = "vllm.logits_processors"

# Logits processors that are always constructed by `build_logitsprocs`,
# ahead of any custom ones
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
    MinTokensLogitsProcessor,
    LogitBiasLogitsProcessor,
    MinPLogitsProcessor,
]


def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    """Load all installed logit processor plugins.

    Plugins are discovered through the `LOGITSPROCS_GROUP` entry-point group.

    Returns:
        List of objects loaded from the installed entry points (empty if no
        plugin is installed)

    Raises:
        RuntimeError: if loading any entry point fails; the original
            exception is chained as the cause
    """

    from importlib.metadata import entry_points

    installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP)
    if not installed_logitsprocs_plugins:
        logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP)
        return []

    # Load logitsprocs plugins
    logger.debug("Loading installed logitsprocs plugins (group %s):", LOGITSPROCS_GROUP)
    classes: list[type[LogitsProcessor]] = []
    for entrypoint in installed_logitsprocs_plugins:
        try:
            logger.debug(
                "- Loading logitproc plugin entrypoint=%s target=%s",
                entrypoint.name,
                entrypoint.value,
            )
            with guard_cuda_initialization():
                classes.append(entrypoint.load())
        except Exception as e:
            logger.error("Failed to load LogitsProcessor plugin %s: %s", entrypoint, e)
            raise RuntimeError(
                f"Failed to load LogitsProcessor plugin {entrypoint}"
            ) from e
    return classes


def _load_logitsprocs_by_fqcns(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """Load logit processor types, identifying them by fully-qualified class
    names (FQCNs).

    Effectively, a mixed list of logitproc types and FQCN strings is converted
    into a list of entirely logitproc types, by loading from the FQCNs.

    FQCN syntax is `module.path:QualifiedName`, e.g. `x.y.z:CustomLogitProc`.
    The part after the colon may itself be dotted to reach a nested class.

    Already-loaded logitproc types must be subclasses of LogitsProcessor

    Args:
      logits_processors: Potentially mixed list of logitsprocs types and FQCN
                         strings for logitproc types

    Returns:
      List of logitproc types

    Raises:
      ValueError: if an entry is a type (or resolves to an object) that is not
                  a `LogitsProcessor` subclass, or if an FQCN string does not
                  contain exactly one `:` separator
      RuntimeError: if importing the module named by an FQCN fails
      AttributeError: if the qualified name cannot be resolved inside the
                      imported module
    """
    if not logits_processors:
        return []

    logger.debug(
        "%s additional custom logits processors specified, checking whether "
        "they need to be loaded.",
        len(logits_processors),
    )

    classes: list[type[LogitsProcessor]] = []
    for ldx, logitproc in enumerate(logits_processors):
        if isinstance(logitproc, type):
            logger.debug(" - Already-loaded logit processor: %s", logitproc.__name__)
            if not issubclass(logitproc, LogitsProcessor):
                raise ValueError(
                    f"{logitproc.__name__} is not a subclass of LogitsProcessor"
                )
            classes.append(logitproc)
            continue

        logger.debug("- Loading logits processor %s", logitproc)
        # Validate the separator up front so that a malformed FQCN produces an
        # actionable message instead of a tuple-unpacking error.
        if logitproc.count(":") != 1:
            raise ValueError(
                f"Invalid logits processor FQCN {logitproc!r}: expected "
                "'module.path:QualifiedName', e.g. 'x.y.z:CustomLogitProc'"
            )
        module_path, qualname = logitproc.split(":")

        try:
            # Load module
            with guard_cuda_initialization():
                module = importlib.import_module(module_path)
        except Exception as e:
            logger.error(
                "Failed to load %sth LogitsProcessor plugin %s: %s",
                ldx,
                logitproc,
                e,
            )
            raise RuntimeError(
                f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
            ) from e

        # Walk down dotted name to get logitproc class
        obj = module
        for attr in qualname.split("."):
            try:
                obj = getattr(obj, attr)
            except AttributeError as e:
                # Same exception type as before, but say which processor and
                # which path component could not be resolved.
                raise AttributeError(
                    f"Failed to resolve {attr!r} while loading logits "
                    f"processor {logitproc}"
                ) from e
        if not isinstance(obj, type):
            raise ValueError("Loaded logit processor must be a type.")
        if not issubclass(obj, LogitsProcessor):
            raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor")
        classes.append(obj)

    return classes


def _load_custom_logitsprocs(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """Load all custom logits processors.

    * First load all installed logitproc plugins
    * Second load custom logitsprocs passed by the user at initialization time

    On TPU platforms nothing is loaded and an empty list is returned.

    Args:
      logits_processors: potentially mixed list of logitproc types and
                         logitproc type fully-qualified names (FQCNs)
                         which need to be loaded

    Returns:
      A list of all loaded logitproc types
    """
    from vllm.platforms import current_platform

    if current_platform.is_tpu():
        # Custom logitsprocs (plugins and caller-specified alike) are skipped
        # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs
        return []

    return _load_logitsprocs_plugins() + _load_logitsprocs_by_fqcns(logits_processors)


def build_logitsprocs(
    vllm_config: "VllmConfig",
    device: torch.device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (),
) -> LogitsProcessors:
    """Construct the `LogitsProcessors` container for an engine instance.

    For pooling models, and when `vllm_config.speculative_config` is set, an
    empty `LogitsProcessors()` is returned (so the builtin processors are not
    constructed either). Otherwise every class in `BUILTIN_LOGITS_PROCESSORS`
    followed by every loaded custom class is instantiated with
    `(vllm_config, device, is_pin_memory)`.

    Args:
      vllm_config: engine configuration, forwarded to each constructor
      device: forwarded to each constructor
      is_pin_memory: forwarded to each constructor
      is_pooling_model: whether the model is a pooling model
      custom_logitsprocs: logitproc types and/or FQCN strings to load in
                          addition to installed plugins

    Returns:
      The populated (or empty) `LogitsProcessors` instance

    Raises:
      ValueError: if custom logits processors are specified together with a
                  pooling model or with speculative decoding
    """
    if is_pooling_model:
        if custom_logitsprocs:
            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
        logger.debug(
            "Skipping logits processor loading because pooling models"
            " do not support logits processors."
        )
        return LogitsProcessors()

    # Check if speculative decoding is enabled.
    if vllm_config.speculative_config:
        if custom_logitsprocs:
            raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
        logger.warning(
            "min_p, logit_bias, and min_tokens parameters won't currently work "
            "with speculative decoding enabled."
        )
        return LogitsProcessors()

    custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)

    return LogitsProcessors(
        ctor(vllm_config, device, is_pin_memory)
        for ctor in itertools.chain(
            BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes
        )
    )


# Memoized variant of `_load_custom_logitsprocs`; the argument must be
# hashable (e.g. a tuple or None), which is why callers convert to a tuple.
cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)


def validate_logits_processors_parameters(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
    sampling_params: SamplingParams,
):
    """Validate sampling params against every custom logits processor.

    The custom logits processor classes are loaded through the cached loader
    and `validate_params(sampling_params)` is invoked on each loaded class.

    Args:
      logits_processors: logitproc types and/or FQCN strings, or None
      sampling_params: request sampling params to validate
    """
    # lru_cache needs a hashable key, so freeze the sequence into a tuple
    logits_processors = (
        tuple(logits_processors) if logits_processors is not None else None
    )
    for logits_procs in cached_load_custom_logitsprocs(logits_processors):
        logits_procs.validate_params(sampling_params)


class AdapterLogitsProcessor(LogitsProcessor):
    """Wrapper for per-request logits processors

    To wrap a specific per-request logits processor,
    * Subclass `AdapterLogitsProcessor`
    * Implement `self.is_argmax_invariant()` base-class method
    * Implement `self.new_req_logits_processor(params)`

    `self.__init__(vllm_config, device, is_pin_memory)` does not need to be
    overridden in general. However, to implement custom constructor behavior -
    especially any logic which operates on or stores `vllm_config`, `device`,
    or `is_pin_memory` - `self.__init__(vllm_config, device, is_pin_memory)`
    must be overridden and the override must call
    `super().__init__(vllm_config, device, is_pin_memory)`
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        """Subclass must invoke
        `super().__init__(vllm_config, device, is_pin_memory)`.

        Subclass constructor may find it useful to utilize the `vllm_config`,
        `device` and `is_pin_memory` argument. However regardless of whether
        these arguments are used, the vLLM logits processor interface requires
        all three arguments to be present.
        """

        # Map req index -> logits processor state
        #
        # State representation is a partial[Tensor] comprising a request-level
        # logits processor with the output token ids argument and (if required)
        # the prompt token ids argument pre-populated
        #
        # Note that the partial carries a *reference* to output token ids, and
        # will thus always operate on the list as it is currently, not as it
        # was when the partial was created.
        self.req_info: dict[int, partial[torch.Tensor]] = {}

    @abstractmethod
    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """Consume request info; return a per-request logits processor.

        Return None if logits processor does not need to be applied to request

        Args:
          params: request sampling params

        Returns:
          None if logits processor should not be applied to request; otherwise
          returns a `RequestLogitsProcessor` instance

        """
        raise NotImplementedError

    def _new_state(
        self,
        params: SamplingParams,
        prompt_ids: list[int] | None,
        output_ids: list[int],
    ) -> partial[torch.Tensor] | None:
        """Return state representation for new request

        Returns None if logits processor is not applicable to request

        Args:
          params: request sampling params
          prompt_ids: request prompt token ids
          output_ids: decoded tokens so far for this request

        Returns:
          logits processor partial[Tensor] or None

        """
        if req_lp := self.new_req_logits_processor(params):
            # A three-parameter processor also receives the prompt token ids;
            # otherwise only the output token ids are pre-populated. The
            # remaining positional argument is supplied in `apply()`.
            args = (
                [prompt_ids, output_ids]
                if (len(inspect.signature(req_lp).parameters) == 3)
                else [output_ids]
            )
            return partial(req_lp, *args)
        return None

    def update_state(self, batch_update: BatchUpdate | None):
        """Update the per-request state map from a batch update.

        Delegates to `process_dict_updates`, passing `self.req_info`, the
        batch update and `self._new_state` as the factory for new entries.

        Args:
          batch_update: batch update to process, or None
        """
        process_dict_updates(
            self.req_info,
            batch_update,
            self._new_state,
        )

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply per-request logits processors to their rows of `logits`.

        Args:
          logits: logits tensor, indexed along its first dimension by the
                  request indices stored in `self.req_info`

        Returns:
          The same `logits` tensor object that was passed in
        """
        # Apply per-request logits processors to corresponding rows of
        # logits tensor (no-op when no request has a registered processor)
        for req_idx, req_lp in self.req_info.items():
            req_logits = logits[req_idx]
            new_logits = req_lp(req_logits)
            if new_logits is not req_logits:
                # Modify logits tensor row in-place if necessary
                logits[req_idx] = new_logits
        return logits


__all__ = [
    "LogitsProcessor",
    "LogitBiasLogitsProcessor",
    "MinPLogitsProcessor",
    "MinTokensLogitsProcessor",
    "BatchUpdate",
    "BatchUpdateBuilder",
    "MoveDirectionality",
    "LogitsProcessors",
    "build_logitsprocs",
    "validate_logits_processors_parameters",
    "STR_POOLING_REJECTS_LOGITSPROCS",
    "STR_SPEC_DEC_REJECTS_LOGITSPROCS",
    "LOGITSPROCS_GROUP",
    "AdapterLogitsProcessor",
]