[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

173
vllm/v1/engine/__init__.py Normal file
View File

@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
from collections.abc import Sequence
from typing import Any, Optional, Union
import msgspec
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")
class FinishReason(enum.IntEnum):
"""
Reason a request finished - stop, length, or abort.
Int rather than Str for more compact serialization.
stop - a stop string was emitted
length - max_tokens was consumed, or max_model_len was reached
abort - aborted for another reason
"""
STOP = 0
LENGTH = 1
ABORT = 2
def __str__(self):
return FINISH_REASON_STRINGS[self.value]
class EngineCoreRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
request_id: str
prompt_token_ids: list[int]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]
cache_salt: Optional[str]
data_parallel_rank: Optional[int]
# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
client_index: int = 0
# Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.
current_wave: int = 0
class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
QUEUED = 1
SCHEDULED = 2
PREEMPTED = 3
class EngineCoreEvent(msgspec.Struct):
"""A timestamped engine core event associated with a request.
The timestamp is a monotonic timestamps and is used for by the engine
frontend to calculate intervals between engine core events. These
timestamps should not be compared with timestamps from other processes.
"""
type: EngineCoreEventType
timestamp: float
@classmethod
def new_event(cls,
event_type: EngineCoreEventType,
timestamp: Optional[float] = None) -> "EngineCoreEvent":
timestamp = time.monotonic() if timestamp is None else timestamp
return cls(event_type, timestamp)
class EngineCoreOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
request_id: str
new_token_ids: list[int]
new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None
# The number of tokens with prefix cache hits.
num_cached_tokens: int = 0
@property
def finished(self) -> bool:
return self.finish_reason is not None
class UtilityOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
call_id: int
# Non-None implies the call failed, result should be None.
failure_message: Optional[str] = None
result: Any = None
class EngineCoreOutputs(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout
engine_index: int = 0
# [num_reqs]
outputs: list[EngineCoreOutput] = []
scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0
utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None
# In DP case, used to signal that the current wave of requests
# has finished and the engines are paused.
wave_complete: Optional[int] = None
# In DP case, used to signal that a request was received for an
# "old" wave, so the next wave needs to be started in other engines.
start_wave: Optional[int] = None
def __post_init__(self):
if self.timestamp == 0.0:
self.timestamp = time.monotonic()
class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so it can be sent over sockets
without separate encoding step.
"""
ADD = b'\x00'
ABORT = b'\x01'
START_DP_WAVE = b'\x02'
UTILITY = b'\x03'
# Sentinel used within EngineCoreProc.
EXECUTOR_FAILED = b'\x04'

558
vllm/v1/engine/async_llm.py Normal file
View File

@@ -0,0 +1,558 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Any, Optional, Union
import numpy as np
import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
logger = init_logger(__name__)
class AsyncLLM(EngineClient):
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
log_requests: bool = True,
start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> None:
"""
Create an AsyncLLM.
Args:
vllm_config: global configuration.
executor_class: an Executor impl, e.g. MultiprocExecutor.
log_stats: Whether to log stats.
usage_context: Usage context of the LLM.
mm_registry: Multi-modal registry.
use_cached_outputs: Whether to use cached outputs.
log_requests: Whether to log requests.
start_engine_loop: Whether to start the engine loop.
stat_loggers: customized stat loggers for the engine.
If not provided, default stat loggers will be used.
PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
Returns:
None
"""
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
# Ensure we can serialize custom transformer configs
maybe_register_config_serialize_by_value()
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.log_requests = log_requests
self.log_stats = log_stats
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=vllm_config,
log_stats=self.log_stats,
engine_num=vllm_config.parallel_config.data_parallel_size,
custom_stat_loggers=stat_loggers,
)
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
vllm_config=vllm_config,
tokenizer=self.tokenizer,
mm_registry=mm_registry,
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer,
log_stats=self.log_stats)
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
client_addresses=client_addresses,
client_index=client_index,
)
if self.stat_loggers:
for stat_logger in self.stat_loggers[0]:
stat_logger.log_engine_initialized()
self.output_handler: Optional[asyncio.Task] = None
try:
# Start output handler eagerly if we are in the asyncio eventloop.
asyncio.get_running_loop()
self._run_output_handler()
except RuntimeError:
pass
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> "AsyncLLM":
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
# Create the LLMEngine.
return cls(
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
start_engine_loop=start_engine_loop,
stat_loggers=stat_loggers,
log_requests=not disable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
client_addresses=client_addresses,
client_index=client_index,
)
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""
# Create the engine configs.
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = Executor.get_class(vllm_config)
# Create the AsyncLLM.
return cls(
vllm_config=vllm_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
def __del__(self):
self.shutdown()
def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC."""
shutdown_prometheus()
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
if handler := getattr(self, "output_handler", None):
handler.cancel()
async def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""
if self.errored:
raise EngineDeadError()
assert isinstance(params, SamplingParams), \
"Pooling is not supported in V1"
# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)
# Convert Input --> Request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority, data_parallel_rank)
if params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue)
return queue
# Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, params)
for idx in range(params.n):
request_id, params = parent_request.get_child_info(idx)
child_request = request if idx == params.n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
await self._add_request(child_request, prompt_str, parent_request,
idx, queue)
return queue
async def _add_request(self, request: EngineCoreRequest,
prompt: Optional[str],
parent_req: Optional[ParentRequest], index: int,
queue: RequestOutputCollector):
# Add the request to OutputProcessor (this process).
self.output_processor.add_request(request, prompt, parent_req, index,
queue)
# Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(request)
if self.log_requests:
logger.info("Added request %s.", request.request_id)
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
* 2) Processing the Input.
* 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process).
A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the
per-request AsyncStream.
The caller of generate() iterates the returned AsyncGenerator,
returning the RequestOutput back to the caller.
"""
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
q = await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
finished = False
while not finished:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() or await q.get()
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
finished = out.finished
yield out
# If the request is disconnected by the client, generate()
# is cancelled or the generator is garbage collected. So,
# we abort the request if we end up here.
except (asyncio.CancelledError, GeneratorExit):
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
# Engine is dead. Do not abort since we shut down.
except EngineDeadError:
if self.log_requests:
logger.info("Request %s failed (engine dead).", request_id)
raise
# Request validation error.
except ValueError:
if self.log_requests:
logger.info("Request %s failed (bad request).", request_id)
raise
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
if self.output_handler is not None:
return
# Ensure that the task doesn't have a circular ref back to the AsyncLLM
# object, or else it won't be garbage collected and cleaned up properly.
engine_core = self.engine_core
output_processor = self.output_processor
log_stats = self.log_stats
stat_loggers = self.stat_loggers if log_stats else None
async def output_handler():
try:
while True:
# 1) Pull EngineCoreOutputs from the EngineCore.
outputs = await engine_core.get_output_async()
num_outputs = len(outputs.outputs)
iteration_stats = IterationStats() if (
log_stats and num_outputs) else None
# Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long.
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
slices = (outputs.outputs, )
else:
slices = np.array_split(
outputs.outputs,
cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
for i, outputs_slice in enumerate(slices):
# 2) Process EngineCoreOutputs.
processed_outputs = output_processor.process_outputs(
outputs_slice, outputs.timestamp, iteration_stats)
# NOTE: RequestOutputs are pushed to their queues.
assert not processed_outputs.request_outputs
# Allow other asyncio tasks to run between chunks
if i + 1 < len(slices):
await asyncio.sleep(0)
# 3) Abort any reqs that finished due to stop strings.
await engine_core.abort_requests_async(
processed_outputs.reqs_to_abort)
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if stat_loggers:
AsyncLLM._record_stats(
stat_loggers[outputs.engine_index],
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
except Exception as e:
logger.exception("AsyncLLM output_handler failed.")
output_processor.propagate_error(e)
self.output_handler = asyncio.create_task(output_handler())
async def abort(self, request_id: str) -> None:
"""Abort RequestId in OutputProcessor and EngineCore."""
request_ids = self.output_processor.abort_requests((request_id, ))
await self.engine_core.abort_requests_async(request_ids)
if self.log_requests:
logger.info("Aborted request %s.", request_id)
@staticmethod
def _record_stats(
stat_loggers: list[StatLoggerBase],
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
"""static so that it can be used from the output_handler task
without a circular ref to AsyncLLM."""
for stat_logger in stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
):
raise ValueError("Not Supported on V1 yet.")
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def get_decoding_config(self):
raise ValueError("Not Supported on V1 yet.")
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.processor.input_preprocessor
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.tokenizer.get_lora_tokenizer(lora_request)
async def is_tracing_enabled(self) -> bool:
return False
async def do_log_stats(
self,
scheduler_outputs=None,
model_output=None,
) -> None:
for loggers in self.stat_loggers:
for stat_logger in loggers:
stat_logger.log()
async def check_health(self) -> None:
logger.debug("Called check_health.")
async def start_profile(self) -> None:
await self.engine_core.profile_async(True)
async def stop_profile(self) -> None:
await self.engine_core.profile_async(False)
async def reset_mm_cache(self) -> None:
self.processor.mm_registry.reset_processor_cache()
self.processor.mm_input_cache_client.reset()
await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
if device == Device.CPU:
raise ValueError("Not supported on CPU.")
await self.engine_core.reset_prefix_cache_async()
async def sleep(self, level: int = 1) -> None:
await self.engine_core.sleep_async(level)
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
await self.engine_core.wake_up_async(tags)
async def is_sleeping(self) -> bool:
return await self.engine_core.is_sleeping_async()
async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""
return await self.engine_core.add_lora_async(lora_request)
async def remove_lora(self, lora_id: int) -> bool:
"""Remove an already loaded LoRA adapter."""
return await self.engine_core.remove_lora_async(lora_id)
async def list_loras(self) -> set[int]:
"""List all registered adapters."""
return await self.engine_core.list_loras_async()
async def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted."""
return await self.engine_core.pin_lora_async(lora_id)
async def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None):
"""
Perform a collective RPC call to the given path.
"""
return await self.engine_core.collective_rpc_async(
method, timeout, args, kwargs)
@property
def is_running(self) -> bool:
# Is None before the loop is started.
return self.output_handler is None or not self.output_handler.done()
@property
def is_stopped(self) -> bool:
return self.errored
@property
def errored(self) -> bool:
return self.engine_core.resources.engine_dead or not self.is_running
@property
def dead_error(self) -> BaseException:
return EngineDeadError()

View File

@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import time
import weakref
from typing import Optional
import msgspec.msgpack
import zmq
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
logger = init_logger(__name__)
class DPCoordinator:
"""Coordinator process used for data-parallel deployments (DP>1).
Intermediates between multiple DP engine rank processes and one or more
front-end API server processes.
* Collects stats from each DP engine (currently just waiting and running
queue lengths), and publishes these to all front-ends for use in
load-balancing decisions.
* Keeps track of the current DP "request wave" number and running state
of the engines. This is received from the DP rank 0 engine and published
to the front-end processes along with the current load stats.
The engines alternate between a global running/paused state. The global
"request wave" number is a count of the number of times that the workers
collectively move from a running state to a paused state. This transition
is synchronized via the all-reduce operation performed in the
DPEngineCoreProc._has_global_unfinished_reqs method.
* Broadcasts the START_DP_WAVE message to engines to move them from paused
to running state when one engine receives a new request. This can happen
in two cases:
1) A front-end sending a new request while the engines are paused will
concurrently notify the coordinator.
2) An engine receiving a request for a stale request wave while in paused
state will notify the coordinator.
Engines will move into running state when receiving a new request or
START_DP_WAVE message.
"""
def __init__(self, parallel_config: ParallelConfig):
# Assume coordinator is colocated with front-end procs.
front_publish_address = get_open_zmq_ipc_path()
dp_size = parallel_config.data_parallel_size
assert dp_size > 1, "Coordinator only used for data parallel"
local_only = dp_size == parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
back_publish_address = get_engine_client_zmq_addr(local_only, host)
back_output_address = get_engine_client_zmq_addr(local_only, host)
context = get_mp_context()
self.proc: multiprocessing.Process = context.Process(
target=CoordinatorProc.run_coordinator,
name="VLLM_DP_Coordinator",
kwargs={
"engine_count": parallel_config.data_parallel_size,
"front_publish_address": front_publish_address,
"back_output_address": back_output_address,
"back_publish_address": back_publish_address,
},
daemon=True)
self.proc.start()
self.stats_publish_address = front_publish_address
self.coord_in_address = back_publish_address
self.coord_out_address = back_output_address
self._finalizer = weakref.finalize(self, shutdown, [self.proc])
def get_stats_publish_address(self) -> str:
return self.stats_publish_address
def get_engine_socket_addresses(self) -> tuple[str, str]:
"""Returns tuple of ZMQ input address, output address."""
return self.coord_in_address, self.coord_out_address
def close(self):
self._finalizer()
class EngineState:
def __init__(self):
self.request_counts = [0, 0] # [waiting, running]
class CoordinatorProc:
def __init__(self, engine_count: int):
self.ctx = zmq.Context()
self.engines = [EngineState() for _ in range(engine_count)]
self.current_wave = 0
self.engines_running = False
self.stats_changed = False
@staticmethod
def run_coordinator(
engine_count: int,
front_publish_address: str,
back_output_address: str,
back_publish_address: str,
):
coordinator = CoordinatorProc(engine_count=engine_count)
try:
coordinator.process_input_socket(
front_publish_address,
back_output_address,
back_publish_address,
)
except KeyboardInterrupt:
logger.info("DP Coordinator process exiting")
def process_input_socket(self, front_publish_address: str,
back_output_address: str,
back_publish_address: str):
decoder = MsgpackDecoder(EngineCoreOutputs)
with make_zmq_socket(
path=front_publish_address, # IPC
ctx=self.ctx,
socket_type=zmq.XPUB,
bind=True,
) as publish_front, make_zmq_socket(
path=back_output_address, # IPC or TCP
ctx=self.ctx,
socket_type=zmq.PULL,
bind=True,
) as output_back, make_zmq_socket(
path=back_publish_address, # IPC or TCP
ctx=self.ctx,
socket_type=zmq.XPUB,
bind=True,
) as publish_back:
poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN)
last_publish_time = 0
while True:
elapsed = int(time.time() * 1000) - last_publish_time
# Send at 100 ms interval if the stats have changed,
# or otherwise every 3 seconds.
wait_for = 100 if self.stats_changed else 3000
events = poller.poll(timeout=max(0, wait_for - elapsed))
if not events:
# Poller timeout - publish current stats to front-ends.
engine_req_counts_list = self._get_engine_counts()
to_publish = (engine_req_counts_list, self.current_wave,
self.engines_running)
publish_front.send(msgspec.msgpack.encode(to_publish))
last_publish_time = int(time.time() * 1000)
self.stats_changed = False
continue
events = dict(events)
if publish_front in events:
buffer = publish_front.recv()
if buffer == b'\x01':
# Ignore subscription messages.
continue
# We received a message on the front-end XPUB socket,
# from an API server sending a new request while the
# engines are paused, so that we can wake the other
# engines.
engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
if wave < self.current_wave:
# If the wave number is stale, ensure the message is
# handled by all the engines.
engine_to_exclude = None
if not self.engines_running:
self.engines_running = True
self.stats_changed = True
self._send_start_wave(publish_back, self.current_wave,
engine_to_exclude)
if output_back in events:
# We received a message from one of the engines.
buffer = output_back.recv()
outputs: EngineCoreOutputs = decoder.decode(buffer)
assert not outputs.outputs
assert outputs.utility_output is None
eng_index = outputs.engine_index
if outputs.scheduler_stats:
# 1. Updated request load stats - update our local
# state with these.
stats = self.engines[eng_index].request_counts
stats[0] = outputs.scheduler_stats.num_waiting_reqs
stats[1] = outputs.scheduler_stats.num_running_reqs
self.stats_changed = True
if (wave := outputs.wave_complete) is not None:
# 2. Notification from rank 0 engine that we've
# moved into the global paused state
# (engines_running==False)
if self.current_wave <= wave:
logger.debug("Moving DP wave from %d to %d.",
self.current_wave, wave)
self.current_wave = wave + 1
self.engines_running = False
self.stats_changed = True
elif (wave := outputs.start_wave) is not None and (
wave > self.current_wave or
(wave == self.current_wave
and not self.engines_running)):
# 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the
# next wave (race condition handling).
logger.debug(
"Starting wave %d after notification of "
"stale wave request from engine.", wave)
self.current_wave = wave
self.engines_running = True
self.stats_changed = True
self._send_start_wave(publish_back, wave, eng_index)
@staticmethod
def _send_start_wave(socket: zmq.Socket, wave: int,
exclude_engine_index: Optional[int]):
"""Broadcast the START_DP_WAVE message to all the engines.
It includes the current wave number and index of engine which
has already received a request with this wave number and so doesn't
require additional notification.
"""
wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
socket.send_multipart(
(EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
def _get_engine_counts(self) -> list[list[int]]:
"""Return list of [waiting, running] count lists for each engine."""
return [e.request_counts for e in self.engines]

962
vllm/v1/engine/core.py Normal file
View File

@@ -0,0 +1,962 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import queue
import signal
import sys
import threading
import time
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
import zmq
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5
HANDSHAKE_TIMEOUT_MINS = 5
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCore:
"""Inner loop of vLLM's Engine."""
def __init__(self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
executor_fail_callback: Optional[Callable] = None):
assert vllm_config.model_config.runner_type != "pooling"
# plugins need to be loaded at the engine/scheduler level too
from vllm.plugins import load_general_plugins
load_general_plugins()
self.vllm_config = vllm_config
logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)
self.log_stats = log_stats
# Setup Model.
self.model_executor = executor_class(vllm_config)
if executor_fail_callback is not None:
self.model_executor.register_failure_callback(
executor_fail_callback)
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
self._initialize_kv_caches(vllm_config)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler.
if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
Scheduler = resolve_obj_by_qualname(
vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = vllm_config.scheduler_config.scheduler_cls
# This warning can be removed once the V1 Scheduler interface is
# finalized and we can maintain support for scheduler classes that
# implement it
if Scheduler is not V1Scheduler:
logger.warning(
"Using configured V1 scheduler class %s. "
"This scheduler interface is not public and "
"compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls)
self.scheduler: SchedulerInterface = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
log_stats=self.log_stats,
)
# Setup MM Input Mapper.
self.mm_input_cache_server = MirroredProcessingCache(
vllm_config.model_config)
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
# schedule and execute batches, and is required by pipeline parallelism
# to eliminate pipeline bubbles.
self.batch_queue_size = self.model_executor.max_concurrent_batches
self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
SchedulerOutput]]] = None
if self.batch_queue_size > 1:
logger.info("Batch queue is enabled with size %d",
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)
def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time()
# Get all kv cache needed by the model
kv_cache_specs = self.model_executor.get_kv_cache_specs()
# Profiles the peak memory usage of the model to determine how much
# memory can be allocated for kv cache.
available_gpu_memory = self.model_executor.determine_available_memory()
assert len(kv_cache_specs) == len(available_gpu_memory)
# Get the kv cache tensor size
kv_cache_configs = [
get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
available_gpu_memory_one_worker)
for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
zip(kv_cache_specs, available_gpu_memory)
]
# Since we use a shared centralized controller, we need the
# `kv_cache_config` to be consistent across all workers to make sure
# all the memory operators can be applied to all workers.
unify_kv_cache_configs(kv_cache_configs)
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to initialize the scheduler.
assert all([
cfg.num_blocks == kv_cache_configs[0].num_blocks
for cfg in kv_cache_configs
])
num_gpu_blocks = kv_cache_configs[0].num_blocks
num_cpu_blocks = 0
scheduler_kv_cache_config = kv_cache_configs[0]
# Initialize kv cache and warmup the execution
self.model_executor.initialize_from_config(kv_cache_configs)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler."""
if request.mm_hashes is not None:
# Here, if hash exists for a multimodal input, then it will be
# fetched from the cache, else it will be added to the cache.
# Note that the cache here is mirrored with the client cache, so
# anything that has a hash must have a HIT cache entry here
# as well.
assert request.mm_inputs is not None
request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)
if req.use_structured_output:
# Start grammar compilation asynchronously
self.structured_output_manager.grammar_init(req)
if req.kv_transfer_params is not None and (
not self.scheduler.get_kv_connector()):
logger.warning("Got kv_transfer_params, but no KVConnector found. "
"Disabling KVTransfer for this request.")
self.scheduler.add_request(req)
def abort_requests(self, request_ids: list[str]):
"""Abort requests from the scheduler."""
# TODO: The scheduler doesn't really need to know the
# specific finish reason, TBD whether we propagate that
# (i.e. client-aborted vs stop criteria met).
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)
def execute_model(self, scheduler_output: SchedulerOutput):
try:
return self.model_executor.execute_model(scheduler_output)
except BaseException as err:
# NOTE: This method is exception-free
dump_engine_exception(self.vllm_config, scheduler_output,
self.scheduler.make_stats())
# Re-raise exception
raise err
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output.
Returns tuple of outputs and a flag indicating whether the model
was executed.
"""
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
model_output = self.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore
return (engine_core_outputs,
scheduler_output.total_num_scheduled_tokens > 0)
def step_with_batch_queue(
self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
"""Schedule and execute batches with the batch queue.
Note that if nothing to output in this step, None is returned.
The execution flow is as follows:
1. Try to schedule a new batch if the batch queue is not full.
If a new batch is scheduled, directly return an empty engine core
output. In other words, fulfilling the batch queue has a higher priority
than getting model outputs.
2. If there is no new scheduled batch, meaning that the batch queue
is full or no other requests can be scheduled, we block until the first
batch in the job queue is finished.
3. Update the scheduler from the output.
"""
assert self.batch_queue is not None
engine_core_outputs = None
scheduler_output = None
# Try to schedule a new batch if the batch queue is not full, but
# the scheduler may return an empty batch if all requests are scheduled.
# Note that this is not blocking.
if not self.batch_queue.full():
scheduler_output = self.scheduler.schedule()
if scheduler_output.total_num_scheduled_tokens > 0:
future = self.model_executor.execute_model(scheduler_output)
self.batch_queue.put_nowait(
(future, scheduler_output)) # type: ignore
scheduled_batch = (scheduler_output is not None
and scheduler_output.total_num_scheduled_tokens > 0)
# If no more requests can be scheduled and the job queue is not empty,
# block until the first batch in the job queue is finished.
# TODO(comaniac): Ideally we should peek the first batch in the
# job queue to check if it's finished before scheduling a new batch,
# but peeking the first element in a queue is not thread-safe,
# so we need more work.
if not scheduled_batch and not self.batch_queue.empty():
future, scheduler_output = self.batch_queue.get_nowait()
# Blocking until the first result is available.
model_output = future.result()
self.batch_queue.task_done()
engine_core_outputs = (self.scheduler.update_from_output(
scheduler_output, model_output))
return engine_core_outputs, scheduled_batch
def shutdown(self):
self.structured_output_manager.clear_backend()
if self.model_executor:
self.model_executor.shutdown()
if self.scheduler:
self.scheduler.shutdown()
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)
def reset_mm_cache(self):
# NOTE: Since this is mainly for debugging, we don't attempt to
# re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
if self.scheduler.has_unfinished_requests():
logger.warning("Resetting the multi-modal cache when requests are "
"in progress may lead to desynced internal caches.")
self.mm_input_cache_server.reset()
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()
def sleep(self, level: int = 1):
self.model_executor.sleep(level)
def wake_up(self, tags: Optional[list[str]] = None):
self.model_executor.wake_up(tags)
def is_sleeping(self) -> bool:
return self.model_executor.is_sleeping
def execute_dummy_batch(self):
self.model_executor.collective_rpc("execute_dummy_batch")
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_executor.remove_lora(lora_id)
def list_loras(self) -> set[int]:
return self.model_executor.list_loras()
def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.model_executor.save_sharded_state(path=path,
pattern=pattern,
max_size=max_size)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
def save_tensorized_model(
self,
tensorizer_config,
) -> None:
self.model_executor.save_tensorized_model(
tensorizer_config=tensorizer_config, )
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'
def __init__(
self,
vllm_config: VllmConfig,
on_head_node: bool,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
engine_index: int = 0,
):
self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
bytes]]()
executor_fail_callback = lambda: self.input_queue.put_nowait(
(EngineCoreRequestType.EXECUTOR_FAILED, b''))
self.engine_index = engine_index
identity = self.engine_index.to_bytes(length=2, byteorder="little")
self.engines_running = False
with self._perform_handshake(handshake_address, identity, on_head_node,
vllm_config) as addresses:
self.client_count = len(addresses.outputs)
# Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None
self._init_data_parallel(vllm_config)
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
threading.Thread(target=self.process_input_sockets,
args=(addresses.inputs, addresses.coordinator_input,
identity),
daemon=True).start()
self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output,
self.engine_index),
daemon=True)
self.output_thread.start()
@contextmanager
def _perform_handshake(
self, handshake_address: str, identity: bytes, on_head_node: bool,
vllm_config: VllmConfig
) -> Generator[EngineZmqAddresses, None, None]:
input_ctx = zmq.Context()
with make_zmq_socket(input_ctx,
handshake_address,
zmq.DEALER,
identity=identity,
linger=5000,
bind=False) as handshake_socket:
# Register engine with front-end.
addresses = self.startup_handshake(handshake_socket, on_head_node,
vllm_config.parallel_config)
# Update config which may have changed from the handshake
vllm_config.__post_init__()
yield addresses
# Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
handshake_socket.send(
msgspec.msgpack.encode({
"status": "READY",
"local": on_head_node,
"num_gpu_blocks": num_gpu_blocks,
}))
@staticmethod
def startup_handshake(
handshake_socket: zmq.Socket, on_head_node: bool,
parallel_config: ParallelConfig) -> EngineZmqAddresses:
# Send registration message.
handshake_socket.send(
msgspec.msgpack.encode({
"status": "HELLO",
"local": on_head_node,
}))
# Receive initialization message.
logger.info("Waiting for init message from front-end.")
if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
raise RuntimeError("Did not receive response from front-end "
f"process within {HANDSHAKE_TIMEOUT_MINS} "
f"minutes")
init_bytes = handshake_socket.recv()
init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
init_bytes, type=EngineHandshakeMetadata)
logger.debug("Received init message: %s", init_message)
received_parallel_config = init_message.parallel_config
for key, value in received_parallel_config.items():
setattr(parallel_config, key, value)
return init_message.addresses
@staticmethod
def run_engine_core(*args,
dp_rank: int = 0,
local_dp_rank: int = 0,
**kwargs):
"""Launch EngineCore busy loop in background process."""
# Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker
# processes to terminate without error
shutdown_requested = False
# Ensure we can serialize transformer config after spawning
maybe_register_config_serialize_by_value()
def signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the engine_core
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
engine_core: Optional[EngineCoreProc] = None
try:
parallel_config: ParallelConfig = kwargs[
"vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
engine_core = DPEngineCoreProc(*args, **kwargs)
else:
engine_core = EngineCoreProc(*args, **kwargs)
engine_core.run_busy_loop()
except SystemExit:
logger.debug("EngineCore exiting.")
raise
except Exception as e:
if engine_core is None:
logger.exception("EngineCore failed to start.")
else:
logger.exception("EngineCore encountered a fatal error.")
engine_core._send_engine_dead()
raise e
finally:
if engine_core is not None:
engine_core.shutdown()
def _init_data_parallel(self, vllm_config: VllmConfig):
pass
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
waited = False
while not self.engines_running and not self.scheduler.has_requests():
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.")
waited = True
req = self.input_queue.get()
self._handle_client_request(*req)
if waited:
logger.debug("EngineCore loop active.")
# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
def _process_engine_step(self) -> bool:
"""Called only when there are unfinished local requests."""
# Step the engine core.
outputs, model_executed = self.step_fn()
# Put EngineCoreOutputs into the output queue.
for output in (outputs.items() if outputs else ()):
self.output_queue.put_nowait(output)
return model_executed
def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None:
"""Dispatch request from client."""
if request_type == EngineCoreRequestType.ADD:
self.add_request(request)
elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request)
elif request_type == EngineCoreRequestType.UTILITY:
client_idx, call_id, method_name, args = request
output = UtilityOutput(call_id)
try:
method = getattr(self, method_name)
output.result = method(
*self._convert_msgspec_args(method, args))
except BaseException as e:
logger.exception("Invocation of %s method failed", method_name)
output.failure_message = (f"Call to {method_name} method"
f" failed: {str(e)}")
self.output_queue.put_nowait(
(client_idx, EngineCoreOutputs(utility_output=output)))
elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
raise RuntimeError("Executor failed.")
else:
logger.error("Unrecognized input request type encountered: %s",
request_type)
@staticmethod
def _convert_msgspec_args(method, args):
"""If a provided arg type doesn't match corresponding target method
arg type, try converting to msgspec object."""
if not args:
return args
arg_types = signature(method).parameters.values()
assert len(args) <= len(arg_types)
return tuple(
msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
and issubclass(p.annotation, msgspec.Struct)
and not isinstance(v, p.annotation) else v
for v, p in zip(args, arg_types))
def _send_engine_dead(self):
"""Send EngineDead status to the EngineCoreClient."""
# Put ENGINE_CORE_DEAD in the queue.
self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)
# Wait until msg sent by the daemon before shutdown.
self.output_thread.join(timeout=5.0)
if self.output_thread.is_alive():
logger.fatal("vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue.")
def process_input_sockets(self, input_addresses: list[str],
coord_input_address: Optional[str],
identity: bytes):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()
with ExitStack() as stack, zmq.Context() as ctx:
input_sockets = [
stack.enter_context(
make_zmq_socket(ctx,
input_address,
zmq.DEALER,
identity=identity,
bind=False))
for input_address in input_addresses
]
if coord_input_address is None:
coord_socket = None
else:
coord_socket = stack.enter_context(
make_zmq_socket(ctx,
coord_input_address,
zmq.XSUB,
identity=identity,
bind=False))
# Send subscription message to coordinator.
coord_socket.send(b'\x01')
# Register sockets with poller.
poller = zmq.Poller()
for input_socket in input_sockets:
# Send initial message to each input socket - this is required
# before the front-end ROUTER socket can send input messages
# back to us.
input_socket.send(b'')
poller.register(input_socket, zmq.POLLIN)
if coord_socket is not None:
poller.register(coord_socket, zmq.POLLIN)
while True:
for input_socket, _ in poller.poll():
# (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(
copy=False)
request_type = EngineCoreRequestType(
bytes(type_frame.buffer))
# Deserialize the request data.
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frames)
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str],
engine_index: int):
"""Output socket IO thread."""
# Msgpack serialization encoding.
encoder = MsgpackEncoder()
# Send buffers to reuse.
reuse_buffers: list[bytearray] = []
# Keep references to outputs and buffers until zmq is finished
# with them (outputs may contain tensors/np arrays whose
# backing buffers were extracted for zero-copy send).
pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
# We must set linger to ensure the ENGINE_CORE_DEAD
# message is sent prior to closing the socket.
with ExitStack() as stack, zmq.Context() as ctx:
sockets = [
stack.enter_context(
make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
for output_path in output_paths
]
coord_socket = stack.enter_context(
make_zmq_socket(
ctx, coord_output_path, zmq.PUSH, bind=False,
linger=4000)) if coord_output_path is not None else None
max_reuse_bufs = len(sockets) + 1
while True:
output = self.output_queue.get()
if output == EngineCoreProc.ENGINE_CORE_DEAD:
for socket in sockets:
socket.send(output)
break
assert not isinstance(output, bytes)
client_index, outputs = output
outputs.engine_index = engine_index
if client_index == -1:
# Don't reuse buffer for coordinator message
# which will be very small.
assert coord_socket is not None
coord_socket.send_multipart(encoder.encode(outputs))
continue
# Reclaim buffers that zmq is finished with.
while pending and pending[-1][0].done:
reuse_buffers.append(pending.pop()[2])
buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
buffers = encoder.encode_into(outputs, buffer)
tracker = sockets[client_index].send_multipart(buffers,
copy=False,
track=True)
if not tracker.done:
ref = outputs if len(buffers) > 1 else None
pending.appendleft((tracker, ref, buffer))
elif len(reuse_buffers) < max_reuse_bufs:
# Limit the number of buffers to reuse.
reuse_buffers.append(buffer)
class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process
in a data parallel context."""
def __init__(
self,
vllm_config: VllmConfig,
on_head_node: bool,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
):
self._decorate_logs()
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
self.current_wave = 0
self.last_counts = (0, 0)
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, handshake_address,
executor_class, log_stats, dp_rank)
def _decorate_logs(self):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
from multiprocessing import current_process
process_name = current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
def _init_data_parallel(self, vllm_config: VllmConfig):
# Configure GPUs and stateless process group for data parallel.
dp_rank = vllm_config.parallel_config.data_parallel_rank
dp_size = vllm_config.parallel_config.data_parallel_size
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
assert dp_size > 1
assert 0 <= local_dp_rank <= dp_rank < dp_size
if vllm_config.kv_transfer_config is not None:
# modify the engine_id and append the local_dp_rank to it to ensure
# that the kv_transfer_config is unique for each DP rank.
vllm_config.kv_transfer_config.engine_id = (
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
)
logger.debug("Setting kv_transfer_config.engine_id to %s",
vllm_config.kv_transfer_config.engine_id)
from vllm.platforms import current_platform
device_control_env_var = current_platform.device_control_env_var
world_size = vllm_config.parallel_config.world_size
os.environ[device_control_env_var] = ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
world_size))
os.environ["MACA_VISIBLE_DEVICES"] = os.environ[device_control_env_var]
self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
def shutdown(self):
super().shutdown()
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)
def add_request(self, request: EngineCoreRequest):
if self.has_coordinator and request.current_wave != self.current_wave:
if request.current_wave > self.current_wave:
self.current_wave = request.current_wave
elif not self.engines_running:
# Request received for an already-completed wave, notify
# front-end that we need to start the next one.
self.output_queue.put_nowait(
(-1, EngineCoreOutputs(start_wave=self.current_wave)))
super().add_request(request)
def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None:
if request_type == EngineCoreRequestType.START_DP_WAVE:
new_wave, exclude_eng_index = request
if exclude_eng_index != self.engine_index and (
new_wave >= self.current_wave):
self.current_wave = new_wave
if not self.engines_running:
logger.debug("EngineCore starting idle loop for wave %d.",
new_wave)
self.engines_running = True
else:
super()._handle_client_request(request_type, request)
def _maybe_publish_request_counts(self):
if not self.has_coordinator:
return
# Publish our request counts (if they've changed).
counts = self.scheduler.get_request_counts()
if counts != self.last_counts:
self.last_counts = counts
stats = SchedulerStats(*counts)
self.output_queue.put_nowait(
(-1, EngineCoreOutputs(scheduler_stats=stats)))
def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case."""
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core.
executed = self._process_engine_step()
self._maybe_publish_request_counts()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if not executed:
if not local_unfinished_reqs and not self.engines_running:
# All engines are idle.
continue
# We are in a running state and so must execute a dummy pass
# if the model didn't execute any ready requests.
self.execute_dummy_batch()
# 3) All-reduce operation to determine global unfinished reqs.
self.engines_running = self._has_global_unfinished_reqs(
local_unfinished_reqs)
if not self.engines_running:
if self.dp_rank == 0:
# Notify client that we are pausing the loop.
logger.debug("Wave %d finished, pausing engine loop.",
self.current_wave)
self.output_queue.put_nowait(
(-1,
EngineCoreOutputs(wave_complete=self.current_wave)))
self.current_wave += 1
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
# Optimization - only perform finish-sync all-reduce every 24 steps.
self.counter += 1
if self.counter != 24:
return True
self.counter = 0
return ParallelConfig.has_unfinished_dp(self.dp_group,
local_unfinished)
class DPEngineCoreActor(DPEngineCoreProc):
"""
Ray actor for running EngineCore in a data parallel context
"""
def __init__(
self,
vllm_config: VllmConfig,
on_head_node: bool,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
dp_rank: int = 0,
local_dp_rank: int = 0,
):
self.addresses = addresses
vllm_config.parallel_config.data_parallel_rank = dp_rank
vllm_config.parallel_config.data_parallel_rank_local = \
local_dp_rank
# Ray sets CUDA_VISIBLE_DEVICES to empty string,
# we clean this up to be able to properly initialize
# data parallel groups.
# del os.environ['CUDA_VISIBLE_DEVICES']
super().__init__(vllm_config, on_head_node, "", executor_class,
log_stats)
def _decorate_logs(self):
pass
@contextmanager
def _perform_handshake(self, handshake_address: str, identity: bytes,
on_head_node: bool, vllm_config: VllmConfig):
"""
For Ray, we don't need to actually perform handshake.
All addresses information is known before the actor creation.
Therefore, we simply yield these addresses.
"""
yield self.addresses
def wait_for_init(self):
"""
Wait until the engine core is initialized.
This is just an empty method. When ray.get() on this method
(or any other method of the actor) returns, it is guaranteed
that actor creation (i.e., __init__) is complete.
"""
pass
def run(self):
"""
Run the engine core busy loop.
"""
try:
self.run_busy_loop()
except SystemExit:
logger.debug("EngineCore exiting.")
raise
except Exception:
logger.exception("EngineCore encountered a fatal error.")
raise
finally:
self.shutdown()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Optional
import tokenizers
from packaging import version
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)
# Only tokenizers >= 0.21.1 supports DecodeStream used for
# FastIncrementalDetokenizer.
USE_FAST_DETOKENIZER = version.parse(
tokenizers.__version__) >= version.parse("0.21.1")
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
class IncrementalDetokenizer:
def __init__(self):
self.token_ids: list[int] = []
@property
def output_token_ids(self) -> list[int]:
return self.token_ids
def update(self, new_token_ids: list[int],
stop_terminated: bool) -> Optional[str]:
self.token_ids.extend(new_token_ids)
return None
def get_next_output_text(self, finished: bool, delta: bool) -> str:
return ""
@classmethod
def from_new_request(
cls,
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "IncrementalDetokenizer":
if tokenizer is None:
# No tokenizer => skipping detokenization.
return IncrementalDetokenizer()
if USE_FAST_DETOKENIZER and isinstance(tokenizer,
PreTrainedTokenizerFast):
# Fast tokenizer => use tokenizers library DecodeStream.
return FastIncrementalDetokenizer(tokenizer, request)
# Fall back to slow python-based incremental detokenization.
return SlowIncrementalDetokenizer(tokenizer, request)
class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
def __init__(self, request: EngineCoreRequest):
super().__init__()
# Stop strings
params = request.sampling_params
self.stop = stop = params.stop
self.include_stop_str_in_output = params.include_stop_str_in_output
# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
if stop and not self.include_stop_str_in_output:
self.stop_buffer_length = max(len(s) for s in stop) - 1
else:
self.stop_buffer_length = 0
self._last_output_text_offset: int = 0
# Generation data
self.output_text = ""
def update(self, new_token_ids: list[int],
stop_terminated: bool) -> Optional[str]:
"""
Update RequestState for the request_id by:
1) Detokenize the new token ids incrementally.
2) Evaluate stop criteria.
Return matched stop string or None.
"""
if not new_token_ids:
# Skip detokenization if no new token ids.
return None
if stop_terminated and not self.include_stop_str_in_output:
# If stop-terminated, exclude last token from detokenization
# based on include_stop_str_in_output parameter.
skipped_stop_token_id = new_token_ids[-1]
new_token_ids = new_token_ids[:-1]
else:
skipped_stop_token_id = None
# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
offset_before = len(self.output_text)
for new_token_id in new_token_ids:
self.token_ids.append(new_token_id)
self.output_text += self.decode_next(new_token_id)
if stop_terminated:
if skipped_stop_token_id is not None:
# Cleanup after skipping detokenization.
self.token_ids.append(skipped_stop_token_id)
# Stop token triggered; skip stop string check.
return None
# 2) Evaluate stop strings.
stop_string = None
if self.stop:
stop = StopChecker.check_stop_strings(
output_text=self.output_text,
new_char_count=len(self.output_text) - offset_before,
stop=self.stop,
include_in_output=self.include_stop_str_in_output,
)
if stop is not None:
stop_string, truncate_to = stop
if truncate_to != -1:
self.output_text = self.output_text[:truncate_to]
return stop_string
@abstractmethod
def decode_next(self, next_token_id: int) -> str:
raise NotImplementedError
def get_next_output_text(self, finished: bool, delta: bool) -> str:
"""If delta is True, only new text since the last call to
this method is returned"""
# We return the full output text if the sequence is finished.
buffer_length = 0 if finished else self.stop_buffer_length
if not delta:
return self.output_text[:-buffer_length] if buffer_length else (
self.output_text)
length = len(self.output_text) - buffer_length
last_offset = self._last_output_text_offset
if last_offset < length:
self._last_output_text_offset = length
return self.output_text[last_offset:length]
return ""
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
def __init__(self, tokenizer: PreTrainedTokenizerFast,
request: EngineCoreRequest):
super().__init__(request)
sampling_params = request.sampling_params
self.request_id = request.request_id
self.skip_special_tokens = sampling_params.skip_special_tokens
self.stream = DecodeStream(
skip_special_tokens=self.skip_special_tokens)
self.tokenizer: Tokenizer = tokenizer._tokenizer
# Find a safe place to start.
prompt_suffix = request.prompt_token_ids
prompt_len = len(prompt_suffix)
if prompt_len > 4:
for i in range(4, min(prompt_len + 1, 24)):
suffix = request.prompt_token_ids[-i:]
if '<EFBFBD>' not in self.tokenizer.decode(suffix):
prompt_suffix = suffix
break
# Prime the stream.
for tid in prompt_suffix:
self._protected_step(tid)
self.spaces_between_special_tokens = (
sampling_params.skip_special_tokens
or sampling_params.spaces_between_special_tokens)
if not self.spaces_between_special_tokens:
# Store dict of added token ids so that we can suppress
# the spaces between them.
if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
None)) is None:
self.tokenizer.added_token_ids = added_token_ids = {
tid: tok.content
for tid, tok in
self.tokenizer.get_added_tokens_decoder().items()
}
if added_token_ids:
self.last_special = False
self.added_token_ids = added_token_ids
else:
# No added tokens.
self.spaces_between_special_tokens = True
def decode_next(self, next_token_id: int) -> str:
token = self._protected_step(next_token_id)
if not self.spaces_between_special_tokens:
special_token = self.added_token_ids.get(next_token_id)
is_special = special_token is not None
if is_special and self.last_special:
# Return raw token string without any prefixed spaces.
token = special_token
self.last_special = is_special
return token or ""
def _protected_step(self, next_token_id: int) -> Optional[str]:
try:
token = self.stream.step(self.tokenizer, next_token_id)
except Exception as e:
if str(e) != INVALID_PREFIX_ERR_MSG:
raise e
# Recover from edge case where tokenizer can produce non-monotonic,
# invalid UTF-8 output, which breaks the internal state of
# tokenizers' DecodeStream.
# See https://github.com/vllm-project/vllm/issues/17448.
logger.warning(
"Encountered invalid prefix detokenization error"
" for request %s, resetting decode stream.", self.request_id)
self.stream = DecodeStream(self.skip_special_tokens)
token = self.stream.step(self.tokenizer, next_token_id)
return token
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
super().__init__(request)
self.tokenizer = tokenizer
# Metadata for incremental detokenization.
self.tokens, self.prefix_offset, self.read_offset = (
convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.sampling_params.
skip_special_tokens,
))
self.token_ids.extend(request.prompt_token_ids)
self.prompt_len = len(request.prompt_token_ids)
params = request.sampling_params
self.skip_special_tokens = params.skip_special_tokens
self.spaces_between_special_tokens = (
params.spaces_between_special_tokens)
@property
def output_token_ids(self) -> list[int]:
return self.token_ids if not self.prompt_len else (
self.token_ids[self.prompt_len:])
def decode_next(self, next_token_id: int) -> str:
new_tokens, decoded_text, prefix_offset, read_offset = (
detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=self.token_ids,
prev_tokens=self.tokens,
prefix_offset=self.prefix_offset,
read_offset=self.read_offset,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.
spaces_between_special_tokens,
))
self.tokens.extend(new_tokens)
self.prefix_offset = prefix_offset
self.read_offset = read_offset
return decoded_text

View File

@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
class EngineGenerateError(Exception):
"""Raised when a AsyncLLM.generate() fails. Recoverable."""
pass
class EngineDeadError(Exception):
"""Raised when the EngineCore dies. Unrecoverable."""
def __init__(self, *args, suppress_context: bool = False, **kwargs):
ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause." # noqa: E501
super().__init__(ENGINE_DEAD_MESSAGE, *args, **kwargs)
# Make stack trace clearer when using with LLMEngine by
# silencing irrelevant ZMQError.
self.__suppress_context__ = suppress_context

View File

@@ -0,0 +1,317 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class LLMEngine:
"""Legacy LLMEngine for backwards compatibility."""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
multiprocess_mode: bool = False,
) -> None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
if stat_loggers is not None:
raise NotImplementedError(
"Passing StatLoggers to LLMEngine in V1 is not yet supported. "
"Set VLLM_USE_V1=0 and file and issue on Github.")
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.log_stats = log_stats
self.stat_logger: Optional[StatLoggerBase] = None
if self.log_stats:
self.stat_logger = PrometheusStatLogger(vllm_config)
# important: init dp group before init the engine_core
# In the decoupled engine case this is handled in EngineCoreProc.
parallel_config = vllm_config.parallel_config
if not multiprocess_mode and parallel_config.data_parallel_size > 1:
self.dp_group = parallel_config.stateless_init_dp_group()
else:
self.dp_group = None
self.should_execute_dummy_batch = False
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config,
tokenizer=self.tokenizer,
mm_registry=mm_registry)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer,
log_stats=self.log_stats)
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client(
multiprocess_mode=multiprocess_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
)
if not multiprocess_mode:
# for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
# Don't keep the dummy data in memory
self.reset_mm_cache()
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_stats: bool = False,
) -> "LLMEngine":
return cls(vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=(not disable_log_stats),
usage_context=usage_context,
stat_loggers=stat_loggers,
multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
enable_multiprocessing: bool = False,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = Executor.get_class(vllm_config)
if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
logger.debug("Enabling multiprocessing for LLMEngine.")
enable_multiprocessing = True
# Create the LLMEngine.
return cls(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
multiprocess_mode=enable_multiprocessing)
def get_num_unfinished_requests(self) -> int:
return self.output_processor.get_num_unfinished_requests()
def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests()
if self.dp_group is None:
return has_unfinished
return self.has_unfinished_requests_dp(has_unfinished)
def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
self.dp_group, has_unfinished)
if not has_unfinished and aggregated_has_unfinished:
self.should_execute_dummy_batch = True
return aggregated_has_unfinished
@classmethod
def validate_outputs(cls, outputs, output_type):
return outputs
def abort_request(self, request_ids: list[str]) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""
request_ids = self.output_processor.abort_requests(request_ids)
self.engine_core.abort_requests(request_ids)
def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
# Process raw inputs into the request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority)
n = params.n if isinstance(params, SamplingParams) else 1
if n == 1:
# Make a new RequestState and queue.
self.output_processor.add_request(request, prompt_str, None, 0)
# Add the request to EngineCore.
self.engine_core.add_request(request)
return
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
for idx in range(n):
request_id, params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
# Make a new RequestState and queue.
self.output_processor.add_request(child_request, prompt_str,
parent_req, idx)
# Add the request to EngineCore.
self.engine_core.add_request(child_request)
def step(self) -> list[RequestOutput]:
if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
self.engine_core.execute_dummy_batch()
return []
# 1) Get EngineCoreOutput from the EngineCore.
outputs = self.engine_core.get_output()
# 2) Process EngineCoreOutputs.
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats)
# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
# 4) Record stats
if self.stat_logger is not None:
assert outputs.scheduler_stats is not None
self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats)
return processed_outputs.request_outputs
def get_vllm_config(self):
return self.vllm_config
def get_model_config(self):
return self.model_config
def start_profile(self):
self.engine_core.profile(True)
def stop_profile(self):
self.engine_core.profile(False)
def reset_mm_cache(self):
self.processor.mm_registry.reset_processor_cache()
self.processor.mm_input_cache_client.reset()
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, device: Optional[Device] = None):
self.engine_core.reset_prefix_cache()
def sleep(self, level: int = 1):
self.engine_core.sleep(level)
def wake_up(self, tags: Optional[list[str]] = None):
self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping()
def get_metrics(self) -> list[Metric]:
assert self.log_stats, "Stat logging disabled"
return get_metrics_snapshot()
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
return self.tokenizer
def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""
return self.engine_core.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
"""Remove an already loaded LoRA adapter."""
return self.engine_core.remove_lora(lora_id)
def list_loras(self) -> set[int]:
"""List all registered adapters."""
return self.engine_core.list_loras()
def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
def __del__(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)

199
vllm/v1/engine/logprobs.py Normal file
View File

@@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional
from vllm.logger import init_logger
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_ids_list_to_tokens)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
logger = init_logger(__name__)
NONES = itertools.repeat(None)
@dataclass
class LogprobsProcessor:
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: Optional[AnyTokenizer]
# Logprobs for this request
logprobs: Optional[SampleLogprobs]
prompt_logprobs: Optional[PromptLogprobs]
cumulative_logprob: Optional[float]
num_logprobs: Optional[int]
num_prompt_logprobs: Optional[int]
@classmethod
def from_new_request(
cls,
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "LogprobsProcessor":
num_logprobs = request.sampling_params.logprobs
num_prompt_logprobs = request.sampling_params.prompt_logprobs
return cls(
tokenizer=tokenizer,
cumulative_logprob=(None if num_logprobs is None else 0.),
logprobs=(None if num_logprobs is None else []),
# NOTE: logprob of first prompt token is None.
prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
num_prompt_logprobs=num_prompt_logprobs,
num_logprobs=num_logprobs,
)
def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
"""Update with sample logprobs from EngineCore.
Outer lists are only of len > 1 if EngineCore made
>1 tokens in prior step (e.g. in spec decoding).
Args:
logprobs_lists: the lists of logprob tokens, logprobs, and ranks.
"""
assert self.num_logprobs is not None
assert self.logprobs is not None
assert self.cumulative_logprob is not None
token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
token_ids_lst):
# Detokenize (non-incrementally).
decoded_tokens = NONES if self.tokenizer is None else (
convert_ids_list_to_tokens(self.tokenizer, token_ids))
# Sampler puts the sampled logprob in first.
sampled_token_logprob = logprobs[0]
self.cumulative_logprob += sampled_token_logprob
# Update with the Logprob dictionary for this pos.
self.logprobs.append(
self._make_logprob_dict(
logprobs,
token_ids,
decoded_tokens,
rank,
self.num_logprobs,
))
def _update_prompt_logprobs(
self,
prompt_logprobs_tensors: LogprobsTensors,
) -> None:
"""Update with prompt logprobs from EngineCore.
Args:
prompt_logprobs_tensors: tuple containing the prompt logprobs
tensors.
"""
# Prompt logprobs are enabled.
assert self.num_prompt_logprobs is not None
assert self.prompt_logprobs is not None
token_ids, logprobs, ranks = prompt_logprobs_tensors
# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
decoded_tokens = None if self.tokenizer is None else (
convert_ids_list_to_tokens(self.tokenizer,
token_ids.flatten().tolist()))
# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape
# Pythonize the torch tensors.
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
# Make Logprob for each position.
for pos in range(num_prompt_tokens):
# Handle flattening.
offset = pos * num_logprobs
offset_end = offset + num_logprobs
decoded_tokens_for_pos = NONES \
if decoded_tokens is None else decoded_tokens[offset:offset_end]
# Update with the Logprob dictionary for this pos.
self.prompt_logprobs.append(
self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
decoded_tokens_for_pos,
prompt_token_ranks[pos],
self.num_prompt_logprobs))
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
"""Pop and return all request prompt logprobs
The logprobs processor aggregates prompt chunk logprobs
over one or more prefill chunks. This method returns
all prompt logprobs at once and then forgets them.
Ensures correct RequestOutputKind.DELTA semantics
wherein all prompt logprobs are returned at once at
the end of prefill.
Returns:
None if prompt logprobs are disabled for this request.
List of all prompt logprobs, otherwise.
"""
plp = self.prompt_logprobs
if plp:
self.prompt_logprobs = []
return plp
@staticmethod
def _make_logprob_dict(
logprobs: list[float],
logprob_token_ids: list[int],
decoded_tokens: Iterable[Optional[str]],
rank: int,
num_logprobs: int,
) -> dict[int, Logprob]:
"""Make a Logprob dictionary for a position.
Args:
logprobs: list of log probabilities
logprob_token_ids: list of top token ids
decoded_tokens: list of decoded top tokens
rank: rank of the sampled token
num_logprobs: number of logprobs requested
by the user (in addition to sampled logprob)
Returns:
dict[token id, Logprob]
"""
# We do not need a special case for the sampled token
# being in the topk, since inserting duplicated data
# into a dictionary twice is the same as doing it once.
topk_ranks = range(1, num_logprobs + 1)
ranks = itertools.chain((rank, ), topk_ranks)
return {
token_id: Logprob(
logprob=logprob,
rank=rank,
decoded_token=token,
)
for token_id, logprob, rank, token in zip(
logprob_token_ids, logprobs, ranks, decoded_tokens)
}
def update_from_output(self, output: EngineCoreOutput) -> None:
if output.new_logprobs is not None:
self._update_sample_logprobs(output.new_logprobs)
if output.new_prompt_logprobs_tensors is not None:
self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)

View File

@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache
from vllm.utils import is_list_of
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
# - MirroredProcessingCache to store the MultiModalKwargs from P0.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
# client (=P0) and server (=P1) processes if the mm_hash is found in the client
# cache.
# Both Client and Server must use the same cache size
# (to perform mirrored caching). This cache size is set by the environment
# variable VLLM_MM_INPUT_CACHE_GIB.
class MirroredProcessingCache:
def __init__(self, model_config):
mm_config = model_config.multimodal_config
disable_mm_preprocessor_cache = (
mm_config is not None and mm_config.disable_mm_preprocessor_cache)
self.use_cache = not disable_mm_preprocessor_cache
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)
def get_and_update_p0(
self,
mm_inputs: Sequence[MultiModalKwargs],
mm_hashes: list[str],
) -> Sequence[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
full_mm_inputs = list[Optional[MultiModalKwargs]]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if self.mm_cache.get(mm_hash) is not None:
mm_input = None
else:
self.mm_cache[mm_hash] = mm_input
full_mm_inputs.append(mm_input)
return full_mm_inputs
def get_and_update_p1(
self,
mm_inputs: Sequence[Optional[MultiModalKwargs]],
mm_hashes: list[str],
) -> Sequence[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
full_mm_inputs = list[MultiModalKwargs]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if mm_input is None:
mm_input = self.mm_cache[mm_hash]
else:
self.mm_cache[mm_hash] = mm_input
full_mm_inputs.append(mm_input)
return full_mm_inputs
def reset(self) -> bool:
self.mm_cache.clear()
return True

View File

@@ -0,0 +1,428 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Optional, Union
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
RequestStateStats)
class RequestOutputCollector:
"""
Collects streamed RequestOutputs per individual request,
for hand-off to the consuming asyncio generate task.
When streaming deltas, RequestOutputs are merged if the
producer gets ahead of the consumer.
"""
def __init__(self, output_kind: RequestOutputKind):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.output: Optional[Union[RequestOutput, Exception]] = None
self.ready = asyncio.Event()
def put(self, output: Union[RequestOutput, Exception]) -> None:
"""Non-blocking put operation."""
if self.output is None or isinstance(output, Exception):
self.output = output
self.ready.set()
elif isinstance(self.output, RequestOutput):
# This ensures that request outputs with different request indexes
# (if n > 1) do not override each other.
self.output.add(output, aggregate=self.aggregate)
async def get(self) -> RequestOutput:
"""Get operation blocks on put event."""
while (output := self.output) is None:
await self.ready.wait()
self.output = None
self.ready.clear()
if isinstance(output, Exception):
raise output
return output
def get_nowait(self) -> Optional[RequestOutput]:
"""Non-blocking get operation."""
output = self.output
if output is not None:
self.output = None
self.ready.clear()
if isinstance(output, Exception):
raise output
return output
@dataclass
class OutputProcessorOutput:
request_outputs: list[RequestOutput]
reqs_to_abort: list[str]
class RequestState:
def __init__(
self,
request_id: str,
parent_req: Optional[ParentRequest],
request_index: int,
lora_name: Optional[str],
output_kind: RequestOutputKind,
prompt: Optional[str],
prompt_token_ids: list[int],
logprobs_processor: LogprobsProcessor,
detokenizer: IncrementalDetokenizer,
max_tokens_param: Optional[int],
arrival_time: float,
queue: Optional[RequestOutputCollector],
log_stats: bool,
):
self.request_id = request_id
self.parent_req = parent_req
self.request_index = request_index
self.lora_name = lora_name
self.output_kind = output_kind
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_len = len(prompt_token_ids)
self.logprobs_processor = logprobs_processor
self.detokenizer = detokenizer
self.max_tokens_param = max_tokens_param
self.is_prefilling = True
self.queue = queue
self.stats = RequestStateStats(
arrival_time=arrival_time) if log_stats else None
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
request: EngineCoreRequest,
prompt: Optional[str],
parent_req: Optional[ParentRequest],
request_index: int,
queue: Optional[RequestOutputCollector],
log_stats: bool,
) -> "RequestState":
if not request.sampling_params.detokenize:
tokenizer = None
return cls(
request_id=request.request_id,
parent_req=parent_req,
request_index=request_index,
lora_name=(request.lora_request.name
if request.lora_request is not None else None),
output_kind=request.sampling_params.output_kind,
prompt=prompt,
prompt_token_ids=request.prompt_token_ids,
logprobs_processor=LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
),
detokenizer=IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
),
max_tokens_param=(request.sampling_params.max_tokens if
request.sampling_params is not None else None),
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
)
def make_request_output(
self,
new_token_ids: list[int],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> Optional[RequestOutput]:
finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
if not finished and final_only:
# Only the final output is required in FINAL_ONLY mode.
return None
completion_output = self._new_completion_output(
new_token_ids, finish_reason, stop_reason)
request_id = self.request_id
if self.parent_req is None:
outputs = [completion_output]
else:
request_id, outputs, finished = self.parent_req.get_outputs(
request_id, completion_output)
if not outputs:
return None
return self._new_request_output(request_id, outputs, finished,
kv_transfer_params, num_cached_tokens)
def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> RequestOutput:
if self.output_kind == RequestOutputKind.DELTA:
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
else:
prompt_logprobs = self.logprobs_processor.prompt_logprobs
return RequestOutput(
request_id=request_id,
prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids,
prompt_logprobs=prompt_logprobs,
outputs=outputs,
finished=finished,
kv_transfer_params=kv_transfer_params,
num_cached_tokens=num_cached_tokens,
)
def _new_completion_output(
self,
token_ids: list[int],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
) -> CompletionOutput:
finished = finish_reason is not None
delta = self.output_kind == RequestOutputKind.DELTA
# Prepare text and token_ids, based on delta mode
text = self.detokenizer.get_next_output_text(finished, delta)
if not delta:
token_ids = self.detokenizer.output_token_ids
# Prepare logprobs, based on delta mode
logprobs = self.logprobs_processor.logprobs
if delta and logprobs:
logprobs = logprobs[-len(token_ids):]
return CompletionOutput(
index=self.request_index,
text=text,
token_ids=token_ids,
logprobs=logprobs,
cumulative_logprob=self.logprobs_processor.cumulative_logprob,
finish_reason=str(finish_reason) if finished else None,
stop_reason=stop_reason if finished else None)
class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""
def __init__(
self,
tokenizer: TokenizerGroup,
log_stats: bool,
):
self.log_stats = log_stats
self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates()
def get_num_unfinished_requests(self):
return len(self.request_states)
def has_unfinished_requests(self) -> bool:
return len(self.request_states) > 0
def propagate_error(self, e: Exception):
"""Propagate error to all generate() tasks."""
for _, state in self.request_states.items():
assert state.queue is not None
state.queue.put(e)
def abort_requests(
self,
request_ids: Iterable[str],
) -> list[str]:
request_ids_to_abort = []
for request_id in request_ids:
req_state = self.request_states.pop(request_id, None)
if req_state is not None:
self.lora_states.abort_request(req_state)
request_ids_to_abort.append(request_id)
else:
parent = self.parent_requests.pop(request_id, None)
if parent and parent.child_requests:
self.abort_requests(parent.child_requests)
request_ids_to_abort.extend(parent.child_requests)
return request_ids_to_abort
def add_request(
self,
request: EngineCoreRequest,
prompt: Optional[str],
parent_req: Optional[ParentRequest] = None,
request_index: int = 0,
queue: Optional[RequestOutputCollector] = None,
) -> None:
request_id = request.request_id
if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.")
req_state = RequestState.from_new_request(
tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
request=request,
prompt=prompt,
parent_req=parent_req,
request_index=request_index,
queue=queue,
log_stats=self.log_stats)
self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req
def process_outputs(
self,
engine_core_outputs: list[EngineCoreOutput],
engine_core_timestamp: Optional[float] = None,
iteration_stats: Optional[IterationStats] = None,
) -> OutputProcessorOutput:
"""
Process the EngineCoreOutputs:
1) Compute stats for logging
2) Detokenize
3) Create and handle RequestOutput objects:
* If there is a queue (for usage with AsyncLLM),
put the RequestOutput objects into the queue for
handling by the per-request generate() tasks.
* If there is no queue (for usage with LLMEngine),
return a list of RequestOutput objects.
NOTE FOR DEVELOPERS
vLLM V1 minimizes the number of python loops over the full
batch to ensure system overheads are minimized. This is the
only function that should loop over EngineCoreOutputs.
If you need to touch every element of the batch, do it from
within the loop below.
"""
request_outputs: list[RequestOutput] = []
reqs_to_abort: list[str] = []
for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id
req_state = self.request_states.get(req_id)
if req_state is None:
# Ignore output for already-aborted request.
continue
# 1) Compute stats for this iteration.
self._update_stats_from_output(req_state, engine_core_output,
engine_core_timestamp,
iteration_stats)
new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP)
if stop_string:
finish_reason = FinishReason.STOP
stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request, if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason,
kv_transfer_params, num_cached_tokens):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
else:
# LLMEngine: return list of RequestOutputs.
request_outputs.append(request_output)
# Free completed requests.
if finish_reason is not None:
self.request_states.pop(req_id)
# Remove parent request if applicable.
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None)
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
# Track per-request stats
self._update_stats_from_finished(req_state, finish_reason,
iteration_stats)
self.lora_states.update_iteration_stats(iteration_stats)
return OutputProcessorOutput(
request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort,
)
def _update_stats_from_output(self, req_state: RequestState,
engine_core_output: EngineCoreOutput,
engine_core_timestamp: Optional[float],
iteration_stats: Optional[IterationStats]):
if iteration_stats is None:
return
lora_stats = self.lora_states.get_stats(req_state)
assert engine_core_timestamp is not None
assert req_state.stats is not None
iteration_stats.update_from_output(engine_core_output,
engine_core_timestamp,
req_state.is_prefilling,
req_state.prompt_len,
req_state.stats, lora_stats)
def _update_stats_from_finished(self, req_state: RequestState,
finish_reason: Optional[FinishReason],
iteration_stats: Optional[IterationStats]):
if iteration_stats is None:
return
assert finish_reason is not None
assert req_state.stats is not None
iteration_stats.update_from_finished_request(
finish_reason=finish_reason,
num_prompt_tokens=len(req_state.prompt_token_ids),
max_tokens_param=req_state.max_tokens_param,
req_stats=req_state.stats)
self.lora_states.finish_request(req_state)
ParentRequest.observe_finished_request(
req_state.parent_req, iteration_stats,
req_state.stats.num_generation_tokens)

View File

@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import copy
from typing import Optional
from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.metrics.stats import IterationStats
class ParentRequest:
"""Info, state & processing for parallel sampling request.
Store parent request ID and sampling params.
Facilitate generating child request sampling params.
"""
request_id: str
sampling_params: SamplingParams
# To track the completion of child requests
child_requests: set[str]
# To aggregate child completions when not streaming
output_aggregator: list[CompletionOutput]
# To find the max number of generated tokens across all children
max_num_generation_tokens: int
# To efficiently obtain child sampling params
cached_child_sampling_params: Optional[SamplingParams]
def __init__(self, request_id: str,
sampling_params: SamplingParams) -> None:
self.request_id = request_id
self.sampling_params = sampling_params
self.child_requests = set()
self.output_aggregator = [None] * sampling_params.n if (
sampling_params.output_kind
== RequestOutputKind.FINAL_ONLY) else []
self.max_num_generation_tokens = 0
self.cached_child_sampling_params = None
def _get_child_sampling_params(
self,
index: int,
) -> SamplingParams:
"""Efficiently obtain child `sampling_params`
If `sampling_params.seed` is not `None` then
each child request requires a unique clone of
parent `sampling_params` with a unique seed.
Args:
index: index within `n` child requests
Returns:
Child `sampling_params` instance.
"""
seed = self.sampling_params.seed
if self.cached_child_sampling_params:
# Reuse child sampling_params data structure
return self.cached_child_sampling_params
# Build child sampling_params
child_sampling_params = copy(self.sampling_params)
child_sampling_params.n = 1
if seed is None:
# Cache child sampling_params for later reuse
self.cached_child_sampling_params = child_sampling_params
else:
# Each child gets a clone with a unique seed
child_sampling_params.seed = seed + index
return child_sampling_params
def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
"""Get child request ID and sampling params.
Args:
index: index within `n` child requests.
Returns:
(request ID, sampling_params) tuple
"""
child_req_id = f"{index}_{self.request_id}"
self.child_requests.add(child_req_id)
return child_req_id, self._get_child_sampling_params(index)
@property
def n(self) -> int:
return self.sampling_params.n
def get_outputs(
self,
child_request_id: str,
completion_output: CompletionOutput,
) -> tuple[str, list[CompletionOutput], bool]:
if completion_output.finished():
self.child_requests.remove(child_request_id)
if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
# If streaming, just return the current output.
outputs = [completion_output]
else:
# If not streaming, aggregate the n final outputs.
self.output_aggregator[completion_output.index] = completion_output
outputs = [] if self.child_requests else self.output_aggregator
finished = not self.child_requests
return self.request_id, outputs, finished
def observe_num_generation_tokens(self, num_generation_tokens: int):
self.max_num_generation_tokens = max(num_generation_tokens,
self.max_num_generation_tokens)
return self.max_num_generation_tokens
@staticmethod
def observe_finished_request(parent_req: Optional['ParentRequest'],
iteration_stats: IterationStats,
num_generation_tokens: int):
n_param = parent_req.n if parent_req is not None else 1
if parent_req is not None:
num_generation_tokens = parent_req.observe_num_generation_tokens(
num_generation_tokens)
# Child requests finished, we can now record to iteration stats
if parent_req is None or not parent_req.child_requests:
iteration_stats.max_num_generation_tokens_iter.append(
num_generation_tokens)
iteration_stats.n_params_iter.append(n_param)

407
vllm/v1/engine/processor.py Normal file
View File

@@ -0,0 +1,407 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional, Union
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar)
class Processor:
def __init__(
self,
vllm_config: VllmConfig,
tokenizer: TokenizerGroup,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.decoding_config = vllm_config.decoding_config
self.tokenizer = tokenizer
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
# Multi-modal hasher (for images)
self.use_hash = self.mm_input_cache_client.use_cache or \
self.cache_config.enable_prefix_caching
@property
def mm_registry(self):
return self.input_preprocessor.mm_registry
def _validate_logprobs(
self,
params: SamplingParams,
) -> None:
max_logprobs = self.model_config.max_logprobs
# Validate sample logprobs.
if params.logprobs and params.logprobs > max_logprobs:
raise ValueError(
f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# Validate prompt logprobs.
if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
raise ValueError(
f"Requested prompt logprobs of {params.prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
def _validate_sampling_params(
self,
params: SamplingParams,
lora_request: Optional[LoRARequest],
) -> None:
self._validate_structured_output(params)
self._validate_logit_bias(params)
if params.allowed_token_ids is None:
return
if not params.allowed_token_ids:
raise ValueError("allowed_token_ids is not None and empty!")
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
vocab_size = len(tokenizer)
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
raise ValueError(
"allowed_token_ids contains out-of-vocab token id!")
def _validate_logit_bias(
self,
params: SamplingParams,
) -> None:
"""Validate logit_bias token IDs are within vocabulary range."""
if not params.logit_bias:
return
vocab_size = self.model_config.get_vocab_size()
invalid_token_ids = []
for token_id in params.logit_bias:
if token_id < 0 or token_id >= vocab_size:
invalid_token_ids.append(token_id)
if invalid_token_ids:
raise ValueError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
def _validate_supported_sampling_params(
self,
params: SamplingParams,
) -> None:
# Best of not yet supported.
if params.best_of is not None and params.best_of > 1:
raise ValueError("vLLM V1 does not yet support best_of.")
# Logits processors not supported.
if params.logits_processors:
raise ValueError("vLLM V1 does not support per request "
"user provided logits processors.")
def _validate_params(
self,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest],
):
"""
Validate supported SamplingParam.
Should raise ValueError if unsupported for API Server.
"""
if not isinstance(params, SamplingParams):
raise ValueError("V1 does not yet support Pooling models.")
self._validate_logprobs(params)
self._validate_sampling_params(params, lora_request)
self._validate_supported_sampling_params(params)
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return
engine_level_backend = self.decoding_config.backend
if params.guided_decoding.backend:
# Request-level backend selection is not supported in V1.
# The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto`
# using the `_auto` option set on the backend in the params.
if (params.guided_decoding.backend != engine_level_backend
and not (engine_level_backend == "auto"
and params.guided_decoding.backend_was_auto)):
raise ValueError(
"Request-level structured output backend selection is no "
"longer supported. The request specified "
f"'{params.guided_decoding.backend}', but vLLM was "
f"initialised with '{engine_level_backend}'. This error "
"can be resolved by removing backend selection from the "
"request.")
else:
params.guided_decoding.backend = engine_level_backend
# Request content validation
if engine_level_backend.startswith("xgrammar"):
# xgrammar with no fallback
validate_xgrammar_grammar(params)
elif engine_level_backend.startswith("guidance"):
# TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)
else:
# NOTE: engine_level_backend must be "auto" here, because we have
# checked supported_backends above.
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
# default as it is less predictable and subject to change
# between releases as feature support changes.
try:
validate_xgrammar_grammar(params)
params.guided_decoding.backend = "xgrammar"
except ValueError:
# The request either failed validation
# or includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
validate_guidance_grammar(params, tokenizer=None)
params.guided_decoding.backend = "guidance"
# Remember that this backend was set automatically
params.guided_decoding.backend_was_auto = True
def process_inputs(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> tuple[Optional[str], EngineCoreRequest]:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
self._validate_lora(lora_request)
self._validate_params(params, lora_request)
if priority != 0:
raise ValueError("V1 does not support priority yet.")
if trace_headers is not None:
raise ValueError("V1 does not support tracing yet.")
if prompt_adapter_request is not None:
raise ValueError("V1 does not support prompt_adapter_request.")
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
data_parallel_size):
raise ValueError(f"data_parallel_rank {data_parallel_rank} "
f"is out of range [0, {data_parallel_size}).")
if arrival_time is None:
arrival_time = time.time()
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
# 3. Apply prompt adapter to prompt token ids if one exists.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=self.use_hash,
)
from vllm.platforms import current_platform
current_platform.validate_request(
prompt=prompt,
params=params,
processed_inputs=processed_inputs,
)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
self._validate_model_inputs(processed_inputs, lora_request)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
# TODO: Impl encoder-decoder
if encoder_inputs is not None:
raise NotImplementedError
assert isinstance(params, SamplingParams)
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
# Multimodal related.
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
(
sorted_item_modalities,
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
decoder_inputs["mm_placeholders"],
decoder_inputs["mm_hashes"] if self.use_hash else None,
)
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# is a single MultiModalKwargs for all items from all modalities.
# This code flattens kwargs for individual items in a list and
# sorts them by each item's position in the input sequence if there
# are multiple modalities.
unique_modalities = set(sorted_item_modalities)
if len(unique_modalities) > 1:
orig_sorted_mm_inputs = []
used_indices = {modality: 0 for modality in unique_modalities}
for modality in sorted_item_modalities:
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
orig_sorted_mm_inputs.append(
MultiModalKwargs.from_items([item]))
used_indices[modality] += 1
else:
orig_sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
orig_sorted_mm_inputs, sorted_mm_hashes)
else:
sorted_mm_inputs = orig_sorted_mm_inputs
return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"],
mm_inputs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params,
eos_token_id=eos_token_id,
arrival_time=arrival_time,
lora_request=lora_request,
cache_salt=decoder_inputs.get("cache_salt"),
data_parallel_rank=data_parallel_rank,
)
def _validate_model_inputs(self,
inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
prompt_ids = prompt_inputs["prompt_token_ids"]
if not prompt_ids:
if prompt_type == "encoder" and model_config.is_multimodal_model:
pass # Mllama may have empty encoder inputs for text-only data
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config,
tokenizer=tokenizer,
)
assert isinstance(mm_processor, EncDecMultiModalProcessor)
if mm_processor.pad_dummy_encoder_prompt:
return # Skip encoder length check for Whisper
if model_config.is_multimodal_model:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.")
else:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens.")
raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}")
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens