init src 0.9.2

This commit is contained in:
2026-01-09 15:09:53 +08:00
parent 0eb2c0a4b3
commit 41d98d4359
1438 changed files with 417605 additions and 683 deletions

179
vllm/v1/engine/__init__.py Normal file
View File

@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
from collections.abc import Sequence
from typing import Any, Optional, Union
import msgspec
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")
class FinishReason(enum.IntEnum):
"""
Reason a request finished - stop, length, or abort.
Int rather than Str for more compact serialization.
stop - a stop string was emitted
length - max_tokens was consumed, or max_model_len was reached
abort - aborted for another reason
"""
STOP = 0
LENGTH = 1
ABORT = 2
def __str__(self):
return FINISH_REASON_STRINGS[self.value]
class EngineCoreRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
request_id: str
prompt_token_ids: list[int]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]
cache_salt: Optional[str]
data_parallel_rank: Optional[int]
# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
client_index: int = 0
# Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.
current_wave: int = 0
priority: int = 0
class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
QUEUED = 1
SCHEDULED = 2
PREEMPTED = 3
class EngineCoreEvent(msgspec.Struct):
"""A timestamped engine core event associated with a request.
The timestamp is a monotonic timestamps and is used for by the engine
frontend to calculate intervals between engine core events. These
timestamps should not be compared with timestamps from other processes.
"""
type: EngineCoreEventType
timestamp: float
@classmethod
def new_event(cls,
event_type: EngineCoreEventType,
timestamp: Optional[float] = None) -> "EngineCoreEvent":
timestamp = time.monotonic() if timestamp is None else timestamp
return cls(event_type, timestamp)
class EngineCoreOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
request_id: str
new_token_ids: list[int]
new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
pooling_output: Optional[torch.Tensor] = None
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None
# The number of tokens with prefix cache hits.
num_cached_tokens: int = 0
@property
def finished(self) -> bool:
return self.finish_reason is not None
class UtilityOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
call_id: int
# Non-None implies the call failed, result should be None.
failure_message: Optional[str] = None
result: Any = None
class EngineCoreOutputs(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout
engine_index: int = 0
# [num_reqs]
outputs: list[EngineCoreOutput] = []
scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0
utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None
# In DP case, used to signal that the current wave of requests
# has finished and the engines are paused.
wave_complete: Optional[int] = None
# In DP case, used to signal that a request was received for an
# "old" wave, so the next wave needs to be started in other engines.
start_wave: Optional[int] = None
def __post_init__(self):
if self.timestamp == 0.0:
self.timestamp = time.monotonic()
class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so it can be sent over sockets
without separate encoding step.
"""
ADD = b'\x00'
ABORT = b'\x01'
START_DP_WAVE = b'\x02'
UTILITY = b'\x03'
# Sentinel used within EngineCoreProc.
EXECUTOR_FAILED = b'\x04'

626
vllm/v1/engine/async_llm.py Normal file
View File

@@ -0,0 +1,626 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Any, Optional, Union
import numpy as np
import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
logger = init_logger(__name__)
class AsyncLLM(EngineClient):
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
log_requests: bool = True,
start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> None:
"""
Create an AsyncLLM.
Args:
vllm_config: global configuration.
executor_class: an Executor impl, e.g. MultiprocExecutor.
log_stats: Whether to log stats.
usage_context: Usage context of the LLM.
mm_registry: Multi-modal registry.
use_cached_outputs: Whether to use cached outputs.
log_requests: Whether to log requests.
start_engine_loop: Whether to start the engine loop.
stat_loggers: customized stat loggers for the engine.
If not provided, default stat loggers will be used.
PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
Returns:
None
"""
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
# Ensure we can serialize custom transformer configs
maybe_register_config_serialize_by_value()
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.log_requests = log_requests
self.log_stats = log_stats
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=vllm_config,
log_stats=self.log_stats,
engine_num=vllm_config.parallel_config.data_parallel_size,
custom_stat_loggers=stat_loggers,
)
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
vllm_config=vllm_config,
tokenizer=self.tokenizer,
mm_registry=mm_registry,
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer,
log_stats=self.log_stats)
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
client_addresses=client_addresses,
client_index=client_index,
)
if self.stat_loggers:
for stat_logger in self.stat_loggers[0]:
stat_logger.log_engine_initialized()
self.output_handler: Optional[asyncio.Task] = None
try:
# Start output handler eagerly if we are in the asyncio eventloop.
asyncio.get_running_loop()
self._run_output_handler()
except RuntimeError:
pass
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> "AsyncLLM":
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
# Create the LLMEngine.
return cls(
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
start_engine_loop=start_engine_loop,
stat_loggers=stat_loggers,
log_requests=not disable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
client_addresses=client_addresses,
client_index=client_index,
)
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""
# Create the engine configs.
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = Executor.get_class(vllm_config)
# Create the AsyncLLM.
return cls(
vllm_config=vllm_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
def __del__(self):
self.shutdown()
def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC."""
shutdown_prometheus()
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
if handler := getattr(self, "output_handler", None):
handler.cancel()
async def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""
if self.errored:
raise EngineDeadError()
is_pooling = isinstance(params, PoolingParams)
# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)
# Convert Input --> Request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority, data_parallel_rank)
if is_pooling or params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue)
return queue
# Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, params)
for idx in range(params.n):
request_id, params = parent_request.get_child_info(idx)
child_request = request if idx == params.n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
await self._add_request(child_request, prompt_str, parent_request,
idx, queue)
return queue
async def _add_request(self, request: EngineCoreRequest,
prompt: Optional[str],
parent_req: Optional[ParentRequest], index: int,
queue: RequestOutputCollector):
# Add the request to OutputProcessor (this process).
self.output_processor.add_request(request, prompt, parent_req, index,
queue)
# Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(request)
if self.log_requests:
logger.info("Added request %s.", request.request_id)
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
* 2) Processing the Input.
* 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process).
A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the
per-request AsyncStream.
The caller of generate() iterates the returned AsyncGenerator,
returning the RequestOutput back to the caller.
"""
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
q = await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
finished = False
while not finished:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() or await q.get()
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
finished = out.finished
yield out
# If the request is disconnected by the client, generate()
# is cancelled or the generator is garbage collected. So,
# we abort the request if we end up here.
except (asyncio.CancelledError, GeneratorExit):
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
# Engine is dead. Do not abort since we shut down.
except EngineDeadError:
if self.log_requests:
logger.info("Request %s failed (engine dead).", request_id)
raise
# Request validation error.
except ValueError:
if self.log_requests:
logger.info("Request %s failed (bad request).", request_id)
raise
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
if self.output_handler is not None:
return
# Ensure that the task doesn't have a circular ref back to the AsyncLLM
# object, or else it won't be garbage collected and cleaned up properly.
engine_core = self.engine_core
output_processor = self.output_processor
log_stats = self.log_stats
stat_loggers = self.stat_loggers if log_stats else None
async def output_handler():
try:
while True:
# 1) Pull EngineCoreOutputs from the EngineCore.
outputs = await engine_core.get_output_async()
num_outputs = len(outputs.outputs)
iteration_stats = IterationStats() if (
log_stats and num_outputs) else None
# Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long.
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
slices = (outputs.outputs, )
else:
slices = np.array_split(
outputs.outputs,
cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
for i, outputs_slice in enumerate(slices):
# 2) Process EngineCoreOutputs.
processed_outputs = output_processor.process_outputs(
outputs_slice, outputs.timestamp, iteration_stats)
# NOTE: RequestOutputs are pushed to their queues.
assert not processed_outputs.request_outputs
# Allow other asyncio tasks to run between chunks
if i + 1 < len(slices):
await asyncio.sleep(0)
# 3) Abort any reqs that finished due to stop strings.
await engine_core.abort_requests_async(
processed_outputs.reqs_to_abort)
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if stat_loggers:
AsyncLLM._record_stats(
stat_loggers[outputs.engine_index],
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
except Exception as e:
logger.exception("AsyncLLM output_handler failed.")
output_processor.propagate_error(e)
self.output_handler = asyncio.create_task(output_handler())
async def abort(self, request_id: str) -> None:
"""Abort RequestId in OutputProcessor and EngineCore."""
request_ids = self.output_processor.abort_requests((request_id, ))
await self.engine_core.abort_requests_async(request_ids)
if self.log_requests:
logger.info("Aborted request %s.", request_id)
@staticmethod
def _record_stats(
stat_loggers: list[StatLoggerBase],
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
"""static so that it can be used from the output_handler task
without a circular ref to AsyncLLM."""
for stat_logger in stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
async def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
* 2) Processing the Input.
* 3) Adding the Request to the EngineCore (separate process).
A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the
per-request AsyncStream.
The caller of generate() iterates the returned AsyncGenerator,
returning the RequestOutput back to the caller.
"""
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
q = await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
finished = False
while not finished:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() or await q.get()
assert isinstance(out, PoolingRequestOutput)
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
finished = out.finished
yield out
# If the request is disconnected by the client, generate()
# is cancelled. So, we abort the request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
# Engine is dead. Do not abort since we shut down.
except EngineDeadError:
if self.log_requests:
logger.info("Request %s failed (engine dead).", request_id)
raise
# Request validation error.
except ValueError:
if self.log_requests:
logger.info("Request %s failed (bad request).", request_id)
raise
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def get_decoding_config(self):
raise ValueError("Not Supported on V1 yet.")
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.processor.input_preprocessor
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.tokenizer.get_lora_tokenizer(lora_request)
async def is_tracing_enabled(self) -> bool:
return False
async def do_log_stats(
self,
scheduler_outputs=None,
model_output=None,
) -> None:
for loggers in self.stat_loggers:
for stat_logger in loggers:
stat_logger.log()
async def check_health(self) -> None:
logger.debug("Called check_health.")
if self.errored:
raise self.dead_error
async def start_profile(self) -> None:
await self.engine_core.profile_async(True)
async def stop_profile(self) -> None:
await self.engine_core.profile_async(False)
async def reset_mm_cache(self) -> None:
self.processor.mm_registry.reset_processor_cache()
self.processor.mm_input_cache_client.reset()
await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
if device == Device.CPU:
raise ValueError("Not supported on CPU.")
await self.engine_core.reset_prefix_cache_async()
async def sleep(self, level: int = 1) -> None:
await self.engine_core.sleep_async(level)
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
await self.engine_core.wake_up_async(tags)
async def is_sleeping(self) -> bool:
return await self.engine_core.is_sleeping_async()
async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""
return await self.engine_core.add_lora_async(lora_request)
async def remove_lora(self, lora_id: int) -> bool:
"""Remove an already loaded LoRA adapter."""
return await self.engine_core.remove_lora_async(lora_id)
async def list_loras(self) -> set[int]:
"""List all registered adapters."""
return await self.engine_core.list_loras_async()
async def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted."""
return await self.engine_core.pin_lora_async(lora_id)
async def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None):
"""
Perform a collective RPC call to the given path.
"""
return await self.engine_core.collective_rpc_async(
method, timeout, args, kwargs)
@property
def is_running(self) -> bool:
# Is None before the loop is started.
return self.output_handler is None or not self.output_handler.done()
@property
def is_stopped(self) -> bool:
return self.errored
@property
def errored(self) -> bool:
return self.engine_core.resources.engine_dead or not self.is_running
@property
def dead_error(self) -> BaseException:
return EngineDeadError()

View File

@@ -0,0 +1,278 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import time
import weakref
from typing import Optional
import msgspec.msgpack
import zmq
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context, make_zmq_socket
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
logger = init_logger(__name__)
class DPCoordinator:
"""Coordinator process used for data-parallel deployments (DP>1).
Intermediates between multiple DP engine rank processes and one or more
front-end API server processes.
* Collects stats from each DP engine (currently just waiting and running
queue lengths), and publishes these to all front-ends for use in
load-balancing decisions.
* Keeps track of the current DP "request wave" number and running state
of the engines. This is received from the DP rank 0 engine and published
to the front-end processes along with the current load stats.
The engines alternate between a global running/paused state. The global
"request wave" number is a count of the number of times that the workers
collectively move from a running state to a paused state. This transition
is synchronized via the all-reduce operation performed in the
DPEngineCoreProc._has_global_unfinished_reqs method.
* Broadcasts the START_DP_WAVE message to engines to move them from paused
to running state when one engine receives a new request. This can happen
in two cases:
1) A front-end sending a new request while the engines are paused will
concurrently notify the coordinator.
2) An engine receiving a request for a stale request wave while in paused
state will notify the coordinator.
Engines will move into running state when receiving a new request or
START_DP_WAVE message.
Note that when deployed in External LB mode, no stats will be published by
the engines and thus updates will only be sent to front-ends when the
request wave / running state changes.
"""
def __init__(self, parallel_config: ParallelConfig):
dp_size = parallel_config.data_parallel_size
assert dp_size > 1, "Coordinator only used for data parallel"
host = parallel_config.data_parallel_master_ip
external_lb = parallel_config.data_parallel_external_lb
# Assume coordinator is colocated with front-end procs when not in
# external DP LB mode.
front_publish_address = get_engine_client_zmq_addr(
local_only=not external_lb, host=host)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
# When in external LB mode, load stats aren't published, only changes
# to request wave / running state, so we don't need to rate-limit the
# updates to the front-end proc(s).
min_stats_update_interval_ms = 0 if external_lb else 100
context = get_mp_context()
self.proc: multiprocessing.Process = context.Process(
target=CoordinatorProc.run_coordinator,
name="VLLM_DP_Coordinator",
kwargs={
"engine_count": parallel_config.data_parallel_size,
"front_publish_address": front_publish_address,
"back_output_address": back_output_address,
"back_publish_address": back_publish_address,
"min_stats_update_interval_ms": min_stats_update_interval_ms,
},
daemon=True)
self.proc.start()
self.stats_publish_address = front_publish_address
self.coord_in_address = back_publish_address
self.coord_out_address = back_output_address
self._finalizer = weakref.finalize(self, shutdown, [self.proc])
def get_stats_publish_address(self) -> str:
return self.stats_publish_address
def get_engine_socket_addresses(self) -> tuple[str, str]:
"""Returns tuple of ZMQ input address, output address."""
return self.coord_in_address, self.coord_out_address
def close(self):
self._finalizer()
class EngineState:
def __init__(self):
self.request_counts = [0, 0] # [waiting, running]
class CoordinatorProc:
def __init__(self,
engine_count: int,
min_stats_update_interval_ms: int = 100):
self.ctx = zmq.Context()
self.engines = [EngineState() for _ in range(engine_count)]
self.stats_update_interval_ms = min_stats_update_interval_ms
self.current_wave = 0
self.engines_running = False
self.stats_changed = False
@staticmethod
def run_coordinator(
engine_count: int,
front_publish_address: str,
back_output_address: str,
back_publish_address: str,
min_stats_update_interval_ms: int = 100,
):
coordinator = CoordinatorProc(
engine_count=engine_count,
min_stats_update_interval_ms=min_stats_update_interval_ms)
try:
coordinator.process_input_socket(
front_publish_address,
back_output_address,
back_publish_address,
)
except KeyboardInterrupt:
logger.info("DP Coordinator process exiting")
def process_input_socket(self, front_publish_address: str,
back_output_address: str,
back_publish_address: str):
decoder = MsgpackDecoder(EngineCoreOutputs)
with make_zmq_socket(
path=front_publish_address, # IPC
ctx=self.ctx,
socket_type=zmq.XPUB,
bind=True,
) as publish_front, make_zmq_socket(
path=back_output_address, # IPC or TCP
ctx=self.ctx,
socket_type=zmq.PULL,
bind=True,
) as output_back, make_zmq_socket(
path=back_publish_address, # IPC or TCP
ctx=self.ctx,
socket_type=zmq.XPUB,
bind=True,
) as publish_back:
poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN)
last_publish_time = 0
while True:
elapsed = int(time.time() * 1000) - last_publish_time
# Send at stats_update_interval_ms interval if the stats have
# changed, or otherwise every 4 seconds.
wait_for = (self.stats_update_interval_ms
if self.stats_changed else 4000)
events = poller.poll(timeout=max(0, wait_for - elapsed))
if not events:
# Poller timeout - publish current stats to front-ends.
engine_req_counts_list = self._get_engine_counts()
to_publish = (engine_req_counts_list, self.current_wave,
self.engines_running)
publish_front.send(msgspec.msgpack.encode(to_publish))
last_publish_time = int(time.time() * 1000)
self.stats_changed = False
continue
events = dict(events)
if publish_front in events:
buffer = publish_front.recv()
if buffer in (b'\x01', b'\x00'):
# Ignore subscription messages.
continue
# We received a message on the front-end XPUB socket,
# from an API server sending a new request while the
# engines are paused, so that we can wake the other
# engines.
engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
if not self.engines_running:
if wave < self.current_wave:
# If the wave number is stale, ensure the message
# is handled by all the engines.
engine_to_exclude = None
self.engines_running = True
self.stats_changed = True
self._send_start_wave(publish_back, self.current_wave,
engine_to_exclude)
if output_back in events:
# We received a message from one of the engines.
buffer = output_back.recv()
outputs: EngineCoreOutputs = decoder.decode(buffer)
assert not outputs.outputs
assert outputs.utility_output is None
eng_index = outputs.engine_index
scheduler_stats = outputs.scheduler_stats
if scheduler_stats:
# 1. Updated request load stats - update our local
# state with these.
stats = self.engines[eng_index].request_counts
stats[0] = scheduler_stats.num_waiting_reqs
stats[1] = scheduler_stats.num_running_reqs
self.stats_changed = True
if (wave := outputs.wave_complete) is not None:
# 2. Notification from rank 0 engine that we've
# moved into the global paused state
# (engines_running==False).
if self.current_wave <= wave:
new_wave = wave + 1
logger.debug("Moving DP wave from %d to %d.",
self.current_wave, new_wave)
self.current_wave = new_wave
self.engines_running = False
self.stats_changed = True
elif (wave := outputs.start_wave) is not None and (
wave > self.current_wave or
(wave == self.current_wave
and not self.engines_running)):
# 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the
# next wave (race condition handling).
logger.debug(
"Starting wave %d after notification of "
"stale wave request from engine.", wave)
self.current_wave = wave
self.engines_running = True
self.stats_changed = True
self._send_start_wave(publish_back, wave, eng_index)
@staticmethod
def _send_start_wave(socket: zmq.Socket, wave: int,
exclude_engine_index: Optional[int]):
"""Broadcast the START_DP_WAVE message to all the engines.
It includes the current wave number and index of engine which
has already received a request with this wave number and so doesn't
require additional notification.
"""
wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
socket.send_multipart(
(EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
def _get_engine_counts(self) -> list[list[int]]:
"""Return list of [waiting, running] count lists for each engine."""
return [e.request_counts for e in self.engines]

1049
vllm/v1/engine/core.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,294 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Optional
import tokenizers
from packaging import version
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)
# Only tokenizers >= 0.21.1 supports DecodeStream used for
# FastIncrementalDetokenizer.
USE_FAST_DETOKENIZER = version.parse(
tokenizers.__version__) >= version.parse("0.21.1")
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
class IncrementalDetokenizer:
def __init__(self):
self.token_ids: list[int] = []
@property
def output_token_ids(self) -> list[int]:
return self.token_ids
def update(self, new_token_ids: list[int],
stop_terminated: bool) -> Optional[str]:
self.token_ids.extend(new_token_ids)
return None
def get_next_output_text(self, finished: bool, delta: bool) -> str:
return ""
@classmethod
def from_new_request(
cls,
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "IncrementalDetokenizer":
assert request.sampling_params is not None
if tokenizer is None:
# No tokenizer => skipping detokenization.
return IncrementalDetokenizer()
if USE_FAST_DETOKENIZER and isinstance(tokenizer,
PreTrainedTokenizerFast):
# Fast tokenizer => use tokenizers library DecodeStream.
return FastIncrementalDetokenizer(tokenizer, request)
# Fall back to slow python-based incremental detokenization.
return SlowIncrementalDetokenizer(tokenizer, request)
class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
def __init__(self, request: EngineCoreRequest):
super().__init__()
# Stop strings
params = request.sampling_params
assert params is not None
self.stop = stop = params.stop
self.include_stop_str_in_output = params.include_stop_str_in_output
# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
if stop and not self.include_stop_str_in_output:
self.stop_buffer_length = max(len(s) for s in stop) - 1
else:
self.stop_buffer_length = 0
self._last_output_text_offset: int = 0
# Generation data
self.output_text = ""
def update(self, new_token_ids: list[int],
stop_terminated: bool) -> Optional[str]:
"""
Update RequestState for the request_id by:
1) Detokenize the new token ids incrementally.
2) Evaluate stop criteria.
Return matched stop string or None.
"""
if not new_token_ids:
# Skip detokenization if no new token ids.
return None
if stop_terminated and not self.include_stop_str_in_output:
# If stop-terminated, exclude last token from detokenization
# based on include_stop_str_in_output parameter.
skipped_stop_token_id = new_token_ids[-1]
new_token_ids = new_token_ids[:-1]
else:
skipped_stop_token_id = None
# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
offset_before = len(self.output_text)
for new_token_id in new_token_ids:
self.token_ids.append(new_token_id)
self.output_text += self.decode_next(new_token_id)
if stop_terminated:
if skipped_stop_token_id is not None:
# Cleanup after skipping detokenization.
self.token_ids.append(skipped_stop_token_id)
# Stop token triggered; skip stop string check.
return None
# 2) Evaluate stop strings.
stop_string = None
if self.stop:
stop = StopChecker.check_stop_strings(
output_text=self.output_text,
new_char_count=len(self.output_text) - offset_before,
stop=self.stop,
include_in_output=self.include_stop_str_in_output,
)
if stop is not None:
stop_string, truncate_to = stop
if truncate_to != -1:
self.output_text = self.output_text[:truncate_to]
return stop_string
@abstractmethod
def decode_next(self, next_token_id: int) -> str:
raise NotImplementedError
def get_next_output_text(self, finished: bool, delta: bool) -> str:
"""If delta is True, only new text since the last call to
this method is returned"""
# We return the full output text if the sequence is finished.
buffer_length = 0 if finished else self.stop_buffer_length
if not delta:
return self.output_text[:-buffer_length] if buffer_length else (
self.output_text)
length = len(self.output_text) - buffer_length
last_offset = self._last_output_text_offset
if last_offset < length:
self._last_output_text_offset = length
return self.output_text[last_offset:length]
return ""
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
def __init__(self, tokenizer: PreTrainedTokenizerFast,
request: EngineCoreRequest):
super().__init__(request)
sampling_params = request.sampling_params
assert sampling_params is not None
self.request_id = request.request_id
self.skip_special_tokens = sampling_params.skip_special_tokens
self.stream = DecodeStream(
skip_special_tokens=self.skip_special_tokens)
self.tokenizer: Tokenizer = tokenizer._tokenizer
# Find a safe place to start.
prompt_suffix = request.prompt_token_ids
prompt_len = len(prompt_suffix)
if prompt_len > 4:
for i in range(4, min(prompt_len + 1, 24)):
suffix = request.prompt_token_ids[-i:]
if '<EFBFBD>' not in self.tokenizer.decode(suffix):
prompt_suffix = suffix
break
# Prime the stream.
for tid in prompt_suffix:
self._protected_step(tid)
self.spaces_between_special_tokens = (
sampling_params.skip_special_tokens
or sampling_params.spaces_between_special_tokens)
if not self.spaces_between_special_tokens:
# Store dict of added token ids so that we can suppress
# the spaces between them.
if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
None)) is None:
self.tokenizer.added_token_ids = added_token_ids = {
tid: tok.content
for tid, tok in
self.tokenizer.get_added_tokens_decoder().items()
}
if added_token_ids:
self.last_special = False
self.added_token_ids = added_token_ids
else:
# No added tokens.
self.spaces_between_special_tokens = True
def decode_next(self, next_token_id: int) -> str:
token = self._protected_step(next_token_id)
if not self.spaces_between_special_tokens:
special_token = self.added_token_ids.get(next_token_id)
is_special = special_token is not None
if is_special and self.last_special:
# Return raw token string without any prefixed spaces.
token = special_token
self.last_special = is_special
return token or ""
def _protected_step(self, next_token_id: int) -> Optional[str]:
try:
token = self.stream.step(self.tokenizer, next_token_id)
except Exception as e:
if str(e) != INVALID_PREFIX_ERR_MSG:
raise e
# Recover from edge case where tokenizer can produce non-monotonic,
# invalid UTF-8 output, which breaks the internal state of
# tokenizers' DecodeStream.
# See https://github.com/vllm-project/vllm/issues/17448.
logger.warning(
"Encountered invalid prefix detokenization error"
" for request %s, resetting decode stream.", self.request_id)
self.stream = DecodeStream(self.skip_special_tokens)
token = self.stream.step(self.tokenizer, next_token_id)
return token
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest, mode="auto"):
super().__init__(request)
self.mode = mode
self.tokenizer = tokenizer
params = request.sampling_params
assert params is not None
# Metadata for incremental detokenization.
self.tokens, self.prefix_offset, self.read_offset = (
convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=params.skip_special_tokens,
))
self.token_ids.extend(request.prompt_token_ids)
self.prompt_len = len(request.prompt_token_ids)
self.skip_special_tokens = params.skip_special_tokens
self.spaces_between_special_tokens = (
params.spaces_between_special_tokens)
@property
def output_token_ids(self) -> list[int]:
return self.token_ids if not self.prompt_len else (
self.token_ids[self.prompt_len:])
def decode_next(self, next_token_id: int) -> str:
new_tokens, decoded_text, prefix_offset, read_offset = (
detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=self.token_ids,
prev_tokens=self.tokens,
prefix_offset=self.prefix_offset,
read_offset=self.read_offset,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.
spaces_between_special_tokens,
mode=self.mode,
))
self.tokens.extend(new_tokens)
self.prefix_offset = prefix_offset
self.read_offset = read_offset
return decoded_text

View File

@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
class EngineGenerateError(Exception):
"""Raised when a AsyncLLM.generate() fails. Recoverable."""
pass
class EngineDeadError(Exception):
"""Raised when the EngineCore dies. Unrecoverable."""
def __init__(self, *args, suppress_context: bool = False, **kwargs):
ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause." # noqa: E501
super().__init__(ENGINE_DEAD_MESSAGE, *args, **kwargs)
# Make stack trace clearer when using with LLMEngine by
# silencing irrelevant ZMQError.
self.__suppress_context__ = suppress_context

View File

@@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class LLMEngine:
"""Legacy LLMEngine for backwards compatibility."""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
multiprocess_mode: bool = False,
) -> None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
if stat_loggers is not None:
raise NotImplementedError(
"Passing StatLoggers to LLMEngine in V1 is not yet supported. "
"Set VLLM_USE_V1=0 and file and issue on Github.")
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.log_stats = log_stats
self.stat_logger: Optional[StatLoggerBase] = None
if self.log_stats:
self.stat_logger = PrometheusStatLogger(vllm_config)
# important: init dp group before init the engine_core
# In the decoupled engine case this is handled in EngineCoreProc.
parallel_config = vllm_config.parallel_config
if not multiprocess_mode and parallel_config.data_parallel_size > 1:
self.dp_group = parallel_config.stateless_init_dp_group()
else:
self.dp_group = None
self.should_execute_dummy_batch = False
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config,
tokenizer=self.tokenizer,
mm_registry=mm_registry)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer,
log_stats=self.log_stats)
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client(
multiprocess_mode=multiprocess_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
)
if not multiprocess_mode:
# for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
# Don't keep the dummy data in memory
self.reset_mm_cache()
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_stats: bool = False,
) -> "LLMEngine":
return cls(vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=(not disable_log_stats),
usage_context=usage_context,
stat_loggers=stat_loggers,
multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
enable_multiprocessing: bool = False,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = Executor.get_class(vllm_config)
if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
logger.debug("Enabling multiprocessing for LLMEngine.")
enable_multiprocessing = True
# Create the LLMEngine.
return cls(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
multiprocess_mode=enable_multiprocessing)
def get_num_unfinished_requests(self) -> int:
return self.output_processor.get_num_unfinished_requests()
def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests()
if self.dp_group is None:
return has_unfinished or self.engine_core.dp_engines_running()
return self.has_unfinished_requests_dp(has_unfinished)
def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
self.dp_group, has_unfinished)
if not has_unfinished and aggregated_has_unfinished:
self.should_execute_dummy_batch = True
return aggregated_has_unfinished
@classmethod
def validate_outputs(cls, outputs, output_type):
return outputs
def abort_request(self, request_ids: list[str]) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""
request_ids = self.output_processor.abort_requests(request_ids)
self.engine_core.abort_requests(request_ids)
def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
# Validate the request_id type.
if not isinstance(request_id, str):
raise TypeError(
f"request_id must be a string, got {type(request_id)}")
# Process raw inputs into the request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority)
n = params.n if isinstance(params, SamplingParams) else 1
if n == 1:
# Make a new RequestState and queue.
self.output_processor.add_request(request, prompt_str, None, 0)
# Add the request to EngineCore.
self.engine_core.add_request(request)
return
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
for idx in range(n):
request_id, params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
# Make a new RequestState and queue.
self.output_processor.add_request(child_request, prompt_str,
parent_req, idx)
# Add the request to EngineCore.
self.engine_core.add_request(child_request)
def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:
if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
self.engine_core.execute_dummy_batch()
return []
# 1) Get EngineCoreOutput from the EngineCore.
outputs = self.engine_core.get_output()
# 2) Process EngineCoreOutputs.
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats)
# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
# 4) Record stats
if self.stat_logger is not None:
assert outputs.scheduler_stats is not None
self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats)
return processed_outputs.request_outputs
def get_vllm_config(self):
return self.vllm_config
def get_model_config(self):
return self.model_config
def start_profile(self):
self.engine_core.profile(True)
def stop_profile(self):
self.engine_core.profile(False)
def reset_mm_cache(self):
self.processor.mm_registry.reset_processor_cache()
self.processor.mm_input_cache_client.reset()
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, device: Optional[Device] = None):
self.engine_core.reset_prefix_cache()
def sleep(self, level: int = 1):
self.engine_core.sleep(level)
def wake_up(self, tags: Optional[list[str]] = None):
self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping()
def get_metrics(self) -> list[Metric]:
assert self.log_stats, "Stat logging disabled"
return get_metrics_snapshot()
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
return self.tokenizer
def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""
return self.engine_core.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
"""Remove an already loaded LoRA adapter."""
return self.engine_core.remove_lora(lora_id)
def list_loras(self) -> set[int]:
"""List all registered adapters."""
return self.engine_core.list_loras()
def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
def __del__(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)

200
vllm/v1/engine/logprobs.py Normal file
View File

@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional
from vllm.logger import init_logger
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_ids_list_to_tokens)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
logger = init_logger(__name__)
NONES = itertools.repeat(None)
@dataclass
class LogprobsProcessor:
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: Optional[AnyTokenizer]
# Logprobs for this request
logprobs: Optional[SampleLogprobs]
prompt_logprobs: Optional[PromptLogprobs]
cumulative_logprob: Optional[float]
num_logprobs: Optional[int]
num_prompt_logprobs: Optional[int]
@classmethod
def from_new_request(
cls,
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "LogprobsProcessor":
assert request.sampling_params is not None
num_logprobs = request.sampling_params.logprobs
num_prompt_logprobs = request.sampling_params.prompt_logprobs
return cls(
tokenizer=tokenizer,
cumulative_logprob=(None if num_logprobs is None else 0.),
logprobs=(None if num_logprobs is None else []),
# NOTE: logprob of first prompt token is None.
prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
num_prompt_logprobs=num_prompt_logprobs,
num_logprobs=num_logprobs,
)
def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
"""Update with sample logprobs from EngineCore.
Outer lists are only of len > 1 if EngineCore made
>1 tokens in prior step (e.g. in spec decoding).
Args:
logprobs_lists: the lists of logprob tokens, logprobs, and ranks.
"""
assert self.num_logprobs is not None
assert self.logprobs is not None
assert self.cumulative_logprob is not None
token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
token_ids_lst):
# Detokenize (non-incrementally).
decoded_tokens = NONES if self.tokenizer is None else (
convert_ids_list_to_tokens(self.tokenizer, token_ids))
# Sampler puts the sampled logprob in first.
sampled_token_logprob = logprobs[0]
self.cumulative_logprob += sampled_token_logprob
# Update with the Logprob dictionary for this pos.
self.logprobs.append(
self._make_logprob_dict(
logprobs,
token_ids,
decoded_tokens,
rank,
self.num_logprobs,
))
def _update_prompt_logprobs(
self,
prompt_logprobs_tensors: LogprobsTensors,
) -> None:
"""Update with prompt logprobs from EngineCore.
Args:
prompt_logprobs_tensors: tuple containing the prompt logprobs
tensors.
"""
# Prompt logprobs are enabled.
assert self.num_prompt_logprobs is not None
assert self.prompt_logprobs is not None
token_ids, logprobs, ranks = prompt_logprobs_tensors
# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
decoded_tokens = None if self.tokenizer is None else (
convert_ids_list_to_tokens(self.tokenizer,
token_ids.flatten().tolist()))
# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape
# Pythonize the torch tensors.
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
# Make Logprob for each position.
for pos in range(num_prompt_tokens):
# Handle flattening.
offset = pos * num_logprobs
offset_end = offset + num_logprobs
decoded_tokens_for_pos = NONES \
if decoded_tokens is None else decoded_tokens[offset:offset_end]
# Update with the Logprob dictionary for this pos.
self.prompt_logprobs.append(
self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
decoded_tokens_for_pos,
prompt_token_ranks[pos],
self.num_prompt_logprobs))
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
"""Pop and return all request prompt logprobs
The logprobs processor aggregates prompt chunk logprobs
over one or more prefill chunks. This method returns
all prompt logprobs at once and then forgets them.
Ensures correct RequestOutputKind.DELTA semantics
wherein all prompt logprobs are returned at once at
the end of prefill.
Returns:
None if prompt logprobs are disabled for this request.
List of all prompt logprobs, otherwise.
"""
plp = self.prompt_logprobs
if plp:
self.prompt_logprobs = []
return plp
@staticmethod
def _make_logprob_dict(
logprobs: list[float],
logprob_token_ids: list[int],
decoded_tokens: Iterable[Optional[str]],
rank: int,
num_logprobs: int,
) -> dict[int, Logprob]:
"""Make a Logprob dictionary for a position.
Args:
logprobs: list of log probabilities
logprob_token_ids: list of top token ids
decoded_tokens: list of decoded top tokens
rank: rank of the sampled token
num_logprobs: number of logprobs requested
by the user (in addition to sampled logprob)
Returns:
dict[token id, Logprob]
"""
# We do not need a special case for the sampled token
# being in the topk, since inserting duplicated data
# into a dictionary twice is the same as doing it once.
topk_ranks = range(1, num_logprobs + 1)
ranks = itertools.chain((rank, ), topk_ranks)
return {
token_id: Logprob(
logprob=logprob,
rank=rank,
decoded_token=token,
)
for token_id, logprob, rank, token in zip(
logprob_token_ids, logprobs, ranks, decoded_tokens)
}
def update_from_output(self, output: EngineCoreOutput) -> None:
if output.new_logprobs is not None:
self._update_sample_logprobs(output.new_logprobs)
if output.new_prompt_logprobs_tensors is not None:
self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)

View File

@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache
from vllm.utils import is_list_of
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
# - MirroredProcessingCache to store the MultiModalKwargs from P0.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
# client (=P0) and server (=P1) processes if the mm_hash is found in the client
# cache.
# Both Client and Server must use the same cache size
# (to perform mirrored caching). This cache size is set by the environment
# variable VLLM_MM_INPUT_CACHE_GIB.
class MirroredProcessingCache:
def __init__(self, model_config):
mm_config = model_config.multimodal_config
disable_mm_preprocessor_cache = (
mm_config is not None and mm_config.disable_mm_preprocessor_cache)
self.use_cache = not disable_mm_preprocessor_cache
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)
def get_and_update_p0(
self,
mm_inputs: Sequence[MultiModalKwargs],
mm_hashes: list[str],
) -> Sequence[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
full_mm_inputs = list[Optional[MultiModalKwargs]]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if self.mm_cache.get(mm_hash) is not None:
mm_input = None
else:
self.mm_cache[mm_hash] = mm_input
full_mm_inputs.append(mm_input)
return full_mm_inputs
def get_and_update_p1(
self,
mm_inputs: Sequence[Optional[MultiModalKwargs]],
mm_hashes: list[str],
) -> Sequence[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
full_mm_inputs = list[MultiModalKwargs]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if mm_input is None:
mm_input = self.mm_cache[mm_hash]
else:
self.mm_cache[mm_hash] = mm_input
full_mm_inputs.append(mm_input)
return full_mm_inputs
def reset(self) -> bool:
self.mm_cache.clear()
return True

View File

@@ -0,0 +1,477 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Optional, Union, cast
import torch
from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
RequestStateStats)
class RequestOutputCollector:
"""
Collects streamed RequestOutputs per individual request,
for hand-off to the consuming asyncio generate task.
When streaming deltas, RequestOutputs are merged if the
producer gets ahead of the consumer.
"""
def __init__(self, output_kind: RequestOutputKind):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
Exception]] = None
self.ready = asyncio.Event()
def put(self, output: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
"""Non-blocking put operation."""
if self.output is None or isinstance(output, Exception):
self.output = output
self.ready.set()
elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
# This ensures that request outputs with different request indexes
# (if n > 1) do not override each other.
self.output.add(output, aggregate=self.aggregate)
async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
"""Get operation blocks on put event."""
while (output := self.output) is None:
await self.ready.wait()
self.output = None
self.ready.clear()
if isinstance(output, Exception):
raise output
return output
def get_nowait(
self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
"""Non-blocking get operation."""
output = self.output
if output is not None:
self.output = None
self.ready.clear()
if isinstance(output, Exception):
raise output
return output
@dataclass
class OutputProcessorOutput:
request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
reqs_to_abort: list[str]
class RequestState:
def __init__(
self,
request_id: str,
parent_req: Optional[ParentRequest],
request_index: int,
lora_name: Optional[str],
output_kind: RequestOutputKind,
prompt: Optional[str],
prompt_token_ids: list[int],
logprobs_processor: Optional[LogprobsProcessor],
detokenizer: Optional[IncrementalDetokenizer],
max_tokens_param: Optional[int],
arrival_time: float,
queue: Optional[RequestOutputCollector],
log_stats: bool,
):
self.request_id = request_id
self.parent_req = parent_req
self.request_index = request_index
self.lora_name = lora_name
self.output_kind = output_kind
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_len = len(prompt_token_ids)
self.logprobs_processor = logprobs_processor
self.detokenizer = detokenizer
self.max_tokens_param = max_tokens_param
self.is_prefilling = True
self.queue = queue
self.stats = RequestStateStats(
arrival_time=arrival_time) if log_stats else None
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
request: EngineCoreRequest,
prompt: Optional[str],
parent_req: Optional[ParentRequest],
request_index: int,
queue: Optional[RequestOutputCollector],
log_stats: bool,
) -> "RequestState":
if sampling_params := request.sampling_params:
if not sampling_params.detokenize:
tokenizer = None
output_kind = sampling_params.output_kind
logprobs_processor = LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
)
detokenizer = IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
)
max_tokens_param = sampling_params.max_tokens
else:
logprobs_processor = None
detokenizer = None
max_tokens_param = None
assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind
return cls(
request_id=request.request_id,
parent_req=parent_req,
request_index=request_index,
lora_name=(request.lora_request.name
if request.lora_request is not None else None),
output_kind=output_kind,
prompt=prompt,
prompt_token_ids=request.prompt_token_ids,
logprobs_processor=logprobs_processor,
detokenizer=detokenizer,
max_tokens_param=max_tokens_param,
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
)
def make_request_output(
self,
new_token_ids: list[int],
pooling_output: Optional[torch.Tensor],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
if not finished and final_only:
# Only the final output is required in FINAL_ONLY mode.
return None
request_id = self.request_id
if pooling_output is not None:
return self._new_request_output(
request_id, [self._new_pooling_output(pooling_output)],
finished)
output = self._new_completion_output(new_token_ids, finish_reason,
stop_reason)
if self.parent_req is None:
outputs = [output]
else:
request_id, outputs, finished = self.parent_req.get_outputs(
request_id, output)
if not outputs:
return None
return self._new_request_output(request_id, outputs, finished,
kv_transfer_params, num_cached_tokens)
def _new_request_output(
self,
request_id: str,
outputs: Union[list[CompletionOutput], list[PoolingOutput]],
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> Union[RequestOutput, PoolingRequestOutput]:
if isinstance(outputs[0], PoolingOutput):
assert len(outputs) == 1
return PoolingRequestOutput(
request_id=request_id,
outputs=outputs[0],
prompt_token_ids=self.prompt_token_ids,
finished=finished,
)
assert self.logprobs_processor is not None
if self.output_kind == RequestOutputKind.DELTA:
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
else:
prompt_logprobs = self.logprobs_processor.prompt_logprobs
return RequestOutput(
request_id=request_id,
prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids,
prompt_logprobs=prompt_logprobs,
outputs=cast(list[CompletionOutput], outputs),
finished=finished,
kv_transfer_params=kv_transfer_params,
num_cached_tokens=num_cached_tokens,
)
def _new_completion_output(
self,
token_ids: list[int],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
) -> CompletionOutput:
assert self.detokenizer is not None
assert self.logprobs_processor is not None
finished = finish_reason is not None
delta = self.output_kind == RequestOutputKind.DELTA
# Prepare text and token_ids, based on delta mode
text = self.detokenizer.get_next_output_text(finished, delta)
if not delta:
token_ids = self.detokenizer.output_token_ids
# Prepare logprobs, based on delta mode
logprobs = self.logprobs_processor.logprobs
if delta and logprobs:
logprobs = logprobs[-len(token_ids):]
return CompletionOutput(
index=self.request_index,
text=text,
token_ids=token_ids,
logprobs=logprobs,
cumulative_logprob=self.logprobs_processor.cumulative_logprob,
finish_reason=str(finish_reason) if finished else None,
stop_reason=stop_reason if finished else None)
def _new_pooling_output(
self,
pooling_output: torch.Tensor,
) -> PoolingOutput:
return PoolingOutput(data=pooling_output)
class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""
def __init__(
self,
tokenizer: TokenizerGroup,
log_stats: bool,
):
self.log_stats = log_stats
self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates()
def get_num_unfinished_requests(self):
return len(self.request_states)
def has_unfinished_requests(self) -> bool:
return len(self.request_states) > 0
def propagate_error(self, e: Exception):
"""Propagate error to all generate() tasks."""
for _, state in self.request_states.items():
assert state.queue is not None
state.queue.put(e)
def abort_requests(
self,
request_ids: Iterable[str],
) -> list[str]:
request_ids_to_abort = []
for request_id in request_ids:
req_state = self.request_states.pop(request_id, None)
if req_state is not None:
self.lora_states.abort_request(req_state)
request_ids_to_abort.append(request_id)
else:
parent = self.parent_requests.pop(request_id, None)
if parent and parent.child_requests:
self.abort_requests(parent.child_requests)
request_ids_to_abort.extend(parent.child_requests)
return request_ids_to_abort
def add_request(
self,
request: EngineCoreRequest,
prompt: Optional[str],
parent_req: Optional[ParentRequest] = None,
request_index: int = 0,
queue: Optional[RequestOutputCollector] = None,
) -> None:
request_id = request.request_id
if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.")
req_state = RequestState.from_new_request(
tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
request=request,
prompt=prompt,
parent_req=parent_req,
request_index=request_index,
queue=queue,
log_stats=self.log_stats)
self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req
def process_outputs(
self,
engine_core_outputs: list[EngineCoreOutput],
engine_core_timestamp: Optional[float] = None,
iteration_stats: Optional[IterationStats] = None,
) -> OutputProcessorOutput:
"""
Process the EngineCoreOutputs:
1) Compute stats for logging
2) Detokenize
3) Create and handle RequestOutput objects:
* If there is a queue (for usage with AsyncLLM),
put the RequestOutput objects into the queue for
handling by the per-request generate() tasks.
* If there is no queue (for usage with LLMEngine),
return a list of RequestOutput objects.
NOTE FOR DEVELOPERS
vLLM V1 minimizes the number of python loops over the full
batch to ensure system overheads are minimized. This is the
only function that should loop over EngineCoreOutputs.
If you need to touch every element of the batch, do it from
within the loop below.
"""
request_outputs: Union[list[RequestOutput],
list[PoolingRequestOutput]] = []
reqs_to_abort: list[str] = []
for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id
req_state = self.request_states.get(req_id)
if req_state is None:
# Ignore output for already-aborted request.
continue
# 1) Compute stats for this iteration.
self._update_stats_from_output(req_state, engine_core_output,
engine_core_timestamp,
iteration_stats)
new_token_ids = engine_core_output.new_token_ids
pooling_output = engine_core_output.pooling_output
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False
if pooling_output is None:
assert req_state.detokenizer is not None
assert req_state.logprobs_processor is not None
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP)
if stop_string:
finish_reason = FinishReason.STOP
stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request,
# if required.
req_state.logprobs_processor.update_from_output(
engine_core_output)
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, pooling_output, finish_reason, stop_reason,
kv_transfer_params, num_cached_tokens):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
else:
# LLMEngine: return list of RequestOutputs.
request_outputs.append(request_output)
# Free completed requests.
if finish_reason is not None:
self.request_states.pop(req_id)
# Remove parent request if applicable.
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None)
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
# Track per-request stats
self._update_stats_from_finished(req_state, finish_reason,
iteration_stats)
self.lora_states.update_iteration_stats(iteration_stats)
return OutputProcessorOutput(
request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort,
)
def _update_stats_from_output(self, req_state: RequestState,
engine_core_output: EngineCoreOutput,
engine_core_timestamp: Optional[float],
iteration_stats: Optional[IterationStats]):
if iteration_stats is None:
return
lora_stats = self.lora_states.get_stats(req_state)
assert engine_core_timestamp is not None
assert req_state.stats is not None
iteration_stats.update_from_output(engine_core_output,
engine_core_timestamp,
req_state.is_prefilling,
req_state.prompt_len,
req_state.stats, lora_stats)
def _update_stats_from_finished(self, req_state: RequestState,
finish_reason: Optional[FinishReason],
iteration_stats: Optional[IterationStats]):
if iteration_stats is None:
return
assert finish_reason is not None
assert req_state.stats is not None
iteration_stats.update_from_finished_request(
finish_reason=finish_reason,
num_prompt_tokens=len(req_state.prompt_token_ids),
max_tokens_param=req_state.max_tokens_param,
req_stats=req_state.stats)
self.lora_states.finish_request(req_state)
ParentRequest.observe_finished_request(
req_state.parent_req, iteration_stats,
req_state.stats.num_generation_tokens)

View File

@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import copy
from typing import Optional
from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.metrics.stats import IterationStats
class ParentRequest:
"""Info, state & processing for parallel sampling request.
Store parent request ID and sampling params.
Facilitate generating child request sampling params.
"""
request_id: str
sampling_params: SamplingParams
# To track the completion of child requests
child_requests: set[str]
# To aggregate child completions when not streaming
output_aggregator: list[CompletionOutput]
# To find the max number of generated tokens across all children
max_num_generation_tokens: int
# To efficiently obtain child sampling params
cached_child_sampling_params: Optional[SamplingParams]
def __init__(self, request_id: str,
sampling_params: SamplingParams) -> None:
self.request_id = request_id
self.sampling_params = sampling_params
self.child_requests = set()
self.output_aggregator = [None] * sampling_params.n if (
sampling_params.output_kind
== RequestOutputKind.FINAL_ONLY) else []
self.max_num_generation_tokens = 0
self.cached_child_sampling_params = None
def _get_child_sampling_params(
self,
index: int,
) -> SamplingParams:
"""Efficiently obtain child `sampling_params`
If `sampling_params.seed` is not `None` then
each child request requires a unique clone of
parent `sampling_params` with a unique seed.
Args:
index: index within `n` child requests
Returns:
Child `sampling_params` instance.
"""
seed = self.sampling_params.seed
if self.cached_child_sampling_params:
# Reuse child sampling_params data structure
return self.cached_child_sampling_params
# Build child sampling_params
child_sampling_params = copy(self.sampling_params)
child_sampling_params.n = 1
if seed is None:
# Cache child sampling_params for later reuse
self.cached_child_sampling_params = child_sampling_params
else:
# Each child gets a clone with a unique seed
child_sampling_params.seed = seed + index
return child_sampling_params
def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
"""Get child request ID and sampling params.
Args:
index: index within `n` child requests.
Returns:
(request ID, sampling_params) tuple
"""
child_req_id = f"{index}_{self.request_id}"
self.child_requests.add(child_req_id)
return child_req_id, self._get_child_sampling_params(index)
@property
def n(self) -> int:
return self.sampling_params.n
def get_outputs(
self,
child_request_id: str,
completion_output: CompletionOutput,
) -> tuple[str, list[CompletionOutput], bool]:
if completion_output.finished():
self.child_requests.remove(child_request_id)
if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
# If streaming, just return the current output.
outputs = [completion_output]
else:
# If not streaming, aggregate the n final outputs.
self.output_aggregator[completion_output.index] = completion_output
outputs = [] if self.child_requests else self.output_aggregator
finished = not self.child_requests
return self.request_id, outputs, finished
def observe_num_generation_tokens(self, num_generation_tokens: int):
self.max_num_generation_tokens = max(num_generation_tokens,
self.max_num_generation_tokens)
return self.max_num_generation_tokens
@staticmethod
def observe_finished_request(parent_req: Optional['ParentRequest'],
iteration_stats: IterationStats,
num_generation_tokens: int):
n_param = parent_req.n if parent_req is not None else 1
if parent_req is not None:
num_generation_tokens = parent_req.observe_num_generation_tokens(
num_generation_tokens)
# Child requests finished, we can now record to iteration stats
if parent_req is None or not parent_req.child_requests:
iteration_stats.max_num_generation_tokens_iter.append(
num_generation_tokens)
iteration_stats.n_params_iter.append(n_param)

422
vllm/v1/engine/processor.py Normal file
View File

@@ -0,0 +1,422 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional, Union
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar)
class Processor:
def __init__(
self,
vllm_config: VllmConfig,
tokenizer: TokenizerGroup,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.decoding_config = vllm_config.decoding_config
self.tokenizer = tokenizer
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
# Multi-modal hasher (for images)
self.use_hash = self.mm_input_cache_client.use_cache or \
self.cache_config.enable_prefix_caching
@property
def mm_registry(self):
return self.input_preprocessor.mm_registry
def _validate_logprobs(
self,
params: SamplingParams,
) -> None:
max_logprobs = self.model_config.max_logprobs
# Validate sample logprobs.
if params.logprobs and params.logprobs > max_logprobs:
raise ValueError(
f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# Validate prompt logprobs.
if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
raise ValueError(
f"Requested prompt logprobs of {params.prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
def _validate_sampling_params(
self,
params: SamplingParams,
lora_request: Optional[LoRARequest],
) -> None:
self._validate_structured_output(params)
self._validate_logit_bias(params)
if params.allowed_token_ids is None:
return
if not params.allowed_token_ids:
raise ValueError("allowed_token_ids is not None and empty!")
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
vocab_size = len(tokenizer)
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
raise ValueError(
"allowed_token_ids contains out-of-vocab token id!")
def _validate_logit_bias(
self,
params: SamplingParams,
) -> None:
"""Validate logit_bias token IDs are within vocabulary range."""
if not params.logit_bias:
return
vocab_size = self.model_config.get_vocab_size()
invalid_token_ids = []
for token_id in params.logit_bias:
if token_id < 0 or token_id >= vocab_size:
invalid_token_ids.append(token_id)
if invalid_token_ids:
raise ValueError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
def _validate_supported_sampling_params(
self,
params: SamplingParams,
) -> None:
# Best of not yet supported.
if params.best_of is not None and params.best_of > 1:
raise ValueError("vLLM V1 does not yet support best_of.")
# Logits processors not supported.
if params.logits_processors:
raise ValueError("vLLM V1 does not support per request "
"user provided logits processors.")
def _validate_params(
self,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest],
):
"""
Validate supported SamplingParam.
Should raise ValueError if unsupported for API Server.
"""
if isinstance(params, PoolingParams):
return
self._validate_logprobs(params)
self._validate_sampling_params(params, lora_request)
self._validate_supported_sampling_params(params)
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return
if self.model_config.skip_tokenizer_init and params.guided_decoding:
raise ValueError(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
)
engine_level_backend = self.decoding_config.backend
if params.guided_decoding.backend:
# Request-level backend selection is not supported in V1.
# The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto`
# using the `_auto` option set on the backend in the params.
if (params.guided_decoding.backend != engine_level_backend
and not (engine_level_backend == "auto"
and params.guided_decoding.backend_was_auto)):
raise ValueError(
"Request-level structured output backend selection is no "
"longer supported. The request specified "
f"'{params.guided_decoding.backend}', but vLLM was "
f"initialised with '{engine_level_backend}'. This error "
"can be resolved by removing backend selection from the "
"request.")
else:
params.guided_decoding.backend = engine_level_backend
# Request content validation
if (isinstance(params.guided_decoding.choice, list)
and not params.guided_decoding.choice):
# It is invalid for choice to be an empty list
raise ValueError(f"Choice '{params.guided_decoding.choice}' "
"cannot be an empty list")
if engine_level_backend.startswith("xgrammar"):
# xgrammar with no fallback
validate_xgrammar_grammar(params)
elif engine_level_backend.startswith("guidance"):
# TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)
else:
# NOTE: engine_level_backend must be "auto" here, because we have
# checked supported_backends above.
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
# default as it is less predictable and subject to change
# between releases as feature support changes.
try:
validate_xgrammar_grammar(params)
params.guided_decoding.backend = "xgrammar"
except ValueError:
# The request either failed validation
# or includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
validate_guidance_grammar(params, tokenizer=None)
params.guided_decoding.backend = "guidance"
# Remember that this backend was set automatically
params.guided_decoding.backend_was_auto = True
def process_inputs(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> tuple[Optional[str], EngineCoreRequest]:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
self._validate_lora(lora_request)
self._validate_params(params, lora_request)
if trace_headers is not None:
raise ValueError("V1 does not support tracing yet.")
if prompt_adapter_request is not None:
raise ValueError("V1 does not support prompt_adapter_request.")
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
data_parallel_size):
raise ValueError(f"data_parallel_rank {data_parallel_rank} "
f"is out of range [0, {data_parallel_size}).")
if arrival_time is None:
arrival_time = time.time()
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
# 3. Apply prompt adapter to prompt token ids if one exists.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=self.use_hash,
)
from vllm.platforms import current_platform
current_platform.validate_request(
prompt=prompt,
params=params,
processed_inputs=processed_inputs,
)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
self._validate_model_inputs(processed_inputs, lora_request)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
# TODO: Impl encoder-decoder
if encoder_inputs is not None:
raise NotImplementedError
sampling_params = None
pooling_params = None
if isinstance(params, SamplingParams):
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
else:
pooling_params = params.clone()
# Multimodal related.
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
(
sorted_item_modalities,
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
decoder_inputs["mm_placeholders"],
decoder_inputs["mm_hashes"] if self.use_hash else None,
)
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# is a single MultiModalKwargs for all items from all modalities.
# This code flattens kwargs for individual items in a list and
# sorts them by each item's position in the input sequence if there
# are multiple modalities.
unique_modalities = set(sorted_item_modalities)
if len(unique_modalities) > 1:
orig_sorted_mm_inputs = []
used_indices = {modality: 0 for modality in unique_modalities}
for modality in sorted_item_modalities:
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
orig_sorted_mm_inputs.append(
MultiModalKwargs.from_items([item]))
used_indices[modality] += 1
else:
orig_sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
orig_sorted_mm_inputs, sorted_mm_hashes)
else:
sorted_mm_inputs = orig_sorted_mm_inputs
return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"],
mm_inputs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params,
pooling_params=pooling_params,
eos_token_id=eos_token_id,
arrival_time=arrival_time,
lora_request=lora_request,
cache_salt=decoder_inputs.get("cache_salt"),
priority=priority,
data_parallel_rank=data_parallel_rank,
)
def _validate_model_inputs(self,
inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
prompt_ids = prompt_inputs["prompt_token_ids"]
if not prompt_ids:
if prompt_type == "encoder" and model_config.is_multimodal_model:
pass # Mllama may have empty encoder inputs for text-only data
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config,
tokenizer=tokenizer,
)
assert isinstance(mm_processor, EncDecMultiModalProcessor)
if mm_processor.pad_dummy_encoder_prompt:
return # Skip encoder length check for Whisper
if model_config.is_multimodal_model:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.")
else:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens.")
raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}")
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens

546
vllm/v1/engine/utils.py Normal file
View File

@@ -0,0 +1,546 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import weakref
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from typing import TYPE_CHECKING, Callable, Optional, Union
import msgspec
import zmq
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
STARTUP_POLL_PERIOD_MS = 10000
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank, used to track state during handshaking."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.identity = index.to_bytes(2, "little")
self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None
# ZMQ socket for front-end to connect to DP coordinator.
# Not used by engine, just relayed to front-end in handshake response.
# Only required for external DP LB case.
frontend_stats_publish_address: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
class CoreEngineProcManager:
"""
Utility class to handle creation, readiness, and shutdown
of background processes used by the AsyncLLM and LLMEngine.
"""
def __init__(
self,
target_fn: Callable,
local_engine_count: int,
start_index: int,
local_start_index: int,
vllm_config: VllmConfig,
local_client: bool,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: Optional[str] = None,
):
context = get_mp_context()
common_kwargs = {
"vllm_config": vllm_config,
"local_client": local_client,
"handshake_address": handshake_address,
"executor_class": executor_class,
"log_stats": log_stats,
}
if client_handshake_address:
common_kwargs[
"client_handshake_address"] = client_handshake_address
self.processes: list[BaseProcess] = []
for index in range(local_engine_count):
local_index = local_start_index + index
global_index = start_index + index
# Start EngineCore in background process.
self.processes.append(
context.Process(target=target_fn,
name=f"EngineCore_{global_index}",
kwargs=common_kwargs | {
"dp_rank": global_index,
"local_dp_rank": local_index,
}))
self._finalizer = weakref.finalize(self, shutdown, self.processes)
try:
for proc in self.processes:
proc.start()
finally:
# Kill other procs if not all are running.
if self.finished_procs():
self.close()
def close(self):
"""Shutdown all procs."""
self._finalizer()
def join_first(self):
"""Wait for any process to exit."""
connection.wait(proc.sentinel for proc in self.processes)
def sentinels(self) -> list:
return [proc.sentinel for proc in self.processes]
def finished_procs(self) -> dict[str, int]:
"""Returns dict of proc name -> exit code for any finished procs."""
return {
proc.name: proc.exitcode
for proc in self.processes if proc.exitcode is not None
}
class CoreEngineActorManager:
"""
Utility class to handle creation, readiness, and shutdown
of core engine Ray actors used by the AsyncLLM and LLMEngine.
Different from CoreEngineProcManager, this class manages
core engines for both local and remote nodes.
"""
def __init__(
self,
vllm_config: VllmConfig,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
placement_groups: Optional[list["PlacementGroup"]] = None,
local_dp_ranks: Optional[list[int]] = None,
):
import copy
import ray
from ray.util.scheduling_strategies import (
PlacementGroupSchedulingStrategy)
from vllm.v1.engine.core import DPEngineCoreActor
self.local_engine_actors: list[ray.ActorHandle] = []
self.remote_engine_actors: list[ray.ActorHandle] = []
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
world_size = vllm_config.parallel_config.world_size
if ray.is_initialized():
logger.info(
"Ray is already initialized. Skipping Ray initialization.")
else:
ray.init()
if placement_groups is not None:
assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if "
"placement_groups is provided")
assert len(placement_groups) == len(local_dp_ranks), (
"placement_groups and local_dp_ranks must "
"have the same length")
logger.info("Using provided placement groups")
# TODO(rui): validate passed-in placement groups
self.created_placement_groups = []
else:
placement_groups, local_dp_ranks = \
CoreEngineActorManager.create_dp_placement_groups(vllm_config)
self.created_placement_groups = placement_groups
assert len(placement_groups) == dp_size, (
"Number of placement groups must match data parallel size")
refs = []
for index in range(dp_size):
local_index = local_dp_ranks[index]
dp_vllm_config = copy.deepcopy(vllm_config)
pg = placement_groups[index]
dp_vllm_config.parallel_config.placement_group = pg
local_client = index < local_engine_count
actor = ray.remote(DPEngineCoreActor).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=world_size,
)).remote(vllm_config=dp_vllm_config,
executor_class=executor_class,
log_stats=log_stats,
local_client=local_client,
addresses=addresses,
dp_rank=index,
local_dp_rank=local_index)
if local_client:
self.local_engine_actors.append(actor)
else:
self.remote_engine_actors.append(actor)
refs.append(actor.wait_for_init.remote())
ray.get(refs)
self.run_refs = []
for actor in self.local_engine_actors + self.remote_engine_actors:
self.run_refs.append(actor.run.remote())
@staticmethod
def create_dp_placement_groups(
vllm_config: VllmConfig
) -> tuple[list["PlacementGroup"], list[int]]:
import ray
from ray._private.state import available_resources_per_node
from ray.util.state import list_nodes
logger.info("Creating placement groups for data parallel")
dp_master_ip = \
vllm_config.parallel_config.data_parallel_master_ip
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
nodes = sorted(list_nodes(),
key=lambda node: node.node_ip != dp_master_ip)
assert nodes[0].node_ip == dp_master_ip, (
"The first node must be the head node")
assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
"There can only be one head node")
available_resources = available_resources_per_node()
world_size = vllm_config.parallel_config.world_size
placement_groups: list[PlacementGroup] = []
local_dp_ranks: list[int] = []
for node in nodes:
node_ip = node.node_ip
node_resources = available_resources[node.node_id]
# For now, each DP rank can only be assigned to one node
# TODO(rui): support allocating a single DP rank
# to multiple nodes
available_engine_count = int(node_resources["GPU"]) // world_size
if node_ip == dp_master_ip:
assert available_engine_count >= local_engine_count, (
"Not enough resources to allocate DP ranks "
f"on DP master node {node_ip}")
for i in range(local_engine_count):
bundles = [{
"GPU": 1.0,
"node:" + dp_master_ip: 0.001
}] * world_size + [{
"CPU": 1.0
}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
local_dp_ranks.append(i)
else:
for i in range(available_engine_count):
if len(placement_groups) == dp_size:
break
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
local_dp_ranks.append(i)
return placement_groups, local_dp_ranks
def get_run_refs(self):
return self.run_refs
def close(self):
import ray
for actor in self.local_engine_actors + self.remote_engine_actors:
ray.kill(actor)
for pg in self.created_placement_groups:
ray.util.remove_placement_group(pg)
@contextlib.contextmanager
def launch_core_engines(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
num_api_servers: int = 1,
) -> Iterator[tuple[
Optional[Union[CoreEngineProcManager, CoreEngineActorManager]],
Optional[DPCoordinator],
EngineZmqAddresses,
]]:
"""Launch engine and DP coordinator processes as needed."""
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
external_dp_lb = parallel_config.data_parallel_external_lb
# In offline mode there is an LLM instance per DP rank and
# one core engine per LLM, see
# examples/offline_inference/data_parallel.py.
offline_mode = local_start_index is not None
# client_local_only = True for cases where this front-end
# sends requests only to colocated engines.
client_local_only = offline_mode or external_dp_lb or (local_engine_count
== dp_size)
# Set up input and output addresses.
addresses = EngineZmqAddresses(
inputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
],
outputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
],
)
# Run the DP Coordinator process with rank 0 when in
# online DP mode.
run_coordinator = dp_size > 1 and not offline_mode and dp_rank == 0
if run_coordinator:
coordinator = DPCoordinator(parallel_config)
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
addresses.frontend_stats_publish_address = (
coordinator.get_stats_publish_address())
logger.info("Started DP Coordinator process (PID: %d)",
coordinator.proc.pid)
else:
coordinator = None
if parallel_config.data_parallel_backend == "ray":
logger.info("Starting ray-based data parallel backend")
engine_actor_manager = CoreEngineActorManager(
vllm_config=vllm_config,
addresses=addresses,
executor_class=executor_class,
log_stats=log_stats,
)
yield engine_actor_manager, coordinator, addresses
return
if offline_mode or (external_dp_lb and dp_rank > 0):
assert local_engine_count == 1
engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
else:
engines_to_handshake = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
# Whether the started engines will handshake only with co-located
# front-end processes. In external_dp_lb mode, ranks > 0 handshake with
# their co-located frontend and also the rank 0 front-end, and hence this
# will be False.
handshake_local_only = offline_mode or local_engine_count == dp_size
handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port)
if external_dp_lb and dp_rank > 0:
assert not handshake_local_only
local_handshake_address = get_open_zmq_ipc_path()
client_handshake_address = local_handshake_address
else:
local_handshake_address = handshake_address
client_handshake_address = None
with zmq_socket_ctx(local_handshake_address, zmq.ROUTER,
bind=True) as handshake_socket:
from vllm.v1.engine.core import EngineCoreProc
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
handshake_address=handshake_address,
client_handshake_address=client_handshake_address,
local_client=True,
local_engine_count=local_engine_count,
start_index=dp_rank,
local_start_index=local_start_index or 0)
else:
local_engine_manager = None
yield local_engine_manager, coordinator, addresses
# Now wait for engines to start.
wait_for_engine_startup(
handshake_socket,
addresses,
engines_to_handshake,
parallel_config,
vllm_config.cache_config,
local_engine_manager,
coordinator.proc if coordinator else None,
)
def wait_for_engine_startup(
handshake_socket: zmq.Socket,
addresses: EngineZmqAddresses,
core_engines: list[CoreEngine],
parallel_config: ParallelConfig,
cache_config: CacheConfig,
proc_manager: Optional[CoreEngineProcManager],
coord_process: Optional[Process],
):
# Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local
remote_count = len(core_engines) - local_count
# [local, remote] counts
conn_pending, start_pending = [local_count, remote_count], [0, 0]
poller = zmq.Poller()
poller.register(handshake_socket, zmq.POLLIN)
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
if coord_process is not None:
poller.register(coord_process.sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect.", *conn_pending)
if any(start_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to start.", *start_pending)
continue
if len(events) > 1 or events[0][0] != handshake_socket:
# One of the local core processes exited.
finished = proc_manager.finished_procs() if proc_manager else {}
if coord_process is not None and coord_process.exitcode is not None:
finished[coord_process.name] = coord_process.exitcode
raise RuntimeError("Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}")
# Receive HELLO and READY messages from the input socket.
eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
eng_index = int.from_bytes(eng_identity, "little")
engine = next((e for e in core_engines if e.identity == eng_identity),
None)
if engine is None:
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.
init_message = msgspec.msgpack.encode(
EngineHandshakeMetadata(
addresses=addresses,
parallel_config={
"data_parallel_master_ip":
parallel_config.data_parallel_master_ip,
"data_parallel_master_port":
parallel_config.data_parallel_master_port,
"data_parallel_size":
parallel_config.data_parallel_size,
}))
handshake_socket.send_multipart((eng_identity, init_message),
copy=False)
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg["num_gpu_blocks"]
cache_config.num_gpu_blocks = num_gpu_blocks
# In external DP LB mode, the coordinator address that the
# front-end procs connect to is obtained from rank 0 via
# one of the engine handshakes, and passed to the local
# front-end process in the response from the other.
if addresses.frontend_stats_publish_address is None:
addresses.frontend_stats_publish_address = msg.get(
"dp_stats_address")
start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
raise RuntimeError(f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state.")
logger.debug("%s from %s core engine process %s.", status,
"local" if local else "remote", eng_index)