Sync from v0.13
This commit is contained in:
217
vllm/v1/engine/__init__.py
Normal file
217
vllm/v1/engine/__init__.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
from vllm.v1.serial_utils import UtilityResult
|
||||
|
||||
# These are possible values of RequestOutput.finish_reason,
|
||||
# so form part of the external API.
|
||||
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
|
||||
|
||||
|
||||
class FinishReason(enum.IntEnum):
|
||||
"""
|
||||
Reason a request finished - stop, length, abort, or error.
|
||||
|
||||
Int rather than Str for more compact serialization.
|
||||
|
||||
stop - a stop string was emitted
|
||||
length - max_tokens was consumed, or max_model_len was reached
|
||||
abort - aborted by client
|
||||
error - retryable request-level internal error (e.g., KV load failure).
|
||||
Invariant: always converted to 500 Internal Server Error.
|
||||
|
||||
"""
|
||||
|
||||
STOP = 0
|
||||
LENGTH = 1
|
||||
ABORT = 2
|
||||
ERROR = 3
|
||||
|
||||
def __str__(self):
|
||||
return FINISH_REASON_STRINGS[self.value]
|
||||
|
||||
|
||||
class EngineCoreRequest(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
gc=False,
|
||||
): # type: ignore[call-arg]
|
||||
request_id: str
|
||||
prompt_token_ids: list[int] | None
|
||||
mm_features: list[MultiModalFeatureSpec] | None
|
||||
sampling_params: SamplingParams | None
|
||||
pooling_params: PoolingParams | None
|
||||
eos_token_id: int | None
|
||||
arrival_time: float
|
||||
lora_request: LoRARequest | None
|
||||
cache_salt: str | None
|
||||
data_parallel_rank: int | None
|
||||
prompt_embeds: torch.Tensor | None = None
|
||||
|
||||
# Index of the client, used to ensure outputs are sent back to the same
|
||||
# client for this request when scaling out the front-end.
|
||||
client_index: int = 0
|
||||
|
||||
# Used in DP case to indicate which wave of requests this is expected to
|
||||
# belong to, to cover a race condition where the request is sent before
|
||||
# a wave finished notification is received.
|
||||
current_wave: int = 0
|
||||
priority: int = 0
|
||||
|
||||
trace_headers: Mapping[str, str] | None = None
|
||||
|
||||
@property
|
||||
def params(self) -> SamplingParams | PoolingParams:
|
||||
"""Return the processed params (sampling or pooling)."""
|
||||
if self.sampling_params is not None:
|
||||
return self.sampling_params
|
||||
assert self.pooling_params is not None
|
||||
return self.pooling_params
|
||||
|
||||
|
||||
class EngineCoreEventType(enum.IntEnum):
|
||||
"""The type of engine core request event."""
|
||||
|
||||
QUEUED = 1
|
||||
SCHEDULED = 2
|
||||
PREEMPTED = 3
|
||||
|
||||
|
||||
class EngineCoreEvent(msgspec.Struct):
|
||||
"""A timestamped engine core event associated with a request.
|
||||
|
||||
The timestamp is a monotonic timestamps and is used for by the engine
|
||||
frontend to calculate intervals between engine core events. These
|
||||
timestamps should not be compared with timestamps from other processes.
|
||||
"""
|
||||
|
||||
type: EngineCoreEventType
|
||||
timestamp: float
|
||||
|
||||
@classmethod
|
||||
def new_event(
|
||||
cls, event_type: EngineCoreEventType, timestamp: float | None = None
|
||||
) -> "EngineCoreEvent":
|
||||
timestamp = time.monotonic() if timestamp is None else timestamp
|
||||
return cls(event_type, timestamp)
|
||||
|
||||
|
||||
class EngineCoreOutput(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
gc=False,
|
||||
): # type: ignore[call-arg]
|
||||
request_id: str
|
||||
new_token_ids: list[int]
|
||||
|
||||
new_logprobs: LogprobsLists | None = None
|
||||
new_prompt_logprobs_tensors: LogprobsTensors | None = None
|
||||
|
||||
pooling_output: torch.Tensor | None = None
|
||||
|
||||
finish_reason: FinishReason | None = None
|
||||
stop_reason: int | str | None = None
|
||||
events: list[EngineCoreEvent] | None = None
|
||||
kv_transfer_params: dict[str, Any] | None = None
|
||||
|
||||
trace_headers: Mapping[str, str] | None = None
|
||||
# The number of tokens with prefix cache hits.
|
||||
num_cached_tokens: int = 0
|
||||
|
||||
# The number of NaNs in logits.
|
||||
# A value greater than 0 indicates that the output is corrupted.
|
||||
num_nans_in_logits: int = 0
|
||||
|
||||
@property
|
||||
def finished(self) -> bool:
|
||||
return self.finish_reason is not None
|
||||
|
||||
|
||||
class UtilityOutput(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
gc=False,
|
||||
): # type: ignore[call-arg]
|
||||
call_id: int
|
||||
|
||||
# Non-None implies the call failed, result should be None.
|
||||
failure_message: str | None = None
|
||||
result: UtilityResult | None = None
|
||||
|
||||
|
||||
class EngineCoreOutputs(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
gc=False,
|
||||
): # type: ignore[call-arg]
|
||||
# NOTE(Nick): We could consider ways to make this more compact,
|
||||
# e.g. columnwise layout
|
||||
|
||||
engine_index: int = 0
|
||||
|
||||
# [num_reqs]
|
||||
outputs: list[EngineCoreOutput] = []
|
||||
scheduler_stats: SchedulerStats | None = None
|
||||
timestamp: float = 0.0
|
||||
|
||||
utility_output: UtilityOutput | None = None
|
||||
finished_requests: set[str] | None = None
|
||||
|
||||
# In DP case, used to signal that the current wave of requests
|
||||
# has finished and the engines are paused.
|
||||
wave_complete: int | None = None
|
||||
# In DP case, used to signal that a request was received for an
|
||||
# "old" wave, so the next wave needs to be started in other engines.
|
||||
start_wave: int | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.timestamp == 0.0:
|
||||
self.timestamp = time.monotonic()
|
||||
|
||||
|
||||
class EngineCoreRequestType(enum.Enum):
|
||||
"""
|
||||
Request types defined as hex byte strings, so it can be sent over sockets
|
||||
without separate encoding step.
|
||||
"""
|
||||
|
||||
ADD = b"\x00"
|
||||
ABORT = b"\x01"
|
||||
START_DP_WAVE = b"\x02"
|
||||
UTILITY = b"\x03"
|
||||
# Sentinel used within EngineCoreProc.
|
||||
EXECUTOR_FAILED = b"\x04"
|
||||
|
||||
|
||||
class ReconfigureDistributedRequest(msgspec.Struct):
|
||||
new_data_parallel_size: int
|
||||
new_data_parallel_rank: int
|
||||
new_data_parallel_rank_local: int
|
||||
new_data_parallel_master_ip: str
|
||||
new_data_parallel_master_port: int
|
||||
|
||||
|
||||
class ReconfigureRankType(enum.IntEnum):
|
||||
"""
|
||||
Rank type for reconfiguring distributed request.
|
||||
"""
|
||||
|
||||
KEEP_CURRENT_RANK = -1
|
||||
SHUTDOWN_CURRENT_RANK = -2
|
||||
866
vllm/v1/engine/async_llm.py
Normal file
866
vllm/v1/engine/async_llm.py
Normal file
@@ -0,0 +1,866 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Iterable, Mapping
|
||||
from copy import copy
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
from vllm.tracing import init_tracer
|
||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.async_utils import cancel_task_threadsafe
|
||||
from vllm.utils.collection_utils import as_list
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.metrics.loggers import (
|
||||
StatLoggerFactory,
|
||||
StatLoggerManager,
|
||||
load_stat_logger_plugin_factories,
|
||||
)
|
||||
from vllm.v1.metrics.prometheus import shutdown_prometheus
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AsyncLLM(EngineClient):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
use_cached_outputs: bool = False,
|
||||
log_requests: bool = True,
|
||||
start_engine_loop: bool = True,
|
||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||
aggregate_engine_logging: bool = False,
|
||||
client_addresses: dict[str, str] | None = None,
|
||||
client_count: int = 1,
|
||||
client_index: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Create an AsyncLLM.
|
||||
|
||||
Args:
|
||||
vllm_config: global configuration.
|
||||
executor_class: an Executor impl, e.g. MultiprocExecutor.
|
||||
log_stats: Whether to log stats.
|
||||
usage_context: Usage context of the LLM.
|
||||
mm_registry: Multi-modal registry.
|
||||
use_cached_outputs: Whether to use cached outputs.
|
||||
log_requests: Whether to log requests.
|
||||
start_engine_loop: Whether to start the engine loop.
|
||||
stat_loggers: customized stat loggers for the engine.
|
||||
If not provided, default stat loggers will be used.
|
||||
PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
|
||||
IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Ensure we can serialize custom transformer configs
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
self.model_config = vllm_config.model_config
|
||||
self.vllm_config = vllm_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
self.log_requests = log_requests
|
||||
|
||||
custom_stat_loggers = list(stat_loggers or [])
|
||||
custom_stat_loggers.extend(load_stat_logger_plugin_factories())
|
||||
|
||||
has_custom_loggers = bool(custom_stat_loggers)
|
||||
self.log_stats = log_stats or has_custom_loggers
|
||||
if not log_stats and has_custom_loggers:
|
||||
logger.info(
|
||||
"AsyncLLM created with log_stats=False, "
|
||||
"but custom stat loggers were found; "
|
||||
"enabling logging without default stat loggers."
|
||||
)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = cached_tokenizer_from_config(self.model_config)
|
||||
|
||||
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
|
||||
self.io_processor = get_io_processor(
|
||||
self.vllm_config,
|
||||
self.model_config.io_processor_plugin,
|
||||
)
|
||||
|
||||
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
|
||||
self.output_processor = OutputProcessor(
|
||||
self.tokenizer,
|
||||
log_stats=self.log_stats,
|
||||
stream_interval=self.vllm_config.scheduler_config.stream_interval,
|
||||
)
|
||||
endpoint = self.observability_config.otlp_traces_endpoint
|
||||
if endpoint is not None:
|
||||
tracer = init_tracer("vllm.llm_engine", endpoint)
|
||||
self.output_processor.tracer = tracer
|
||||
|
||||
# EngineCore (starts the engine in background process).
|
||||
self.engine_core = EngineCoreClient.make_async_mp_client(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=self.log_stats,
|
||||
client_addresses=client_addresses,
|
||||
client_count=client_count,
|
||||
client_index=client_index,
|
||||
)
|
||||
|
||||
# Loggers.
|
||||
self.logger_manager: StatLoggerManager | None = None
|
||||
if self.log_stats:
|
||||
self.logger_manager = StatLoggerManager(
|
||||
vllm_config=vllm_config,
|
||||
engine_idxs=self.engine_core.engine_ranks_managed,
|
||||
custom_stat_loggers=custom_stat_loggers,
|
||||
enable_default_loggers=log_stats,
|
||||
client_count=client_count,
|
||||
aggregate_engine_logging=aggregate_engine_logging,
|
||||
)
|
||||
self.logger_manager.log_engine_initialized()
|
||||
|
||||
# Pause / resume state for async RL workflows.
|
||||
self._pause_cond = asyncio.Condition()
|
||||
self._paused = False
|
||||
|
||||
self.output_handler: asyncio.Task | None = None
|
||||
try:
|
||||
# Start output handler eagerly if we are in the asyncio eventloop.
|
||||
asyncio.get_running_loop()
|
||||
self._run_output_handler()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
if (
|
||||
vllm_config.profiler_config.profiler == "torch"
|
||||
and not vllm_config.profiler_config.ignore_frontend
|
||||
):
|
||||
profiler_dir = vllm_config.profiler_config.torch_profiler_dir
|
||||
logger.info(
|
||||
"Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501
|
||||
profiler_dir,
|
||||
)
|
||||
worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
],
|
||||
with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
profiler_dir,
|
||||
worker_name=worker_name,
|
||||
use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"`AsyncLLM.processor` has been renamed to `AsyncLLM.input_processor`. "
|
||||
"The old name will be removed in v0.14."
|
||||
)
|
||||
def processor(self):
|
||||
return self.input_processor
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||
enable_log_requests: bool = False,
|
||||
aggregate_engine_logging: bool = False,
|
||||
disable_log_stats: bool = False,
|
||||
client_addresses: dict[str, str] | None = None,
|
||||
client_count: int = 1,
|
||||
client_index: int = 0,
|
||||
) -> "AsyncLLM":
|
||||
# Create the LLMEngine.
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
start_engine_loop=start_engine_loop,
|
||||
stat_loggers=stat_loggers,
|
||||
log_requests=enable_log_requests,
|
||||
log_stats=not disable_log_stats,
|
||||
aggregate_engine_logging=aggregate_engine_logging,
|
||||
usage_context=usage_context,
|
||||
client_addresses=client_addresses,
|
||||
client_count=client_count,
|
||||
client_index=client_index,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||
) -> "AsyncLLM":
|
||||
"""Create an AsyncLLM from the EngineArgs."""
|
||||
|
||||
# Create the engine configs.
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
# Create the AsyncLLM.
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_requests=engine_args.enable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown, cleaning up the background proc and IPC."""
|
||||
|
||||
shutdown_prometheus()
|
||||
|
||||
if engine_core := getattr(self, "engine_core", None):
|
||||
engine_core.shutdown()
|
||||
|
||||
handler = getattr(self, "output_handler", None)
|
||||
if handler is not None:
|
||||
cancel_task_threadsafe(handler)
|
||||
|
||||
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
return await self.engine_core.get_supported_tasks_async()
|
||||
|
||||
async def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: EngineCoreRequest | PromptType,
|
||||
params: SamplingParams | PoolingParams,
|
||||
arrival_time: float | None = None,
|
||||
lora_request: LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: int | None = None,
|
||||
prompt_text: str | None = None,
|
||||
) -> RequestOutputCollector:
|
||||
"""Add new request to the AsyncLLM."""
|
||||
|
||||
if self.errored:
|
||||
raise EngineDeadError()
|
||||
|
||||
is_pooling = isinstance(params, PoolingParams)
|
||||
|
||||
# Create a new output collector for the request.
|
||||
queue = RequestOutputCollector(output_kind=params.output_kind)
|
||||
|
||||
# Convert Input --> Request.
|
||||
if isinstance(prompt, EngineCoreRequest):
|
||||
request = prompt
|
||||
else:
|
||||
assert prompt_text is None
|
||||
request = self.input_processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
arrival_time,
|
||||
lora_request,
|
||||
tokenization_kwargs,
|
||||
trace_headers,
|
||||
priority,
|
||||
data_parallel_rank,
|
||||
)
|
||||
if isinstance(prompt, str):
|
||||
prompt_text = prompt
|
||||
elif isinstance(prompt, Mapping):
|
||||
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||
|
||||
# Use cloned params that may have been updated in process_inputs()
|
||||
params = request.params
|
||||
|
||||
if is_pooling or params.n == 1:
|
||||
await self._add_request(request, prompt_text, None, 0, queue)
|
||||
return queue
|
||||
|
||||
parent_params = params
|
||||
assert isinstance(parent_params, SamplingParams)
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_request = ParentRequest(request_id, parent_params)
|
||||
for idx in range(parent_params.n):
|
||||
request_id, child_params = parent_request.get_child_info(idx)
|
||||
child_request = request if idx == parent_params.n - 1 else copy(request)
|
||||
child_request.request_id = request_id
|
||||
child_request.sampling_params = child_params
|
||||
await self._add_request(
|
||||
child_request, prompt_text, parent_request, idx, queue
|
||||
)
|
||||
return queue
|
||||
|
||||
async def _add_request(
|
||||
self,
|
||||
request: EngineCoreRequest,
|
||||
prompt: str | None,
|
||||
parent_req: ParentRequest | None,
|
||||
index: int,
|
||||
queue: RequestOutputCollector,
|
||||
):
|
||||
# Add the request to OutputProcessor (this process).
|
||||
self.output_processor.add_request(request, prompt, parent_req, index, queue)
|
||||
|
||||
# Add the EngineCoreRequest to EngineCore (separate process).
|
||||
await self.engine_core.add_request_async(request)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Added request %s.", request.request_id)
|
||||
|
||||
# TODO: we should support multiple prompts in one call, as you
|
||||
# can do with LLM.generate. So that for multi-prompt completion
|
||||
# requests we don't need to send multiple messages to core proc,
|
||||
# and so we don't need multiple streams which then get
|
||||
# re-multiplexed in the API server anyhow.
|
||||
async def generate(
|
||||
self,
|
||||
prompt: EngineCoreRequest | PromptType,
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
*,
|
||||
prompt_text: str | None = None,
|
||||
lora_request: LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: int | None = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""
|
||||
Main function called by the API server to kick off a request
|
||||
* 1) Making an AsyncStream corresponding to the Request.
|
||||
* 2) Processing the Input.
|
||||
* 3) Adding the Request to the Detokenizer.
|
||||
* 4) Adding the Request to the EngineCore (separate process).
|
||||
|
||||
A separate output_handler loop runs in a background AsyncIO task,
|
||||
pulling outputs from EngineCore and putting them into the
|
||||
per-request AsyncStream.
|
||||
|
||||
The caller of generate() iterates the returned AsyncGenerator,
|
||||
returning the RequestOutput back to the caller.
|
||||
"""
|
||||
|
||||
if (
|
||||
self.vllm_config.cache_config.kv_sharing_fast_prefill
|
||||
and sampling_params.prompt_logprobs
|
||||
):
|
||||
raise ValueError(
|
||||
"--kv-sharing-fast-prefill produces incorrect logprobs for "
|
||||
"prompt tokens, please disable it when the requests need "
|
||||
"prompt logprobs"
|
||||
)
|
||||
|
||||
try:
|
||||
# We start the output_handler on the first call to generate() so
|
||||
# we can call __init__ before the event loop, which enables us
|
||||
# to handle startup failure gracefully in the OpenAI server.
|
||||
self._run_output_handler()
|
||||
|
||||
# Wait until generation is resumed if the engine is paused.
|
||||
async with self._pause_cond:
|
||||
await self._pause_cond.wait_for(lambda: not self._paused)
|
||||
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
|
||||
|
||||
_validate_truncation_size(
|
||||
self.model_config.max_model_len,
|
||||
truncate_prompt_tokens,
|
||||
tokenization_kwargs,
|
||||
)
|
||||
|
||||
q = await self.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
prompt_text=prompt_text,
|
||||
)
|
||||
|
||||
# The output_handler task pushes items into the queue.
|
||||
# This task pulls from the queue and yields to caller.
|
||||
finished = False
|
||||
while not finished:
|
||||
# Note: drain queue without await if possible (avoids
|
||||
# task switching under load which helps performance).
|
||||
out = q.get_nowait() or await q.get()
|
||||
|
||||
# Note: both OutputProcessor and EngineCore handle their
|
||||
# own request cleanup based on finished.
|
||||
finished = out.finished
|
||||
assert isinstance(out, RequestOutput)
|
||||
yield out
|
||||
|
||||
# If the request is disconnected by the client, generate()
|
||||
# is cancelled or the generator is garbage collected. So,
|
||||
# we abort the request if we end up here.
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
await self.abort(request_id)
|
||||
if self.log_requests:
|
||||
logger.info("Request %s aborted.", request_id)
|
||||
raise
|
||||
|
||||
# Engine is dead. Do not abort since we shut down.
|
||||
except EngineDeadError:
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed (engine dead).", request_id)
|
||||
raise
|
||||
|
||||
# Request validation error.
|
||||
except ValueError:
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed (bad request).", request_id)
|
||||
raise
|
||||
|
||||
# Unexpected error in the generate() task (possibly recoverable).
|
||||
except Exception as e:
|
||||
await self.abort(request_id)
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed.", request_id)
|
||||
raise EngineGenerateError() from e
|
||||
|
||||
def _run_output_handler(self):
|
||||
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
|
||||
|
||||
if self.output_handler is not None:
|
||||
return
|
||||
|
||||
# Ensure that the task doesn't have a circular ref back to the AsyncLLM
|
||||
# object, or else it won't be garbage collected and cleaned up properly.
|
||||
engine_core = self.engine_core
|
||||
output_processor = self.output_processor
|
||||
log_stats = self.log_stats
|
||||
logger_manager = self.logger_manager
|
||||
input_processor = self.input_processor
|
||||
|
||||
async def output_handler():
|
||||
try:
|
||||
while True:
|
||||
# 1) Pull EngineCoreOutputs from the EngineCore.
|
||||
outputs = await engine_core.get_output_async()
|
||||
num_outputs = len(outputs.outputs)
|
||||
|
||||
iteration_stats = (
|
||||
IterationStats() if (log_stats and num_outputs) else None
|
||||
)
|
||||
|
||||
# Split outputs into chunks of at most
|
||||
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
|
||||
# event loop for too long.
|
||||
if num_outputs <= envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
|
||||
slices = (outputs.outputs,)
|
||||
else:
|
||||
slices = np.array_split(
|
||||
outputs.outputs,
|
||||
cdiv(num_outputs, envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE),
|
||||
)
|
||||
|
||||
for i, outputs_slice in enumerate(slices):
|
||||
# 2) Process EngineCoreOutputs.
|
||||
processed_outputs = output_processor.process_outputs(
|
||||
outputs_slice, outputs.timestamp, iteration_stats
|
||||
)
|
||||
# NOTE: RequestOutputs are pushed to their queues.
|
||||
assert not processed_outputs.request_outputs
|
||||
|
||||
# Allow other asyncio tasks to run between chunks
|
||||
if i + 1 < len(slices):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# 3) Abort any reqs that finished due to stop strings.
|
||||
await engine_core.abort_requests_async(
|
||||
processed_outputs.reqs_to_abort
|
||||
)
|
||||
|
||||
output_processor.update_scheduler_stats(outputs.scheduler_stats)
|
||||
|
||||
# 4) Logging.
|
||||
# TODO(rob): make into a coroutine and launch it in
|
||||
# background thread once Prometheus overhead is non-trivial.
|
||||
if logger_manager:
|
||||
logger_manager.record(
|
||||
engine_idx=outputs.engine_index,
|
||||
scheduler_stats=outputs.scheduler_stats,
|
||||
iteration_stats=iteration_stats,
|
||||
mm_cache_stats=input_processor.stat_mm_cache(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("AsyncLLM output_handler failed.")
|
||||
output_processor.propagate_error(e)
|
||||
|
||||
self.output_handler = asyncio.create_task(output_handler())
|
||||
|
||||
async def abort(self, request_id: str | Iterable[str]) -> None:
|
||||
"""Abort RequestId in OutputProcessor and EngineCore."""
|
||||
|
||||
request_ids = (
|
||||
(request_id,) if isinstance(request_id, str) else as_list(request_id)
|
||||
)
|
||||
all_request_ids = self.output_processor.abort_requests(request_ids)
|
||||
await self.engine_core.abort_requests_async(all_request_ids)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Aborted request(s) %s.", ",".join(request_ids))
|
||||
|
||||
async def pause_generation(
|
||||
self,
|
||||
*,
|
||||
wait_for_inflight_requests: bool = False,
|
||||
clear_cache: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Pause generation to allow model weight updates.
|
||||
|
||||
New generation/encoding requests are blocked until resume.
|
||||
|
||||
Args:
|
||||
wait_for_inflight_requests: When ``True`` waits for in-flight
|
||||
requests to finish before pausing. When ``False`` (default),
|
||||
immediately aborts any in-flight requests.
|
||||
clear_cache: Whether to clear KV cache and prefix cache after
|
||||
draining. Set to ``False`` to preserve cache for faster resume.
|
||||
Default is ``True`` (clear caches).
|
||||
"""
|
||||
|
||||
async with self._pause_cond:
|
||||
if self._paused:
|
||||
return
|
||||
self._paused = True
|
||||
|
||||
if not wait_for_inflight_requests:
|
||||
request_ids = list(self.output_processor.request_states.keys())
|
||||
if request_ids:
|
||||
await self.abort(request_ids)
|
||||
|
||||
# Wait for running requests to drain before clearing cache.
|
||||
if self.output_processor.has_unfinished_requests():
|
||||
await self.output_processor.wait_for_requests_to_drain()
|
||||
|
||||
# Clear cache
|
||||
if clear_cache:
|
||||
await self.reset_prefix_cache()
|
||||
await self.reset_mm_cache()
|
||||
|
||||
async def resume_generation(self) -> None:
|
||||
"""Resume generation after :meth:`pause_generation`."""
|
||||
|
||||
async with self._pause_cond:
|
||||
self._paused = False
|
||||
self._pause_cond.notify_all() # Wake up all waiting requests
|
||||
|
||||
async def is_paused(self) -> bool:
|
||||
"""Return whether the engine is currently paused."""
|
||||
|
||||
async with self._pause_cond:
|
||||
return self._paused
|
||||
|
||||
async def encode(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
||||
"""
|
||||
Main function called by the API server to kick off a request
|
||||
* 1) Making an AsyncStream corresponding to the Request.
|
||||
* 2) Processing the Input.
|
||||
* 3) Adding the Request to the EngineCore (separate process).
|
||||
|
||||
A separate output_handler loop runs in a background AsyncIO task,
|
||||
pulling outputs from EngineCore and putting them into the
|
||||
per-request AsyncStream.
|
||||
|
||||
The caller of generate() iterates the returned AsyncGenerator,
|
||||
returning the RequestOutput back to the caller.
|
||||
"""
|
||||
|
||||
try:
|
||||
# We start the output_handler on the first call to generate() so
|
||||
# we can call __init__ before the event loop, which enables us
|
||||
# to handle startup failure gracefully in the OpenAI server.
|
||||
self._run_output_handler()
|
||||
|
||||
# Respect pause state before accepting new requests.
|
||||
async with self._pause_cond:
|
||||
await self._pause_cond.wait_for(lambda: not self._paused)
|
||||
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
_validate_truncation_size(
|
||||
self.model_config.max_model_len,
|
||||
truncate_prompt_tokens,
|
||||
tokenization_kwargs,
|
||||
)
|
||||
|
||||
q = await self.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
pooling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
# The output_handler task pushes items into the queue.
|
||||
# This task pulls from the queue and yields to caller.
|
||||
finished = False
|
||||
while not finished:
|
||||
# Note: drain queue without await if possible (avoids
|
||||
# task switching under load which helps performance).
|
||||
out = q.get_nowait() or await q.get()
|
||||
assert isinstance(out, PoolingRequestOutput)
|
||||
# Note: both OutputProcessor and EngineCore handle their
|
||||
# own request cleanup based on finished.
|
||||
finished = out.finished
|
||||
yield out
|
||||
|
||||
# If the request is disconnected by the client, generate()
|
||||
# is cancelled. So, we abort the request if we end up here.
|
||||
except asyncio.CancelledError:
|
||||
await self.abort(request_id)
|
||||
if self.log_requests:
|
||||
logger.info("Request %s aborted.", request_id)
|
||||
raise
|
||||
|
||||
# Engine is dead. Do not abort since we shut down.
|
||||
except EngineDeadError:
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed (engine dead).", request_id)
|
||||
raise
|
||||
|
||||
# Request validation error.
|
||||
except ValueError:
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed (bad request).", request_id)
|
||||
raise
|
||||
|
||||
# Unexpected error in the generate() task (possibly recoverable).
|
||||
except Exception as e:
|
||||
await self.abort(request_id)
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed.", request_id)
|
||||
raise EngineGenerateError() from e
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> TokenizerLike | None:
|
||||
return self.input_processor.tokenizer
|
||||
|
||||
async def get_tokenizer(self) -> TokenizerLike:
|
||||
if self.tokenizer is None:
|
||||
raise ValueError(
|
||||
"Unable to get tokenizer because `skip_tokenizer_init=True`"
|
||||
)
|
||||
|
||||
return self.tokenizer
|
||||
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
return self.observability_config.otlp_traces_endpoint is not None # type: ignore
|
||||
|
||||
async def do_log_stats(self) -> None:
|
||||
if self.logger_manager:
|
||||
self.logger_manager.log()
|
||||
|
||||
async def check_health(self) -> None:
|
||||
logger.debug("Called check_health.")
|
||||
if self.errored:
|
||||
raise self.dead_error
|
||||
|
||||
async def start_profile(self) -> None:
|
||||
coros = [self.engine_core.profile_async(True)]
|
||||
if self.profiler is not None:
|
||||
coros.append(asyncio.to_thread(self.profiler.start))
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
async def stop_profile(self) -> None:
|
||||
coros = [self.engine_core.profile_async(False)]
|
||||
if self.profiler is not None:
|
||||
coros.append(asyncio.to_thread(self.profiler.stop))
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
async def reset_mm_cache(self) -> None:
|
||||
self.input_processor.clear_mm_cache()
|
||||
await self.engine_core.reset_mm_cache_async()
|
||||
|
||||
async def reset_prefix_cache(
|
||||
self, reset_running_requests: bool = False, reset_connector: bool = False
|
||||
) -> bool:
|
||||
return await self.engine_core.reset_prefix_cache_async(
|
||||
reset_running_requests, reset_connector
|
||||
)
|
||||
|
||||
async def sleep(self, level: int = 1) -> None:
|
||||
await self.reset_prefix_cache()
|
||||
await self.engine_core.sleep_async(level)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
self.logger_manager.record_sleep_state(1, level)
|
||||
|
||||
async def wake_up(self, tags: list[str] | None = None) -> None:
|
||||
await self.engine_core.wake_up_async(tags)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
self.logger_manager.record_sleep_state(0, 0)
|
||||
|
||||
async def is_sleeping(self) -> bool:
|
||||
return await self.engine_core.is_sleeping_async()
|
||||
|
||||
async def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
"""Load a new LoRA adapter into the engine for future requests."""
|
||||
return await self.engine_core.add_lora_async(lora_request)
|
||||
|
||||
async def remove_lora(self, lora_id: int) -> bool:
|
||||
"""Remove an already loaded LoRA adapter."""
|
||||
return await self.engine_core.remove_lora_async(lora_id)
|
||||
|
||||
async def list_loras(self) -> set[int]:
|
||||
"""List all registered adapters."""
|
||||
return await self.engine_core.list_loras_async()
|
||||
|
||||
async def pin_lora(self, lora_id: int) -> bool:
|
||||
"""Prevent an adapter from being evicted."""
|
||||
return await self.engine_core.pin_lora_async(lora_id)
|
||||
|
||||
async def collective_rpc(
|
||||
self,
|
||||
method: str,
|
||||
timeout: float | None = None,
|
||||
args: tuple = (),
|
||||
kwargs: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Perform a collective RPC call to the given path.
|
||||
"""
|
||||
return await self.engine_core.collective_rpc_async(
|
||||
method, timeout, args, kwargs
|
||||
)
|
||||
|
||||
async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
|
||||
"""Wait for all requests to be drained."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < drain_timeout:
|
||||
if not self.engine_core.dp_engines_running():
|
||||
logger.info("Engines are idle, requests have been drained")
|
||||
return
|
||||
|
||||
logger.info("Engines are still running, waiting for requests to drain...")
|
||||
await asyncio.sleep(1) # Wait 1 second before checking again
|
||||
|
||||
raise TimeoutError(
|
||||
f"Timeout reached after {drain_timeout} seconds "
|
||||
"waiting for requests to drain."
|
||||
)
|
||||
|
||||
async def scale_elastic_ep(
|
||||
self, new_data_parallel_size: int, drain_timeout: int = 300
|
||||
):
|
||||
"""
|
||||
Scale up or down the data parallel size by adding or removing
|
||||
engine cores.
|
||||
Args:
|
||||
new_data_parallel_size: The new number of data parallel workers
|
||||
drain_timeout:
|
||||
Maximum time to wait for requests to drain (seconds)
|
||||
"""
|
||||
old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
if old_data_parallel_size == new_data_parallel_size:
|
||||
logger.info(
|
||||
"Data parallel size is already %s, skipping scale",
|
||||
new_data_parallel_size,
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"Waiting for requests to drain before scaling up to %s engines...",
|
||||
new_data_parallel_size,
|
||||
)
|
||||
await self.wait_for_requests_to_drain(drain_timeout)
|
||||
logger.info(
|
||||
"Requests have been drained, proceeding with scale to %s engines",
|
||||
new_data_parallel_size,
|
||||
)
|
||||
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
|
||||
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
|
||||
|
||||
# recreate stat loggers
|
||||
if new_data_parallel_size > old_data_parallel_size and self.log_stats:
|
||||
# TODO(rob): fix this after talking with Ray team.
|
||||
# This resets all the prometheus metrics since we
|
||||
# unregister during initialization. Need to understand
|
||||
# the intended behavior here better.
|
||||
self.logger_manager = StatLoggerManager(
|
||||
vllm_config=self.vllm_config,
|
||||
engine_idxs=list(range(new_data_parallel_size)),
|
||||
custom_stat_loggers=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
# Is None before the loop is started.
|
||||
return self.output_handler is None or not self.output_handler.done()
|
||||
|
||||
@property
|
||||
def is_stopped(self) -> bool:
|
||||
return self.errored
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
return self.engine_core.resources.engine_dead or not self.is_running
|
||||
|
||||
@property
|
||||
def dead_error(self) -> BaseException:
|
||||
return EngineDeadError()
|
||||
377
vllm/v1/engine/coordinator.py
Normal file
377
vllm/v1/engine/coordinator.py
Normal file
@@ -0,0 +1,377 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import multiprocessing
|
||||
import time
|
||||
import weakref
|
||||
|
||||
import msgspec.msgpack
|
||||
import zmq
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.utils.system_utils import get_mp_context, set_process_title
|
||||
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DPCoordinator:
|
||||
"""Coordinator process used for data-parallel deployments (DP>1).
|
||||
|
||||
Intermediates between multiple DP engine rank processes and one or more
|
||||
front-end API server processes.
|
||||
|
||||
* Collects stats from each DP engine (currently just waiting and running
|
||||
queue lengths), and publishes these to all front-ends for use in
|
||||
load-balancing decisions.
|
||||
|
||||
* Keeps track of the current DP "request wave" number and running state
|
||||
of the engines. This is received from the DP rank 0 engine and published
|
||||
to the front-end processes along with the current load stats.
|
||||
|
||||
The engines alternate between a global running/paused state. The global
|
||||
"request wave" number is a count of the number of times that the workers
|
||||
collectively move from a running state to a paused state. This transition
|
||||
is synchronized via the all-reduce operation performed in the
|
||||
DPEngineCoreProc._has_global_unfinished_reqs method.
|
||||
|
||||
* Broadcasts the START_DP_WAVE message to engines to move them from paused
|
||||
to running state when one engine receives a new request. This can happen
|
||||
in two cases:
|
||||
1) A front-end sending a new request while the engines are paused will
|
||||
concurrently notify the coordinator.
|
||||
2) An engine receiving a request for a stale request wave while in paused
|
||||
state will notify the coordinator.
|
||||
|
||||
Engines will move into running state when receiving a new request or
|
||||
START_DP_WAVE message.
|
||||
|
||||
Note that when deployed in External LB mode, no stats will be published by
|
||||
the engines and thus updates will only be sent to front-ends when the
|
||||
request wave / running state changes.
|
||||
"""
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
assert dp_size > 1, "Coordinator only used for data parallel"
|
||||
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
external_lb = parallel_config.data_parallel_external_lb
|
||||
hybrid_lb = parallel_config.data_parallel_hybrid_lb
|
||||
|
||||
# Assume coordinator is colocated with front-end procs when not in
|
||||
# either external or hybrid DP LB mode.
|
||||
local_only = not (external_lb or hybrid_lb)
|
||||
front_publish_address = get_engine_client_zmq_addr(
|
||||
local_only=local_only, host=host
|
||||
)
|
||||
|
||||
local_only_eng = dp_size == parallel_config.data_parallel_size_local
|
||||
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
|
||||
context = get_mp_context()
|
||||
self.proc: multiprocessing.Process = context.Process(
|
||||
target=DPCoordinatorProc.run_coordinator,
|
||||
name="VLLM_DP_Coordinator",
|
||||
kwargs={
|
||||
"engine_count": parallel_config.data_parallel_size,
|
||||
"front_publish_address": front_publish_address,
|
||||
"back_output_address": back_output_address,
|
||||
"back_publish_address": back_publish_address,
|
||||
},
|
||||
daemon=True,
|
||||
)
|
||||
self.proc.start()
|
||||
|
||||
self.stats_publish_address = front_publish_address
|
||||
self.coord_in_address = back_publish_address
|
||||
self.coord_out_address = back_output_address
|
||||
self._finalizer = weakref.finalize(self, shutdown, [self.proc])
|
||||
|
||||
def get_stats_publish_address(self) -> str:
|
||||
return self.stats_publish_address
|
||||
|
||||
def get_engine_socket_addresses(self) -> tuple[str, str]:
|
||||
"""Returns tuple of ZMQ input address, output address."""
|
||||
return self.coord_in_address, self.coord_out_address
|
||||
|
||||
def close(self):
|
||||
self._finalizer()
|
||||
|
||||
|
||||
class EngineState:
|
||||
def __init__(self):
|
||||
self.request_counts = [0, 0] # [waiting, running]
|
||||
|
||||
|
||||
class DPCoordinatorProc:
|
||||
def __init__(self, engine_count: int, min_stats_update_interval_ms: int = 100):
|
||||
set_process_title("DPCoordinator")
|
||||
self.ctx = zmq.Context()
|
||||
|
||||
self.engines = [EngineState() for _ in range(engine_count)]
|
||||
|
||||
self.stats_update_interval_ms = min_stats_update_interval_ms
|
||||
|
||||
@staticmethod
|
||||
def run_coordinator(
|
||||
engine_count: int,
|
||||
front_publish_address: str,
|
||||
back_output_address: str,
|
||||
back_publish_address: str,
|
||||
min_stats_update_interval_ms: int = 100,
|
||||
):
|
||||
coordinator = DPCoordinatorProc(
|
||||
engine_count=engine_count,
|
||||
min_stats_update_interval_ms=min_stats_update_interval_ms,
|
||||
)
|
||||
try:
|
||||
coordinator.process_input_socket(
|
||||
front_publish_address,
|
||||
back_output_address,
|
||||
back_publish_address,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("DP Coordinator process exiting")
|
||||
|
||||
def process_input_socket(
|
||||
self,
|
||||
front_publish_address: str,
|
||||
back_output_address: str,
|
||||
back_publish_address: str,
|
||||
):
|
||||
decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||
|
||||
# For tracking request wave progression.
|
||||
current_wave = 0
|
||||
engines_running = False
|
||||
|
||||
# For tracking request counts for internal load-balancing.
|
||||
stats_changed = False
|
||||
last_stats_step = -1
|
||||
last_stats_wave = -1
|
||||
last_step_counts: list[list[int]] | None = None
|
||||
|
||||
with (
|
||||
make_zmq_socket(
|
||||
path=front_publish_address, # IPC
|
||||
ctx=self.ctx,
|
||||
socket_type=zmq.XPUB,
|
||||
bind=True,
|
||||
) as publish_front,
|
||||
make_zmq_socket(
|
||||
path=back_output_address, # IPC or TCP
|
||||
ctx=self.ctx,
|
||||
socket_type=zmq.PULL,
|
||||
bind=True,
|
||||
) as output_back,
|
||||
make_zmq_socket(
|
||||
path=back_publish_address, # IPC or TCP
|
||||
ctx=self.ctx,
|
||||
socket_type=zmq.XPUB,
|
||||
bind=True,
|
||||
) as publish_back,
|
||||
):
|
||||
# Wait until all engines subscribe.
|
||||
for _ in self.engines:
|
||||
if publish_back.recv() != b"\x01":
|
||||
logger.error(
|
||||
"DP Coordinator received unexpected message while "
|
||||
"waiting for engines to subscribe"
|
||||
)
|
||||
return
|
||||
# Send ready message to engines.
|
||||
publish_back.send(b"READY")
|
||||
|
||||
logger.info("All engine subscriptions received by DP coordinator")
|
||||
|
||||
poller = zmq.Poller()
|
||||
poller.register(publish_front, zmq.POLLIN)
|
||||
poller.register(output_back, zmq.POLLIN)
|
||||
last_publish_time = 0
|
||||
while True:
|
||||
elapsed = int(time.time() * 1000) - last_publish_time
|
||||
# Send at stats_update_interval_ms interval if the stats have
|
||||
# changed, or otherwise every 5 seconds.
|
||||
wait_for = self.stats_update_interval_ms if stats_changed else 5000
|
||||
|
||||
# Wait at least 50ms to ensure we've received all stats for
|
||||
# the current step.
|
||||
min_timeout = 50 if last_step_counts is None else 0
|
||||
|
||||
events = poller.poll(timeout=max(min_timeout, wait_for - elapsed))
|
||||
if not events:
|
||||
# Poller timeout - publish current stats to front-ends.
|
||||
if last_step_counts is not None:
|
||||
engine_req_counts_list = last_step_counts
|
||||
last_step_counts = None
|
||||
else:
|
||||
engine_req_counts_list = self._get_engine_counts()
|
||||
stats_changed = False
|
||||
|
||||
to_publish = (engine_req_counts_list, current_wave, engines_running)
|
||||
publish_front.send(msgspec.msgpack.encode(to_publish))
|
||||
last_publish_time = int(time.time() * 1000)
|
||||
continue
|
||||
|
||||
events = dict(events)
|
||||
wave_state_changed = False
|
||||
|
||||
if publish_front in events:
|
||||
buffer = publish_front.recv()
|
||||
if buffer in (b"\x01", b"\x00"):
|
||||
# Ignore subscription messages.
|
||||
continue
|
||||
|
||||
decoded = msgspec.msgpack.decode(buffer)
|
||||
if (
|
||||
isinstance(decoded, (list, tuple))
|
||||
and len(decoded) == 2
|
||||
and decoded[0] == "SCALE_ELASTIC_EP"
|
||||
):
|
||||
# Handle scale up notification
|
||||
new_engine_count = decoded[1]
|
||||
current_count = len(self.engines)
|
||||
if new_engine_count > current_count:
|
||||
for _ in range(new_engine_count - current_count):
|
||||
self.engines.append(EngineState())
|
||||
# NOTE(yongji): handle the case
|
||||
# where newly started engines have current_wave = 0
|
||||
# if existing engines just finished a wave
|
||||
# and engine_running isn't updated yet at
|
||||
# CoordinatorProc requests routed to newly started
|
||||
# engines may not wake up existing engines, as long
|
||||
# as 0 < request.wave < existing engines'
|
||||
# current_wave
|
||||
# we note that 0 is the wave number for the new
|
||||
# engine
|
||||
engines_running = False
|
||||
logger.info(
|
||||
"DPCoordinator scaled up from %s to %s engines",
|
||||
current_count,
|
||||
new_engine_count,
|
||||
)
|
||||
else:
|
||||
self.engines = self.engines[:new_engine_count]
|
||||
logger.info(
|
||||
"DPCoordinator scaled down from %s to %s engines",
|
||||
current_count,
|
||||
new_engine_count,
|
||||
)
|
||||
continue # Skip normal engine notification processing
|
||||
|
||||
# We received a message on the front-end XPUB socket,
|
||||
# from an API server sending a new request while the
|
||||
# engines are paused, so that we can wake the other
|
||||
# engines.
|
||||
engine_to_exclude, wave = decoded
|
||||
if not engines_running:
|
||||
if wave < current_wave:
|
||||
# If the wave number is stale, ensure the message
|
||||
# is handled by all the engines.
|
||||
engine_to_exclude = None
|
||||
|
||||
engines_running = True
|
||||
wave_state_changed = True
|
||||
self._send_start_wave(
|
||||
publish_back, current_wave, engine_to_exclude
|
||||
)
|
||||
|
||||
if output_back in events:
|
||||
# We received a message from one of the engines.
|
||||
|
||||
buffer = output_back.recv()
|
||||
outputs: EngineCoreOutputs = decoder.decode(buffer)
|
||||
|
||||
assert not outputs.outputs
|
||||
assert outputs.utility_output is None
|
||||
|
||||
eng_index = outputs.engine_index
|
||||
scheduler_stats = outputs.scheduler_stats
|
||||
if scheduler_stats:
|
||||
# 1. Updated request load stats - update our local
|
||||
# state with these.
|
||||
stats = self.engines[eng_index].request_counts
|
||||
stats_step = scheduler_stats.step_counter
|
||||
stats_wave = scheduler_stats.current_wave
|
||||
if (
|
||||
stats_wave > last_stats_wave
|
||||
or stats_wave == last_stats_wave
|
||||
and stats_step > last_stats_step
|
||||
):
|
||||
if stats_changed:
|
||||
last_step_counts = self._get_engine_counts(do_copy=True)
|
||||
last_stats_step = stats_step
|
||||
last_stats_wave = stats_wave
|
||||
elif stats_wave != last_stats_wave or (
|
||||
stats_step != last_stats_step
|
||||
):
|
||||
logger.warning(
|
||||
"Received stats for out-of-order "
|
||||
"step (%d, %d) from engine %d (expected "
|
||||
"> (%d, %d))",
|
||||
stats_wave,
|
||||
stats_step,
|
||||
eng_index,
|
||||
last_stats_wave,
|
||||
last_stats_step,
|
||||
)
|
||||
stats[0] = scheduler_stats.num_waiting_reqs
|
||||
stats[1] = scheduler_stats.num_running_reqs
|
||||
stats_changed = True
|
||||
|
||||
if (wave := outputs.wave_complete) is not None:
|
||||
# 2. Notification from rank 0 engine that we've
|
||||
# moved into the global paused state
|
||||
# (engines_running==False).
|
||||
if current_wave <= wave:
|
||||
new_wave = wave + 1
|
||||
logger.debug(
|
||||
"Moving DP wave from %d to %d.", current_wave, new_wave
|
||||
)
|
||||
current_wave = new_wave
|
||||
engines_running = False
|
||||
wave_state_changed = True
|
||||
elif (wave := outputs.start_wave) is not None and (
|
||||
wave > current_wave
|
||||
or (wave == current_wave and not engines_running)
|
||||
):
|
||||
# 3. The engine received request for a non-current wave
|
||||
# so we must ensure that other engines progress to the
|
||||
# next wave (race condition handling).
|
||||
logger.debug(
|
||||
"Starting wave %d after notification of "
|
||||
"stale wave request from engine.",
|
||||
wave,
|
||||
)
|
||||
current_wave = wave
|
||||
engines_running = True
|
||||
wave_state_changed = True
|
||||
self._send_start_wave(publish_back, wave, eng_index)
|
||||
|
||||
if wave_state_changed:
|
||||
message = (None, current_wave, engines_running)
|
||||
publish_front.send(msgspec.msgpack.encode(message))
|
||||
|
||||
@staticmethod
|
||||
def _send_start_wave(
|
||||
socket: zmq.Socket, wave: int, exclude_engine_index: int | None
|
||||
):
|
||||
"""Broadcast the START_DP_WAVE message to all the engines.
|
||||
It includes the current wave number and index of engine which
|
||||
has already received a request with this wave number and so doesn't
|
||||
require additional notification.
|
||||
"""
|
||||
wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
|
||||
socket.send_multipart((EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
|
||||
|
||||
def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
|
||||
"""Return list of [waiting, running] count lists for each engine."""
|
||||
if do_copy:
|
||||
return [copy.copy(e.request_counts) for e in self.engines]
|
||||
return [e.request_counts for e in self.engines]
|
||||
1455
vllm/v1/engine/core.py
Normal file
1455
vllm/v1/engine/core.py
Normal file
File diff suppressed because it is too large
Load Diff
1416
vllm/v1/engine/core_client.py
Normal file
1416
vllm/v1/engine/core_client.py
Normal file
File diff suppressed because it is too large
Load Diff
351
vllm/v1/engine/detokenizer.py
Normal file
351
vllm/v1/engine/detokenizer.py
Normal file
@@ -0,0 +1,351 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import tokenizers
|
||||
from packaging import version
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.decoders import DecodeStream
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.detokenizer_utils import (
|
||||
convert_prompt_ids_to_tokens,
|
||||
detokenize_incrementally,
|
||||
)
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Only tokenizers >= 0.21.1 supports DecodeStream used for
|
||||
# FastIncrementalDetokenizer.
|
||||
USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1")
|
||||
|
||||
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
|
||||
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
|
||||
|
||||
|
||||
class IncrementalDetokenizer:
|
||||
def __init__(self):
|
||||
self.token_ids: list[int] = []
|
||||
|
||||
@property
|
||||
def output_token_ids(self) -> list[int]:
|
||||
return self.token_ids
|
||||
|
||||
def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
|
||||
self.token_ids.extend(new_token_ids)
|
||||
return None
|
||||
|
||||
def get_next_output_text(self, finished: bool, delta: bool) -> str:
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def from_new_request(
|
||||
cls,
|
||||
tokenizer: TokenizerLike | None,
|
||||
request: EngineCoreRequest,
|
||||
) -> "IncrementalDetokenizer":
|
||||
assert request.sampling_params is not None
|
||||
|
||||
if tokenizer is None:
|
||||
# No tokenizer => skipping detokenization.
|
||||
return IncrementalDetokenizer()
|
||||
|
||||
if USE_FAST_DETOKENIZER and isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
# Fast tokenizer => use tokenizers library DecodeStream.
|
||||
return FastIncrementalDetokenizer(tokenizer, request)
|
||||
|
||||
# Fall back to slow python-based incremental detokenization.
|
||||
return SlowIncrementalDetokenizer(tokenizer, request)
|
||||
|
||||
|
||||
class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
||||
def __init__(self, request: EngineCoreRequest):
|
||||
super().__init__()
|
||||
|
||||
# Stop strings
|
||||
params = request.sampling_params
|
||||
assert params is not None
|
||||
stop_list: list[str]
|
||||
if params.stop is None:
|
||||
stop_list = []
|
||||
elif isinstance(params.stop, str):
|
||||
stop_list = [params.stop]
|
||||
else:
|
||||
stop_list = params.stop
|
||||
self.stop = stop_list
|
||||
self.min_tokens = params.min_tokens
|
||||
self.include_stop_str_in_output = params.include_stop_str_in_output
|
||||
|
||||
# Number of chars to hold back when stop strings are to be excluded
|
||||
# from streamed output.
|
||||
if self.stop and not self.include_stop_str_in_output:
|
||||
self.stop_buffer_length = max(len(s) for s in self.stop) - 1
|
||||
else:
|
||||
self.stop_buffer_length = 0
|
||||
self._last_output_text_offset: int = 0
|
||||
|
||||
# Generation data
|
||||
self.output_text = ""
|
||||
|
||||
def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
|
||||
"""
|
||||
Update RequestState for the request_id by:
|
||||
1) Detokenize the new token ids incrementally.
|
||||
2) Evaluate stop criteria.
|
||||
|
||||
Return matched stop string or None.
|
||||
"""
|
||||
if not new_token_ids:
|
||||
# Skip detokenization if no new token ids.
|
||||
return None
|
||||
|
||||
if stop_terminated and not self.include_stop_str_in_output:
|
||||
# If stop-terminated, exclude last token from detokenization
|
||||
# based on include_stop_str_in_output parameter.
|
||||
skipped_stop_token_id = new_token_ids[-1]
|
||||
new_token_ids = new_token_ids[:-1]
|
||||
else:
|
||||
skipped_stop_token_id = None
|
||||
|
||||
# 1) Detokenize the new token ids incrementally.
|
||||
# TODO(woosuk): This method becomes very inefficient when the number of
|
||||
# new_token_ids is more than 1. We need to optimize this.
|
||||
stop_check_offset = len(self.output_text)
|
||||
for new_token_id in new_token_ids:
|
||||
self.token_ids.append(new_token_id)
|
||||
self.output_text += self.decode_next(new_token_id)
|
||||
# Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014
|
||||
if self.min_tokens and len(self.output_token_ids) <= self.min_tokens:
|
||||
stop_check_offset = len(self.output_text)
|
||||
|
||||
if skipped_stop_token_id is not None:
|
||||
# Cleanup after skipping detokenization.
|
||||
self.token_ids.append(skipped_stop_token_id)
|
||||
|
||||
# 2) Evaluate stop strings.
|
||||
stop_string = None
|
||||
if self.stop and len(self.output_token_ids) > self.min_tokens:
|
||||
stop = check_stop_strings(
|
||||
output_text=self.output_text,
|
||||
new_char_count=len(self.output_text) - stop_check_offset,
|
||||
stop=self.stop,
|
||||
include_in_output=self.include_stop_str_in_output,
|
||||
)
|
||||
if stop is not None:
|
||||
stop_string, truncate_to = stop
|
||||
if truncate_to != -1:
|
||||
self.output_text = self.output_text[:truncate_to]
|
||||
|
||||
return stop_string
|
||||
|
||||
@abstractmethod
|
||||
def decode_next(self, next_token_id: int) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_next_output_text(self, finished: bool, delta: bool) -> str:
|
||||
"""If delta is True, only new text since the last call to
|
||||
this method is returned"""
|
||||
|
||||
# We return the full output text if the sequence is finished.
|
||||
buffer_length = 0 if finished else self.stop_buffer_length
|
||||
if not delta:
|
||||
return (
|
||||
self.output_text[:-buffer_length]
|
||||
if buffer_length
|
||||
else (self.output_text)
|
||||
)
|
||||
length = len(self.output_text) - buffer_length
|
||||
last_offset = self._last_output_text_offset
|
||||
if last_offset < length:
|
||||
self._last_output_text_offset = length
|
||||
return self.output_text[last_offset:length]
|
||||
return ""
|
||||
|
||||
|
||||
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest):
|
||||
super().__init__(request)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
assert sampling_params is not None
|
||||
|
||||
self.request_id = request.request_id
|
||||
self.skip_special_tokens = sampling_params.skip_special_tokens
|
||||
self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
|
||||
|
||||
self.tokenizer: Tokenizer = tokenizer._tokenizer
|
||||
|
||||
# Find a safe place to start.
|
||||
prompt_token_ids = request.prompt_token_ids or []
|
||||
prompt_suffix = prompt_token_ids
|
||||
prompt_len = len(prompt_suffix)
|
||||
if prompt_len > 4:
|
||||
for i in range(4, min(prompt_len + 1, 24)):
|
||||
suffix = prompt_token_ids[-i:]
|
||||
if "<EFBFBD>" not in self.tokenizer.decode(suffix):
|
||||
prompt_suffix = suffix
|
||||
break
|
||||
|
||||
# Prime the stream.
|
||||
for tid in prompt_suffix:
|
||||
self._protected_step(tid)
|
||||
|
||||
self.spaces_between_special_tokens = (
|
||||
sampling_params.skip_special_tokens
|
||||
or sampling_params.spaces_between_special_tokens
|
||||
)
|
||||
|
||||
if not self.spaces_between_special_tokens:
|
||||
# Store dict of added token ids so that we can suppress
|
||||
# the spaces between them.
|
||||
if (
|
||||
added_token_ids := getattr(self.tokenizer, "added_token_ids", None)
|
||||
) is None:
|
||||
self.tokenizer.added_token_ids = added_token_ids = {
|
||||
tid: tok.content
|
||||
for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
|
||||
}
|
||||
|
||||
if added_token_ids:
|
||||
self.last_special = False
|
||||
self.added_token_ids = added_token_ids
|
||||
else:
|
||||
# No added tokens.
|
||||
self.spaces_between_special_tokens = True
|
||||
|
||||
def decode_next(self, next_token_id: int) -> str:
|
||||
token = self._protected_step(next_token_id)
|
||||
|
||||
if not self.spaces_between_special_tokens:
|
||||
special_token = self.added_token_ids.get(next_token_id)
|
||||
is_special = special_token is not None
|
||||
if is_special and self.last_special:
|
||||
# Return raw token string without any prefixed spaces.
|
||||
token = special_token
|
||||
self.last_special = is_special
|
||||
|
||||
return token or ""
|
||||
|
||||
def _protected_step(self, next_token_id: int) -> str | None:
|
||||
try:
|
||||
token = self.stream.step(self.tokenizer, next_token_id)
|
||||
except (OverflowError, TypeError):
|
||||
# Handle rare observed overflow, still to be diagnosed.
|
||||
# See https://github.com/vllm-project/vllm/issues/21951.
|
||||
logger.exception("Encountered invalid token id: %r", next_token_id)
|
||||
token = None
|
||||
except Exception as e:
|
||||
if not str(e).startswith(INVALID_PREFIX_ERR_MSG):
|
||||
raise e
|
||||
# Recover from edge case where tokenizer can produce non-monotonic,
|
||||
# invalid UTF-8 output, which breaks the internal state of
|
||||
# tokenizers' DecodeStream.
|
||||
# See https://github.com/vllm-project/vllm/issues/17448.
|
||||
logger.warning(
|
||||
"Encountered invalid prefix detokenization error"
|
||||
" for request %s, resetting decode stream.",
|
||||
self.request_id,
|
||||
)
|
||||
self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
|
||||
token = self.stream.step(self.tokenizer, next_token_id)
|
||||
return token
|
||||
|
||||
|
||||
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike, request: EngineCoreRequest):
|
||||
super().__init__(request)
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
params = request.sampling_params
|
||||
assert params is not None
|
||||
|
||||
self.prompt_len = length_from_prompt_token_ids_or_embeds(
|
||||
request.prompt_token_ids, request.prompt_embeds
|
||||
)
|
||||
|
||||
# Metadata for incremental detokenization.
|
||||
if request.prompt_token_ids is not None:
|
||||
self.tokens, self.prefix_offset, self.read_offset = (
|
||||
convert_prompt_ids_to_tokens(
|
||||
tokenizer=tokenizer,
|
||||
prompt_ids=request.prompt_token_ids,
|
||||
skip_special_tokens=params.skip_special_tokens,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Prompt embedding requests cannot be detokenized, in general.
|
||||
self.tokens = [""] * self.prompt_len
|
||||
self.prefix_offset = 0
|
||||
self.read_offest = 0
|
||||
|
||||
self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len)
|
||||
|
||||
self.skip_special_tokens = params.skip_special_tokens
|
||||
self.spaces_between_special_tokens = params.spaces_between_special_tokens
|
||||
|
||||
@property
|
||||
def output_token_ids(self) -> list[int]:
|
||||
return (
|
||||
self.token_ids
|
||||
if not self.prompt_len
|
||||
else (self.token_ids[self.prompt_len :])
|
||||
)
|
||||
|
||||
def decode_next(self, next_token_id: int) -> str:
|
||||
new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally(
|
||||
tokenizer=self.tokenizer,
|
||||
all_input_ids=self.token_ids,
|
||||
prev_tokens=self.tokens,
|
||||
prefix_offset=self.prefix_offset,
|
||||
read_offset=self.read_offset,
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||
)
|
||||
|
||||
self.tokens.extend(new_tokens)
|
||||
self.prefix_offset = prefix_offset
|
||||
self.read_offset = read_offset
|
||||
|
||||
return decoded_text
|
||||
|
||||
|
||||
def check_stop_strings(
|
||||
output_text: str,
|
||||
new_char_count: int,
|
||||
stop: list[str],
|
||||
include_in_output: bool,
|
||||
) -> tuple[str, int] | None:
|
||||
"""Check if any stop strings are matched and truncate sequence
|
||||
output text accordingly.
|
||||
|
||||
Returns tuple (stop_string, offset) if matched or else None.
|
||||
|
||||
Where stop_string is the matched stop string and offset is the
|
||||
length to which output_text should be truncated, or -1 for no
|
||||
truncation.
|
||||
"""
|
||||
if not new_char_count or not stop:
|
||||
return None
|
||||
|
||||
for stop_str in stop:
|
||||
stop_string_len = len(stop_str)
|
||||
# Avoid searching already-searched text.
|
||||
stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len)
|
||||
if stop_index == -1:
|
||||
continue
|
||||
|
||||
if include_in_output:
|
||||
# Truncate to end of stop string.
|
||||
stop_index += stop_string_len
|
||||
if stop_index >= len(output_text):
|
||||
# No truncation required.
|
||||
return stop_str, -1
|
||||
|
||||
# Truncate the output text to either the beginning
|
||||
# or end of the stop string.
|
||||
return stop_str, stop_index
|
||||
return None
|
||||
18
vllm/v1/engine/exceptions.py
Normal file
18
vllm/v1/engine/exceptions.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
class EngineGenerateError(Exception):
|
||||
"""Raised when a AsyncLLM.generate() fails. Recoverable."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EngineDeadError(Exception):
|
||||
"""Raised when the EngineCore dies. Unrecoverable."""
|
||||
|
||||
def __init__(self, *args, suppress_context: bool = False, **kwargs):
|
||||
ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause." # noqa: E501
|
||||
|
||||
super().__init__(ENGINE_DEAD_MESSAGE, *args, **kwargs)
|
||||
# Make stack trace clearer when using with LLMEngine by
|
||||
# silencing irrelevant ZMQError.
|
||||
self.__suppress_context__ = suppress_context
|
||||
643
vllm/v1/engine/input_processor.py
Normal file
643
vllm/v1/engine/input_processor.py
Normal file
@@ -0,0 +1,643 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.cache import processor_cache_from_config
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
from vllm.multimodal.processing import EncDecMultiModalProcessor
|
||||
from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.metrics.stats import MultiModalCacheStats
|
||||
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
|
||||
from vllm.v1.structured_output.backend_lm_format_enforcer import (
|
||||
validate_structured_output_request_lm_format_enforcer,
|
||||
)
|
||||
from vllm.v1.structured_output.backend_outlines import (
|
||||
validate_structured_output_request_outlines,
|
||||
)
|
||||
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class InputProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
tokenizer: TokenizerLike | None,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.structured_outputs_config = vllm_config.structured_outputs_config
|
||||
|
||||
self.generation_config_fields = self.model_config.try_get_generation_config()
|
||||
|
||||
self.mm_registry = mm_registry
|
||||
self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry)
|
||||
|
||||
self.input_preprocessor = InputPreprocessor(
|
||||
self.model_config,
|
||||
tokenizer,
|
||||
mm_registry,
|
||||
mm_processor_cache=self.mm_processor_cache,
|
||||
)
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> TokenizerLike | None:
|
||||
return self.input_preprocessor.tokenizer
|
||||
|
||||
def _validate_logprobs(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
max_logprobs = self.model_config.max_logprobs
|
||||
if max_logprobs == -1:
|
||||
max_logprobs = self.model_config.get_vocab_size()
|
||||
|
||||
# Validate sample logprobs.
|
||||
if params.logprobs:
|
||||
num_logprobs = params.logprobs
|
||||
if num_logprobs == -1:
|
||||
num_logprobs = self.model_config.get_vocab_size()
|
||||
if num_logprobs > max_logprobs:
|
||||
raise ValueError(
|
||||
f"Requested sample logprobs of {num_logprobs}, "
|
||||
f"which is greater than max allowed: {max_logprobs}"
|
||||
)
|
||||
|
||||
# Validate prompt logprobs.
|
||||
if params.prompt_logprobs:
|
||||
num_prompt_logprobs = params.prompt_logprobs
|
||||
if num_prompt_logprobs == -1:
|
||||
num_prompt_logprobs = self.model_config.get_vocab_size()
|
||||
if num_prompt_logprobs > max_logprobs:
|
||||
raise ValueError(
|
||||
f"Requested prompt logprobs of {num_prompt_logprobs}, "
|
||||
f"which is greater than max allowed: {max_logprobs}"
|
||||
)
|
||||
|
||||
def _validate_sampling_params(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
self._validate_structured_output(params)
|
||||
self._validate_logit_bias(params)
|
||||
|
||||
if params.allowed_token_ids is None:
|
||||
return
|
||||
if not params.allowed_token_ids:
|
||||
raise ValueError("allowed_token_ids is not None and empty!")
|
||||
if self.tokenizer is None:
|
||||
# When skip_tokenizer_init=True, we can't validate token IDs
|
||||
# Skip validation and let the model handle invalid tokens
|
||||
return
|
||||
vocab_size = len(self.tokenizer)
|
||||
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
|
||||
raise ValueError("allowed_token_ids contains out-of-vocab token id!")
|
||||
|
||||
def _validate_logit_bias(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
"""Validate logit_bias token IDs are within vocabulary range."""
|
||||
if not params.logit_bias:
|
||||
return
|
||||
|
||||
vocab_size = self.model_config.get_vocab_size()
|
||||
invalid_token_ids = []
|
||||
|
||||
for token_id in params.logit_bias:
|
||||
if token_id < 0 or token_id >= vocab_size:
|
||||
invalid_token_ids.append(token_id)
|
||||
|
||||
if invalid_token_ids:
|
||||
raise ValueError(
|
||||
f"token_id(s) {invalid_token_ids} in logit_bias contain "
|
||||
f"out-of-vocab token ids. Vocabulary size: {vocab_size}"
|
||||
)
|
||||
|
||||
def _validate_supported_sampling_params(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
# Logits processors not supported.
|
||||
if params.logits_processors:
|
||||
raise ValueError(
|
||||
"vLLM V1 does not support per request user provided logits processors."
|
||||
)
|
||||
# Async scheduling + spec decode currently incompatible with some
|
||||
# sampling parameters.
|
||||
if (
|
||||
self.vllm_config.speculative_config is not None
|
||||
and self.vllm_config.scheduler_config.async_scheduling
|
||||
and (
|
||||
params.frequency_penalty != 0.0
|
||||
or params.presence_penalty != 0.0
|
||||
or params.repetition_penalty != 1.0
|
||||
or params.bad_words_token_ids
|
||||
or params.structured_outputs
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"async scheduling with spec decoding doesn't yet support "
|
||||
"penalties, bad words or structured outputs in sampling parameters."
|
||||
)
|
||||
|
||||
def _validate_params(
|
||||
self,
|
||||
params: SamplingParams | PoolingParams,
|
||||
):
|
||||
"""
|
||||
Validate supported SamplingParam.
|
||||
Should raise ValueError if unsupported for API Server.
|
||||
"""
|
||||
|
||||
if isinstance(params, PoolingParams):
|
||||
return
|
||||
|
||||
self._validate_logprobs(params)
|
||||
self._validate_sampling_params(params)
|
||||
self._validate_supported_sampling_params(params)
|
||||
|
||||
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
|
||||
"""
|
||||
Validate that user-provided multi_modal_uuids align with
|
||||
multi_modal_data in the incoming request prompt(s).
|
||||
Only checks lengths; `None` entries are allowed and will be
|
||||
auto-hashed downstream.
|
||||
"""
|
||||
|
||||
def _validate_single_prompt(single_prompt: dict | str) -> None:
|
||||
if not isinstance(single_prompt, dict):
|
||||
return
|
||||
|
||||
mm_data = single_prompt.get("multi_modal_data")
|
||||
mm_uuids = single_prompt.get("multi_modal_uuids")
|
||||
if not mm_data or not mm_uuids:
|
||||
return
|
||||
|
||||
import torch
|
||||
|
||||
def _get_len(items: object):
|
||||
if isinstance(items, dict): # Embedding inputs
|
||||
return _get_len(next(iter(items.values()))) if items else 1
|
||||
|
||||
if isinstance(items, list):
|
||||
return len(items)
|
||||
if isinstance(items, torch.Tensor):
|
||||
# To keep backwards compatibility for single item embedding input
|
||||
return 1 if getattr(items, "_is_single_item", False) else len(items)
|
||||
|
||||
return 1
|
||||
|
||||
for modality, items in mm_data.items():
|
||||
if modality in mm_uuids:
|
||||
data_len = _get_len(items)
|
||||
uuid_len = _get_len(mm_uuids[modality])
|
||||
if uuid_len != data_len:
|
||||
raise ValueError(
|
||||
f"multi_modal_uuids for modality {modality!r} "
|
||||
"must have same length as data: got "
|
||||
f"{uuid_len} uuids vs {data_len} items."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"multi_modal_uuids for modality {modality!r} must "
|
||||
"be provided if multi_modal_data is provided."
|
||||
)
|
||||
|
||||
# Handle explicit encoder/decoder prompts or singleton prompt
|
||||
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
|
||||
enc = prompt.get("encoder_prompt")
|
||||
dec = prompt.get("decoder_prompt")
|
||||
if enc is not None:
|
||||
_validate_single_prompt(cast(dict | str, enc))
|
||||
if dec is not None:
|
||||
_validate_single_prompt(cast(dict | str, dec))
|
||||
else:
|
||||
_validate_single_prompt(prompt) # type: ignore[arg-type]
|
||||
|
||||
def _validate_lora(self, lora_request: LoRARequest | None) -> None:
|
||||
if lora_request is None:
|
||||
return
|
||||
|
||||
# LoRA request passed in while LoRA is not enabled
|
||||
if not self.lora_config:
|
||||
raise ValueError(
|
||||
f"Got lora_request {lora_request} but LoRA is not enabled!"
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
logger.warning_once(
|
||||
"vLLM has deprecated support for supporting different "
|
||||
"tokenizers for different LoRAs. By default, vLLM uses base "
|
||||
"model's tokenizer. If you are using a LoRA "
|
||||
"with its own tokenizer, consider specifying `--tokenizer "
|
||||
"[lora_path]` to use the LoRA tokenizer."
|
||||
)
|
||||
|
||||
def _validate_structured_output(self, params: SamplingParams) -> None:
|
||||
if not params.structured_outputs or not self.structured_outputs_config:
|
||||
return
|
||||
|
||||
if self.model_config.skip_tokenizer_init and params.structured_outputs:
|
||||
raise ValueError(
|
||||
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
|
||||
)
|
||||
|
||||
backend = self.structured_outputs_config.backend
|
||||
if _backend := params.structured_outputs._backend:
|
||||
# Request-level backend selection is not supported.
|
||||
# The values may differ if `params` is reused and was set
|
||||
# to a specific backend based on `auto` behavior in a previous
|
||||
# request. We remember that it was set as a result of `auto`
|
||||
# using the `_backend_was_auto` field set in the params.
|
||||
if backend != _backend and not (
|
||||
backend == "auto" and params.structured_outputs._backend_was_auto
|
||||
):
|
||||
raise ValueError(
|
||||
"Request-level structured output backend selection is not "
|
||||
f"supported. The request specified '{_backend}', but vLLM "
|
||||
f"was initialised with '{backend}'. This error can be "
|
||||
"resolved by removing '_backend' from the request."
|
||||
)
|
||||
else:
|
||||
params.structured_outputs._backend = backend
|
||||
|
||||
# Request content validation
|
||||
if (
|
||||
isinstance(params.structured_outputs.choice, list)
|
||||
and not params.structured_outputs.choice
|
||||
):
|
||||
# It is invalid for choice to be an empty list
|
||||
raise ValueError(
|
||||
f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501
|
||||
)
|
||||
# Reject empty string grammar early to avoid engine-side crashes
|
||||
if (
|
||||
isinstance(params.structured_outputs.grammar, str)
|
||||
and params.structured_outputs.grammar.strip() == ""
|
||||
):
|
||||
raise ValueError("structured_outputs.grammar cannot be an empty string")
|
||||
|
||||
if backend.startswith("xgrammar"):
|
||||
# xgrammar with no fallback
|
||||
validate_xgrammar_grammar(params)
|
||||
elif backend.startswith("guidance"):
|
||||
# TODO: ideally we would have the LLTokenizer here as Lark syntax
|
||||
# allows <|special_token|> and similar, see
|
||||
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
|
||||
# Without tokenizer these are disallowed in grammars.
|
||||
if isinstance(self.tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"Mistral tokenizer is not supported for the 'guidance' "
|
||||
"structured output backend. Please use ['xgrammar', 'outlines'] "
|
||||
"backends or tokenizer_mode='hf' instead."
|
||||
)
|
||||
validate_guidance_grammar(params, tokenizer=None)
|
||||
elif backend == "outlines":
|
||||
# outlines backend
|
||||
validate_structured_output_request_outlines(params)
|
||||
elif backend == "lm-format-enforcer":
|
||||
# lm format enforcer backend
|
||||
if isinstance(self.tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"Mistral tokenizer is not supported for the 'lm-format-enforcer' "
|
||||
"structured output backend. Please use ['xgrammar', 'outlines'] "
|
||||
"backends or tokenizer_mode='hf' instead."
|
||||
)
|
||||
validate_structured_output_request_lm_format_enforcer(params)
|
||||
else:
|
||||
# NOTE: backend must be "auto" here, because we have
|
||||
# checked supported_backends above.
|
||||
# In this mode, we set opinionated defaults based on what we think
|
||||
# will satisfy the most use cases without having to worry about
|
||||
# this setting. We include fallback behavior here, but not with any
|
||||
# other setting where a specific backend was specified.
|
||||
try:
|
||||
validate_xgrammar_grammar(params)
|
||||
params.structured_outputs._backend = "xgrammar"
|
||||
except ValueError:
|
||||
# The request either failed validation
|
||||
# or includes some jsonschema feature(s) that
|
||||
# are not supported in xgrammar.
|
||||
if isinstance(self.tokenizer, MistralTokenizer):
|
||||
# Fall back to outlines if the tokenizer is Mistral
|
||||
validate_structured_output_request_outlines(params)
|
||||
params.structured_outputs._backend = "outlines"
|
||||
else:
|
||||
# Fall back to guidance by default.
|
||||
validate_guidance_grammar(params, tokenizer=None)
|
||||
params.structured_outputs._backend = "guidance"
|
||||
# Remember that this backend was set automatically
|
||||
params.structured_outputs._backend_was_auto = True
|
||||
|
||||
def _maybe_build_mm_uuids(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
) -> MultiModalUUIDDict | None:
|
||||
"""Build per-item multimodal hash overrides when enabled. In this case,
|
||||
multimodal data items are identified by their request id, modality and
|
||||
index rather than their content.
|
||||
|
||||
Returns a dictionary of modality -> list[str] of overrides, or None if
|
||||
disabled or no multimodal data is present.
|
||||
"""
|
||||
|
||||
def _extract_mm_data(p: PromptType):
|
||||
if isinstance(p, dict) and "encoder_prompt" in p:
|
||||
enc = p.get("encoder_prompt")
|
||||
if isinstance(enc, dict):
|
||||
return enc.get("multi_modal_data")
|
||||
return None
|
||||
if isinstance(p, dict):
|
||||
return p.get("multi_modal_data")
|
||||
return None
|
||||
|
||||
mm_data = _extract_mm_data(prompt)
|
||||
if not mm_data:
|
||||
return None
|
||||
|
||||
mm_uuids: dict[str, list[str | None] | str] = {}
|
||||
for modality, data in mm_data.items():
|
||||
# Hash each item for embedding inputs.
|
||||
n = (
|
||||
len(data)
|
||||
if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
|
||||
else 1
|
||||
)
|
||||
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
|
||||
return mm_uuids
|
||||
|
||||
def process_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
params: SamplingParams | PoolingParams,
|
||||
arrival_time: float | None = None,
|
||||
lora_request: LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: int | None = None,
|
||||
) -> EngineCoreRequest:
|
||||
self._validate_lora(lora_request)
|
||||
self._validate_params(params)
|
||||
|
||||
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
if data_parallel_rank is not None and not (
|
||||
0 <= data_parallel_rank < data_parallel_size
|
||||
):
|
||||
raise ValueError(
|
||||
f"data_parallel_rank {data_parallel_rank} "
|
||||
f"is out of range [0, {data_parallel_size})."
|
||||
)
|
||||
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
# Optionally generate multimodal hash overrides to avoid hashing
|
||||
# multimodal data items by their content as their identifiers.
|
||||
|
||||
# NOTE: when users explicitly turn off BOTH prefix caching and input
|
||||
# processing caching, no multimodal features or embeddings will be
|
||||
# reused across requests, therefore identifying multimodal data items
|
||||
# by their content is no longer necessary, and we create uuids with
|
||||
# request id-modality-index as multimodal hash overrides.
|
||||
if (
|
||||
self.model_config.multimodal_config
|
||||
and self.model_config.multimodal_config.mm_processor_cache_gb == 0
|
||||
and not self.cache_config.enable_prefix_caching
|
||||
):
|
||||
mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
|
||||
else:
|
||||
# Otherwise, use user-provided uuids as multimodal hash overrides
|
||||
# if provided.
|
||||
self._validate_multi_modal_uuids(prompt)
|
||||
if isinstance(prompt, dict):
|
||||
mm_uuids = cast(
|
||||
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
|
||||
)
|
||||
else:
|
||||
mm_uuids = None
|
||||
|
||||
# Process inputs, which includes:
|
||||
# 1. Tokenize text prompt, with LoRA request if one exists.
|
||||
# 2. For multimodal models with a merged preprocessor, preprocess
|
||||
# multimodal data and expand prompt token ids accordingly.
|
||||
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
current_platform.validate_request(
|
||||
prompt=prompt,
|
||||
params=params,
|
||||
processed_inputs=processed_inputs,
|
||||
)
|
||||
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id()
|
||||
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||
self._validate_model_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
# Mypy can be conservative for TypedDict unions; normalize access.
|
||||
if decoder_inputs["type"] == "embeds":
|
||||
prompt_token_ids = None
|
||||
prompt_embeds = decoder_inputs["prompt_embeds"]
|
||||
else:
|
||||
prompt_token_ids = decoder_inputs["prompt_token_ids"]
|
||||
prompt_embeds = None
|
||||
|
||||
sampling_params = None
|
||||
pooling_params = None
|
||||
if isinstance(params, SamplingParams):
|
||||
# TODO: can we avoid cloning here in multiproc case?
|
||||
sampling_params = params.clone()
|
||||
# If unset max tokens, then generate up to the max_model_len.
|
||||
if sampling_params.max_tokens is None:
|
||||
seq_len = length_from_prompt_token_ids_or_embeds(
|
||||
prompt_token_ids, prompt_embeds
|
||||
)
|
||||
sampling_params.max_tokens = self.model_config.max_model_len - seq_len
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields, eos_token_id
|
||||
)
|
||||
if self.tokenizer is not None:
|
||||
sampling_params.update_from_tokenizer(self.tokenizer)
|
||||
else:
|
||||
pooling_params = params.clone()
|
||||
|
||||
# Multimodal related.
|
||||
mm_features: list[MultiModalFeatureSpec] | None = None
|
||||
|
||||
if decoder_inputs["type"] == "multimodal":
|
||||
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
|
||||
decoder_mm_positions = decoder_inputs["mm_placeholders"]
|
||||
decoder_mm_hashes = decoder_inputs["mm_hashes"]
|
||||
|
||||
# Merge and flatten multimodal placeholders, hashes and inputs
|
||||
# from dictionaries to lists, and sort them by each item's position
|
||||
# in the input sequence.
|
||||
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
|
||||
|
||||
mm_features = []
|
||||
for modality, idx in sorted_mm_idxs:
|
||||
mm_features.append(
|
||||
MultiModalFeatureSpec(
|
||||
data=decoder_mm_inputs[modality][idx],
|
||||
modality=modality,
|
||||
identifier=decoder_mm_hashes[modality][idx],
|
||||
mm_position=decoder_mm_positions[modality][idx],
|
||||
)
|
||||
)
|
||||
|
||||
return EngineCoreRequest(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_embeds=prompt_embeds,
|
||||
mm_features=mm_features,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=pooling_params,
|
||||
eos_token_id=eos_token_id,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
cache_salt=decoder_inputs.get("cache_salt"),
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
def _validate_model_inputs(
|
||||
self, encoder_inputs: SingletonInputs | None, decoder_inputs: SingletonInputs
|
||||
):
|
||||
if encoder_inputs is not None:
|
||||
self._validate_model_input(encoder_inputs, prompt_type="encoder")
|
||||
|
||||
self._validate_model_input(decoder_inputs, prompt_type="decoder")
|
||||
|
||||
def _validate_model_input(
|
||||
self,
|
||||
prompt_inputs: SingletonInputs,
|
||||
*,
|
||||
prompt_type: Literal["encoder", "decoder"],
|
||||
):
|
||||
model_config = self.model_config
|
||||
|
||||
prompt_ids = (
|
||||
None
|
||||
if prompt_inputs["type"] == "embeds"
|
||||
else prompt_inputs["prompt_token_ids"]
|
||||
)
|
||||
prompt_embeds = (
|
||||
prompt_inputs["prompt_embeds"]
|
||||
if prompt_inputs["type"] == "embeds"
|
||||
else None
|
||||
)
|
||||
prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds)
|
||||
if not prompt_ids:
|
||||
if prompt_type == "encoder" and model_config.is_multimodal_model:
|
||||
pass # Mllama may have empty encoder inputs for text-only data
|
||||
elif prompt_inputs["type"] == "embeds":
|
||||
pass # Prompt embeds should not have prompt_ids.
|
||||
else:
|
||||
raise ValueError(f"The {prompt_type} prompt cannot be empty")
|
||||
|
||||
tokenizer = self.tokenizer
|
||||
if tokenizer is not None:
|
||||
max_input_id = max(prompt_ids or [], default=0)
|
||||
|
||||
# NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
|
||||
# self.model_config.get_vocab_size() is the model’s vocab size.
|
||||
# For Qwen3 models, the language model has extra tokens that do
|
||||
# not exist in the tokenizer, and vice versa for multimodal
|
||||
# placeholder tokens in some multimodal models.
|
||||
# See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
|
||||
# and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501
|
||||
|
||||
# Here we take the max of the two to determine if a token id is
|
||||
# truly out-of-vocabulary.
|
||||
if max_input_id > max(
|
||||
tokenizer.max_token_id, self.model_config.get_vocab_size() - 1
|
||||
):
|
||||
raise ValueError(f"Token id {max_input_id} is out of vocabulary")
|
||||
|
||||
max_prompt_len = self.model_config.max_model_len
|
||||
if prompt_len > max_prompt_len:
|
||||
if prompt_type == "encoder" and model_config.is_multimodal_model:
|
||||
mm_registry = self.input_preprocessor.mm_registry
|
||||
mm_processor = mm_registry.create_processor(
|
||||
model_config,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
assert isinstance(mm_processor, EncDecMultiModalProcessor)
|
||||
|
||||
if mm_processor.pad_dummy_encoder_prompt:
|
||||
return # Skip encoder length check for Whisper
|
||||
|
||||
if model_config.is_multimodal_model:
|
||||
suggestion = (
|
||||
"Make sure that `max_model_len` is no smaller than the "
|
||||
"number of text tokens plus multimodal tokens. For image "
|
||||
"inputs, the number of image tokens depends on the number "
|
||||
"of images, and possibly their aspect ratios as well."
|
||||
)
|
||||
else:
|
||||
suggestion = (
|
||||
"Make sure that `max_model_len` is no smaller than the "
|
||||
"number of text tokens."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"The {prompt_type} prompt (length {prompt_len}) is "
|
||||
f"longer than the maximum model length of {max_prompt_len}. "
|
||||
f"{suggestion}"
|
||||
)
|
||||
|
||||
# TODO: Find out how many placeholder tokens are there so we can
|
||||
# check that chunked prefill does not truncate them
|
||||
# max_batch_len = self.scheduler_config.max_num_batched_tokens
|
||||
|
||||
if (
|
||||
prompt_len == max_prompt_len
|
||||
and prompt_type == "decoder"
|
||||
and not model_config.is_multimodal_model
|
||||
and self.model_config.runner_type != "pooling"
|
||||
):
|
||||
suggestion = (
|
||||
"Make sure that `max_model_len` is no smaller than the "
|
||||
"number of text tokens (prompt + requested output tokens)."
|
||||
)
|
||||
raise ValueError(
|
||||
f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
|
||||
f"requested output tokens (at least 1) is longer than the maximum "
|
||||
f"model length of {max_prompt_len}. {suggestion}"
|
||||
)
|
||||
|
||||
def stat_mm_cache(self) -> MultiModalCacheStats | None:
|
||||
return self.input_preprocessor.stat_mm_cache()
|
||||
|
||||
def clear_mm_cache(self) -> None:
|
||||
self.input_preprocessor.clear_mm_cache()
|
||||
414
vllm/v1/engine/llm_engine.py
Normal file
414
vllm/v1/engine/llm_engine.py
Normal file
@@ -0,0 +1,414 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from collections.abc import Callable, Mapping
|
||||
from copy import copy
|
||||
from typing import Any, cast
|
||||
|
||||
import torch.nn as nn
|
||||
from typing_extensions import TypeVar, deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
from vllm.tracing import init_tracer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
|
||||
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
from vllm.v1.utils import record_function_or_nullcontext
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""Legacy LLMEngine for backwards compatibility."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
aggregate_engine_logging: bool = False,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
use_cached_outputs: bool = False,
|
||||
multiprocess_mode: bool = False,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
|
||||
self.log_stats = log_stats
|
||||
|
||||
executor_backend = self.vllm_config.parallel_config.distributed_executor_backend
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.external_launcher_dp = (
|
||||
parallel_config.data_parallel_size > 1
|
||||
and executor_backend == "external_launcher"
|
||||
)
|
||||
# important: init dp group before init the engine_core
|
||||
# In the decoupled engine case this is handled in EngineCoreProc.
|
||||
if (
|
||||
not multiprocess_mode
|
||||
and parallel_config.data_parallel_size > 1
|
||||
and not self.external_launcher_dp
|
||||
):
|
||||
self.dp_group = parallel_config.stateless_init_dp_group()
|
||||
else:
|
||||
self.dp_group = None
|
||||
self.should_execute_dummy_batch = False
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = cached_tokenizer_from_config(self.model_config)
|
||||
|
||||
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
|
||||
self.io_processor = get_io_processor(
|
||||
self.vllm_config,
|
||||
self.model_config.io_processor_plugin,
|
||||
)
|
||||
|
||||
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
|
||||
self.output_processor = OutputProcessor(
|
||||
self.tokenizer,
|
||||
log_stats=self.log_stats,
|
||||
stream_interval=self.vllm_config.scheduler_config.stream_interval,
|
||||
)
|
||||
endpoint = self.observability_config.otlp_traces_endpoint
|
||||
if endpoint is not None:
|
||||
tracer = init_tracer("vllm.llm_engine", endpoint)
|
||||
self.output_processor.tracer = tracer
|
||||
|
||||
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
|
||||
self.engine_core = EngineCoreClient.make_client(
|
||||
multiprocess_mode=multiprocess_mode,
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=self.log_stats,
|
||||
)
|
||||
|
||||
self.logger_manager: StatLoggerManager | None = None
|
||||
if self.log_stats:
|
||||
self.logger_manager = StatLoggerManager(
|
||||
vllm_config=vllm_config,
|
||||
custom_stat_loggers=stat_loggers,
|
||||
enable_default_loggers=log_stats,
|
||||
aggregate_engine_logging=aggregate_engine_logging,
|
||||
)
|
||||
self.logger_manager.log_engine_initialized()
|
||||
|
||||
if not multiprocess_mode:
|
||||
# for v0 compatibility
|
||||
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
|
||||
|
||||
if self.external_launcher_dp:
|
||||
# If we use DP in external launcher mode, we reuse the
|
||||
# existing DP group used for data communication.
|
||||
self.dp_group = get_dp_group().cpu_group
|
||||
|
||||
# Don't keep the dummy data in memory
|
||||
self.reset_mm_cache()
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"`LLMEngine.processor` has been renamed to `LLMEngine.input_processor`. "
|
||||
"The old name will be removed in v0.14."
|
||||
)
|
||||
def processor(self):
|
||||
return self.input_processor
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||
disable_log_stats: bool = False,
|
||||
) -> "LLMEngine":
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=(not disable_log_stats),
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: EngineArgs,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||
enable_multiprocessing: bool = False,
|
||||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
|
||||
# Create the engine configs.
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
|
||||
logger.debug("Enabling multiprocessing for LLMEngine.")
|
||||
enable_multiprocessing = True
|
||||
|
||||
# Create the LLMEngine.
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
multiprocess_mode=enable_multiprocessing,
|
||||
)
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
return self.output_processor.get_num_unfinished_requests()
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
has_unfinished = self.output_processor.has_unfinished_requests()
|
||||
if self.dp_group is None:
|
||||
return has_unfinished or self.engine_core.dp_engines_running()
|
||||
return self.has_unfinished_requests_dp(has_unfinished)
|
||||
|
||||
def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
|
||||
aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
|
||||
self.dp_group, has_unfinished
|
||||
)
|
||||
if not has_unfinished and aggregated_has_unfinished:
|
||||
self.should_execute_dummy_batch = True
|
||||
return aggregated_has_unfinished
|
||||
|
||||
@classmethod
|
||||
def validate_outputs(cls, outputs, output_type):
|
||||
return outputs
|
||||
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
return self.engine_core.get_supported_tasks()
|
||||
|
||||
def abort_request(self, request_ids: list[str]) -> None:
|
||||
"""Remove request_ids from EngineCore and Detokenizer."""
|
||||
|
||||
request_ids = self.output_processor.abort_requests(request_ids)
|
||||
self.engine_core.abort_requests(request_ids)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: EngineCoreRequest | PromptType,
|
||||
params: SamplingParams | PoolingParams,
|
||||
arrival_time: float | None = None,
|
||||
lora_request: LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
prompt_text: str | None = None,
|
||||
) -> None:
|
||||
# Validate the request_id type.
|
||||
if not isinstance(request_id, str):
|
||||
raise TypeError(f"request_id must be a string, got {type(request_id)}")
|
||||
|
||||
# Process raw inputs into the request.
|
||||
if isinstance(prompt, EngineCoreRequest):
|
||||
request = prompt
|
||||
else:
|
||||
assert prompt_text is None
|
||||
request = self.input_processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
arrival_time,
|
||||
lora_request,
|
||||
tokenization_kwargs,
|
||||
trace_headers,
|
||||
priority,
|
||||
)
|
||||
if isinstance(prompt, str):
|
||||
prompt_text = prompt
|
||||
elif isinstance(prompt, Mapping):
|
||||
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||
|
||||
# Use cloned params that may have been updated in process_inputs()
|
||||
params = request.params
|
||||
|
||||
n = params.n if isinstance(params, SamplingParams) else 1
|
||||
|
||||
if n == 1:
|
||||
# Make a new RequestState and queue.
|
||||
self.output_processor.add_request(request, prompt_text, None, 0)
|
||||
# Add the request to EngineCore.
|
||||
self.engine_core.add_request(request)
|
||||
return
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_req = ParentRequest(request_id, params)
|
||||
for idx in range(n):
|
||||
request_id, child_params = parent_req.get_child_info(idx)
|
||||
child_request = request if idx == n - 1 else copy(request)
|
||||
child_request.request_id = request_id
|
||||
child_request.sampling_params = child_params
|
||||
|
||||
# Make a new RequestState and queue.
|
||||
self.output_processor.add_request(
|
||||
child_request, prompt_text, parent_req, idx
|
||||
)
|
||||
# Add the request to EngineCore.
|
||||
self.engine_core.add_request(child_request)
|
||||
|
||||
def step(self) -> list[RequestOutput | PoolingRequestOutput]:
|
||||
if self.should_execute_dummy_batch:
|
||||
self.should_execute_dummy_batch = False
|
||||
self.engine_core.execute_dummy_batch()
|
||||
return []
|
||||
|
||||
# 1) Get EngineCoreOutput from the EngineCore.
|
||||
with record_function_or_nullcontext("llm_engine step: get_output"):
|
||||
outputs = self.engine_core.get_output()
|
||||
|
||||
# 2) Process EngineCoreOutputs.
|
||||
with record_function_or_nullcontext("llm_engine step: process_outputs"):
|
||||
iteration_stats = IterationStats() if self.log_stats else None
|
||||
processed_outputs = self.output_processor.process_outputs(
|
||||
outputs.outputs,
|
||||
engine_core_timestamp=outputs.timestamp,
|
||||
iteration_stats=iteration_stats,
|
||||
)
|
||||
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
|
||||
|
||||
# 3) Abort any reqs that finished due to stop strings.
|
||||
with record_function_or_nullcontext("llm_engine step: abort_requests"):
|
||||
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
|
||||
|
||||
# 4) Record stats
|
||||
with record_function_or_nullcontext("llm_engine step: record_stats"):
|
||||
if self.logger_manager is not None and outputs.scheduler_stats is not None:
|
||||
self.logger_manager.record(
|
||||
scheduler_stats=outputs.scheduler_stats,
|
||||
iteration_stats=iteration_stats,
|
||||
mm_cache_stats=self.input_processor.stat_mm_cache(),
|
||||
)
|
||||
self.do_log_stats_with_interval()
|
||||
|
||||
return processed_outputs.request_outputs
|
||||
|
||||
def start_profile(self):
|
||||
self.engine_core.profile(True)
|
||||
|
||||
def stop_profile(self):
|
||||
self.engine_core.profile(False)
|
||||
|
||||
def reset_mm_cache(self):
|
||||
self.input_processor.clear_mm_cache()
|
||||
self.engine_core.reset_mm_cache()
|
||||
|
||||
def reset_prefix_cache(
|
||||
self, reset_running_requests: bool = False, reset_connector: bool = False
|
||||
) -> bool:
|
||||
return self.engine_core.reset_prefix_cache(
|
||||
reset_running_requests, reset_connector
|
||||
)
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
self.engine_core.sleep(level)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
self.logger_manager.record_sleep_state(1, level)
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None):
|
||||
self.engine_core.wake_up(tags)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
self.logger_manager.record_sleep_state(0, 0)
|
||||
|
||||
def is_sleeping(self) -> bool:
|
||||
return self.engine_core.is_sleeping()
|
||||
|
||||
def get_metrics(self) -> list[Metric]:
|
||||
assert self.log_stats, "Stat logging disabled"
|
||||
return get_metrics_snapshot()
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> TokenizerLike | None:
|
||||
return self.input_processor.tokenizer
|
||||
|
||||
def get_tokenizer(self) -> TokenizerLike:
|
||||
if self.tokenizer is None:
|
||||
raise ValueError(
|
||||
"Unable to get tokenizer because `skip_tokenizer_init=True`"
|
||||
)
|
||||
|
||||
return self.tokenizer
|
||||
|
||||
def do_log_stats(self) -> None:
|
||||
"""Log stats if logging is enabled."""
|
||||
if self.logger_manager:
|
||||
self.logger_manager.log()
|
||||
|
||||
def do_log_stats_with_interval(self) -> None:
|
||||
"""Log stats when the time interval has passed."""
|
||||
now = time.time()
|
||||
if not hasattr(self, "_last_log_time"):
|
||||
self._last_log_time = now
|
||||
if now - self._last_log_time >= envs.VLLM_LOG_STATS_INTERVAL:
|
||||
self.do_log_stats()
|
||||
self._last_log_time = now
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
"""Load a new LoRA adapter into the engine for future requests."""
|
||||
return self.engine_core.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
"""Remove an already loaded LoRA adapter."""
|
||||
return self.engine_core.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
"""List all registered adapters."""
|
||||
return self.engine_core.list_loras()
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
"""Prevent an adapter from being evicted."""
|
||||
return self.engine_core.pin_lora(lora_id)
|
||||
|
||||
def collective_rpc(
|
||||
self,
|
||||
method: str | Callable[[WorkerBase], _R],
|
||||
timeout: float | None = None,
|
||||
args: tuple = (),
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> list[_R]:
|
||||
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||
return self.collective_rpc("apply_model", args=(func,))
|
||||
|
||||
def __del__(self):
|
||||
dp_group = getattr(self, "dp_group", None)
|
||||
if dp_group is not None and not self.external_launcher_dp:
|
||||
stateless_destroy_torch_distributed_process_group(dp_group)
|
||||
189
vllm/v1/engine/logprobs.py
Normal file
189
vllm/v1/engine/logprobs.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import (
|
||||
PromptLogprobs,
|
||||
SampleLogprobs,
|
||||
append_logprobs_for_next_position,
|
||||
create_prompt_logprobs,
|
||||
create_sample_logprobs,
|
||||
)
|
||||
from vllm.tokenizers.detokenizer_utils import (
|
||||
TokenizerLike,
|
||||
convert_ids_list_to_tokens,
|
||||
)
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NONES = itertools.repeat(None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogprobsProcessor:
|
||||
# Tokenizer for this request,
|
||||
# None if detokenization is disabled.
|
||||
tokenizer: TokenizerLike | None
|
||||
|
||||
# Logprobs for this request
|
||||
logprobs: SampleLogprobs | None
|
||||
prompt_logprobs: PromptLogprobs | None
|
||||
cumulative_logprob: float | None
|
||||
num_logprobs: int | None
|
||||
num_prompt_logprobs: int | None
|
||||
|
||||
@classmethod
|
||||
def from_new_request(
|
||||
cls,
|
||||
tokenizer: TokenizerLike | None,
|
||||
request: EngineCoreRequest,
|
||||
) -> "LogprobsProcessor":
|
||||
sampling_params = request.sampling_params
|
||||
assert sampling_params is not None
|
||||
num_logprobs = sampling_params.logprobs
|
||||
num_prompt_logprobs = sampling_params.prompt_logprobs
|
||||
return cls(
|
||||
tokenizer=tokenizer,
|
||||
cumulative_logprob=(None if num_logprobs is None else 0.0),
|
||||
logprobs=(
|
||||
None
|
||||
if num_logprobs is None
|
||||
else create_sample_logprobs(sampling_params.flat_logprobs)
|
||||
),
|
||||
prompt_logprobs=(
|
||||
None
|
||||
if num_prompt_logprobs is None
|
||||
else create_prompt_logprobs(sampling_params.flat_logprobs)
|
||||
),
|
||||
num_prompt_logprobs=num_prompt_logprobs,
|
||||
num_logprobs=num_logprobs,
|
||||
)
|
||||
|
||||
def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
|
||||
"""Update with sample logprobs from EngineCore.
|
||||
|
||||
Outer lists are only of len > 1 if EngineCore made
|
||||
>1 tokens in prior step (e.g. in spec decoding).
|
||||
|
||||
Args:
|
||||
logprobs_lists: the lists of logprob tokens, logprobs, and ranks.
|
||||
|
||||
"""
|
||||
|
||||
assert self.num_logprobs is not None
|
||||
assert self.logprobs is not None
|
||||
assert self.cumulative_logprob is not None
|
||||
|
||||
token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists
|
||||
|
||||
for rank_np, logprobs_np, token_ids_np in zip(
|
||||
ranks_lst, logprobs_lst, token_ids_lst
|
||||
):
|
||||
rank = rank_np.tolist()
|
||||
logprobs = logprobs_np.tolist()
|
||||
token_ids = token_ids_np.tolist()
|
||||
# Detokenize (non-incrementally).
|
||||
decoded_tokens = (
|
||||
NONES
|
||||
if self.tokenizer is None
|
||||
else (convert_ids_list_to_tokens(self.tokenizer, token_ids))
|
||||
)
|
||||
|
||||
# Sampler puts the sampled logprob in first.
|
||||
sampled_token_logprob = logprobs[0]
|
||||
self.cumulative_logprob += sampled_token_logprob
|
||||
|
||||
# Update with the Logprob container for this pos.
|
||||
append_logprobs_for_next_position(
|
||||
self.logprobs,
|
||||
token_ids,
|
||||
logprobs,
|
||||
decoded_tokens,
|
||||
rank,
|
||||
self.num_logprobs,
|
||||
)
|
||||
|
||||
def _update_prompt_logprobs(
|
||||
self,
|
||||
prompt_logprobs_tensors: LogprobsTensors,
|
||||
) -> None:
|
||||
"""Update with prompt logprobs from EngineCore.
|
||||
|
||||
Args:
|
||||
prompt_logprobs_tensors: tuple containing the prompt logprobs
|
||||
tensors.
|
||||
|
||||
"""
|
||||
|
||||
# Prompt logprobs are enabled.
|
||||
assert self.num_prompt_logprobs is not None
|
||||
assert self.prompt_logprobs is not None
|
||||
|
||||
token_ids, logprobs, ranks = prompt_logprobs_tensors
|
||||
|
||||
# Detokenize non-incrementally.
|
||||
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
|
||||
decoded_tokens = (
|
||||
None
|
||||
if self.tokenizer is None
|
||||
else (
|
||||
convert_ids_list_to_tokens(self.tokenizer, token_ids.flatten().tolist())
|
||||
)
|
||||
)
|
||||
|
||||
# Recover shapes.
|
||||
num_prompt_tokens, num_logprobs = logprobs.shape
|
||||
|
||||
# Pythonize the torch tensors.
|
||||
prompt_token_ranks = ranks.tolist()
|
||||
prompt_logprobs = logprobs.tolist()
|
||||
token_ids = token_ids.tolist()
|
||||
|
||||
# Make Logprob for each position.
|
||||
for pos in range(num_prompt_tokens):
|
||||
# Handle flattening.
|
||||
offset = pos * num_logprobs
|
||||
offset_end = offset + num_logprobs
|
||||
decoded_tokens_for_pos = (
|
||||
NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
|
||||
)
|
||||
|
||||
# Update with the Logprob container for this pos.
|
||||
append_logprobs_for_next_position(
|
||||
self.prompt_logprobs,
|
||||
token_ids[pos],
|
||||
prompt_logprobs[pos],
|
||||
decoded_tokens_for_pos,
|
||||
prompt_token_ranks[pos],
|
||||
self.num_prompt_logprobs,
|
||||
)
|
||||
|
||||
def pop_prompt_logprobs(self) -> PromptLogprobs | None:
|
||||
"""Pop and return all request prompt logprobs
|
||||
|
||||
The logprobs processor aggregates prompt chunk logprobs
|
||||
over one or more prefill chunks. This method returns
|
||||
all prompt logprobs at once and then forgets them.
|
||||
Ensures correct RequestOutputKind.DELTA semantics
|
||||
wherein all prompt logprobs are returned at once at
|
||||
the end of prefill.
|
||||
|
||||
Returns:
|
||||
None if prompt logprobs are disabled for this request.
|
||||
List of all prompt logprobs, otherwise.
|
||||
"""
|
||||
plp = self.prompt_logprobs
|
||||
if plp:
|
||||
self.prompt_logprobs = []
|
||||
return plp
|
||||
|
||||
def update_from_output(self, output: EngineCoreOutput) -> None:
|
||||
if output.new_logprobs is not None:
|
||||
self._update_sample_logprobs(output.new_logprobs)
|
||||
if output.new_prompt_logprobs_tensors is not None:
|
||||
self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)
|
||||
659
vllm/v1/engine/output_processor.py
Normal file
659
vllm/v1/engine/output_processor.py
Normal file
@@ -0,0 +1,659 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.outputs import (
|
||||
CompletionOutput,
|
||||
PoolingOutput,
|
||||
PoolingRequestOutput,
|
||||
RequestOutput,
|
||||
)
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
|
||||
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
|
||||
from vllm.v1.engine.logprobs import LogprobsProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.metrics.stats import (
|
||||
IterationStats,
|
||||
LoRARequestStates,
|
||||
RequestStateStats,
|
||||
SchedulerStats,
|
||||
)
|
||||
|
||||
|
||||
class RequestOutputCollector:
|
||||
"""
|
||||
Collects streamed RequestOutputs per individual request,
|
||||
for hand-off to the consuming asyncio generate task.
|
||||
|
||||
When streaming deltas, RequestOutputs are merged if the
|
||||
producer gets ahead of the consumer.
|
||||
"""
|
||||
|
||||
def __init__(self, output_kind: RequestOutputKind):
|
||||
self.aggregate = output_kind == RequestOutputKind.DELTA
|
||||
self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
|
||||
self.ready = asyncio.Event()
|
||||
|
||||
def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
|
||||
"""Non-blocking put operation."""
|
||||
if self.output is None or isinstance(output, Exception):
|
||||
self.output = output
|
||||
self.ready.set()
|
||||
elif isinstance(self.output, RequestOutput) and isinstance(
|
||||
output, RequestOutput
|
||||
):
|
||||
# This ensures that request outputs with different request indexes
|
||||
# (if n > 1) do not override each other.
|
||||
self.output.add(output, aggregate=self.aggregate)
|
||||
elif isinstance(self.output, PoolingRequestOutput) and isinstance(
|
||||
output, PoolingRequestOutput
|
||||
):
|
||||
self.output = output
|
||||
|
||||
async def get(self) -> RequestOutput | PoolingRequestOutput:
|
||||
"""Get operation blocks on put event."""
|
||||
while (output := self.output) is None:
|
||||
await self.ready.wait()
|
||||
self.output = None
|
||||
self.ready.clear()
|
||||
if isinstance(output, Exception):
|
||||
raise output
|
||||
return output
|
||||
|
||||
def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
|
||||
"""Non-blocking get operation."""
|
||||
output = self.output
|
||||
if output is not None:
|
||||
self.output = None
|
||||
self.ready.clear()
|
||||
if isinstance(output, Exception):
|
||||
raise output
|
||||
return output
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputProcessorOutput:
|
||||
request_outputs: list[RequestOutput | PoolingRequestOutput]
|
||||
reqs_to_abort: list[str]
|
||||
|
||||
|
||||
class RequestState:
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
parent_req: ParentRequest | None,
|
||||
request_index: int,
|
||||
lora_name: str | None,
|
||||
output_kind: RequestOutputKind,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None,
|
||||
prompt_embeds: torch.Tensor | None,
|
||||
logprobs_processor: LogprobsProcessor | None,
|
||||
detokenizer: IncrementalDetokenizer | None,
|
||||
max_tokens_param: int | None,
|
||||
arrival_time: float,
|
||||
queue: RequestOutputCollector | None,
|
||||
log_stats: bool,
|
||||
stream_interval: int,
|
||||
top_p: float | None = None,
|
||||
n: int | None = None,
|
||||
temperature: float | None = None,
|
||||
):
|
||||
self.request_id = request_id
|
||||
self.parent_req = parent_req
|
||||
self.request_index = request_index
|
||||
self.lora_name = lora_name
|
||||
self.output_kind = output_kind
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_embeds = prompt_embeds
|
||||
self.prompt_len = length_from_prompt_token_ids_or_embeds(
|
||||
self.prompt_token_ids, self.prompt_embeds
|
||||
)
|
||||
self.logprobs_processor = logprobs_processor
|
||||
self.detokenizer = detokenizer
|
||||
self.max_tokens_param = max_tokens_param
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.temperature = temperature
|
||||
self.is_prefilling = True
|
||||
self.queue = queue
|
||||
self.num_cached_tokens = 0
|
||||
|
||||
self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None
|
||||
|
||||
# Stream Interval
|
||||
self.stream_interval = stream_interval
|
||||
self.sent_tokens_offset = 0 # Offset of sent tokens
|
||||
|
||||
@classmethod
|
||||
def from_new_request(
|
||||
cls,
|
||||
tokenizer: TokenizerLike | None,
|
||||
request: EngineCoreRequest,
|
||||
prompt: str | None,
|
||||
parent_req: ParentRequest | None,
|
||||
request_index: int,
|
||||
queue: RequestOutputCollector | None,
|
||||
log_stats: bool,
|
||||
stream_interval: int,
|
||||
) -> "RequestState":
|
||||
if sampling_params := request.sampling_params:
|
||||
if not sampling_params.detokenize:
|
||||
tokenizer = None
|
||||
output_kind = sampling_params.output_kind
|
||||
logprobs_processor = LogprobsProcessor.from_new_request(
|
||||
tokenizer=tokenizer,
|
||||
request=request,
|
||||
)
|
||||
detokenizer = IncrementalDetokenizer.from_new_request(
|
||||
tokenizer=tokenizer,
|
||||
request=request,
|
||||
)
|
||||
max_tokens_param = sampling_params.max_tokens
|
||||
top_p = sampling_params.top_p
|
||||
n = sampling_params.n
|
||||
temperature = sampling_params.temperature
|
||||
else:
|
||||
logprobs_processor = None
|
||||
detokenizer = None
|
||||
max_tokens_param = None
|
||||
top_p = None
|
||||
n = None
|
||||
temperature = None
|
||||
assert request.pooling_params is not None
|
||||
output_kind = request.pooling_params.output_kind
|
||||
|
||||
return cls(
|
||||
request_id=request.request_id,
|
||||
parent_req=parent_req,
|
||||
request_index=request_index,
|
||||
lora_name=(
|
||||
request.lora_request.name if request.lora_request is not None else None
|
||||
),
|
||||
output_kind=output_kind,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
logprobs_processor=logprobs_processor,
|
||||
detokenizer=detokenizer,
|
||||
max_tokens_param=max_tokens_param,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
temperature=temperature,
|
||||
arrival_time=request.arrival_time,
|
||||
queue=queue,
|
||||
log_stats=log_stats,
|
||||
stream_interval=stream_interval,
|
||||
)
|
||||
|
||||
def make_request_output(
|
||||
self,
|
||||
new_token_ids: list[int],
|
||||
pooling_output: torch.Tensor | None,
|
||||
finish_reason: FinishReason | None,
|
||||
stop_reason: int | str | None,
|
||||
kv_transfer_params: dict[str, Any] | None = None,
|
||||
) -> RequestOutput | PoolingRequestOutput | None:
|
||||
finished = finish_reason is not None
|
||||
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
|
||||
|
||||
if not finished and final_only:
|
||||
# Only the final output is required in FINAL_ONLY mode.
|
||||
return None
|
||||
|
||||
if self.stream_interval > 1:
|
||||
assert self.detokenizer is not None
|
||||
|
||||
# Send output request only when
|
||||
# 1. It has finished, or
|
||||
# 2. It is the first token, or
|
||||
# 3. It has reached the stream interval number of tokens
|
||||
if not (
|
||||
finished
|
||||
or self.sent_tokens_offset == 0
|
||||
or len(self.detokenizer.output_token_ids) - self.sent_tokens_offset
|
||||
>= self.stream_interval
|
||||
):
|
||||
return None
|
||||
|
||||
if self.output_kind == RequestOutputKind.DELTA:
|
||||
# Send tokens from the offset in DELTA mode, otherwise all
|
||||
# tokens are sent.
|
||||
new_token_ids = self.detokenizer.output_token_ids[
|
||||
self.sent_tokens_offset :
|
||||
]
|
||||
self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
|
||||
|
||||
request_id = self.request_id
|
||||
if pooling_output is not None:
|
||||
return self._new_request_output(
|
||||
request_id, [self._new_pooling_output(pooling_output)], finished
|
||||
)
|
||||
|
||||
output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
|
||||
|
||||
if self.parent_req is None:
|
||||
outputs = [output]
|
||||
else:
|
||||
request_id, outputs, finished = self.parent_req.get_outputs(
|
||||
request_id, output
|
||||
)
|
||||
if not outputs:
|
||||
return None
|
||||
|
||||
return self._new_request_output(
|
||||
request_id, outputs, finished, kv_transfer_params
|
||||
)
|
||||
|
||||
def _new_request_output(
|
||||
self,
|
||||
request_id: str,
|
||||
outputs: list[CompletionOutput] | list[PoolingOutput],
|
||||
finished: bool,
|
||||
kv_transfer_params: dict[str, Any] | None = None,
|
||||
) -> RequestOutput | PoolingRequestOutput:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, PoolingOutput):
|
||||
assert len(outputs) == 1
|
||||
# Prompt embeddings are currently not supported by pooling requests.
|
||||
assert self.prompt_token_ids is not None
|
||||
return PoolingRequestOutput(
|
||||
request_id=request_id,
|
||||
outputs=first_output,
|
||||
num_cached_tokens=self.num_cached_tokens,
|
||||
prompt_token_ids=self.prompt_token_ids,
|
||||
finished=finished,
|
||||
)
|
||||
assert self.logprobs_processor is not None
|
||||
if self.output_kind == RequestOutputKind.DELTA:
|
||||
# Side effect: logprobs processor forgets prompt logprobs
|
||||
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
|
||||
else:
|
||||
prompt_logprobs = self.logprobs_processor.prompt_logprobs
|
||||
|
||||
# If prompt embeds were used, put placeholder prompt token ids
|
||||
prompt_token_ids = self.prompt_token_ids
|
||||
if prompt_token_ids is None and self.prompt_embeds is not None:
|
||||
prompt_token_ids = [0] * len(self.prompt_embeds)
|
||||
|
||||
return RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=self.prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
outputs=cast(list[CompletionOutput], outputs),
|
||||
finished=finished,
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
num_cached_tokens=self.num_cached_tokens,
|
||||
metrics=self.stats,
|
||||
)
|
||||
|
||||
def _new_completion_output(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
finish_reason: FinishReason | None,
|
||||
stop_reason: int | str | None,
|
||||
) -> CompletionOutput:
|
||||
assert self.detokenizer is not None
|
||||
assert self.logprobs_processor is not None
|
||||
finished = finish_reason is not None
|
||||
delta = self.output_kind == RequestOutputKind.DELTA
|
||||
|
||||
# Prepare text and token_ids, based on delta mode
|
||||
text = self.detokenizer.get_next_output_text(finished, delta)
|
||||
if not delta:
|
||||
token_ids = self.detokenizer.output_token_ids
|
||||
|
||||
# Prepare logprobs, based on delta mode
|
||||
logprobs = self.logprobs_processor.logprobs
|
||||
if delta and logprobs:
|
||||
logprobs = logprobs[-len(token_ids) :]
|
||||
|
||||
return CompletionOutput(
|
||||
index=self.request_index,
|
||||
text=text,
|
||||
token_ids=token_ids,
|
||||
logprobs=logprobs,
|
||||
cumulative_logprob=self.logprobs_processor.cumulative_logprob,
|
||||
finish_reason=str(finish_reason) if finished else None,
|
||||
stop_reason=stop_reason if finished else None,
|
||||
)
|
||||
|
||||
def _new_pooling_output(
|
||||
self,
|
||||
pooling_output: torch.Tensor,
|
||||
) -> PoolingOutput:
|
||||
return PoolingOutput(data=pooling_output)
|
||||
|
||||
|
||||
class OutputProcessor:
|
||||
"""Process EngineCoreOutputs into RequestOutputs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: TokenizerLike | None,
|
||||
log_stats: bool,
|
||||
stream_interval: int = 1,
|
||||
):
|
||||
self.log_stats = log_stats
|
||||
self.tokenizer = tokenizer
|
||||
self.stream_interval = stream_interval
|
||||
self.request_states: dict[str, RequestState] = {}
|
||||
self.parent_requests: dict[str, ParentRequest] = {}
|
||||
self.lora_states = LoRARequestStates(log_stats)
|
||||
self.tracer: Tracer | None = None
|
||||
self._requests_drained = asyncio.Event()
|
||||
self._requests_drained.set()
|
||||
|
||||
def get_num_unfinished_requests(self):
|
||||
return len(self.request_states)
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
return len(self.request_states) > 0
|
||||
|
||||
async def wait_for_requests_to_drain(self) -> None:
|
||||
if not self.request_states:
|
||||
return
|
||||
await self._requests_drained.wait()
|
||||
|
||||
def propagate_error(self, e: Exception):
|
||||
"""Propagate error to all generate() tasks."""
|
||||
|
||||
for _, state in self.request_states.items():
|
||||
assert state.queue is not None
|
||||
state.queue.put(e)
|
||||
|
||||
def abort_requests(
|
||||
self,
|
||||
request_ids: Iterable[str],
|
||||
) -> list[str]:
|
||||
request_ids_to_abort = []
|
||||
for request_id in request_ids:
|
||||
req_state = self.request_states.pop(request_id, None)
|
||||
if req_state is not None:
|
||||
self.lora_states.request_finished(request_id, req_state.lora_name)
|
||||
request_ids_to_abort.append(request_id)
|
||||
# Produce final abort output.
|
||||
if req_state.queue is not None and (
|
||||
request_output := req_state.make_request_output(
|
||||
new_token_ids=[],
|
||||
# Set pooling_output is not None to
|
||||
# correctly enter the abort pooling branch
|
||||
pooling_output=torch.randn(0, device="cpu")
|
||||
if req_state.detokenizer is None
|
||||
else None,
|
||||
finish_reason=FinishReason.ABORT,
|
||||
stop_reason=None,
|
||||
kv_transfer_params=None,
|
||||
)
|
||||
):
|
||||
req_state.queue.put(request_output)
|
||||
elif parent := self.parent_requests.get(request_id):
|
||||
# Abort children prior to removing the parent.
|
||||
if parent.child_requests:
|
||||
child_reqs = list(parent.child_requests)
|
||||
child_reqs = self.abort_requests(child_reqs)
|
||||
request_ids_to_abort.extend(child_reqs)
|
||||
self.parent_requests.pop(request_id, None)
|
||||
if not self.request_states:
|
||||
self._requests_drained.set()
|
||||
return request_ids_to_abort
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: EngineCoreRequest,
|
||||
prompt: str | None,
|
||||
parent_req: ParentRequest | None = None,
|
||||
request_index: int = 0,
|
||||
queue: RequestOutputCollector | None = None,
|
||||
) -> None:
|
||||
request_id = request.request_id
|
||||
if request_id in self.request_states:
|
||||
raise ValueError(f"Request id {request_id} already running.")
|
||||
|
||||
req_state = RequestState.from_new_request(
|
||||
tokenizer=self.tokenizer,
|
||||
request=request,
|
||||
prompt=prompt,
|
||||
parent_req=parent_req,
|
||||
request_index=request_index,
|
||||
queue=queue,
|
||||
log_stats=self.log_stats,
|
||||
stream_interval=self.stream_interval,
|
||||
)
|
||||
if self._requests_drained.is_set():
|
||||
self._requests_drained.clear()
|
||||
self.request_states[request_id] = req_state
|
||||
if parent_req:
|
||||
self.parent_requests[parent_req.request_id] = parent_req
|
||||
|
||||
def process_outputs(
|
||||
self,
|
||||
engine_core_outputs: list[EngineCoreOutput],
|
||||
engine_core_timestamp: float | None = None,
|
||||
iteration_stats: IterationStats | None = None,
|
||||
) -> OutputProcessorOutput:
|
||||
"""
|
||||
Process the EngineCoreOutputs:
|
||||
1) Compute stats for logging
|
||||
2) Detokenize
|
||||
3) Create and handle RequestOutput objects:
|
||||
* If there is a queue (for usage with AsyncLLM),
|
||||
put the RequestOutput objects into the queue for
|
||||
handling by the per-request generate() tasks.
|
||||
|
||||
* If there is no queue (for usage with LLMEngine),
|
||||
return a list of RequestOutput objects.
|
||||
|
||||
NOTE FOR DEVELOPERS
|
||||
|
||||
vLLM V1 minimizes the number of python loops over the full
|
||||
batch to ensure system overheads are minimized. This is the
|
||||
only function that should loop over EngineCoreOutputs.
|
||||
|
||||
If you need to touch every element of the batch, do it from
|
||||
within the loop below.
|
||||
"""
|
||||
|
||||
request_outputs: list[RequestOutput | PoolingRequestOutput] = []
|
||||
reqs_to_abort: list[str] = []
|
||||
for engine_core_output in engine_core_outputs:
|
||||
req_id = engine_core_output.request_id
|
||||
req_state = self.request_states.get(req_id)
|
||||
if req_state is None:
|
||||
# Ignore output for already-aborted request.
|
||||
continue
|
||||
|
||||
# 1) Compute stats for this iteration.
|
||||
self._update_stats_from_output(
|
||||
req_state, engine_core_output, engine_core_timestamp, iteration_stats
|
||||
)
|
||||
|
||||
new_token_ids = engine_core_output.new_token_ids
|
||||
pooling_output = engine_core_output.pooling_output
|
||||
finish_reason = engine_core_output.finish_reason
|
||||
stop_reason = engine_core_output.stop_reason
|
||||
kv_transfer_params = engine_core_output.kv_transfer_params
|
||||
req_state.num_cached_tokens = engine_core_output.num_cached_tokens
|
||||
req_state.is_prefilling = False
|
||||
|
||||
if pooling_output is None:
|
||||
assert req_state.detokenizer is not None
|
||||
assert req_state.logprobs_processor is not None
|
||||
# 2) Detokenize the token ids into text and perform stop checks.
|
||||
stop_string = req_state.detokenizer.update(
|
||||
new_token_ids, finish_reason == FinishReason.STOP
|
||||
)
|
||||
if stop_string:
|
||||
finish_reason = FinishReason.STOP
|
||||
stop_reason = stop_string
|
||||
|
||||
# 3) Compute sample and prompt logprobs for request,
|
||||
# if required.
|
||||
req_state.logprobs_processor.update_from_output(engine_core_output)
|
||||
|
||||
# 4) Create and handle RequestOutput objects.
|
||||
if request_output := req_state.make_request_output(
|
||||
new_token_ids,
|
||||
pooling_output,
|
||||
finish_reason,
|
||||
stop_reason,
|
||||
kv_transfer_params,
|
||||
):
|
||||
if req_state.queue is not None:
|
||||
# AsyncLLM: put into queue for handling by generate().
|
||||
req_state.queue.put(request_output)
|
||||
else:
|
||||
# LLMEngine: return list of RequestOutputs.
|
||||
request_outputs.append(request_output)
|
||||
|
||||
# Free completed requests.
|
||||
if finish_reason is not None:
|
||||
self.request_states.pop(req_id)
|
||||
# Remove parent request if applicable.
|
||||
parent_req = req_state.parent_req
|
||||
if parent_req and not parent_req.child_requests:
|
||||
self.parent_requests.pop(parent_req.request_id, None)
|
||||
if not self.request_states:
|
||||
self._requests_drained.set()
|
||||
if not engine_core_output.finished:
|
||||
# If req not finished in EngineCore, but Detokenizer
|
||||
# detected stop string, abort needed in EngineCore.
|
||||
reqs_to_abort.append(req_id)
|
||||
|
||||
# Track per-request stats
|
||||
self._update_stats_from_finished(
|
||||
req_state, finish_reason, iteration_stats
|
||||
)
|
||||
if self.tracer:
|
||||
self.do_tracing(engine_core_output, req_state, iteration_stats)
|
||||
|
||||
return OutputProcessorOutput(
|
||||
request_outputs=request_outputs,
|
||||
reqs_to_abort=reqs_to_abort,
|
||||
)
|
||||
|
||||
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
|
||||
self.lora_states.update_scheduler_stats(scheduler_stats)
|
||||
|
||||
def do_tracing(
|
||||
self,
|
||||
engine_core_output: EngineCoreOutput,
|
||||
req_state: RequestState,
|
||||
iteration_stats: IterationStats | None,
|
||||
) -> None:
|
||||
assert req_state.stats is not None
|
||||
assert iteration_stats is not None
|
||||
assert self.tracer is not None
|
||||
|
||||
arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
|
||||
trace_context = extract_trace_context(engine_core_output.trace_headers)
|
||||
prompt_length = length_from_prompt_token_ids_or_embeds(
|
||||
req_state.prompt_token_ids, req_state.prompt_embeds
|
||||
)
|
||||
with self.tracer.start_as_current_span(
|
||||
"llm_request",
|
||||
kind=SpanKind.SERVER,
|
||||
context=trace_context,
|
||||
start_time=arrival_time_nano_seconds,
|
||||
) as span:
|
||||
metrics = req_state.stats
|
||||
e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
|
||||
queued_time = metrics.scheduled_ts - metrics.queued_ts
|
||||
prefill_time = metrics.first_token_ts - metrics.scheduled_ts
|
||||
decode_time = metrics.last_token_ts - metrics.first_token_ts
|
||||
inference_time = metrics.last_token_ts - metrics.scheduled_ts
|
||||
span.set_attribute(
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
|
||||
metrics.first_token_latency,
|
||||
)
|
||||
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
|
||||
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
|
||||
span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
|
||||
span.set_attribute(
|
||||
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
|
||||
metrics.num_generation_tokens,
|
||||
)
|
||||
span.set_attribute(
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
|
||||
)
|
||||
span.set_attribute(
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
|
||||
)
|
||||
span.set_attribute(
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
|
||||
)
|
||||
|
||||
# meta
|
||||
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
|
||||
if req_state.top_p:
|
||||
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
|
||||
if req_state.max_tokens_param:
|
||||
span.set_attribute(
|
||||
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
|
||||
)
|
||||
if req_state.temperature:
|
||||
span.set_attribute(
|
||||
SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
|
||||
)
|
||||
if req_state.n:
|
||||
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)
|
||||
|
||||
def _update_stats_from_output(
|
||||
self,
|
||||
req_state: RequestState,
|
||||
engine_core_output: EngineCoreOutput,
|
||||
engine_core_timestamp: float | None,
|
||||
iteration_stats: IterationStats | None,
|
||||
):
|
||||
if iteration_stats is None:
|
||||
return
|
||||
|
||||
assert engine_core_timestamp is not None
|
||||
assert req_state.stats is not None
|
||||
iteration_stats.update_from_output(
|
||||
engine_core_output,
|
||||
engine_core_timestamp,
|
||||
req_state.is_prefilling,
|
||||
req_state.prompt_len,
|
||||
req_state.stats,
|
||||
self.lora_states,
|
||||
req_state.lora_name,
|
||||
)
|
||||
|
||||
def _update_stats_from_finished(
|
||||
self,
|
||||
req_state: RequestState,
|
||||
finish_reason: FinishReason | None,
|
||||
iteration_stats: IterationStats | None,
|
||||
):
|
||||
if iteration_stats is None:
|
||||
return
|
||||
|
||||
assert finish_reason is not None
|
||||
assert req_state.stats is not None
|
||||
iteration_stats.update_from_finished_request(
|
||||
finish_reason=finish_reason,
|
||||
num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
|
||||
req_state.prompt_token_ids, req_state.prompt_embeds
|
||||
),
|
||||
max_tokens_param=req_state.max_tokens_param,
|
||||
req_stats=req_state.stats,
|
||||
num_cached_tokens=req_state.num_cached_tokens,
|
||||
)
|
||||
self.lora_states.request_finished(req_state.request_id, req_state.lora_name)
|
||||
|
||||
ParentRequest.observe_finished_request(
|
||||
req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
|
||||
)
|
||||
145
vllm/v1/engine/parallel_sampling.py
Normal file
145
vllm/v1/engine/parallel_sampling.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import copy
|
||||
from typing import Optional, cast
|
||||
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
|
||||
class ParentRequest:
|
||||
"""Info, state & processing for parallel sampling request.
|
||||
|
||||
Store parent request ID and sampling params.
|
||||
Facilitate generating child request sampling params.
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
sampling_params: SamplingParams
|
||||
|
||||
# To track the completion of child requests
|
||||
child_requests: set[str]
|
||||
|
||||
# To aggregate child completions when not streaming
|
||||
output_aggregator: list[CompletionOutput]
|
||||
|
||||
# To find the max number of generated tokens across all children
|
||||
max_num_generation_tokens: int
|
||||
|
||||
# To efficiently obtain child sampling params
|
||||
cached_child_sampling_params: SamplingParams | None
|
||||
|
||||
def __init__(self, request_id: str, sampling_params: SamplingParams) -> None:
|
||||
self.request_id = request_id
|
||||
self.sampling_params = sampling_params
|
||||
|
||||
self.child_requests = set()
|
||||
self.output_aggregator = (
|
||||
[cast(CompletionOutput, None)] * sampling_params.n
|
||||
if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY)
|
||||
else []
|
||||
)
|
||||
self.max_num_generation_tokens = 0
|
||||
self.cached_child_sampling_params = None
|
||||
|
||||
def _get_child_sampling_params(
|
||||
self,
|
||||
index: int,
|
||||
) -> SamplingParams:
|
||||
"""Efficiently obtain child `sampling_params`
|
||||
|
||||
If `sampling_params.seed` is not `None` then
|
||||
each child request requires a unique clone of
|
||||
parent `sampling_params` with a unique seed.
|
||||
|
||||
Args:
|
||||
index: index within `n` child requests
|
||||
|
||||
Returns:
|
||||
Child `sampling_params` instance.
|
||||
"""
|
||||
seed = self.sampling_params.seed
|
||||
if self.cached_child_sampling_params:
|
||||
# Reuse child sampling_params data structure
|
||||
return self.cached_child_sampling_params
|
||||
# Build child sampling_params
|
||||
child_sampling_params = copy(self.sampling_params)
|
||||
child_sampling_params.n = 1
|
||||
if seed is None:
|
||||
# Cache child sampling_params for later reuse
|
||||
self.cached_child_sampling_params = child_sampling_params
|
||||
else:
|
||||
# Each child gets a clone with a unique seed
|
||||
child_sampling_params.seed = seed + index
|
||||
return child_sampling_params
|
||||
|
||||
def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
|
||||
"""Get child request ID and sampling params.
|
||||
|
||||
Args:
|
||||
index: index within `n` child requests.
|
||||
|
||||
Returns:
|
||||
(request ID, sampling_params) tuple
|
||||
"""
|
||||
child_req_id = f"{index}_{self.request_id}"
|
||||
self.child_requests.add(child_req_id)
|
||||
return child_req_id, self._get_child_sampling_params(index)
|
||||
|
||||
@property
|
||||
def n(self) -> int:
|
||||
return self.sampling_params.n
|
||||
|
||||
def get_outputs(
|
||||
self,
|
||||
child_request_id: str,
|
||||
completion_output: CompletionOutput,
|
||||
) -> tuple[str, list[CompletionOutput], bool]:
|
||||
already_finished_and_returned: bool = False
|
||||
if completion_output.finished():
|
||||
if child_request_id in self.child_requests:
|
||||
self.child_requests.remove(child_request_id)
|
||||
else:
|
||||
# child request ID is not available in child_requests
|
||||
# which means the request had finished in previous
|
||||
# batch step and returned to the client earlier
|
||||
already_finished_and_returned = True
|
||||
|
||||
if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
|
||||
# If streaming, just return the current output
|
||||
#
|
||||
# DO NOT output finished and already returned child request to client again
|
||||
outputs = [] if already_finished_and_returned else [completion_output]
|
||||
else:
|
||||
# If not streaming, aggregate the n final outputs.
|
||||
self.output_aggregator[completion_output.index] = completion_output
|
||||
outputs = [] if self.child_requests else self.output_aggregator
|
||||
|
||||
finished = not self.child_requests
|
||||
return self.request_id, outputs, finished
|
||||
|
||||
def observe_num_generation_tokens(self, num_generation_tokens: int):
|
||||
self.max_num_generation_tokens = max(
|
||||
num_generation_tokens, self.max_num_generation_tokens
|
||||
)
|
||||
return self.max_num_generation_tokens
|
||||
|
||||
@staticmethod
|
||||
def observe_finished_request(
|
||||
parent_req: Optional["ParentRequest"],
|
||||
iteration_stats: IterationStats,
|
||||
num_generation_tokens: int,
|
||||
):
|
||||
n_param = parent_req.n if parent_req is not None else 1
|
||||
|
||||
if parent_req is not None:
|
||||
num_generation_tokens = parent_req.observe_num_generation_tokens(
|
||||
num_generation_tokens
|
||||
)
|
||||
|
||||
# Child requests finished, we can now record to iteration stats
|
||||
if parent_req is None or not parent_req.child_requests:
|
||||
iteration_stats.max_num_generation_tokens_iter.append(num_generation_tokens)
|
||||
iteration_stats.n_params_iter.append(n_param)
|
||||
20
vllm/v1/engine/processor.py
Normal file
20
vllm/v1/engine/processor.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import warnings
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name == "Processor":
|
||||
from .input_processor import InputProcessor
|
||||
|
||||
warnings.warn(
|
||||
"`vllm.v1.engine.processor.Processor` has been moved to "
|
||||
"`vllm.v1.engine.input_processor.InputProcessor`. "
|
||||
"The old name will be removed in v0.14.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return InputProcessor
|
||||
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
1068
vllm/v1/engine/utils.py
Normal file
1068
vllm/v1/engine/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user