Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

217
vllm/v1/engine/__init__.py Normal file
View File

@@ -0,0 +1,217 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
from collections.abc import Mapping
from typing import Any
import msgspec
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
from vllm.v1.serial_utils import UtilityResult
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")


class FinishReason(enum.IntEnum):
    """
    Why a request finished: stop, length, abort, or error.

    Stored as an int rather than a str for more compact serialization;
    ``str(member)`` yields the matching entry of ``FINISH_REASON_STRINGS``.

    stop - a stop string was emitted
    length - max_tokens was consumed, or max_model_len was reached
    abort - aborted by client
    error - retryable request-level internal error (e.g., KV load failure).
            Invariant: always converted to 500 Internal Server Error.
    """

    STOP = 0
    LENGTH = 1
    ABORT = 2
    ERROR = 3

    def __str__(self):
        # Member values double as indices into FINISH_REASON_STRINGS.
        return FINISH_REASON_STRINGS[int(self)]
class EngineCoreRequest(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Request message carrying one prompt and its sampling/pooling params.

    NOTE: with ``array_like=True`` instances are encoded positionally, so
    the order of the fields below must not be changed casually.
    """

    request_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec] | None
    # `params` below returns sampling_params when set, else pooling_params.
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    eos_token_id: int | None
    arrival_time: float
    lora_request: LoRARequest | None
    cache_salt: str | None
    data_parallel_rank: int | None
    prompt_embeds: torch.Tensor | None = None

    # Index of the client, used to ensure outputs are sent back to the same
    # client for this request when scaling out the front-end.
    client_index: int = 0

    # Used in DP case to indicate which wave of requests this is expected to
    # belong to, to cover a race condition where the request is sent before
    # a wave finished notification is received.
    current_wave: int = 0
    priority: int = 0

    trace_headers: Mapping[str, str] | None = None

    @property
    def params(self) -> SamplingParams | PoolingParams:
        """Return the processed params (sampling or pooling).

        Sampling params take precedence; otherwise pooling params are
        returned and asserted to be present.
        """
        if (sampling := self.sampling_params) is not None:
            return sampling
        pooling = self.pooling_params
        assert pooling is not None
        return pooling
class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event.

    Used as the ``type`` field of ``EngineCoreEvent``.
    """

    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3
class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is a monotonic timestamp and is used by the engine
    frontend to calculate intervals between engine core events. These
    timestamps should not be compared with timestamps from other processes.
    """

    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(
        cls, event_type: EngineCoreEventType, timestamp: float | None = None
    ) -> "EngineCoreEvent":
        """Build an event of ``event_type``.

        Args:
            event_type: The kind of event being recorded.
            timestamp: Explicit timestamp to store; when ``None`` the
                current ``time.monotonic()`` reading is used.
        """
        if timestamp is None:
            timestamp = time.monotonic()
        return cls(event_type, timestamp)
class EngineCoreOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Per-request output record; a list of these rides in EngineCoreOutputs.

    NOTE: with ``array_like=True`` instances are encoded positionally, so
    the order of the fields below must not be changed casually.
    """

    request_id: str
    new_token_ids: list[int]

    new_logprobs: LogprobsLists | None = None
    new_prompt_logprobs_tensors: LogprobsTensors | None = None

    pooling_output: torch.Tensor | None = None

    # `finished` is derived from this field: None means still running.
    finish_reason: FinishReason | None = None
    stop_reason: int | str | None = None
    events: list[EngineCoreEvent] | None = None
    kv_transfer_params: dict[str, Any] | None = None

    trace_headers: Mapping[str, str] | None = None
    # The number of tokens with prefix cache hits.
    num_cached_tokens: int = 0

    # The number of NaNs in logits.
    # A value greater than 0 indicates that the output is corrupted.
    num_nans_in_logits: int = 0

    @property
    def finished(self) -> bool:
        """True once a finish reason has been recorded for the request."""
        reason = self.finish_reason
        return reason is not None
class UtilityOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Outcome of a utility call, identified by ``call_id``.

    Exactly one of ``failure_message`` / ``result`` is expected to carry
    information: a failed call sets the message and leaves result as None.
    """

    call_id: int

    # Non-None implies the call failed, result should be None.
    failure_message: str | None = None
    result: UtilityResult | None = None
class EngineCoreOutputs(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Batch of outputs produced by one engine, plus stats and DP signals."""

    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0

    # [num_reqs]
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: SchedulerStats | None = None
    # 0.0 is treated as "unset" and replaced in __post_init__.
    timestamp: float = 0.0

    utility_output: UtilityOutput | None = None
    finished_requests: set[str] | None = None

    # In DP case, used to signal that the current wave of requests
    # has finished and the engines are paused.
    wave_complete: int | None = None
    # In DP case, used to signal that a request was received for an
    # "old" wave, so the next wave needs to be started in other engines.
    start_wave: int | None = None

    def __post_init__(self):
        """Stamp the batch with ``time.monotonic()`` if no timestamp was given."""
        if self.timestamp != 0.0:
            return
        self.timestamp = time.monotonic()
class EngineCoreRequestType(enum.Enum):
    """
    Kinds of request, with single-byte ``bytes`` values.

    The values are already byte strings, so a member's value can be sent
    over sockets without a separate encoding step.
    """

    ADD = b"\x00"
    ABORT = b"\x01"
    START_DP_WAVE = b"\x02"
    UTILITY = b"\x03"
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b"\x04"
class ReconfigureDistributedRequest(msgspec.Struct):
    """New data-parallel settings: size, ranks, and master ip/port."""

    new_data_parallel_size: int
    new_data_parallel_rank: int
    new_data_parallel_rank_local: int
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int
class ReconfigureRankType(enum.IntEnum):
    """
    Rank type for reconfiguring distributed request.

    NOTE(review): the negative values look like sentinels that cannot
    collide with a real rank index - confirm against the callers.
    """

    KEEP_CURRENT_RANK = -1
    SHUTDOWN_CURRENT_RANK = -2

866
vllm/v1/engine/async_llm.py Normal file
View File

@@ -0,0 +1,866 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import socket
import time
from collections.abc import AsyncGenerator, Iterable, Mapping
from copy import copy
from typing import Any, cast
import numpy as np
import torch
from typing_extensions import deprecated
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.usage.usage_lib import UsageContext
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
from vllm.utils.math_utils import cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import (
StatLoggerFactory,
StatLoggerManager,
load_stat_logger_plugin_factories,
)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__)
class AsyncLLM(EngineClient):
def __init__(
    self,
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    use_cached_outputs: bool = False,
    log_requests: bool = True,
    start_engine_loop: bool = True,
    stat_loggers: list[StatLoggerFactory] | None = None,
    aggregate_engine_logging: bool = False,
    client_addresses: dict[str, str] | None = None,
    client_count: int = 1,
    client_index: int = 0,
) -> None:
    """
    Create an AsyncLLM.

    Args:
        vllm_config: global configuration.
        executor_class: an Executor impl, e.g. MultiprocExecutor.
        log_stats: Whether to log stats.
        usage_context: Usage context of the LLM.
        mm_registry: Multi-modal registry.
        use_cached_outputs: Whether to use cached outputs.
        log_requests: Whether to log requests.
        start_engine_loop: Whether to start the engine loop.
        stat_loggers: customized stat loggers for the engine.
            If not provided, default stat loggers will be used.
            PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
            IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
        aggregate_engine_logging: forwarded to StatLoggerManager.
        client_addresses: forwarded to the engine core client.
        client_count: forwarded to the engine core client and to
            StatLoggerManager.
        client_index: forwarded to the engine core client.

    Note:
        ``usage_context``, ``mm_registry``, ``use_cached_outputs`` and
        ``start_engine_loop`` are accepted but not read in this body.

    Returns:
        None
    """
    # Ensure we can serialize custom transformer configs
    maybe_register_config_serialize_by_value()

    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.observability_config = vllm_config.observability_config
    self.log_requests = log_requests

    # Explicitly passed factories first, then any discovered via plugins.
    custom_stat_loggers = [
        *(stat_loggers or []),
        *load_stat_logger_plugin_factories(),
    ]
    has_custom_loggers = bool(custom_stat_loggers)
    # Custom loggers force stats on even when log_stats is False.
    self.log_stats = log_stats or has_custom_loggers
    if has_custom_loggers and not log_stats:
        logger.info(
            "AsyncLLM created with log_stats=False, "
            "but custom stat loggers were found; "
            "enabling logging without default stat loggers."
        )

    tokenizer = (
        None
        if self.model_config.skip_tokenizer_init
        else cached_tokenizer_from_config(self.model_config)
    )

    self.input_processor = InputProcessor(self.vllm_config, tokenizer)
    self.io_processor = get_io_processor(
        self.vllm_config,
        self.model_config.io_processor_plugin,
    )

    # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
    # self.tokenizer is the property backed by self.input_processor.
    self.output_processor = OutputProcessor(
        self.tokenizer,
        log_stats=self.log_stats,
        stream_interval=self.vllm_config.scheduler_config.stream_interval,
    )
    traces_endpoint = self.observability_config.otlp_traces_endpoint
    if traces_endpoint is not None:
        self.output_processor.tracer = init_tracer(
            "vllm.llm_engine", traces_endpoint
        )

    # EngineCore (starts the engine in background process).
    self.engine_core = EngineCoreClient.make_async_mp_client(
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=self.log_stats,
        client_addresses=client_addresses,
        client_count=client_count,
        client_index=client_index,
    )

    # Loggers.
    self.logger_manager: StatLoggerManager | None = None
    if self.log_stats:
        self.logger_manager = StatLoggerManager(
            vllm_config=vllm_config,
            engine_idxs=self.engine_core.engine_ranks_managed,
            custom_stat_loggers=custom_stat_loggers,
            # Default loggers only when the caller asked for stats.
            enable_default_loggers=log_stats,
            client_count=client_count,
            aggregate_engine_logging=aggregate_engine_logging,
        )
        self.logger_manager.log_engine_initialized()

    # Pause / resume state for async RL workflows.
    self._pause_cond = asyncio.Condition()
    self._paused = False

    self.output_handler: asyncio.Task | None = None
    try:
        # Start output handler eagerly if we are in the asyncio eventloop.
        asyncio.get_running_loop()
        self._run_output_handler()
    except RuntimeError:
        # No running loop here; generate()/encode() start the handler.
        pass

    profiler_config = vllm_config.profiler_config
    if not (
        profiler_config.profiler == "torch"
        and not profiler_config.ignore_frontend
    ):
        self.profiler = None
    else:
        profiler_dir = profiler_config.torch_profiler_dir
        logger.info(
            "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
            profiler_dir,
        )
        worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
        trace_handler = torch.profiler.tensorboard_trace_handler(
            profiler_dir,
            worker_name=worker_name,
            use_gzip=profiler_config.torch_profiler_use_gzip,
        )
        self.profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
            ],
            with_stack=profiler_config.torch_profiler_with_stack,
            on_trace_ready=trace_handler,
        )
@property
@deprecated(
    "`AsyncLLM.processor` has been renamed to `AsyncLLM.input_processor`. "
    "The old name will be removed in v0.14."
)
def processor(self):
    """Deprecated alias for ``input_processor``."""
    input_processor = self.input_processor
    return input_processor
@classmethod
def from_vllm_config(
    cls,
    vllm_config: VllmConfig,
    start_engine_loop: bool = True,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: list[StatLoggerFactory] | None = None,
    enable_log_requests: bool = False,
    aggregate_engine_logging: bool = False,
    disable_log_stats: bool = False,
    client_addresses: dict[str, str] | None = None,
    client_count: int = 1,
    client_index: int = 0,
) -> "AsyncLLM":
    """Create an AsyncLLM from an already-built VllmConfig.

    The executor class is resolved from ``vllm_config``; the
    ``enable_log_requests`` / ``disable_log_stats`` flags are mapped onto
    the constructor's ``log_requests`` / ``log_stats`` arguments.
    """
    executor_class = Executor.get_class(vllm_config)
    # Create the AsyncLLM.
    return cls(
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=not disable_log_stats,
        log_requests=enable_log_requests,
        usage_context=usage_context,
        start_engine_loop=start_engine_loop,
        stat_loggers=stat_loggers,
        aggregate_engine_logging=aggregate_engine_logging,
        client_addresses=client_addresses,
        client_count=client_count,
        client_index=client_index,
    )
@classmethod
def from_engine_args(
    cls,
    engine_args: AsyncEngineArgs,
    start_engine_loop: bool = True,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: list[StatLoggerFactory] | None = None,
) -> "AsyncLLM":
    """Create an AsyncLLM from the EngineArgs.

    Builds the engine config from ``engine_args`` for the given usage
    context, resolves the executor class from it, then constructs the
    instance with logging flags taken from ``engine_args``.
    """
    # Create the engine configs.
    config = engine_args.create_engine_config(usage_context)

    # Create the AsyncLLM.
    return cls(
        vllm_config=config,
        executor_class=Executor.get_class(config),
        log_stats=not engine_args.disable_log_stats,
        log_requests=engine_args.enable_log_requests,
        usage_context=usage_context,
        start_engine_loop=start_engine_loop,
        stat_loggers=stat_loggers,
    )
def __del__(self):
self.shutdown()
def shutdown(self):
    """Shutdown, cleaning up the background proc and IPC."""
    shutdown_prometheus()

    # getattr() with a default: either attribute may be missing if
    # construction did not get far enough to set it.
    core = getattr(self, "engine_core", None)
    if core:
        core.shutdown()

    if (task := getattr(self, "output_handler", None)) is not None:
        cancel_task_threadsafe(task)
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
    """Return the supported tasks as reported by the engine core."""
    supported = await self.engine_core.get_supported_tasks_async()
    return supported
async def add_request(
    self,
    request_id: str,
    prompt: EngineCoreRequest | PromptType,
    params: SamplingParams | PoolingParams,
    arrival_time: float | None = None,
    lora_request: LoRARequest | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
    trace_headers: Mapping[str, str] | None = None,
    priority: int = 0,
    data_parallel_rank: int | None = None,
    prompt_text: str | None = None,
) -> RequestOutputCollector:
    """Add new request to the AsyncLLM.

    Args:
        request_id: Id of the request (the parent id when ``n > 1``).
        prompt: Either an already-processed EngineCoreRequest, which is
            used as-is, or a raw prompt that is run through the input
            processor.
        params: Sampling or pooling params for the request.
        arrival_time, lora_request, tokenization_kwargs, trace_headers,
        priority, data_parallel_rank: Forwarded to the input processor;
            only used when ``prompt`` is a raw prompt.
        prompt_text: Prompt text to associate with a pre-processed
            request. Must be ``None`` for raw prompts, where it is derived
            from the prompt itself.

    Returns:
        The collector that will receive this request's outputs. With
        ``n > 1`` all child requests share the same collector.

    Raises:
        EngineDeadError: If the engine has errored.
    """
    if self.errored:
        raise EngineDeadError()

    # Decided on the caller's params, before they are swapped for the
    # processed copy below.
    is_pooling = isinstance(params, PoolingParams)

    # Create a new output collector for the request.
    queue = RequestOutputCollector(output_kind=params.output_kind)

    # Convert Input --> Request.
    if not isinstance(prompt, EngineCoreRequest):
        assert prompt_text is None
        request = self.input_processor.process_inputs(
            request_id,
            prompt,
            params,
            arrival_time,
            lora_request,
            tokenization_kwargs,
            trace_headers,
            priority,
            data_parallel_rank,
        )
        if isinstance(prompt, str):
            prompt_text = prompt
        elif isinstance(prompt, Mapping):
            prompt_text = cast(str | None, prompt.get("prompt"))
    else:
        request = prompt

    # Use cloned params that may have been updated in process_inputs()
    params = request.params

    if is_pooling or params.n == 1:
        await self._add_request(request, prompt_text, None, 0, queue)
        return queue

    parent_params = params
    assert isinstance(parent_params, SamplingParams)

    # Fan out child requests (for n>1).
    parent_request = ParentRequest(request_id, parent_params)
    for idx in range(parent_params.n):
        child_id, child_params = parent_request.get_child_info(idx)
        # The last child reuses the original request object; every other
        # child gets its own shallow copy.
        is_last_child = idx == parent_params.n - 1
        child_request = request if is_last_child else copy(request)
        child_request.request_id = child_id
        child_request.sampling_params = child_params
        await self._add_request(
            child_request, prompt_text, parent_request, idx, queue
        )
    return queue
async def _add_request(
    self,
    request: EngineCoreRequest,
    prompt: str | None,
    parent_req: ParentRequest | None,
    index: int,
    queue: RequestOutputCollector,
):
    """Register one request locally, then hand it to the engine core.

    The output processor is told about the request before it is sent to
    the engine core.
    """
    # Add the request to OutputProcessor (this process).
    self.output_processor.add_request(request, prompt, parent_req, index, queue)

    # Add the EngineCoreRequest to EngineCore (separate process).
    await self.engine_core.add_request_async(request)

    if not self.log_requests:
        return
    logger.info("Added request %s.", request.request_id)
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
async def generate(
    self,
    prompt: EngineCoreRequest | PromptType,
    sampling_params: SamplingParams,
    request_id: str,
    *,
    prompt_text: str | None = None,
    lora_request: LoRARequest | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
    trace_headers: Mapping[str, str] | None = None,
    priority: int = 0,
    data_parallel_rank: int | None = None,
) -> AsyncGenerator[RequestOutput, None]:
    """
    Main function called by the API server to kick off a request
        * 1) Making an AsyncStream corresponding to the Request.
        * 2) Processing the Input.
        * 3) Adding the Request to the Detokenizer.
        * 4) Adding the Request to the EngineCore (separate process).

    A separate output_handler loop runs in a background AsyncIO task,
    pulling outputs from EngineCore and putting them into the
    per-request AsyncStream.

    The caller of generate() iterates the returned AsyncGenerator,
    returning the RequestOutput back to the caller.

    Raises:
        ValueError: For invalid requests, including prompt logprobs
            combined with kv-sharing fast prefill.
        EngineDeadError: If the engine is dead.
        EngineGenerateError: For any other unexpected failure; the
            request is aborted first.
    """
    fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
    if fast_prefill and sampling_params.prompt_logprobs:
        raise ValueError(
            "--kv-sharing-fast-prefill produces incorrect logprobs for "
            "prompt tokens, please disable it when the requests need "
            "prompt logprobs"
        )

    try:
        # We start the output_handler on the first call to generate() so
        # we can call __init__ before the event loop, which enables us
        # to handle startup failure gracefully in the OpenAI server.
        self._run_output_handler()

        # Wait until generation is resumed if the engine is paused.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        tokenization_kwargs = (
            {} if tokenization_kwargs is None else tokenization_kwargs
        )
        _validate_truncation_size(
            self.model_config.max_model_len,
            sampling_params.truncate_prompt_tokens,
            tokenization_kwargs,
        )

        collector = await self.add_request(
            request_id,
            prompt,
            sampling_params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            prompt_text=prompt_text,
        )

        # The output_handler task pushes items into the queue.
        # This task pulls from the queue and yields to caller.
        while True:
            # Note: drain queue without await if possible (avoids
            # task switching under load which helps performance).
            out = collector.get_nowait() or await collector.get()

            # Note: both OutputProcessor and EngineCore handle their
            # own request cleanup based on finished.
            is_last = out.finished
            assert isinstance(out, RequestOutput)
            yield out
            if is_last:
                break

    # If the request is disconnected by the client, generate()
    # is cancelled or the generator is garbage collected. So,
    # we abort the request if we end up here.
    except (asyncio.CancelledError, GeneratorExit):
        await self.abort(request_id)
        if self.log_requests:
            logger.info("Request %s aborted.", request_id)
        raise

    # Engine is dead. Do not abort since we shut down.
    except EngineDeadError:
        if self.log_requests:
            logger.info("Request %s failed (engine dead).", request_id)
        raise

    # Request validation error.
    except ValueError:
        if self.log_requests:
            logger.info("Request %s failed (bad request).", request_id)
        raise

    # Unexpected error in the generate() task (possibly recoverable).
    except Exception as e:
        await self.abort(request_id)
        if self.log_requests:
            logger.info("Request %s failed.", request_id)
        raise EngineGenerateError() from e
def _run_output_handler(self):
    """Background loop: pulls from EngineCore and pushes to AsyncStreams.

    The asyncio task is created at most once; if ``self.output_handler``
    is already set this returns immediately.
    """
    if self.output_handler is not None:
        return

    # Ensure that the task doesn't have a circular ref back to the AsyncLLM
    # object, or else it won't be garbage collected and cleaned up properly.
    # Everything the loop needs is therefore captured in locals.
    engine_core = self.engine_core
    output_processor = self.output_processor
    log_stats = self.log_stats
    logger_manager = self.logger_manager
    input_processor = self.input_processor

    async def output_handler():
        try:
            while True:
                # 1) Pull EngineCoreOutputs from the EngineCore.
                outputs = await engine_core.get_output_async()
                num_outputs = len(outputs.outputs)

                if log_stats and num_outputs:
                    iteration_stats = IterationStats()
                else:
                    iteration_stats = None

                # Split outputs into chunks of at most
                # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                # event loop for too long.
                chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
                if num_outputs > chunk_size:
                    slices = np.array_split(
                        outputs.outputs,
                        cdiv(num_outputs, chunk_size),
                    )
                else:
                    slices = (outputs.outputs,)

                num_slices = len(slices)
                for i, outputs_slice in enumerate(slices):
                    # 2) Process EngineCoreOutputs.
                    processed_outputs = output_processor.process_outputs(
                        outputs_slice, outputs.timestamp, iteration_stats
                    )
                    # NOTE: RequestOutputs are pushed to their queues.
                    assert not processed_outputs.request_outputs

                    # Allow other asyncio tasks to run between chunks
                    if i + 1 < num_slices:
                        await asyncio.sleep(0)

                    # 3) Abort any reqs that finished due to stop strings.
                    await engine_core.abort_requests_async(
                        processed_outputs.reqs_to_abort
                    )

                output_processor.update_scheduler_stats(outputs.scheduler_stats)

                # 4) Logging.
                # TODO(rob): make into a coroutine and launch it in
                # background thread once Prometheus overhead is non-trivial.
                if logger_manager:
                    logger_manager.record(
                        engine_idx=outputs.engine_index,
                        scheduler_stats=outputs.scheduler_stats,
                        iteration_stats=iteration_stats,
                        mm_cache_stats=input_processor.stat_mm_cache(),
                    )
        except Exception as e:
            # Log, then hand the error to the output processor so it can
            # be propagated.
            logger.exception("AsyncLLM output_handler failed.")
            output_processor.propagate_error(e)

    self.output_handler = asyncio.create_task(output_handler())
async def abort(self, request_id: str | Iterable[str]) -> None:
    """Abort RequestId in OutputProcessor and EngineCore.

    Args:
        request_id: A single request id, or an iterable of ids.
    """
    # A bare str is itself iterable, so it is wrapped explicitly rather
    # than being treated as a sequence of one-character ids.
    if isinstance(request_id, str):
        request_ids = (request_id,)
    else:
        request_ids = as_list(request_id)

    # The ids returned by the output processor are what gets aborted in
    # the engine core.
    all_request_ids = self.output_processor.abort_requests(request_ids)
    await self.engine_core.abort_requests_async(all_request_ids)

    if self.log_requests:
        logger.info("Aborted request(s) %s.", ",".join(request_ids))
async def pause_generation(
    self,
    *,
    wait_for_inflight_requests: bool = False,
    clear_cache: bool = True,
) -> None:
    """
    Pause generation to allow model weight updates.

    New generation/encoding requests are blocked until resume. Calling
    this while already paused is a no-op.

    Args:
        wait_for_inflight_requests: When ``True`` waits for in-flight
            requests to finish before pausing. When ``False`` (default),
            immediately aborts any in-flight requests.
        clear_cache: Whether to clear KV cache and prefix cache after
            draining. Set to ``False`` to preserve cache for faster resume.
            Default is ``True`` (clear caches).
    """
    async with self._pause_cond:
        if self._paused:
            return
        self._paused = True

    if not wait_for_inflight_requests:
        # Snapshot the ids first so aborting does not iterate a live view.
        in_flight = [*self.output_processor.request_states.keys()]
        if in_flight:
            await self.abort(in_flight)

    # Wait for running requests to drain before clearing cache.
    if self.output_processor.has_unfinished_requests():
        await self.output_processor.wait_for_requests_to_drain()

    # Clear cache
    if not clear_cache:
        return
    await self.reset_prefix_cache()
    await self.reset_mm_cache()
async def resume_generation(self) -> None:
    """Resume generation after :meth:`pause_generation`."""
    cond = self._pause_cond
    async with cond:
        self._paused = False
        # Wake up all waiting requests
        cond.notify_all()
async def is_paused(self) -> bool:
"""Return whether the engine is currently paused."""
async with self._pause_cond:
return self._paused
async def encode(
    self,
    prompt: PromptType,
    pooling_params: PoolingParams,
    request_id: str,
    lora_request: LoRARequest | None = None,
    trace_headers: Mapping[str, str] | None = None,
    priority: int = 0,
    truncate_prompt_tokens: int | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
    """
    Main function called by the API server to kick off a pooling request
        * 1) Making an AsyncStream corresponding to the Request.
        * 2) Processing the Input.
        * 3) Adding the Request to the EngineCore (separate process).

    A separate output_handler loop runs in a background AsyncIO task,
    pulling outputs from EngineCore and putting them into the
    per-request AsyncStream.

    The caller of encode() iterates the returned AsyncGenerator,
    returning the PoolingRequestOutput back to the caller.

    Raises:
        ValueError: For invalid requests (re-raised unchanged).
        EngineDeadError: If the engine is dead.
        EngineGenerateError: For any other unexpected failure; the
            request is aborted first.
    """
    try:
        # We start the output_handler on the first call to encode() so
        # we can call __init__ before the event loop, which enables us
        # to handle startup failure gracefully in the OpenAI server.
        self._run_output_handler()

        # Respect pause state before accepting new requests.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        if tokenization_kwargs is None:
            tokenization_kwargs = {}
        _validate_truncation_size(
            self.model_config.max_model_len,
            truncate_prompt_tokens,
            tokenization_kwargs,
        )

        q = await self.add_request(
            request_id,
            prompt,
            pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
        )

        # The output_handler task pushes items into the queue.
        # This task pulls from the queue and yields to caller.
        finished = False
        while not finished:
            # Note: drain queue without await if possible (avoids
            # task switching under load which helps performance).
            out = q.get_nowait() or await q.get()
            assert isinstance(out, PoolingRequestOutput)
            # Note: both OutputProcessor and EngineCore handle their
            # own request cleanup based on finished.
            finished = out.finished
            yield out

    # If the request is disconnected by the client, encode() is
    # cancelled or the generator is closed / garbage collected. So, we
    # abort the request if we end up here. GeneratorExit is not an
    # Exception subclass, so it has to be named explicitly; this matches
    # the handling in generate().
    except (asyncio.CancelledError, GeneratorExit):
        await self.abort(request_id)
        if self.log_requests:
            logger.info("Request %s aborted.", request_id)
        raise

    # Engine is dead. Do not abort since we shut down.
    except EngineDeadError:
        if self.log_requests:
            logger.info("Request %s failed (engine dead).", request_id)
        raise

    # Request validation error.
    except ValueError:
        if self.log_requests:
            logger.info("Request %s failed (bad request).", request_id)
        raise

    # Unexpected error in the encode() task (possibly recoverable).
    except Exception as e:
        await self.abort(request_id)
        if self.log_requests:
            logger.info("Request %s failed.", request_id)
        raise EngineGenerateError() from e
@property
def tokenizer(self) -> TokenizerLike | None:
    """The tokenizer held by the input processor (may be ``None``)."""
    processor = self.input_processor
    return processor.tokenizer
async def get_tokenizer(self) -> TokenizerLike:
    """Return the tokenizer.

    Raises:
        ValueError: If no tokenizer is available.
    """
    tokenizer = self.tokenizer
    if tokenizer is not None:
        return tokenizer
    raise ValueError(
        "Unable to get tokenizer because `skip_tokenizer_init=True`"
    )
async def is_tracing_enabled(self) -> bool:
    """Return True when an OTLP traces endpoint is configured."""
    endpoint = self.observability_config.otlp_traces_endpoint
    return endpoint is not None  # type: ignore
async def do_log_stats(self) -> None:
    """Ask the stat logger manager to log; no-op when there is none."""
    manager = self.logger_manager
    if not manager:
        return
    manager.log()
async def check_health(self) -> None:
    """Raise ``self.dead_error`` if the engine has errored; else return."""
    logger.debug("Called check_health.")
    if not self.errored:
        return
    raise self.dead_error
async def start_profile(self) -> None:
    """Start profiling in the engine core and, if set, the local profiler.

    The local profiler's ``start`` is run in a worker thread, concurrently
    with the engine-core call.
    """
    pending = [self.engine_core.profile_async(True)]
    if (profiler := self.profiler) is not None:
        pending.append(asyncio.to_thread(profiler.start))
    await asyncio.gather(*pending)
async def stop_profile(self) -> None:
    """Stop profiling in the engine core and, if set, the local profiler.

    The local profiler's ``stop`` is run in a worker thread, concurrently
    with the engine-core call.
    """
    pending = [self.engine_core.profile_async(False)]
    if (profiler := self.profiler) is not None:
        pending.append(asyncio.to_thread(profiler.stop))
    await asyncio.gather(*pending)
async def reset_mm_cache(self) -> None:
    """Clear the multi-modal cache in this process, then in the engine core."""
    input_processor = self.input_processor
    input_processor.clear_mm_cache()
    await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(
    self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
    """Forward a prefix-cache reset to the engine core.

    Both flags are passed through positionally; the engine core's result
    is returned unchanged.
    """
    outcome = await self.engine_core.reset_prefix_cache_async(
        reset_running_requests, reset_connector
    )
    return outcome
async def sleep(self, level: int = 1) -> None:
    """Reset the prefix cache, then put the engine core to sleep.

    Args:
        level: Sleep level forwarded to the engine core and recorded with
            the stat logger manager (when one exists).
    """
    await self.reset_prefix_cache()
    await self.engine_core.sleep_async(level)

    manager = self.logger_manager
    if manager is None:
        return
    manager.record_sleep_state(1, level)
async def wake_up(self, tags: list[str] | None = None) -> None:
await self.engine_core.wake_up_async(tags)
if self.logger_manager is not None:
self.logger_manager.record_sleep_state(0, 0)
async def is_sleeping(self) -> bool:
    """Return the sleeping state reported by the engine core."""
    sleeping = await self.engine_core.is_sleeping_async()
    return sleeping
async def add_lora(self, lora_request: LoRARequest) -> bool:
    """Load a new LoRA adapter into the engine for future requests."""
    added = await self.engine_core.add_lora_async(lora_request)
    return added
async def remove_lora(self, lora_id: int) -> bool:
    """Remove an already loaded LoRA adapter."""
    removed = await self.engine_core.remove_lora_async(lora_id)
    return removed
async def list_loras(self) -> set[int]:
    """List all registered adapters."""
    lora_ids = await self.engine_core.list_loras_async()
    return lora_ids
async def pin_lora(self, lora_id: int) -> bool:
    """Prevent an adapter from being evicted."""
    pinned = await self.engine_core.pin_lora_async(lora_id)
    return pinned
async def collective_rpc(
self,
method: str,
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
):
"""
Perform a collective RPC call to the given path.
"""
return await self.engine_core.collective_rpc_async(
method, timeout, args, kwargs
)
async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
    """Wait for all requests to be drained.

    Polls ``engine_core.dp_engines_running()`` roughly once per second
    until it reports idle or the timeout elapses.

    Args:
        drain_timeout: Maximum time to wait, in seconds.

    Raises:
        TimeoutError: If the engines are still running once
            ``drain_timeout`` seconds have elapsed.
    """
    # Use the monotonic clock for the deadline: wall-clock time can jump
    # (NTP / manual adjustment) and would distort the timeout.
    deadline = time.monotonic() + drain_timeout
    while True:
        # Check engine state before the deadline so that an idle engine
        # is always reported as drained, even with a zero timeout or
        # when the deadline passed during the last sleep.
        if not self.engine_core.dp_engines_running():
            logger.info("Engines are idle, requests have been drained")
            return

        if time.monotonic() >= deadline:
            break

        logger.info("Engines are still running, waiting for requests to drain...")
        await asyncio.sleep(1)  # Wait 1 second before checking again

    raise TimeoutError(
        f"Timeout reached after {drain_timeout} seconds "
        "waiting for requests to drain."
    )
async def scale_elastic_ep(
    self, new_data_parallel_size: int, drain_timeout: int = 300
):
    """
    Scale up or down the data parallel size by adding or removing
    engine cores.

    Requests are drained first. After the engine core has been rescaled,
    the configured data parallel size is updated, and the stat logger
    manager is recreated only when scaling up with stats enabled.

    Args:
        new_data_parallel_size: The new number of data parallel workers
        drain_timeout:
            Maximum time to wait for requests to drain (seconds)
    """
    parallel_config = self.vllm_config.parallel_config
    old_data_parallel_size = parallel_config.data_parallel_size
    if new_data_parallel_size == old_data_parallel_size:
        logger.info(
            "Data parallel size is already %s, skipping scale",
            new_data_parallel_size,
        )
        return

    logger.info(
        "Waiting for requests to drain before scaling up to %s engines...",
        new_data_parallel_size,
    )
    await self.wait_for_requests_to_drain(drain_timeout)
    logger.info(
        "Requests have been drained, proceeding with scale to %s engines",
        new_data_parallel_size,
    )

    await self.engine_core.scale_elastic_ep(new_data_parallel_size)
    self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

    # recreate stat loggers
    scaled_up = new_data_parallel_size > old_data_parallel_size
    if scaled_up and self.log_stats:
        # TODO(rob): fix this after talking with Ray team.
        # This resets all the prometheus metrics since we
        # unregister during initialization. Need to understand
        # the intended behavior here better.
        self.logger_manager = StatLoggerManager(
            vllm_config=self.vllm_config,
            engine_idxs=list(range(new_data_parallel_size)),
            custom_stat_loggers=None,
        )
@property
def is_running(self) -> bool:
    """True while the output handler task has not finished.

    Also True before the handler has been started.
    """
    handler = self.output_handler
    # Is None before the loop is started.
    if handler is None:
        return True
    return not handler.done()
@property
def is_stopped(self) -> bool:
    """Alias for ``errored``."""
    has_errored = self.errored
    return has_errored
@property
def errored(self) -> bool:
    """True if the engine core is marked dead or the handler has stopped."""
    engine_dead = self.engine_core.resources.engine_dead
    return engine_dead or not self.is_running
@property
def dead_error(self) -> BaseException:
    """A fresh ``EngineDeadError`` instance, built on each access."""
    error = EngineDeadError()
    return error

View File

@@ -0,0 +1,377 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import multiprocessing
import time
import weakref
import msgspec.msgpack
import zmq
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.system_utils import get_mp_context, set_process_title
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
logger = init_logger(__name__)
class DPCoordinator:
    """Coordinator process used for data-parallel deployments (DP>1).

    Sits between the DP engine rank processes and one or more front-end
    API server processes:

    * Gathers stats from every DP engine (currently just the waiting and
      running queue lengths) and publishes them to all front-ends for
      load-balancing decisions.

    * Tracks the current DP "request wave" number and the running state of
      the engines. These come from the DP rank 0 engine and are published
      to the front-ends together with the current load stats.

      Engines alternate between a global running and paused state. The
      global "request wave" number counts how many times the workers have
      collectively moved from running to paused. That transition is
      synchronized by the all-reduce performed in
      DPEngineCoreProc._has_global_unfinished_reqs.

    * Broadcasts the START_DP_WAVE message to engines to move them from
      paused to running when one engine receives a new request. This
      happens in two cases:

      1) A front-end sending a new request while the engines are paused
         concurrently notifies the coordinator.
      2) An engine receiving a request for a stale request wave while
         paused notifies the coordinator.

    Engines move into the running state when they receive either a new
    request or a START_DP_WAVE message.

    Note that in External LB mode the engines publish no stats, so
    front-ends are only updated when the request wave / running state
    changes.
    """

    def __init__(self, parallel_config: ParallelConfig):
        """Choose the ZMQ addresses and start the coordinator daemon process."""
        engine_total = parallel_config.data_parallel_size
        assert engine_total > 1, "Coordinator only used for data parallel"

        master_ip = parallel_config.data_parallel_master_ip

        # Assume coordinator is colocated with front-end procs when not in
        # either external or hybrid DP LB mode.
        frontends_are_local = not (
            parallel_config.data_parallel_external_lb
            or parallel_config.data_parallel_hybrid_lb
        )
        stats_addr = get_engine_client_zmq_addr(
            local_only=frontends_are_local, host=master_ip
        )

        # Engine-facing sockets can be local-only when every DP rank is
        # local to this node.
        engines_are_local = engine_total == parallel_config.data_parallel_size_local
        engine_in_addr = get_engine_client_zmq_addr(engines_are_local, master_ip)
        engine_out_addr = get_engine_client_zmq_addr(engines_are_local, master_ip)

        proc_kwargs = {
            "engine_count": parallel_config.data_parallel_size,
            "front_publish_address": stats_addr,
            "back_output_address": engine_out_addr,
            "back_publish_address": engine_in_addr,
        }
        mp_context = get_mp_context()
        self.proc: multiprocessing.Process = mp_context.Process(
            target=DPCoordinatorProc.run_coordinator,
            name="VLLM_DP_Coordinator",
            kwargs=proc_kwargs,
            daemon=True,
        )
        self.proc.start()

        self.stats_publish_address = stats_addr
        self.coord_in_address = engine_in_addr
        self.coord_out_address = engine_out_addr
        # Shut the child process down when this object is collected or
        # close() is called.
        self._finalizer = weakref.finalize(self, shutdown, [self.proc])

    def get_stats_publish_address(self) -> str:
        """Address on which the coordinator publishes to front-ends."""
        return self.stats_publish_address

    def get_engine_socket_addresses(self) -> tuple[str, str]:
        """Returns tuple of ZMQ input address, output address."""
        return (self.coord_in_address, self.coord_out_address)

    def close(self):
        """Run the finalizer, which shuts down the coordinator process."""
        self._finalizer()
class EngineState:
    """Per-engine load snapshot kept by the coordinator."""

    def __init__(self):
        waiting, running = 0, 0
        # Mutable two-element list: [waiting, running]
        self.request_counts = [waiting, running]
class DPCoordinatorProc:
    """Body of the coordinator process: owns the ZMQ sockets and event loop.

    Holds one ``EngineState`` per engine and relays load stats and request
    wave state between the engines and the front-ends.
    """

    def __init__(self, engine_count: int, min_stats_update_interval_ms: int = 100):
        """Create the ZMQ context and one ``EngineState`` per engine.

        Args:
            engine_count: Number of engines to track initially.
            min_stats_update_interval_ms: Publish interval (ms) used while
                stats have changed since the last publish.
        """
        set_process_title("DPCoordinator")
        self.ctx = zmq.Context()

        self.engines = [EngineState() for _ in range(engine_count)]

        self.stats_update_interval_ms = min_stats_update_interval_ms

    @staticmethod
    def run_coordinator(
        engine_count: int,
        front_publish_address: str,
        back_output_address: str,
        back_publish_address: str,
        min_stats_update_interval_ms: int = 100,
    ):
        """Process entry point: build a coordinator and run its loop.

        A ``KeyboardInterrupt`` raised from the loop is logged and swallowed
        so the process exits quietly.
        """
        coordinator = DPCoordinatorProc(
            engine_count=engine_count,
            min_stats_update_interval_ms=min_stats_update_interval_ms,
        )
        try:
            coordinator.process_input_socket(
                front_publish_address,
                back_output_address,
                back_publish_address,
            )
        except KeyboardInterrupt:
            logger.info("DP Coordinator process exiting")

    def process_input_socket(
        self,
        front_publish_address: str,
        back_output_address: str,
        back_publish_address: str,
    ):
        """Bind the three sockets and run the coordinator event loop.

        Sockets (all bound here):
            * ``publish_front`` (XPUB): publishes stats / wave state to the
              front-ends and receives their notifications.
            * ``output_back`` (PULL): receives ``EngineCoreOutputs`` messages
              from the engines.
            * ``publish_back`` (XPUB): broadcasts to the engines.

        Returns early only if an unexpected message arrives while waiting
        for the engine subscriptions; otherwise loops forever.
        """
        # NOTE(review): the nesting below was reconstructed from the
        # statement order; confirm against the upstream file.
        decoder = MsgpackDecoder(EngineCoreOutputs)

        # For tracking request wave progression.
        current_wave = 0
        engines_running = False

        # For tracking request counts for internal load-balancing.
        stats_changed = False
        last_stats_step = -1
        last_stats_wave = -1
        # Snapshot of the counts taken when a newer step's stats start to
        # arrive before the previous step's counts were published.
        last_step_counts: list[list[int]] | None = None

        with (
            make_zmq_socket(
                path=front_publish_address,  # IPC
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
            ) as publish_front,
            make_zmq_socket(
                path=back_output_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.PULL,
                bind=True,
            ) as output_back,
            make_zmq_socket(
                path=back_publish_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
            ) as publish_back,
        ):
            # Wait until all engines subscribe: one b"\x01" message is
            # expected per tracked engine.
            for _ in self.engines:
                if publish_back.recv() != b"\x01":
                    logger.error(
                        "DP Coordinator received unexpected message while "
                        "waiting for engines to subscribe"
                    )
                    return
            # Send ready message to engines.
            publish_back.send(b"READY")

            logger.info("All engine subscriptions received by DP coordinator")

            poller = zmq.Poller()
            poller.register(publish_front, zmq.POLLIN)
            poller.register(output_back, zmq.POLLIN)
            # Milliseconds since the epoch of the last publish; 0 means
            # "never published".
            last_publish_time = 0
            while True:
                elapsed = int(time.time() * 1000) - last_publish_time
                # Send at stats_update_interval_ms interval if the stats have
                # changed, or otherwise every 5 seconds.
                wait_for = self.stats_update_interval_ms if stats_changed else 5000

                # Wait at least 50ms to ensure we've received all stats for
                # the current step.
                min_timeout = 50 if last_step_counts is None else 0

                events = poller.poll(timeout=max(min_timeout, wait_for - elapsed))
                if not events:
                    # Poller timeout - publish current stats to front-ends.
                    if last_step_counts is not None:
                        # Publish the saved previous-step snapshot first;
                        # stats_changed stays set so the live counts are
                        # published on a later timeout.
                        engine_req_counts_list = last_step_counts
                        last_step_counts = None
                    else:
                        engine_req_counts_list = self._get_engine_counts()
                        stats_changed = False

                    to_publish = (engine_req_counts_list, current_wave, engines_running)
                    publish_front.send(msgspec.msgpack.encode(to_publish))
                    last_publish_time = int(time.time() * 1000)
                    continue

                events = dict(events)
                wave_state_changed = False

                if publish_front in events:
                    buffer = publish_front.recv()
                    if buffer in (b"\x01", b"\x00"):
                        # Ignore subscription messages.
                        continue

                    decoded = msgspec.msgpack.decode(buffer)
                    if (
                        isinstance(decoded, (list, tuple))
                        and len(decoded) == 2
                        and decoded[0] == "SCALE_ELASTIC_EP"
                    ):
                        # Handle scale notification (up or down); the
                        # second element is the new engine count.
                        new_engine_count = decoded[1]
                        current_count = len(self.engines)
                        if new_engine_count > current_count:
                            for _ in range(new_engine_count - current_count):
                                self.engines.append(EngineState())
                            # NOTE(yongji): handle the case
                            # where newly started engines have current_wave = 0
                            # if existing engines just finished a wave
                            # and engine_running isn't updated yet at
                            # CoordinatorProc requests routed to newly started
                            # engines may not wake up existing engines, as long
                            # as 0 < request.wave < existing engines'
                            # current_wave
                            # we note that 0 is the wave number for the new
                            # engine
                            engines_running = False
                            logger.info(
                                "DPCoordinator scaled up from %s to %s engines",
                                current_count,
                                new_engine_count,
                            )
                        else:
                            # This branch also runs when the count is
                            # unchanged (the slice is then a plain copy).
                            self.engines = self.engines[:new_engine_count]
                            logger.info(
                                "DPCoordinator scaled down from %s to %s engines",
                                current_count,
                                new_engine_count,
                            )
                        continue  # Skip normal engine notification processing

                    # We received a message on the front-end XPUB socket,
                    # from an API server sending a new request while the
                    # engines are paused, so that we can wake the other
                    # engines.
                    engine_to_exclude, wave = decoded
                    if not engines_running:
                        if wave < current_wave:
                            # If the wave number is stale, ensure the message
                            # is handled by all the engines.
                            engine_to_exclude = None

                        engines_running = True
                        wave_state_changed = True
                        self._send_start_wave(
                            publish_back, current_wave, engine_to_exclude
                        )

                if output_back in events:
                    # We received a message from one of the engines.

                    buffer = output_back.recv()
                    outputs: EngineCoreOutputs = decoder.decode(buffer)

                    # Only stats / wave notifications are expected here,
                    # never request outputs or utility results.
                    assert not outputs.outputs
                    assert outputs.utility_output is None

                    eng_index = outputs.engine_index
                    scheduler_stats = outputs.scheduler_stats
                    if scheduler_stats:
                        # 1. Updated request load stats - update our local
                        # state with these.
                        stats = self.engines[eng_index].request_counts
                        stats_step = scheduler_stats.step_counter
                        stats_wave = scheduler_stats.current_wave
                        # (wave, step) is compared lexicographically against
                        # the newest pair seen so far.
                        if (
                            stats_wave > last_stats_wave
                            or stats_wave == last_stats_wave
                            and stats_step > last_stats_step
                        ):
                            if stats_changed:
                                # Freeze the previous step's counts before
                                # they are overwritten below.
                                last_step_counts = self._get_engine_counts(do_copy=True)
                            last_stats_step = stats_step
                            last_stats_wave = stats_wave
                        elif stats_wave != last_stats_wave or (
                            stats_step != last_stats_step
                        ):
                            logger.warning(
                                "Received stats for out-of-order "
                                "step (%d, %d) from engine %d (expected "
                                "> (%d, %d))",
                                stats_wave,
                                stats_step,
                                eng_index,
                                last_stats_wave,
                                last_stats_step,
                            )
                        # Updates the engine's list in place:
                        # [waiting, running].
                        stats[0] = scheduler_stats.num_waiting_reqs
                        stats[1] = scheduler_stats.num_running_reqs
                        stats_changed = True

                    if (wave := outputs.wave_complete) is not None:
                        # 2. Notification from rank 0 engine that we've
                        # moved into the global paused state
                        # (engines_running==False).
                        if current_wave <= wave:
                            new_wave = wave + 1
                            logger.debug(
                                "Moving DP wave from %d to %d.", current_wave, new_wave
                            )
                            current_wave = new_wave
                        engines_running = False
                        wave_state_changed = True
                    elif (wave := outputs.start_wave) is not None and (
                        wave > current_wave
                        or (wave == current_wave and not engines_running)
                    ):
                        # 3. The engine received request for a non-current wave
                        # so we must ensure that other engines progress to the
                        # next wave (race condition handling).
                        logger.debug(
                            "Starting wave %d after notification of "
                            "stale wave request from engine.",
                            wave,
                        )
                        current_wave = wave
                        engines_running = True
                        wave_state_changed = True
                        self._send_start_wave(publish_back, wave, eng_index)

                if wave_state_changed:
                    # Wave/running-state update only; the stats slot is None.
                    message = (None, current_wave, engines_running)
                    publish_front.send(msgspec.msgpack.encode(message))

    @staticmethod
    def _send_start_wave(
        socket: zmq.Socket, wave: int, exclude_engine_index: int | None
    ):
        """Broadcast the START_DP_WAVE message to all the engines.

        It includes the current wave number and index of engine which
        has already received a request with this wave number and so doesn't
        require additional notification.

        The message is sent as two frames: the START_DP_WAVE request type
        value, then the msgpack-encoded ``(wave, exclude_engine_index)``.
        """
        wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
        socket.send_multipart((EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))

    def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
        """Return list of [waiting, running] count lists for each engine.

        With ``do_copy=False`` the inner lists are the engines' live lists,
        so later in-place stat updates are visible through the result; with
        ``do_copy=True`` each inner list is a shallow copy.
        """
        if do_copy:
            return [copy.copy(e.request_counts) for e in self.engines]
        return [e.request_counts for e in self.engines]

1455
vllm/v1/engine/core.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,351 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import tokenizers
from packaging import version
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.detokenizer_utils import (
convert_prompt_ids_to_tokens,
detokenize_incrementally,
)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)
# DecodeStream, which FastIncrementalDetokenizer relies on, is only
# supported by tokenizers >= 0.21.1; older versions use the slow path.
USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1")

# Prefix of the error raised by tokenizers' DecodeStream when its internal
# state becomes invalid; matched with str.startswith() to decide whether the
# stream can be reset and retried.
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
class IncrementalDetokenizer:
    """Detokenizer that only records token ids and never produces text.

    Returned directly by ``from_new_request`` when there is no tokenizer.
    """

    def __init__(self):
        # Every token id handed to ``update`` so far.
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        """The recorded token id list itself (not a copy)."""
        return self.token_ids

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """Record ``new_token_ids``; no stop string is ever reported."""
        self.token_ids += new_token_ids
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Always the empty string: nothing is detokenized."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Build the detokenizer variant suited to ``tokenizer``.

        Requires ``request.sampling_params`` to be set.
        """
        assert request.sampling_params is not None

        if tokenizer is None:
            # No tokenizer => skipping detokenization.
            detokenizer = IncrementalDetokenizer()
        elif USE_FAST_DETOKENIZER and isinstance(tokenizer, PreTrainedTokenizerFast):
            # Fast tokenizer => use tokenizers library DecodeStream.
            detokenizer = FastIncrementalDetokenizer(tokenizer, request)
        else:
            # Fall back to slow python-based incremental detokenization.
            detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
        return detokenizer
class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Output-text and stop-string bookkeeping shared by real detokenizers.

    Subclasses implement ``decode_next`` to turn one token id into text.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        sampling = request.sampling_params
        assert sampling is not None

        # Normalise ``stop`` (None, a single string, or a list) to a list.
        raw_stop = sampling.stop
        normalized: list[str]
        if raw_stop is None:
            normalized = []
        elif isinstance(raw_stop, str):
            normalized = [raw_stop]
        else:
            normalized = raw_stop
        self.stop = normalized
        self.min_tokens = sampling.min_tokens
        self.include_stop_str_in_output = sampling.include_stop_str_in_output

        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output.
        hold_back = bool(self.stop) and not self.include_stop_str_in_output
        self.stop_buffer_length = max(map(len, self.stop)) - 1 if hold_back else 0
        self._last_output_text_offset: int = 0

        # Generation data
        self.output_text = ""

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """
        Consume newly generated token ids:

        1) Detokenize the new token ids incrementally.
        2) Evaluate stop criteria.

        Return matched stop string or None.
        """
        if not new_token_ids:
            # Skip detokenization if no new token ids.
            return None

        # If stop-terminated, exclude last token from detokenization
        # based on include_stop_str_in_output parameter.
        held_token_id = None
        if stop_terminated and not self.include_stop_str_in_output:
            held_token_id = new_token_ids[-1]
            new_token_ids = new_token_ids[:-1]

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        search_start = len(self.output_text)
        for token_id in new_token_ids:
            self.token_ids.append(token_id)
            self.output_text += self.decode_next(token_id)
            # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014
            if self.min_tokens and len(self.output_token_ids) <= self.min_tokens:
                search_start = len(self.output_text)

        if held_token_id is not None:
            # Cleanup after skipping detokenization.
            self.token_ids.append(held_token_id)

        # 2) Evaluate stop strings.
        if not self.stop or len(self.output_token_ids) <= self.min_tokens:
            return None

        match = check_stop_strings(
            output_text=self.output_text,
            new_char_count=len(self.output_text) - search_start,
            stop=self.stop,
            include_in_output=self.include_stop_str_in_output,
        )
        if match is None:
            return None

        stop_string, truncate_to = match
        if truncate_to != -1:
            self.output_text = self.output_text[:truncate_to]
        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Return the text contributed by ``next_token_id``."""
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned"""

        # We return the full output text if the sequence is finished.
        held_back = 0 if finished else self.stop_buffer_length
        text = self.output_text

        if not delta:
            if held_back:
                return text[:-held_back]
            return text

        visible_end = len(text) - held_back
        start = self._last_output_text_offset
        if start >= visible_end:
            return ""
        self._last_output_text_offset = visible_end
        return text[start:visible_end]
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer backed by the tokenizers ``DecodeStream``."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest):
        """Create the decode stream and prime it with a suffix of the prompt.

        Args:
            tokenizer: Fast tokenizer whose underlying ``_tokenizer`` is used.
            request: Request supplying the sampling params, request id and
                prompt token ids.
        """
        super().__init__(request)

        sampling_params = request.sampling_params
        assert sampling_params is not None

        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)

        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Find a safe place to start: the shortest prompt suffix (between 4
        # and 23 tokens) whose decoding contains no U+FFFD replacement
        # character. Falls back to the whole prompt when none qualifies.
        prompt_token_ids = request.prompt_token_ids or []
        prompt_suffix = prompt_token_ids
        prompt_len = len(prompt_suffix)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 24)):
                suffix = prompt_token_ids[-i:]
                # The escape is used instead of the raw character so the
                # check cannot be damaged by an encoding round-trip.
                if "\ufffd" not in self.tokenizer.decode(suffix):
                    prompt_suffix = suffix
                    break

        # Prime the stream.
        for tid in prompt_suffix:
            self._protected_step(tid)

        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens
        )

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them. The dict is cached on the tokenizer
            # object so it is built once per tokenizer.
            if (
                added_token_ids := getattr(self.tokenizer, "added_token_ids", None)
            ) is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        """Return the text for ``next_token_id`` ("" when none is produced)."""
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> str | None:
        """Advance the decode stream by one token, recovering from known errors.

        Returns the decoded text, or None when the stream produced nothing
        or the token id was rejected. Exceptions other than the handled
        ones are re-raised.
        """
        try:
            token = self.stream.step(self.tokenizer, next_token_id)
        except (OverflowError, TypeError):
            # Handle rare observed overflow, still to be diagnosed.
            # See https://github.com/vllm-project/vllm/issues/21951.
            logger.exception("Encountered invalid token id: %r", next_token_id)
            token = None
        except Exception as e:
            if not str(e).startswith(INVALID_PREFIX_ERR_MSG):
                raise
            # Recover from edge case where tokenizer can produce non-monotonic,
            # invalid UTF-8 output, which breaks the internal state of
            # tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.",
                self.request_id,
            )
            self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
            token = self.stream.step(self.tokenizer, next_token_id)
        return token
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Python-based incremental detokenizer used as the fallback path."""

    def __init__(self, tokenizer: TokenizerLike, request: EngineCoreRequest):
        """Seed the detokenization state from the request's prompt.

        Args:
            tokenizer: Tokenizer passed through to the detokenization helpers.
            request: Request supplying the sampling params and the prompt
                token ids or prompt embeddings.
        """
        super().__init__(request)

        self.tokenizer = tokenizer
        params = request.sampling_params
        assert params is not None

        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )

        # Metadata for incremental detokenization.
        if request.prompt_token_ids is not None:
            self.tokens, self.prefix_offset, self.read_offset = (
                convert_prompt_ids_to_tokens(
                    tokenizer=tokenizer,
                    prompt_ids=request.prompt_token_ids,
                    skip_special_tokens=params.skip_special_tokens,
                )
            )
        else:
            # Prompt embedding requests cannot be detokenized, in general.
            self.tokens = [""] * self.prompt_len
            self.prefix_offset = 0
            # Must be set on this path too: decode_next() reads it.
            self.read_offset = 0

        # token_ids holds the prompt ids first (zeros as placeholders for
        # prompt-embedding requests), followed by the generated ids.
        self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len)

        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = params.spaces_between_special_tokens

    @property
    def output_token_ids(self) -> list[int]:
        """Generated token ids, i.e. ``token_ids`` without the prompt prefix."""
        return (
            self.token_ids
            if not self.prompt_len
            else (self.token_ids[self.prompt_len :])
        )

    def decode_next(self, next_token_id: int) -> str:
        """Decode the newest token from the accumulated state.

        ``next_token_id`` itself is not used here: the helper is given
        ``self.token_ids``, which the caller has already extended.
        """
        new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally(
            tokenizer=self.tokenizer,
            all_input_ids=self.token_ids,
            prev_tokens=self.tokens,
            prefix_offset=self.prefix_offset,
            read_offset=self.read_offset,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
        )

        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text
def check_stop_strings(
output_text: str,
new_char_count: int,
stop: list[str],
include_in_output: bool,
) -> tuple[str, int] | None:
"""Check if any stop strings are matched and truncate sequence
output text accordingly.
Returns tuple (stop_string, offset) if matched or else None.
Where stop_string is the matched stop string and offset is the
length to which output_text should be truncated, or -1 for no
truncation.
"""
if not new_char_count or not stop:
return None
for stop_str in stop:
stop_string_len = len(stop_str)
# Avoid searching already-searched text.
stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len)
if stop_index == -1:
continue
if include_in_output:
# Truncate to end of stop string.
stop_index += stop_string_len
if stop_index >= len(output_text):
# No truncation required.
return stop_str, -1
# Truncate the output text to either the beginning
# or end of the stop string.
return stop_str, stop_index
return None

View File

@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
class EngineGenerateError(Exception):
    """Raised when an AsyncLLM.generate() call fails. Recoverable."""
class EngineDeadError(Exception):
    """Raised when the EngineCore dies. Unrecoverable."""

    def __init__(self, *args, suppress_context: bool = False, **kwargs):
        """Build the error with a fixed leading message.

        Args:
            *args: Extra positional arguments appended after the message.
            suppress_context: Value stored in ``__suppress_context__``.
            **kwargs: Forwarded to ``Exception.__init__``.
        """
        message = (
            "EngineCore encountered an issue. "
            "See stack trace (above) for the root cause."
        )
        super().__init__(message, *args, **kwargs)

        # Make stack trace clearer when using with LLMEngine by
        # silencing irrelevant ZMQError.
        self.__suppress_context__ = suppress_context

View File

@@ -0,0 +1,643 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Mapping
from typing import Any, Literal, cast
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer,
)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines,
)
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar
logger = init_logger(__name__)
class InputProcessor:
def __init__(
    self,
    vllm_config: VllmConfig,
    tokenizer: TokenizerLike | None,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None:
    """Cache the relevant config sections and build the input preprocessor.

    Args:
        vllm_config: Top-level config; several sub-configs are stored on
            the instance.
        tokenizer: Tokenizer handed to the input preprocessor, or None.
        mm_registry: Multimodal registry used for the processor cache and
            the preprocessor.
    """
    model_config = vllm_config.model_config

    self.vllm_config = vllm_config
    self.model_config = model_config
    self.cache_config = vllm_config.cache_config
    self.lora_config = vllm_config.lora_config
    self.structured_outputs_config = vllm_config.structured_outputs_config

    self.generation_config_fields = model_config.try_get_generation_config()

    self.mm_registry = mm_registry
    processor_cache = processor_cache_from_config(vllm_config, mm_registry)
    self.mm_processor_cache = processor_cache

    self.input_preprocessor = InputPreprocessor(
        model_config,
        tokenizer,
        mm_registry,
        mm_processor_cache=processor_cache,
    )
@property
def tokenizer(self) -> TokenizerLike | None:
    """Tokenizer held by the input preprocessor (may be None)."""
    return self.input_preprocessor.tokenizer
def _validate_logprobs(
    self,
    params: SamplingParams,
) -> None:
    """Reject logprob requests that exceed the configured maximum.

    A value of -1, for the configured maximum or for either request
    field, stands for the full vocabulary size.

    Raises:
        ValueError: If the sample or prompt logprob count is above the
            allowed maximum.
    """
    model_config = self.model_config

    max_logprobs = model_config.max_logprobs
    if max_logprobs == -1:
        max_logprobs = model_config.get_vocab_size()

    # Validate sample logprobs.
    num_logprobs = params.logprobs
    if num_logprobs:
        if num_logprobs == -1:
            num_logprobs = model_config.get_vocab_size()
        if num_logprobs > max_logprobs:
            raise ValueError(
                f"Requested sample logprobs of {num_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}"
            )

    # Validate prompt logprobs.
    num_prompt_logprobs = params.prompt_logprobs
    if num_prompt_logprobs:
        if num_prompt_logprobs == -1:
            num_prompt_logprobs = model_config.get_vocab_size()
        if num_prompt_logprobs > max_logprobs:
            raise ValueError(
                f"Requested prompt logprobs of {num_prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}"
            )
def _validate_sampling_params(
    self,
    params: SamplingParams,
) -> None:
    """Validate structured output, logit bias and ``allowed_token_ids``.

    Raises:
        ValueError: If ``allowed_token_ids`` is empty or holds an id
            outside ``[0, len(tokenizer))``, or if one of the delegated
            validators rejects the params.
    """
    self._validate_structured_output(params)
    self._validate_logit_bias(params)

    allowed = params.allowed_token_ids
    if allowed is None:
        return
    if not allowed:
        raise ValueError("allowed_token_ids is not None and empty!")

    tokenizer = self.tokenizer
    if tokenizer is None:
        # When skip_tokenizer_init=True, we can't validate token IDs
        # Skip validation and let the model handle invalid tokens
        return

    vocab_size = len(tokenizer)
    if any(not (0 <= tid < vocab_size) for tid in allowed):
        raise ValueError("allowed_token_ids contains out-of-vocab token id!")
def _validate_logit_bias(
    self,
    params: SamplingParams,
) -> None:
    """Validate logit_bias token IDs are within vocabulary range.

    Raises:
        ValueError: Listing every key of ``logit_bias`` that is negative
            or not below the model's vocabulary size.
    """
    logit_bias = params.logit_bias
    if not logit_bias:
        return

    vocab_size = self.model_config.get_vocab_size()
    invalid_token_ids = [
        token_id
        for token_id in logit_bias
        if token_id < 0 or token_id >= vocab_size
    ]
    if not invalid_token_ids:
        return

    raise ValueError(
        f"token_id(s) {invalid_token_ids} in logit_bias contain "
        f"out-of-vocab token ids. Vocabulary size: {vocab_size}"
    )
def _validate_supported_sampling_params(
    self,
    params: SamplingParams,
) -> None:
    """Reject sampling features that this engine configuration cannot serve.

    Raises:
        ValueError: If per-request logits processors are supplied, or if
            async scheduling is combined with speculative decoding and the
            params use penalties, bad words or structured outputs.
    """
    # Logits processors not supported.
    if params.logits_processors:
        raise ValueError(
            "vLLM V1 does not support per request user provided logits processors."
        )

    # Async scheduling + spec decode currently incompatible with some
    # sampling parameters.
    if self.vllm_config.speculative_config is None:
        return
    if not self.vllm_config.scheduler_config.async_scheduling:
        return

    if (
        params.frequency_penalty != 0.0
        or params.presence_penalty != 0.0
        or params.repetition_penalty != 1.0
        or params.bad_words_token_ids
        or params.structured_outputs
    ):
        raise ValueError(
            "async scheduling with spec decoding doesn't yet support "
            "penalties, bad words or structured outputs in sampling parameters."
        )
def _validate_params(
    self,
    params: SamplingParams | PoolingParams,
):
    """
    Validate supported SamplingParam.
    Should raise ValueError if unsupported for API Server.

    Pooling params are accepted without any checks.
    """
    if isinstance(params, PoolingParams):
        return

    # Run the sampling validators in their fixed order.
    for validate in (
        self._validate_logprobs,
        self._validate_sampling_params,
        self._validate_supported_sampling_params,
    ):
        validate(params)
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
    """
    Validate that user-provided multi_modal_uuids align with
    multi_modal_data in the incoming request prompt(s).
    Only checks lengths; `None` entries are allowed and will be
    auto-hashed downstream.

    Raises:
        ValueError: If a modality present in the data has no uuids entry,
            or the two have different lengths.
    """

    def _check_prompt(candidate: dict | str) -> None:
        # Only dict prompts can carry multimodal fields.
        if not isinstance(candidate, dict):
            return

        data_by_modality = candidate.get("multi_modal_data")
        uuids_by_modality = candidate.get("multi_modal_uuids")
        if not (data_by_modality and uuids_by_modality):
            return

        import torch

        def _count(obj: object):
            if isinstance(obj, dict):  # Embedding inputs
                if not obj:
                    return 1
                # Measure the first value of the dict.
                return _count(next(iter(obj.values())))

            if isinstance(obj, list):
                return len(obj)

            if isinstance(obj, torch.Tensor):
                # To keep backwards compatibility for single item embedding input
                is_single = getattr(obj, "_is_single_item", False)
                return 1 if is_single else len(obj)

            return 1

        for modality, items in data_by_modality.items():
            if modality not in uuids_by_modality:
                raise ValueError(
                    f"multi_modal_uuids for modality {modality!r} must "
                    "be provided if multi_modal_data is provided."
                )
            data_len = _count(items)
            uuid_len = _count(uuids_by_modality[modality])
            if uuid_len != data_len:
                raise ValueError(
                    f"multi_modal_uuids for modality {modality!r} "
                    "must have same length as data: got "
                    f"{uuid_len} uuids vs {data_len} items."
                )

    # Handle explicit encoder/decoder prompts or singleton prompt
    is_enc_dec = isinstance(prompt, dict) and "encoder_prompt" in prompt
    if not is_enc_dec:
        _check_prompt(prompt)  # type: ignore[arg-type]
        return

    parts = (prompt.get("encoder_prompt"), prompt.get("decoder_prompt"))
    for part in parts:
        if part is not None:
            _check_prompt(cast(dict | str, part))
def _validate_lora(self, lora_request: LoRARequest | None) -> None:
    """Check that a LoRA request is only used when LoRA is enabled.

    Also emits a one-time deprecation warning about per-LoRA tokenizers
    whenever a tokenizer is present.

    Raises:
        ValueError: If ``lora_request`` is given but no LoRA config is set.
    """
    if lora_request is None:
        return

    # LoRA request passed in while LoRA is not enabled
    lora_enabled = bool(self.lora_config)
    if not lora_enabled:
        raise ValueError(
            f"Got lora_request {lora_request} but LoRA is not enabled!"
        )

    if self.tokenizer is None:
        return

    logger.warning_once(
        "vLLM has deprecated support for supporting different "
        "tokenizers for different LoRAs. By default, vLLM uses base "
        "model's tokenizer. If you are using a LoRA "
        "with its own tokenizer, consider specifying `--tokenizer "
        "[lora_path]` to use the LoRA tokenizer."
    )
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.structured_outputs or not self.structured_outputs_config:
return
if self.model_config.skip_tokenizer_init and params.structured_outputs:
raise ValueError(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
)
backend = self.structured_outputs_config.backend
if _backend := params.structured_outputs._backend:
# Request-level backend selection is not supported.
# The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto`
# using the `_backend_was_auto` field set in the params.
if backend != _backend and not (
backend == "auto" and params.structured_outputs._backend_was_auto
):
raise ValueError(
"Request-level structured output backend selection is not "
f"supported. The request specified '{_backend}', but vLLM "
f"was initialised with '{backend}'. This error can be "
"resolved by removing '_backend' from the request."
)
else:
params.structured_outputs._backend = backend
# Request content validation
if (
isinstance(params.structured_outputs.choice, list)
and not params.structured_outputs.choice
):
# It is invalid for choice to be an empty list
raise ValueError(
f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501
)
# Reject empty string grammar early to avoid engine-side crashes
if (
isinstance(params.structured_outputs.grammar, str)
and params.structured_outputs.grammar.strip() == ""
):
raise ValueError("structured_outputs.grammar cannot be an empty string")
if backend.startswith("xgrammar"):
# xgrammar with no fallback
validate_xgrammar_grammar(params)
elif backend.startswith("guidance"):
# TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
if isinstance(self.tokenizer, MistralTokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'guidance' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_guidance_grammar(params, tokenizer=None)
elif backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
elif backend == "lm-format-enforcer":
# lm format enforcer backend
if isinstance(self.tokenizer, MistralTokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'lm-format-enforcer' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_structured_output_request_lm_format_enforcer(params)
else:
# NOTE: backend must be "auto" here, because we have
# checked supported_backends above.
# In this mode, we set opinionated defaults based on what we think
# will satisfy the most use cases without having to worry about
# this setting. We include fallback behavior here, but not with any
# other setting where a specific backend was specified.
try:
validate_xgrammar_grammar(params)
params.structured_outputs._backend = "xgrammar"
except ValueError:
# The request either failed validation
# or includes some jsonschema feature(s) that
# are not supported in xgrammar.
if isinstance(self.tokenizer, MistralTokenizer):
# Fall back to outlines if the tokenizer is Mistral
validate_structured_output_request_outlines(params)
params.structured_outputs._backend = "outlines"
else:
# Fall back to guidance by default.
validate_guidance_grammar(params, tokenizer=None)
params.structured_outputs._backend = "guidance"
# Remember that this backend was set automatically
params.structured_outputs._backend_was_auto = True
def _maybe_build_mm_uuids(
self,
request_id: str,
prompt: PromptType,
) -> MultiModalUUIDDict | None:
"""Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and
index rather than their content.
Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present.
"""
def _extract_mm_data(p: PromptType):
if isinstance(p, dict) and "encoder_prompt" in p:
enc = p.get("encoder_prompt")
if isinstance(enc, dict):
return enc.get("multi_modal_data")
return None
if isinstance(p, dict):
return p.get("multi_modal_data")
return None
mm_data = _extract_mm_data(prompt)
if not mm_data:
return None
mm_uuids: dict[str, list[str | None] | str] = {}
for modality, data in mm_data.items():
# Hash each item for embedding inputs.
n = (
len(data)
if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
else 1
)
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
return mm_uuids
def process_inputs(
    self,
    request_id: str,
    prompt: PromptType,
    params: SamplingParams | PoolingParams,
    arrival_time: float | None = None,
    lora_request: LoRARequest | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
    trace_headers: Mapping[str, str] | None = None,
    priority: int = 0,
    data_parallel_rank: int | None = None,
) -> EngineCoreRequest:
    """Validate a raw request and turn it into an `EngineCoreRequest`.

    Steps, in order: validate the LoRA request and params, range-check
    `data_parallel_rank`, default `arrival_time` to the current time,
    resolve multimodal UUID overrides, preprocess the prompt, run the
    platform and model-input validation, clone the params, and flatten
    multimodal features sorted by their position in the input.

    Raises:
        ValueError: if `data_parallel_rank` is outside
            `[0, data_parallel_size)`. The validation helpers called here
            may raise as well.
    """
    self._validate_lora(lora_request)
    self._validate_params(params)

    data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
    if data_parallel_rank is not None:
        rank_in_range = 0 <= data_parallel_rank < data_parallel_size
        if not rank_in_range:
            raise ValueError(
                f"data_parallel_rank {data_parallel_rank} "
                f"is out of range [0, {data_parallel_size})."
            )

    if arrival_time is None:
        arrival_time = time.time()

    # Optionally generate multimodal hash overrides to avoid hashing
    # multimodal data items by their content as their identifiers.
    # NOTE: when users explicitly turn off BOTH prefix caching and input
    # processing caching, no multimodal features or embeddings will be
    # reused across requests, therefore identifying multimodal data items
    # by their content is no longer necessary, and we create uuids with
    # request id-modality-index as multimodal hash overrides.
    mm_config = self.model_config.multimodal_config
    if (
        mm_config
        and mm_config.mm_processor_cache_gb == 0
        and not self.cache_config.enable_prefix_caching
    ):
        mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
    else:
        # Otherwise, use user-provided uuids as multimodal hash overrides
        # if provided.
        self._validate_multi_modal_uuids(prompt)
        mm_uuids = (
            cast(MultiModalUUIDDict | None, prompt.get("multi_modal_uuids"))
            if isinstance(prompt, dict)
            else None
        )

    # Process inputs, which includes:
    # 1. Tokenize text prompt, with LoRA request if one exists.
    # 2. For multimodal models with a merged preprocessor, preprocess
    #    multimodal data and expand prompt token ids accordingly.
    processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
        prompt,
        tokenization_kwargs=tokenization_kwargs,
        mm_uuids=mm_uuids,
    )

    from vllm.platforms import current_platform

    current_platform.validate_request(
        prompt=prompt,
        params=params,
        processed_inputs=processed_inputs,
    )

    eos_token_id = self.input_preprocessor.get_eos_token_id()

    encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
    self._validate_model_inputs(encoder_inputs, decoder_inputs)

    # Mypy can be conservative for TypedDict unions; normalize access.
    is_embeds = decoder_inputs["type"] == "embeds"
    prompt_token_ids = None if is_embeds else decoder_inputs["prompt_token_ids"]
    prompt_embeds = decoder_inputs["prompt_embeds"] if is_embeds else None

    sampling_params = None
    pooling_params = None
    if not isinstance(params, SamplingParams):
        pooling_params = params.clone()
    else:
        # TODO: can we avoid cloning here in multiproc case?
        sampling_params = params.clone()
        # If unset max tokens, then generate up to the max_model_len.
        if sampling_params.max_tokens is None:
            seq_len = length_from_prompt_token_ids_or_embeds(
                prompt_token_ids, prompt_embeds
            )
            sampling_params.max_tokens = self.model_config.max_model_len - seq_len
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id
        )
        if self.tokenizer is not None:
            sampling_params.update_from_tokenizer(self.tokenizer)

    # Multimodal related.
    mm_features: list[MultiModalFeatureSpec] | None = None
    if decoder_inputs["type"] == "multimodal":
        decoder_mm_inputs = decoder_inputs["mm_kwargs"]
        decoder_mm_positions = decoder_inputs["mm_placeholders"]
        decoder_mm_hashes = decoder_inputs["mm_hashes"]

        # Merge and flatten multimodal placeholders, hashes and inputs
        # from dictionaries to lists, and sort them by each item's position
        # in the input sequence.
        mm_features = [
            MultiModalFeatureSpec(
                data=decoder_mm_inputs[modality][idx],
                modality=modality,
                identifier=decoder_mm_hashes[modality][idx],
                mm_position=decoder_mm_positions[modality][idx],
            )
            for modality, idx in argsort_mm_positions(decoder_mm_positions)
        ]

    return EngineCoreRequest(
        request_id=request_id,
        prompt_token_ids=prompt_token_ids,
        prompt_embeds=prompt_embeds,
        mm_features=mm_features,
        sampling_params=sampling_params,
        pooling_params=pooling_params,
        eos_token_id=eos_token_id,
        arrival_time=arrival_time,
        lora_request=lora_request,
        cache_salt=decoder_inputs.get("cache_salt"),
        priority=priority,
        data_parallel_rank=data_parallel_rank,
        trace_headers=trace_headers,
    )
def _validate_model_inputs(
self, encoder_inputs: SingletonInputs | None, decoder_inputs: SingletonInputs
):
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs, prompt_type="encoder")
self._validate_model_input(decoder_inputs, prompt_type="decoder")
def _validate_model_input(
    self,
    prompt_inputs: SingletonInputs,
    *,
    prompt_type: Literal["encoder", "decoder"],
):
    """Validate one side (encoder or decoder) of the processed inputs.

    Checks that the prompt is not empty (unless that is allowed), that no
    token id is out of vocabulary, and that the prompt length fits within
    `max_model_len`.

    Raises:
        ValueError: if the prompt is empty when it must not be, contains
            an out-of-vocabulary token id, or is too long.
    """
    model_config = self.model_config

    is_embeds = prompt_inputs["type"] == "embeds"
    prompt_ids = None if is_embeds else prompt_inputs["prompt_token_ids"]
    prompt_embeds = prompt_inputs["prompt_embeds"] if is_embeds else None
    prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds)

    if not prompt_ids:
        # Mllama may have empty encoder inputs for text-only data, and
        # prompt embeds should not have prompt_ids.
        empty_is_allowed = (
            prompt_type == "encoder" and model_config.is_multimodal_model
        ) or is_embeds
        if not empty_is_allowed:
            raise ValueError(f"The {prompt_type} prompt cannot be empty")

    tokenizer = self.tokenizer
    if tokenizer is not None:
        max_input_id = max(prompt_ids or [], default=0)

        # NOTE: tokenizer.max_token_id is the tokenizers vocab size while
        # self.model_config.get_vocab_size() is the models vocab size.
        # For Qwen3 models, the language model has extra tokens that do
        # not exist in the tokenizer, and vice versa for multimodal
        # placeholder tokens in some multimodal models.
        # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
        # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501

        # Here we take the max of the two to determine if a token id is
        # truly out-of-vocabulary.
        highest_valid_id = max(
            tokenizer.max_token_id, self.model_config.get_vocab_size() - 1
        )
        if max_input_id > highest_valid_id:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

    max_prompt_len = self.model_config.max_model_len
    if prompt_len > max_prompt_len:
        if prompt_type == "encoder" and model_config.is_multimodal_model:
            mm_registry = self.input_preprocessor.mm_registry
            mm_processor = mm_registry.create_processor(
                model_config,
                tokenizer=tokenizer,
            )
            assert isinstance(mm_processor, EncDecMultiModalProcessor)

            if mm_processor.pad_dummy_encoder_prompt:
                return  # Skip encoder length check for Whisper

        if model_config.is_multimodal_model:
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens plus multimodal tokens. For image "
                "inputs, the number of image tokens depends on the number "
                "of images, and possibly their aspect ratios as well."
            )
        else:
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens."
            )

        raise ValueError(
            f"The {prompt_type} prompt (length {prompt_len}) is "
            f"longer than the maximum model length of {max_prompt_len}. "
            f"{suggestion}"
        )

    # TODO: Find out how many placeholder tokens are there so we can
    # check that chunked prefill does not truncate them
    # max_batch_len = self.scheduler_config.max_num_batched_tokens

    # A text-only generative decoder prompt that exactly fills the
    # context leaves no room for even one output token.
    leaves_no_room_for_output = (
        prompt_len == max_prompt_len
        and prompt_type == "decoder"
        and not model_config.is_multimodal_model
        and self.model_config.runner_type != "pooling"
    )
    if leaves_no_room_for_output:
        suggestion = (
            "Make sure that `max_model_len` is no smaller than the "
            "number of text tokens (prompt + requested output tokens)."
        )
        raise ValueError(
            f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
            f"requested output tokens (at least 1) is longer than the maximum "
            f"model length of {max_prompt_len}. {suggestion}"
        )
def stat_mm_cache(self) -> MultiModalCacheStats | None:
    """Return whatever the input preprocessor reports as multimodal cache stats."""
    preprocessor = self.input_preprocessor
    return preprocessor.stat_mm_cache()
def clear_mm_cache(self) -> None:
    """Ask the input preprocessor to clear its multimodal cache."""
    preprocessor = self.input_preprocessor
    preprocessor.clear_mm_cache()

View File

@@ -0,0 +1,414 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Callable, Mapping
from copy import copy
from typing import Any, cast
import torch.nn as nn
from typing_extensions import TypeVar, deprecated
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.distributed.parallel_state import get_dp_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tracing import init_tracer
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.worker_base import WorkerBase
# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)
# Result type of worker-side callables passed to `collective_rpc` /
# `apply_model`; defaults to `Any` when not otherwise bound.
_R = TypeVar("_R", default=Any)
class LLMEngine:
    """Legacy LLMEngine for backwards compatibility."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        aggregate_engine_logging: bool = False,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Wire up the input processor, output processor and engine core.

        A data-parallel group is created here only in the in-process,
        multi-rank, non-external-launcher case; with the external launcher
        the existing DP group's CPU group is reused instead.
        """
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats

        parallel_config = vllm_config.parallel_config
        executor_backend = parallel_config.distributed_executor_backend
        uses_data_parallel = parallel_config.data_parallel_size > 1

        self.external_launcher_dp = (
            uses_data_parallel and executor_backend == "external_launcher"
        )
        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        needs_own_dp_group = (
            not multiprocess_mode
            and uses_data_parallel
            and not self.external_launcher_dp
        )
        if needs_own_dp_group:
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        self.should_execute_dummy_batch = False

        tokenizer = (
            None
            if self.model_config.skip_tokenizer_init
            else cached_tokenizer_from_config(self.model_config)
        )

        self.input_processor = InputProcessor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            self.output_processor.tracer = init_tracer("vllm.llm_engine", endpoint)

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        if self.external_launcher_dp:
            # If we use DP in external launcher mode, we reuse the
            # existing DP group used for data communication.
            self.dp_group = get_dp_group().cpu_group

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

    @property
    @deprecated(
        "`LLMEngine.processor` has been renamed to `LLMEngine.input_processor`. "
        "The old name will be removed in v0.14."
    )
    def processor(self):
        """Deprecated alias of `input_processor`."""
        return self.input_processor

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an already-built `VllmConfig`."""
        executor_class = Executor.get_class(vllm_config)
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=(not disable_log_stats),
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # The environment switch overrides the caller's choice.
        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=enable_multiprocessing,
        )

    def get_num_unfinished_requests(self) -> int:
        """Return the output processor's count of unfinished requests."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Return whether any request is still unfinished.

        With a DP group the answer is aggregated across the group;
        otherwise it also reflects whether DP engines are still running.
        """
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is not None:
            return self.has_unfinished_requests_dp(has_unfinished)
        return has_unfinished or self.engine_core.dp_engines_running()

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Aggregate the unfinished flag over the DP group.

        If this engine is idle while the aggregated flag is set, a dummy
        batch is scheduled for the next `step()`.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished
        )
        idle_while_peers_busy = aggregated_has_unfinished and not has_unfinished
        if idle_while_peers_busy:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Return `outputs` unchanged; `output_type` is not used."""
        return outputs

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks reported by the engine core."""
        return self.engine_core.get_supported_tasks()

    def abort_request(self, request_ids: list[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer."""

        # The output processor returns the ids to forward to the core.
        ids_for_core = self.output_processor.abort_requests(request_ids)
        self.engine_core.abort_requests(ids_for_core)

    def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        prompt_text: str | None = None,
    ) -> None:
        """Register a request with the output processor and engine core.

        `prompt` may be a ready `EngineCoreRequest` or a raw prompt that is
        first run through the input processor. Sampling requests with
        `n > 1` are fanned out into `n` child requests.

        Raises:
            TypeError: if `request_id` is not a string.
        """
        # Validate the request_id type.
        if not isinstance(request_id, str):
            raise TypeError(f"request_id must be a string, got {type(request_id)}")

        # Process raw inputs into the request.
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
        else:
            assert prompt_text is None
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
            )
            if isinstance(prompt, str):
                prompt_text = prompt
            elif isinstance(prompt, Mapping):
                prompt_text = cast(str | None, prompt.get("prompt"))

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_text, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request_id, params)
        last_idx = n - 1
        for idx in range(n):
            request_id, child_params = parent_req.get_child_info(idx)
            # The final child reuses the original request object; the
            # others get a shallow copy.
            child_request = request if idx == last_idx else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = child_params

            # Make a new RequestState and queue.
            self.output_processor.add_request(
                child_request, prompt_text, parent_req, idx
            )
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
        """Run one engine iteration and return the processed outputs.

        When a dummy batch is pending, it is executed instead and an empty
        list is returned.
        """
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        with record_function_or_nullcontext("llm_engine step: get_output"):
            outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        with record_function_or_nullcontext("llm_engine step: process_outputs"):
            iteration_stats = IterationStats() if self.log_stats else None
            processed_outputs = self.output_processor.process_outputs(
                outputs.outputs,
                engine_core_timestamp=outputs.timestamp,
                iteration_stats=iteration_stats,
            )
            self.output_processor.update_scheduler_stats(outputs.scheduler_stats)

        # 3) Abort any reqs that finished due to stop strings.
        with record_function_or_nullcontext("llm_engine step: abort_requests"):
            self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        # 4) Record stats
        with record_function_or_nullcontext("llm_engine step: record_stats"):
            scheduler_stats = outputs.scheduler_stats
            if self.logger_manager is not None and scheduler_stats is not None:
                self.logger_manager.record(
                    scheduler_stats=scheduler_stats,
                    iteration_stats=iteration_stats,
                    mm_cache_stats=self.input_processor.stat_mm_cache(),
                )
                self.do_log_stats_with_interval()

        return processed_outputs.request_outputs

    def start_profile(self):
        """Call `engine_core.profile(True)`."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Call `engine_core.profile(False)`."""
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        """Clear the multimodal caches of the input processor and engine core."""
        self.input_processor.clear_mm_cache()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Forward a prefix-cache reset to the engine core and return its result."""
        return self.engine_core.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def sleep(self, level: int = 1):
        """Put the engine core to sleep at `level` and record the state."""
        self.engine_core.sleep(level)

        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(1, level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake the engine core (optionally by tags) and record the state."""
        self.engine_core.wake_up(tags)

        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

    def is_sleeping(self) -> bool:
        """Return the engine core's sleeping flag."""
        return self.engine_core.is_sleeping()

    def get_metrics(self) -> list[Metric]:
        """Return a metrics snapshot; requires stat logging to be enabled."""
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The input processor's tokenizer (may be None)."""
        return self.input_processor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer.

        Raises:
            ValueError: if no tokenizer is available.
        """
        tokenizer = self.tokenizer
        if tokenizer is not None:
            return tokenizer
        raise ValueError(
            "Unable to get tokenizer because `skip_tokenizer_init=True`"
        )

    def do_log_stats(self) -> None:
        """Log stats if logging is enabled."""
        if not self.logger_manager:
            return
        self.logger_manager.log()

    def do_log_stats_with_interval(self) -> None:
        """Log stats when the time interval has passed."""
        now = time.time()
        # The first call only starts the interval clock.
        if not hasattr(self, "_last_log_time"):
            self._last_log_time = now
        elapsed = now - self._last_log_time
        if elapsed >= envs.VLLM_LOG_STATS_INTERVAL:
            self.do_log_stats()
            self._last_log_time = now

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Forward a collective RPC to the engine core and return its results."""
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Invoke the "apply_model" RPC with `func` as its only argument."""
        return self.collective_rpc("apply_model", args=(func,))

    def __del__(self):
        """Destroy the DP process group, unless it is the external launcher's."""
        # `dp_group` may be missing if `__init__` did not get that far.
        dp_group = getattr(self, "dp_group", None)
        if dp_group is None or self.external_launcher_dp:
            return
        stateless_destroy_torch_distributed_process_group(dp_group)

189
vllm/v1/engine/logprobs.py Normal file
View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from dataclasses import dataclass
from vllm.logger import init_logger
from vllm.logprobs import (
PromptLogprobs,
SampleLogprobs,
append_logprobs_for_next_position,
create_prompt_logprobs,
create_sample_logprobs,
)
from vllm.tokenizers.detokenizer_utils import (
TokenizerLike,
convert_ids_list_to_tokens,
)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)
# Endless iterator of `None`, used in place of decoded token strings when
# no tokenizer is available.
NONES = itertools.repeat(None)
@dataclass
class LogprobsProcessor:
    """Accumulates sample and prompt logprobs for a single request."""

    # Tokenizer for this request,
    # None if detokenization is disabled.
    tokenizer: TokenizerLike | None

    # Logprobs for this request
    logprobs: SampleLogprobs | None
    prompt_logprobs: PromptLogprobs | None
    cumulative_logprob: float | None
    num_logprobs: int | None
    num_prompt_logprobs: int | None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
    ) -> "LogprobsProcessor":
        """Build a processor from a request's sampling params.

        Containers are only created for the logprob kinds the request
        asked for; the others stay `None`.
        """
        sampling_params = request.sampling_params
        assert sampling_params is not None
        num_logprobs = sampling_params.logprobs
        num_prompt_logprobs = sampling_params.prompt_logprobs

        if num_logprobs is None:
            cumulative_logprob = None
            sample_container = None
        else:
            cumulative_logprob = 0.0
            sample_container = create_sample_logprobs(sampling_params.flat_logprobs)

        if num_prompt_logprobs is None:
            prompt_container = None
        else:
            prompt_container = create_prompt_logprobs(sampling_params.flat_logprobs)

        return cls(
            tokenizer=tokenizer,
            logprobs=sample_container,
            prompt_logprobs=prompt_container,
            cumulative_logprob=cumulative_logprob,
            num_logprobs=num_logprobs,
            num_prompt_logprobs=num_prompt_logprobs,
        )

    def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
        """Update with sample logprobs from EngineCore.

        Outer lists are only of len > 1 if EngineCore made
        >1 tokens in prior step (e.g. in spec decoding).

        Args:
          logprobs_lists: the lists of logprob tokens, logprobs, and ranks.
        """
        assert self.num_logprobs is not None
        assert self.logprobs is not None
        assert self.cumulative_logprob is not None

        token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists

        for token_ids_arr, logprobs_arr, ranks_arr in zip(
            token_ids_lst, logprobs_lst, ranks_lst
        ):
            rank = ranks_arr.tolist()
            logprobs = logprobs_arr.tolist()
            token_ids = token_ids_arr.tolist()

            # Detokenize (non-incrementally).
            if self.tokenizer is None:
                decoded_tokens = NONES
            else:
                decoded_tokens = convert_ids_list_to_tokens(self.tokenizer, token_ids)

            # Sampler puts the sampled logprob in first.
            self.cumulative_logprob += logprobs[0]

            # Update with the Logprob container for this pos.
            append_logprobs_for_next_position(
                self.logprobs,
                token_ids,
                logprobs,
                decoded_tokens,
                rank,
                self.num_logprobs,
            )

    def _update_prompt_logprobs(
        self,
        prompt_logprobs_tensors: LogprobsTensors,
    ) -> None:
        """Update with prompt logprobs from EngineCore.

        Args:
          prompt_logprobs_tensors: tuple containing the prompt logprobs
                                   tensors.
        """
        # Prompt logprobs are enabled.
        assert self.num_prompt_logprobs is not None
        assert self.prompt_logprobs is not None

        token_ids, logprobs, ranks = prompt_logprobs_tensors

        # Detokenize non-incrementally.
        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
        if self.tokenizer is None:
            decoded_tokens = None
        else:
            flat_ids = token_ids.flatten().tolist()
            decoded_tokens = convert_ids_list_to_tokens(self.tokenizer, flat_ids)

        # Recover shapes.
        num_prompt_tokens, num_logprobs = logprobs.shape

        # Pythonize the torch tensors.
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        token_ids = token_ids.tolist()

        # Make Logprob for each position.
        for pos in range(num_prompt_tokens):
            # Handle flattening.
            if decoded_tokens is None:
                decoded_tokens_for_pos = NONES
            else:
                start = pos * num_logprobs
                decoded_tokens_for_pos = decoded_tokens[start : start + num_logprobs]

            # Update with the Logprob container for this pos.
            append_logprobs_for_next_position(
                self.prompt_logprobs,
                token_ids[pos],
                prompt_logprobs[pos],
                decoded_tokens_for_pos,
                prompt_token_ranks[pos],
                self.num_prompt_logprobs,
            )

    def pop_prompt_logprobs(self) -> PromptLogprobs | None:
        """Pop and return all request prompt logprobs

        The logprobs processor aggregates prompt chunk logprobs
        over one or more prefill chunks. This method returns
        all prompt logprobs at once and then forgets them.
        Ensures correct RequestOutputKind.DELTA semantics
        wherein all prompt logprobs are returned at once at
        the end of prefill.

        Returns:
          None if prompt logprobs are disabled for this request.
          List of all prompt logprobs, otherwise.
        """
        popped = self.prompt_logprobs
        if popped:
            # Start a fresh container; the caller now owns the old one.
            self.prompt_logprobs = []
        return popped

    def update_from_output(self, output: EngineCoreOutput) -> None:
        """Fold any new sample and prompt logprobs from `output` into this state."""
        new_sample_logprobs = output.new_logprobs
        if new_sample_logprobs is not None:
            self._update_sample_logprobs(new_sample_logprobs)

        new_prompt_tensors = output.new_prompt_logprobs_tensors
        if new_prompt_tensors is not None:
            self._update_prompt_logprobs(new_prompt_tensors)

View File

@@ -0,0 +1,659 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, cast
import torch
from vllm.outputs import (
CompletionOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.sampling_params import RequestOutputKind
from vllm.tokenizers import TokenizerLike
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import (
IterationStats,
LoRARequestStates,
RequestStateStats,
SchedulerStats,
)
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind):
        """Create an empty collector; merges aggregate in DELTA mode."""
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        self.ready = asyncio.Event()

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Non-blocking put operation."""
        # An empty slot takes anything; an exception always replaces
        # whatever is pending.
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
            return

        pending = self.output
        if isinstance(pending, RequestOutput) and isinstance(output, RequestOutput):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            pending.add(output, aggregate=self.aggregate)
        elif isinstance(pending, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        ):
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Get operation blocks on put event."""
        while self.output is None:
            await self.ready.wait()
        output = self.output
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Non-blocking get operation."""
        output = self.output
        if output is None:
            return None
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output
@dataclass
class OutputProcessorOutput:
    """Result of one `OutputProcessor.process_outputs` call."""

    # Outputs returned directly to the caller; `process_outputs` only fills
    # this for requests that have no per-request queue.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # IDs of requests that finished in the output processor while the
    # engine-core output itself was not marked finished.
    reqs_to_abort: list[str]
class RequestState:
    """Per-request state kept by the output processor.

    Holds the prompt data, the detokenizer and logprobs processor (both
    None for pooling requests), the optional output queue, the per-request
    stats and the stream-interval bookkeeping needed to turn engine-core
    outputs into RequestOutput / PoolingRequestOutput objects.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_name: str | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        # Stats are only tracked when stat logging is enabled.
        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

        # Stream Interval
        self.stream_interval = stream_interval
        self.sent_tokens_offset = 0  # Offset of sent tokens

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
    ) -> "RequestState":
        """Build the state for a newly added request.

        Requests with sampling params get a logprobs processor and a
        detokenizer; otherwise the request must carry pooling params and
        both helpers are left as None.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                # Detokenization disabled: the helpers are built without a
                # tokenizer.
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(
                request.lora_request.name if request.lora_request is not None else None
            ),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
            stream_interval=stream_interval,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Build the output to emit for this step, if any.

        Args:
            new_token_ids: Token ids produced in this step.
            pooling_output: Pooling result; a non-None value selects the
                pooling path.
            finish_reason: Set when the request finished in this step.
            stop_reason: Stop reason forwarded to the completion output.
            kv_transfer_params: Forwarded to the RequestOutput.

        Returns:
            The output to deliver, or None when nothing should be emitted
            yet (FINAL_ONLY before completion, stream interval not reached,
            or the parent request has nothing to return).
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        # The stream-interval throttle counts detokenized output tokens, so
        # it only applies to generation outputs. Pooling outputs (including
        # the placeholder passed on abort) belong to requests that have no
        # detokenizer and must bypass the throttle rather than trip the
        # assertion below.
        if pooling_output is None and self.stream_interval > 1:
            assert self.detokenizer is not None

            # Send output request only when
            # 1. It has finished, or
            # 2. It is the first token, or
            # 3. It has reached the stream interval number of tokens
            if not (
                finished
                or self.sent_tokens_offset == 0
                or len(self.detokenizer.output_token_ids) - self.sent_tokens_offset
                >= self.stream_interval
            ):
                return None

            if self.output_kind == RequestOutputKind.DELTA:
                # Send tokens from the offset in DELTA mode, otherwise all
                # tokens are sent.
                new_token_ids = self.detokenizer.output_token_ids[
                    self.sent_tokens_offset :
                ]
                # The next delta starts after the tokens sent here.
                self.sent_tokens_offset = len(self.detokenizer.output_token_ids)

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)], finished
            )

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # With a parent request, the parent decides which completions
            # are returned, under which request id, and whether the
            # request as a whole is finished.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output
            )
            if not outputs:
                return None

        return self._new_request_output(
            request_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap completion or pooling outputs into the request-level output.

        The type of the first element of `outputs` selects between
        PoolingRequestOutput and RequestOutput.
        """
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )

        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode the given `token_ids` and the matching tail of the
        logprobs are used; otherwise the detokenizer's full output token
        ids and all logprobs are returned.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # NOTE(review): with an empty `token_ids` this is
            # `logprobs[-0:]`, i.e. the whole sequence rather than an empty
            # delta -- confirm whether an empty delta can reach this point
            # with logprobs enabled.
            logprobs = logprobs[-len(token_ids) :]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)
class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerLike | None,
        log_stats: bool,
        stream_interval: int = 1,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.stream_interval = stream_interval
        # Live per-request state, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parent requests, keyed by the parent request id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates(log_stats)
        self.tracer: Tracer | None = None
        # Set while no request is tracked; cleared when a request is added.
        self._requests_drained = asyncio.Event()
        self._requests_drained.set()

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests currently tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True while at least one request is tracked."""
        return len(self.request_states) > 0

    async def wait_for_requests_to_drain(self) -> None:
        """Return once no request is tracked any more."""
        if not self.request_states:
            return
        await self._requests_drained.wait()

    def propagate_error(self, e: Exception) -> None:
        """Propagate error to all generate() tasks."""
        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Drop the given requests and emit their final abort outputs.

        An id that matches a tracked request is removed directly. An id
        that matches a parent request has its remaining children aborted
        and the parent entry removed.

        Returns:
            The ids of the tracked requests that were actually removed
            (for a parent id, the ids of its aborted children).
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.request_finished(request_id, req_state.lora_name)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Set pooling_output is not None to
                        # correctly enter the abort pooling branch
                        pooling_output=torch.randn(0, device="cpu")
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    # Snapshot the ids: aborting a child can remove it from
                    # `parent.child_requests` while we iterate.
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)

        if not self.request_states:
            self._requests_drained.set()
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: If a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
            stream_interval=self.stream_interval,
        )
        if self._requests_drained.is_set():
            self._requests_drained.clear()
        self.request_states[request_id] = req_state
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """
        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not self.request_states:
                    self._requests_drained.set()
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None) -> None:
        """Forward scheduler stats to the LoRA request state tracker."""
        self.lora_states.update_scheduler_stats(scheduler_stats)

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Record one `llm_request` span for a finished request.

        The span starts at the request arrival time and carries latency,
        token usage and sampling-parameter attributes. Requires a tracer,
        per-request stats and iteration stats to be available.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        prompt_length = length_from_prompt_token_ids_or_embeds(
            req_state.prompt_token_ids, req_state.prompt_embeds
        )
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
            # Compare against None rather than relying on truthiness, so a
            # legitimate zero value (e.g. temperature=0.0 for greedy
            # sampling) is still recorded; None marks an unset parameter.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Fold one engine-core output into the iteration stats.

        No-op when `iteration_stats` is None.
        """
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            self.lora_states,
            req_state.lora_name,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Record a finished request in the iteration, LoRA and parent stats.

        No-op when `iteration_stats` is None.
        """
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
                req_state.prompt_token_ids, req_state.prompt_embeds
            ),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
            num_cached_tokens=req_state.num_cached_tokens,
        )
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )

View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import copy
from typing import Optional, cast
from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.metrics.stats import IterationStats
class ParentRequest:
    """Bookkeeping for one parallel-sampling (n > 1) request.

    Keeps the parent request ID and sampling params, derives the
    sampling params of each child request, and merges child outputs.
    """

    request_id: str
    sampling_params: SamplingParams

    # Children that have not finished yet
    child_requests: set[str]

    # Slots for the final child completions (used when not streaming)
    output_aggregator: list[CompletionOutput]

    # Largest generated-token count observed across the children
    max_num_generation_tokens: int

    # Shared child sampling params, reused when no seed is set
    cached_child_sampling_params: SamplingParams | None

    def __init__(self, request_id: str, sampling_params: SamplingParams) -> None:
        self.request_id = request_id
        self.sampling_params = sampling_params
        self.child_requests = set()

        if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY:
            # One placeholder slot per child, filled as each child finishes.
            self.output_aggregator = [
                cast(CompletionOutput, None)
            ] * sampling_params.n
        else:
            self.output_aggregator = []

        self.max_num_generation_tokens = 0
        self.cached_child_sampling_params = None

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Return the `sampling_params` to use for one child request.

        Without a seed, all children share one cached copy of the parent
        params (with `n` forced to 1). With a seed, every child gets its
        own copy whose seed is offset by the child index.

        Args:
          index: index within `n` child requests

        Returns:
          Child `sampling_params` instance.
        """
        cached = self.cached_child_sampling_params
        if cached:
            # Shared instance already built for an unseeded request.
            return cached

        parent_seed = self.sampling_params.seed
        child = copy(self.sampling_params)
        child.n = 1
        if parent_seed is not None:
            # Seeded: a distinct clone per child, never cached.
            child.seed = parent_seed + index
        else:
            # Unseeded: keep the clone around for the remaining children.
            self.cached_child_sampling_params = child
        return child

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        """Register a child request and return its ID and sampling params.

        Args:
          index: index within `n` child requests.

        Returns:
          (request ID, sampling_params) tuple
        """
        child_id = f"{index}_{self.request_id}"
        self.child_requests.add(child_id)
        child_params = self._get_child_sampling_params(index)
        return child_id, child_params

    @property
    def n(self) -> int:
        """Number of samples requested by the parent sampling params."""
        return self.sampling_params.n

    def get_outputs(
        self,
        child_request_id: str,
        completion_output: CompletionOutput,
    ) -> tuple[str, list[CompletionOutput], bool]:
        """Merge one child completion into the parent's view.

        Returns:
          (parent request ID, completions to return now, whether every
          child has finished)
        """
        repeated_finish = False
        if completion_output.finished():
            try:
                self.child_requests.remove(child_request_id)
            except KeyError:
                # The child is no longer tracked: it finished in a previous
                # batch step and was already returned to the client.
                repeated_finish = True

        if self.sampling_params.output_kind == RequestOutputKind.FINAL_ONLY:
            # Not streaming: collect the n final outputs and release them
            # together once no child is left.
            self.output_aggregator[completion_output.index] = completion_output
            outputs = self.output_aggregator if not self.child_requests else []
        elif repeated_finish:
            # Streaming: DO NOT output a finished and already returned
            # child request to the client again.
            outputs = []
        else:
            # Streaming: just return the current output.
            outputs = [completion_output]

        all_done = not self.child_requests
        return self.request_id, outputs, all_done

    def observe_num_generation_tokens(self, num_generation_tokens: int):
        """Track and return the running maximum of generated tokens."""
        if num_generation_tokens > self.max_num_generation_tokens:
            self.max_num_generation_tokens = num_generation_tokens
        return self.max_num_generation_tokens

    @staticmethod
    def observe_finished_request(
        parent_req: Optional["ParentRequest"],
        iteration_stats: IterationStats,
        num_generation_tokens: int,
    ):
        """Record a finished request in the iteration stats.

        With a parent request, the entry is only written once all of its
        children are done, using the maximum token count across them.
        """
        if parent_req is None:
            n_param = 1
        else:
            n_param = parent_req.n
            num_generation_tokens = parent_req.observe_num_generation_tokens(
                num_generation_tokens
            )
            if parent_req.child_requests:
                # Some children are still running; nothing to record yet.
                return

        # Child requests finished, we can now record to iteration stats
        iteration_stats.max_num_generation_tokens_iter.append(num_generation_tokens)
        iteration_stats.n_params_iter.append(n_param)

View File

@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
def __getattr__(name: str):
    """Module-level attribute hook for the deprecated `Processor` name.

    Looking up `Processor` on this module emits a DeprecationWarning and
    returns `InputProcessor`; any other name raises AttributeError.
    """
    if name != "Processor":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported lazily, only when the deprecated name is actually requested.
    from .input_processor import InputProcessor

    warnings.warn(
        "`vllm.v1.engine.processor.Processor` has been moved to "
        "`vllm.v1.engine.input_processor.InputProcessor`. "
        "The old name will be removed in v0.14.",
        DeprecationWarning,
        stacklevel=2,
    )
    return InputProcessor

1068
vllm/v1/engine/utils.py Normal file

File diff suppressed because it is too large Load Diff