[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
173
vllm/v1/engine/__init__.py
Normal file
173
vllm/v1/engine/__init__.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import msgspec
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
|
||||
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")


class FinishReason(enum.IntEnum):
    """Why a request stopped: STOP, LENGTH, or ABORT.

    Stored as a compact integer for serialization; str() yields the
    external string form used in RequestOutput.finish_reason:

        stop   - a stop string was emitted
        length - max_tokens was consumed, or max_model_len was reached
        abort  - aborted for another reason
    """
    STOP = 0
    LENGTH = 1
    ABORT = 2

    def __str__(self):
        # The member's integer value doubles as an index into the
        # external string tuple.
        return FINISH_REASON_STRINGS[self.value]
|
||||
|
||||
|
||||
class EngineCoreRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """A single request sent from the front-end to the engine core.

    Serialized with msgspec in array_like form, so field order is part
    of the wire format — do not reorder fields.
    """

    request_id: str
    prompt_token_ids: list[int]
    # Multi-modal inputs, content hashes and placeholder positions
    # (None for text-only requests).
    mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
    mm_hashes: Optional[list[str]]
    mm_placeholders: Optional[list[PlaceholderRange]]
    sampling_params: SamplingParams
    eos_token_id: Optional[int]
    # Monotonic-clock arrival time, set by the front-end.
    arrival_time: float
    lora_request: Optional[LoRARequest]
    # Extra salt mixed into prefix-cache keys, if any.
    cache_salt: Optional[str]
    # Target DP rank, when pinned to a specific engine.
    data_parallel_rank: Optional[int]

    # Index of the client, used to ensure outputs are sent back to the same
    # client for this request when scaling out the front-end.
    client_index: int = 0

    # Used in DP case to indicate which wave of requests this is expected to
    # belong to, to cover a race condition where the request is sent before
    # a wave finished notification is received.
    current_wave: int = 0
|
||||
|
||||
|
||||
class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event."""
    # Request entered the engine core's queue.
    QUEUED = 1
    # Request was scheduled for execution.
    SCHEDULED = 2
    # Request was preempted (will need rescheduling).
    PREEMPTED = 3
|
||||
|
||||
|
||||
class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp comes from time.monotonic() and is only meaningful
    for computing intervals between engine core events within the same
    process; it must not be compared with timestamps from other
    processes.
    """
    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(cls,
                  event_type: EngineCoreEventType,
                  timestamp: Optional[float] = None) -> "EngineCoreEvent":
        """Build an event, stamping it with now() when none is supplied."""
        if timestamp is None:
            timestamp = time.monotonic()
        return cls(event_type, timestamp)
|
||||
|
||||
|
||||
class EngineCoreOutput(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Per-request output of one engine step, sent core -> front-end.

    array_like serialization: field order is part of the wire format.
    """

    request_id: str
    # Token ids generated for this request in this step.
    new_token_ids: list[int]

    # Sample logprobs for the new tokens, when requested.
    new_logprobs: Optional[LogprobsLists] = None
    # Prompt logprobs, when requested.
    new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None

    # Set exactly when the request has finished (see `finished` below).
    finish_reason: Optional[FinishReason] = None
    stop_reason: Union[int, str, None] = None
    # Timestamped scheduling events for this request, if tracked.
    events: Optional[list[EngineCoreEvent]] = None
    kv_transfer_params: Optional[dict[str, Any]] = None

    # The number of tokens with prefix cache hits.
    num_cached_tokens: int = 0

    @property
    def finished(self) -> bool:
        # A request is finished iff a finish reason has been set.
        return self.finish_reason is not None
|
||||
|
||||
|
||||
class UtilityOutput(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Result of a utility (RPC-style) call executed by the engine core."""

    # Correlates this result with the originating utility call.
    call_id: int

    # Non-None implies the call failed, result should be None.
    failure_message: Optional[str] = None
    result: Any = None
|
||||
|
||||
|
||||
class EngineCoreOutputs(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Batch of per-request outputs for one step, sent core -> front-end."""

    #NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    # Which engine produced this batch (used by the front-end to pick
    # the matching per-rank stat loggers).
    engine_index: int = 0

    # [num_reqs]
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: Optional[SchedulerStats] = None
    # Monotonic send-time; filled in by __post_init__ when left unset.
    timestamp: float = 0.0

    # Result of a utility call, when this message carries one.
    utility_output: Optional[UtilityOutput] = None
    finished_requests: Optional[set[str]] = None

    # In DP case, used to signal that the current wave of requests
    # has finished and the engines are paused.
    wave_complete: Optional[int] = None
    # In DP case, used to signal that a request was received for an
    # "old" wave, so the next wave needs to be started in other engines.
    start_wave: Optional[int] = None

    def __post_init__(self):
        # 0.0 is the "unset" sentinel; stamp with the current monotonic
        # clock so receivers can compute intervals.
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()
|
||||
|
||||
|
||||
class EngineCoreRequestType(enum.Enum):
    """
    Request types defined as hex byte strings, so it can be sent over sockets
    without separate encoding step.
    """
    # Add a new request.
    ADD = b'\x00'
    # Abort in-flight request(s).
    ABORT = b'\x01'
    # DP only: start the next request wave.
    START_DP_WAVE = b'\x02'
    # Utility (RPC-style) call.
    UTILITY = b'\x03'
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b'\x04'
|
||||
558
vllm/v1/engine/async_llm.py
Normal file
558
vllm/v1/engine/async_llm.py
Normal file
@@ -0,0 +1,558 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from copy import copy
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Device, cdiv
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
from vllm.v1.engine.output_processor import (OutputProcessor,
|
||||
RequestOutputCollector)
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
|
||||
setup_default_loggers)
|
||||
from vllm.v1.metrics.prometheus import shutdown_prometheus
|
||||
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AsyncLLM(EngineClient):
|
||||
|
||||
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats.
            usage_context: Usage context of the LLM.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Whether to use cached outputs.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.

        Returns:
            None
        """
        # NOTE(review): use_cached_outputs and start_engine_loop are
        # accepted for interface compatibility but are not read in this
        # constructor.
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.log_requests = log_requests
        self.log_stats = log_stats

        # Set up stat loggers; independent set for each DP rank.
        self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
            vllm_config=vllm_config,
            log_stats=self.log_stats,
            engine_num=vllm_config.parallel_config.data_parallel_size,
            custom_stat_loggers=stat_loggers,
        )

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
            vllm_config=vllm_config,
            tokenizer=self.tokenizer,
            mm_registry=mm_registry,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_index=client_index,
        )
        if self.stat_loggers:
            # Only the rank-0 logger set announces initialization.
            for stat_logger in self.stat_loggers[0]:
                stat_logger.log_engine_initialized()
        self.output_handler: Optional[asyncio.Task] = None
        try:
            # Start output handler eagerly if we are in the asyncio eventloop.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            # No running loop; generate() will start the handler lazily.
            pass
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
disable_log_requests: bool = False,
|
||||
disable_log_stats: bool = False,
|
||||
client_addresses: Optional[dict[str, str]] = None,
|
||||
client_index: int = 0,
|
||||
) -> "AsyncLLM":
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError(
|
||||
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
|
||||
"This should not happen. As a workaround, try using "
|
||||
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
|
||||
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
|
||||
|
||||
# Create the LLMEngine.
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
start_engine_loop=start_engine_loop,
|
||||
stat_loggers=stat_loggers,
|
||||
log_requests=not disable_log_requests,
|
||||
log_stats=not disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
client_addresses=client_addresses,
|
||||
client_index=client_index,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
) -> "AsyncLLM":
|
||||
"""Create an AsyncLLM from the EngineArgs."""
|
||||
|
||||
# Create the engine configs.
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
# Create the AsyncLLM.
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
)
|
||||
|
||||
    def __del__(self):
        # Best-effort cleanup in case the caller never called shutdown().
        self.shutdown()
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown, cleaning up the background proc and IPC."""
|
||||
|
||||
shutdown_prometheus()
|
||||
|
||||
if engine_core := getattr(self, "engine_core", None):
|
||||
engine_core.shutdown()
|
||||
|
||||
if handler := getattr(self, "output_handler", None):
|
||||
handler.cancel()
|
||||
|
||||
async def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> RequestOutputCollector:
|
||||
"""Add new request to the AsyncLLM."""
|
||||
|
||||
if self.errored:
|
||||
raise EngineDeadError()
|
||||
|
||||
assert isinstance(params, SamplingParams), \
|
||||
"Pooling is not supported in V1"
|
||||
|
||||
# Create a new output collector for the request.
|
||||
queue = RequestOutputCollector(output_kind=params.output_kind)
|
||||
|
||||
# Convert Input --> Request.
|
||||
prompt_str, request = self.processor.process_inputs(
|
||||
request_id, prompt, params, arrival_time, lora_request,
|
||||
tokenization_kwargs, trace_headers, prompt_adapter_request,
|
||||
priority, data_parallel_rank)
|
||||
|
||||
if params.n == 1:
|
||||
await self._add_request(request, prompt_str, None, 0, queue)
|
||||
return queue
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_request = ParentRequest(request_id, params)
|
||||
for idx in range(params.n):
|
||||
request_id, params = parent_request.get_child_info(idx)
|
||||
child_request = request if idx == params.n - 1 else copy(request)
|
||||
child_request.request_id = request_id
|
||||
child_request.sampling_params = params
|
||||
await self._add_request(child_request, prompt_str, parent_request,
|
||||
idx, queue)
|
||||
return queue
|
||||
|
||||
    async def _add_request(self, request: EngineCoreRequest,
                           prompt: Optional[str],
                           parent_req: Optional[ParentRequest], index: int,
                           queue: RequestOutputCollector):
        """Register one (possibly child) request locally and with the core.

        Args:
            request: the preprocessed request to submit.
            prompt: the decoded prompt string, if available.
            parent_req: parent for n>1 fan-out, else None.
            index: child index within the parent (0 for simple requests).
            queue: collector that receives this request's outputs.
        """
        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index,
                                          queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)
|
||||
|
||||
    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            EngineDeadError: if the engine core died.
            EngineGenerateError: wraps unexpected generation failures.
            ValueError: on request validation errors.
        """
        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e
|
||||
|
||||
    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams."""

        # Idempotent: the handler task is started at most once.
        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        stat_loggers = self.stat_loggers if log_stats else None

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = IterationStats() if (
                        log_stats and num_outputs) else None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = (outputs.outputs, )
                    else:
                        slices = np.array_split(
                            outputs.outputs,
                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                    for i, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats)
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if i + 1 < len(slices):
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed_outputs.reqs_to_abort)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_loggers:
                        # Called on the class (not self) to avoid a
                        # circular ref from the task back to AsyncLLM.
                        AsyncLLM._record_stats(
                            stat_loggers[outputs.engine_index],
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                # Surface the failure to all pending request streams.
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())
|
||||
|
||||
async def abort(self, request_id: str) -> None:
|
||||
"""Abort RequestId in OutputProcessor and EngineCore."""
|
||||
|
||||
request_ids = self.output_processor.abort_requests((request_id, ))
|
||||
await self.engine_core.abort_requests_async(request_ids)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Aborted request %s.", request_id)
|
||||
|
||||
@staticmethod
|
||||
def _record_stats(
|
||||
stat_loggers: list[StatLoggerBase],
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
):
|
||||
"""static so that it can be used from the output_handler task
|
||||
without a circular ref to AsyncLLM."""
|
||||
for stat_logger in stat_loggers:
|
||||
stat_logger.record(scheduler_stats=scheduler_stats,
|
||||
iteration_stats=iteration_stats)
|
||||
|
||||
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ):
        # Pooling/embedding is unconditionally rejected by this engine.
        raise ValueError("Not Supported on V1 yet.")
|
||||
|
||||
    async def get_vllm_config(self) -> VllmConfig:
        # Returns the config this engine was constructed with.
        return self.vllm_config

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def get_decoding_config(self):
        # No decoding config is exposed by this engine.
        raise ValueError("Not Supported on V1 yet.")

    async def get_input_preprocessor(self) -> InputPreprocessor:
        return self.processor.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        # Resolves the LoRA-specific tokenizer when a LoRA request is given.
        return self.tokenizer.get_lora_tokenizer(lora_request)

    async def is_tracing_enabled(self) -> bool:
        # Always False: tracing is not wired up here.
        return False
|
||||
|
||||
    async def do_log_stats(
        self,
        scheduler_outputs=None,
        model_output=None,
    ) -> None:
        # Both arguments are unused here; trigger every configured logger.
        for loggers in self.stat_loggers:
            for stat_logger in loggers:
                stat_logger.log()

    async def check_health(self) -> None:
        # No-op beyond a debug log.
        logger.debug("Called check_health.")

    async def start_profile(self) -> None:
        await self.engine_core.profile_async(True)

    async def stop_profile(self) -> None:
        await self.engine_core.profile_async(False)

    async def reset_mm_cache(self) -> None:
        # Clear both front-end multi-modal caches, then the engine-side one.
        self.processor.mm_registry.reset_processor_cache()
        self.processor.mm_input_cache_client.reset()
        await self.engine_core.reset_mm_cache_async()

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        # Only the default (GPU) prefix cache can be reset.
        if device == Device.CPU:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

    async def sleep(self, level: int = 1) -> None:
        await self.engine_core.sleep_async(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        await self.engine_core.wake_up_async(tags)

    async def is_sleeping(self) -> bool:
        return await self.engine_core.is_sleeping_async()
|
||||
|
||||
    # All LoRA management and collective RPCs are delegated to the engine
    # core client (separate process).
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return await self.engine_core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return await self.engine_core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return await self.engine_core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return await self.engine_core.pin_lora_async(lora_id)

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
            method, timeout, args, kwargs)
|
||||
|
||||
    @property
    def is_running(self) -> bool:
        # Is None before the loop is started.
        return self.output_handler is None or not self.output_handler.done()

    @property
    def is_stopped(self) -> bool:
        return self.errored

    @property
    def errored(self) -> bool:
        # Errored means the engine core died or the output handler task
        # has exited.
        return self.engine_core.resources.engine_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        # Exception instance surfaced to callers once the engine is dead.
        return EngineDeadError()
|
||||
253
vllm/v1/engine/coordinator.py
Normal file
253
vllm/v1/engine/coordinator.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import multiprocessing
|
||||
import time
|
||||
import weakref
|
||||
from typing import Optional
|
||||
|
||||
import msgspec.msgpack
|
||||
import zmq
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
|
||||
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DPCoordinator:
    """Coordinator process used for data-parallel deployments (DP>1).

    Intermediates between multiple DP engine rank processes and one or more
    front-end API server processes.

    * Collects stats from each DP engine (currently just waiting and running
      queue lengths), and publishes these to all front-ends for use in
      load-balancing decisions.

    * Keeps track of the current DP "request wave" number and running state
      of the engines. This is received from the DP rank 0 engine and published
      to the front-end processes along with the current load stats.

      The engines alternate between a global running/paused state. The global
      "request wave" number is a count of the number of times that the workers
      collectively move from a running state to a paused state. This transition
      is synchronized via the all-reduce operation performed in the
      DPEngineCoreProc._has_global_unfinished_reqs method.

    * Broadcasts the START_DP_WAVE message to engines to move them from paused
      to running state when one engine receives a new request. This can happen
      in two cases:
      1) A front-end sending a new request while the engines are paused will
         concurrently notify the coordinator.
      2) An engine receiving a request for a stale request wave while in paused
         state will notify the coordinator.

    Engines will move into running state when receiving a new request or
    START_DP_WAVE message.
    """

    def __init__(self, parallel_config: ParallelConfig):

        # Assume coordinator is colocated with front-end procs.
        front_publish_address = get_open_zmq_ipc_path()

        dp_size = parallel_config.data_parallel_size
        assert dp_size > 1, "Coordinator only used for data parallel"

        # Local (IPC) sockets suffice only when all DP ranks run on
        # this node; otherwise TCP addresses are handed out.
        local_only = dp_size == parallel_config.data_parallel_size_local
        host = parallel_config.data_parallel_master_ip
        back_publish_address = get_engine_client_zmq_addr(local_only, host)
        back_output_address = get_engine_client_zmq_addr(local_only, host)

        # Spawn the coordinator loop in a daemon child process.
        context = get_mp_context()
        self.proc: multiprocessing.Process = context.Process(
            target=CoordinatorProc.run_coordinator,
            name="VLLM_DP_Coordinator",
            kwargs={
                "engine_count": parallel_config.data_parallel_size,
                "front_publish_address": front_publish_address,
                "back_output_address": back_output_address,
                "back_publish_address": back_publish_address,
            },
            daemon=True)
        self.proc.start()

        self.stats_publish_address = front_publish_address
        self.coord_in_address = back_publish_address
        self.coord_out_address = back_output_address
        # Ensure the child process is shut down when this object is GC'd.
        self._finalizer = weakref.finalize(self, shutdown, [self.proc])

    def get_stats_publish_address(self) -> str:
        # Address front-ends subscribe to for load/wave updates.
        return self.stats_publish_address

    def get_engine_socket_addresses(self) -> tuple[str, str]:
        """Returns tuple of ZMQ input address, output address."""
        return self.coord_in_address, self.coord_out_address

    def close(self):
        # Runs the registered finalizer, terminating the child process.
        self._finalizer()
|
||||
|
||||
|
||||
class EngineState:
    """Per-engine load snapshot tracked by the coordinator."""

    def __init__(self):
        # [waiting, running] request counts for this engine.
        self.request_counts = [0, 0]
|
||||
|
||||
|
||||
class CoordinatorProc:
    """Background process coordinating data-parallel (DP) engine waves.

    Tracks per-engine [waiting, running] request counts reported by the
    engines, periodically publishes them (with the current wave number and
    running state) to the front-end API servers, and broadcasts
    START_DP_WAVE messages to wake paused engines when needed.
    """

    def __init__(self, engine_count: int):
        """Create coordinator state for `engine_count` DP engine ranks."""
        self.ctx = zmq.Context()

        # Per-engine [waiting, running] request counts.
        self.engines = [EngineState() for _ in range(engine_count)]

        # Monotonically increasing DP wave number.
        self.current_wave = 0
        # True while the engines are executing (not globally paused).
        self.engines_running = False
        # Set when local state changed and should be re-published promptly.
        self.stats_changed = False

    @staticmethod
    def run_coordinator(
        engine_count: int,
        front_publish_address: str,
        back_output_address: str,
        back_publish_address: str,
    ):
        """Process entry point: build the coordinator and run its socket
        loop until interrupted (KeyboardInterrupt)."""
        coordinator = CoordinatorProc(engine_count=engine_count)
        try:
            coordinator.process_input_socket(
                front_publish_address,
                back_output_address,
                back_publish_address,
            )
        except KeyboardInterrupt:
            logger.info("DP Coordinator process exiting")

    def process_input_socket(self, front_publish_address: str,
                             back_output_address: str,
                             back_publish_address: str):
        """Main loop bridging front-end and engine sockets.

        - publish_front (XPUB, bind): publishes stats to front-ends and
          receives their wake-up messages.
        - output_back (PULL, bind): receives stats / wave notifications
          from the engines.
        - publish_back (XPUB, bind): broadcasts START_DP_WAVE to engines.
        """

        decoder = MsgpackDecoder(EngineCoreOutputs)

        with make_zmq_socket(
                path=front_publish_address,  # IPC
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
        ) as publish_front, make_zmq_socket(
                path=back_output_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.PULL,
                bind=True,
        ) as output_back, make_zmq_socket(
                path=back_publish_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
        ) as publish_back:

            poller = zmq.Poller()
            poller.register(publish_front, zmq.POLLIN)
            poller.register(output_back, zmq.POLLIN)
            # Millisecond timestamp of the last stats publish; starting at 0
            # makes the first poll time out immediately and publish at once.
            last_publish_time = 0
            while True:
                elapsed = int(time.time() * 1000) - last_publish_time
                # Send at 100 ms interval if the stats have changed,
                # or otherwise every 3 seconds.
                wait_for = 100 if self.stats_changed else 3000
                events = poller.poll(timeout=max(0, wait_for - elapsed))
                if not events:
                    # Poller timeout - publish current stats to front-ends.
                    engine_req_counts_list = self._get_engine_counts()
                    to_publish = (engine_req_counts_list, self.current_wave,
                                  self.engines_running)
                    publish_front.send(msgspec.msgpack.encode(to_publish))
                    last_publish_time = int(time.time() * 1000)
                    self.stats_changed = False
                    continue

                events = dict(events)

                if publish_front in events:
                    buffer = publish_front.recv()
                    if buffer == b'\x01':
                        # Ignore subscription messages.
                        continue

                    # We received a message on the front-end XPUB socket,
                    # from an API server sending a new request while the
                    # engines are paused, so that we can wake the other
                    # engines.
                    engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
                    if wave < self.current_wave:
                        # If the wave number is stale, ensure the message is
                        # handled by all the engines.
                        engine_to_exclude = None
                    if not self.engines_running:
                        self.engines_running = True
                        self.stats_changed = True
                        self._send_start_wave(publish_back, self.current_wave,
                                              engine_to_exclude)

                if output_back in events:
                    # We received a message from one of the engines.

                    buffer = output_back.recv()
                    outputs: EngineCoreOutputs = decoder.decode(buffer)

                    # Engines only send control/stats messages here, never
                    # token outputs or utility results.
                    assert not outputs.outputs
                    assert outputs.utility_output is None

                    eng_index = outputs.engine_index
                    if outputs.scheduler_stats:
                        # 1. Updated request load stats - update our local
                        # state with these.
                        stats = self.engines[eng_index].request_counts
                        stats[0] = outputs.scheduler_stats.num_waiting_reqs
                        stats[1] = outputs.scheduler_stats.num_running_reqs
                        self.stats_changed = True

                    if (wave := outputs.wave_complete) is not None:
                        # 2. Notification from rank 0 engine that we've
                        # moved into the global paused state
                        # (engines_running==False)
                        if self.current_wave <= wave:
                            logger.debug("Moving DP wave from %d to %d.",
                                         self.current_wave, wave)
                            self.current_wave = wave + 1
                            self.engines_running = False
                            self.stats_changed = True
                    elif (wave := outputs.start_wave) is not None and (
                            wave > self.current_wave or
                            (wave == self.current_wave
                             and not self.engines_running)):
                        # 3. The engine received request for a non-current wave
                        # so we must ensure that other engines progress to the
                        # next wave (race condition handling).
                        logger.debug(
                            "Starting wave %d after notification of "
                            "stale wave request from engine.", wave)
                        self.current_wave = wave
                        self.engines_running = True
                        self.stats_changed = True
                        self._send_start_wave(publish_back, wave, eng_index)

    @staticmethod
    def _send_start_wave(socket: zmq.Socket, wave: int,
                         exclude_engine_index: Optional[int]):
        """Broadcast the START_DP_WAVE message to all the engines.
        It includes the current wave number and index of engine which
        has already received a request with this wave number and so doesn't
        require additional notification.
        """
        wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
        socket.send_multipart(
            (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))

    def _get_engine_counts(self) -> list[list[int]]:
        """Return list of [waiting, running] count lists for each engine."""
        return [e.request_counts for e in self.engines]
|
||||
962
vllm/v1/engine/core.py
Normal file
962
vllm/v1/engine/core.py
Normal file
@@ -0,0 +1,962 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import Future
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.executor.multiproc_worker_utils import _add_prefix
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils.dump_input import dump_engine_exception
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
|
||||
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
|
||||
unify_kv_cache_configs)
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_S = 2.5
|
||||
HANDSHAKE_TIMEOUT_MINS = 5
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
|
||||
class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        """Build the engine: executor (model), KV caches, scheduler,
        multimodal input cache, and (optionally) the async batch queue.

        `executor_fail_callback`, if given, is registered with the executor
        so failures can be surfaced to the owner of this EngineCore.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size
            > 1,
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MirroredProcessingCache(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Profile memory, build the KV-cache config, and warm up execution.

        Returns (num_gpu_blocks, num_cpu_blocks, kv_cache_config) where the
        config is the one used to initialize the scheduler.
        """
        start = time.time()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all([
            cfg.num_blocks == kv_cache_configs[0].num_blocks
            for cfg in kv_cache_configs
        ])
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.time() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler."""

        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        if req.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model(self, scheduler_output: SchedulerOutput):
        """Execute the scheduled batch on the model executor.

        Dumps the engine state for debugging before re-raising any failure.
        """
        try:
            return self.model_executor.execute_model(scheduler_output)
        except BaseException as err:
            # NOTE: This method is exception-free
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Re-raise exception
            raise err

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return (engine_core_outputs,
                scheduler_output.total_num_scheduled_tokens > 0)

    def step_with_batch_queue(
            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = (self.scheduler.update_from_output(
                scheduler_output, model_output))

        return engine_core_outputs, scheduled_batch

    def shutdown(self):
        """Release structured-output, executor, and scheduler resources."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (is_start=True) or stop profiling on the model executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Reset the multimodal input cache (debugging aid)."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        self.mm_input_cache_server.reset()

    def reset_prefix_cache(self):
        """Reset the scheduler's prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the model executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Wake the model executor (optionally only the given tags)."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Whether the model executor is currently sleeping."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on all workers via collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA adapter via the model executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter via the model executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the IDs of loaded LoRA adapters."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter via the model executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save sharded model state to `path` via the model executor."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Invoke `method` on all workers via the model executor."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Save the model via tensorizer using the given config."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config, )
|
||||
|
||||
|
||||
class EngineCoreProc(EngineCore):
|
||||
"""ZMQ-wrapper for running EngineCore in background process."""
|
||||
|
||||
ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'
|
||||
|
||||
def __init__(
    self,
    vllm_config: VllmConfig,
    on_head_node: bool,
    handshake_address: str,
    executor_class: type[Executor],
    log_stats: bool,
    engine_index: int = 0,
):
    """Wrap EngineCore with ZMQ-based IO for use as a background process.

    Performs the startup handshake with the front-end over
    `handshake_address`, initializes the underlying EngineCore, and starts
    daemon threads that shuttle requests/outputs between ZMQ sockets and
    in-process queues.
    """
    # Requests received from clients, consumed by the core busy loop.
    self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
    # Outputs produced by the core loop: (client_index, outputs) pairs,
    # or the raw ENGINE_CORE_DEAD sentinel bytes.
    self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                          bytes]]()
    # Executor failures are funneled through the input queue so the busy
    # loop can raise from its own thread.
    executor_fail_callback = lambda: self.input_queue.put_nowait(
        (EngineCoreRequestType.EXECUTOR_FAILED, b''))

    self.engine_index = engine_index
    # 2-byte little-endian engine index used as the ZMQ socket identity.
    identity = self.engine_index.to_bytes(length=2, byteorder="little")
    self.engines_running = False

    with self._perform_handshake(handshake_address, identity, on_head_node,
                                 vllm_config) as addresses:
        self.client_count = len(addresses.outputs)

        # Set up data parallel environment.
        self.has_coordinator = addresses.coordinator_output is not None
        self._init_data_parallel(vllm_config)

        # Initialize the core engine (loads model, profiles KV caches).
        super().__init__(vllm_config, executor_class, log_stats,
                         executor_fail_callback)

    # Batch-queue stepping is used only when pipeline parallelism enabled
    # the queue in EngineCore.__init__.
    self.step_fn = (self.step if self.batch_queue is None else
                    self.step_with_batch_queue)

    # Background Threads and Queues for IO. These enable us to
    # overlap ZMQ socket IO with GPU since they release the GIL,
    # and to overlap some serialization/deserialization with the
    # model forward pass.
    # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
    threading.Thread(target=self.process_input_sockets,
                     args=(addresses.inputs, addresses.coordinator_input,
                           identity),
                     daemon=True).start()
    self.output_thread = threading.Thread(
        target=self.process_output_sockets,
        args=(addresses.outputs, addresses.coordinator_output,
              self.engine_index),
        daemon=True)
    self.output_thread.start()
|
||||
|
||||
@contextmanager
def _perform_handshake(
        self, handshake_address: str, identity: bytes, on_head_node: bool,
        vllm_config: VllmConfig
) -> Generator[EngineZmqAddresses, None, None]:
    """Run the startup handshake with the front-end as a context manager.

    On entry: registers this engine and receives the ZMQ addresses
    (yielded to the caller). On normal exit: sends the READY message
    including the profiled number of GPU KV-cache blocks.
    """
    input_ctx = zmq.Context()
    with make_zmq_socket(input_ctx,
                         handshake_address,
                         zmq.DEALER,
                         identity=identity,
                         linger=5000,
                         bind=False) as handshake_socket:
        # Register engine with front-end.
        addresses = self.startup_handshake(handshake_socket, on_head_node,
                                           vllm_config.parallel_config)

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

        yield addresses

        # Send ready message.
        num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
        handshake_socket.send(
            msgspec.msgpack.encode({
                "status": "READY",
                "local": on_head_node,
                "num_gpu_blocks": num_gpu_blocks,
            }))
|
||||
|
||||
@staticmethod
def startup_handshake(
        handshake_socket: zmq.Socket, on_head_node: bool,
        parallel_config: ParallelConfig) -> EngineZmqAddresses:
    """Register with the front-end and receive handshake metadata.

    Sends a HELLO message, waits up to HANDSHAKE_TIMEOUT_MINS for the
    init message, applies the parallel-config overrides it carries, and
    returns the ZMQ addresses the engine should connect to.

    Raises:
        RuntimeError: if no response arrives within the timeout.
    """

    # Send registration message.
    handshake_socket.send(
        msgspec.msgpack.encode({
            "status": "HELLO",
            "local": on_head_node,
        }))

    # Receive initialization message.
    logger.info("Waiting for init message from front-end.")
    if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
        raise RuntimeError("Did not receive response from front-end "
                           f"process within {HANDSHAKE_TIMEOUT_MINS} "
                           f"minutes")
    init_bytes = handshake_socket.recv()
    init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
        init_bytes, type=EngineHandshakeMetadata)
    logger.debug("Received init message: %s", init_message)

    # Apply any parallel-config values sent by the front-end.
    received_parallel_config = init_message.parallel_config
    for key, value in received_parallel_config.items():
        setattr(parallel_config, key, value)

    return init_message.addresses
|
||||
|
||||
@staticmethod
def run_engine_core(*args,
                    dp_rank: int = 0,
                    local_dp_rank: int = 0,
                    **kwargs):
    """Launch EngineCore busy loop in background process.

    Installs SIGTERM/SIGINT handlers that raise SystemExit, constructs
    a DPEngineCoreProc (when data parallelism is enabled) or a plain
    EngineCoreProc, and runs its busy loop until shutdown. On a fatal
    error, notifies clients via _send_engine_dead().
    """

    # Signal handler used for graceful termination.
    # SystemExit exception is only raised once to allow this and worker
    # processes to terminate without error
    shutdown_requested = False

    # Ensure we can serialize transformer config after spawning
    maybe_register_config_serialize_by_value()

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit()

    # Either SIGTERM or SIGINT will terminate the engine_core
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    engine_core: Optional[EngineCoreProc] = None
    try:
        parallel_config: ParallelConfig = kwargs[
            "vllm_config"].parallel_config
        if parallel_config.data_parallel_size > 1 or dp_rank > 0:
            # Set data parallel rank for this engine process.
            parallel_config.data_parallel_rank = dp_rank
            parallel_config.data_parallel_rank_local = local_dp_rank
            engine_core = DPEngineCoreProc(*args, **kwargs)
        else:
            engine_core = EngineCoreProc(*args, **kwargs)

        engine_core.run_busy_loop()

    except SystemExit:
        logger.debug("EngineCore exiting.")
        raise
    except Exception as e:
        if engine_core is None:
            logger.exception("EngineCore failed to start.")
        else:
            logger.exception("EngineCore encountered a fatal error.")
            # Let clients know the engine is gone before we exit.
            engine_core._send_engine_dead()
        raise e
    finally:
        if engine_core is not None:
            engine_core.shutdown()
|
||||
|
||||
def _init_data_parallel(self, vllm_config: VllmConfig):
    # No-op in the base engine; DP-specific setup is expected to be done
    # by subclasses (see DPEngineCoreProc used in run_engine_core).
    pass
|
||||
|
||||
def run_busy_loop(self):
    """Core busy loop of the EngineCore.

    Runs forever; termination happens via the SystemExit raised from the
    signal handler installed in run_engine_core.
    """

    # Loop until process is sent a SIGINT or SIGTERM
    while True:
        # 1) Poll the input queue until there is work to do.
        self._process_input_queue()
        # 2) Step the engine core and return the outputs.
        self._process_engine_step()
|
||||
|
||||
def _process_input_queue(self):
    """Exits when an engine step needs to be performed.

    Blocks handling client requests while the engine is idle (no wave
    running and no requests in the scheduler), then drains any remaining
    queued requests without blocking.
    """

    waited = False
    while not self.engines_running and not self.scheduler.has_requests():
        if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
            logger.debug("EngineCore waiting for work.")
            waited = True
        # Blocking get: sleeps until a client request arrives.
        req = self.input_queue.get()
        self._handle_client_request(*req)

    if waited:
        logger.debug("EngineCore loop active.")

    # Handle any more client requests.
    while not self.input_queue.empty():
        req = self.input_queue.get_nowait()
        self._handle_client_request(*req)
|
||||
|
||||
def _process_engine_step(self) -> bool:
|
||||
"""Called only when there are unfinished local requests."""
|
||||
|
||||
# Step the engine core.
|
||||
outputs, model_executed = self.step_fn()
|
||||
# Put EngineCoreOutputs into the output queue.
|
||||
for output in (outputs.items() if outputs else ()):
|
||||
self.output_queue.put_nowait(output)
|
||||
|
||||
return model_executed
|
||||
|
||||
def _handle_client_request(self, request_type: EngineCoreRequestType,
                           request: Any) -> None:
    """Dispatch request from client."""

    if request_type == EngineCoreRequestType.ADD:
        self.add_request(request)
    elif request_type == EngineCoreRequestType.ABORT:
        self.abort_requests(request)
    elif request_type == EngineCoreRequestType.UTILITY:
        # RPC-style invocation of an arbitrary engine method; the result
        # (or failure message) is sent back to the requesting client.
        client_idx, call_id, method_name, args = request
        output = UtilityOutput(call_id)
        try:
            method = getattr(self, method_name)
            output.result = method(
                *self._convert_msgspec_args(method, args))
        except BaseException as e:
            logger.exception("Invocation of %s method failed", method_name)
            output.failure_message = (f"Call to {method_name} method"
                                      f" failed: {str(e)}")
        self.output_queue.put_nowait(
            (client_idx, EngineCoreOutputs(utility_output=output)))
    elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
        # Raised into the busy loop; terminates the engine process.
        raise RuntimeError("Executor failed.")
    else:
        logger.error("Unrecognized input request type encountered: %s",
                     request_type)
|
||||
|
||||
@staticmethod
|
||||
def _convert_msgspec_args(method, args):
|
||||
"""If a provided arg type doesn't match corresponding target method
|
||||
arg type, try converting to msgspec object."""
|
||||
if not args:
|
||||
return args
|
||||
arg_types = signature(method).parameters.values()
|
||||
assert len(args) <= len(arg_types)
|
||||
return tuple(
|
||||
msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
|
||||
and issubclass(p.annotation, msgspec.Struct)
|
||||
and not isinstance(v, p.annotation) else v
|
||||
for v, p in zip(args, arg_types))
|
||||
|
||||
def _send_engine_dead(self):
    """Send EngineDead status to the EngineCoreClient."""

    # Put ENGINE_CORE_DEAD in the queue.
    self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

    # Wait until msg sent by the daemon before shutdown.
    self.output_thread.join(timeout=5.0)
    if self.output_thread.is_alive():
        # The IO thread failed to flush within 5s; clients may never
        # learn that the engine died.
        logger.fatal("vLLM shutdown signal from EngineCore failed "
                     "to send. Please report this issue.")
|
||||
|
||||
def process_input_sockets(self, input_addresses: list[str],
                          coord_input_address: Optional[str],
                          identity: bytes):
    """Input socket IO thread.

    Connects DEALER sockets to each front-end input address (and an XSUB
    socket to the DP coordinator, if any), then forwards every decoded
    request onto `self.input_queue` for the core busy loop.
    """

    # Msgpack serialization decoding.
    add_request_decoder = MsgpackDecoder(EngineCoreRequest)
    generic_decoder = MsgpackDecoder()

    with ExitStack() as stack, zmq.Context() as ctx:
        input_sockets = [
            stack.enter_context(
                make_zmq_socket(ctx,
                                input_address,
                                zmq.DEALER,
                                identity=identity,
                                bind=False))
            for input_address in input_addresses
        ]
        if coord_input_address is None:
            coord_socket = None
        else:
            coord_socket = stack.enter_context(
                make_zmq_socket(ctx,
                                coord_input_address,
                                zmq.XSUB,
                                identity=identity,
                                bind=False))
            # Send subscription message to coordinator.
            coord_socket.send(b'\x01')

        # Register sockets with poller.
        poller = zmq.Poller()
        for input_socket in input_sockets:
            # Send initial message to each input socket - this is required
            # before the front-end ROUTER socket can send input messages
            # back to us.
            input_socket.send(b'')
            poller.register(input_socket, zmq.POLLIN)
        if coord_socket is not None:
            poller.register(coord_socket, zmq.POLLIN)

        while True:
            for input_socket, _ in poller.poll():
                # (RequestType, RequestData)
                type_frame, *data_frames = input_socket.recv_multipart(
                    copy=False)
                request_type = EngineCoreRequestType(
                    bytes(type_frame.buffer))

                # Deserialize the request data.
                decoder = add_request_decoder if (
                    request_type
                    == EngineCoreRequestType.ADD) else generic_decoder
                request = decoder.decode(data_frames)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))
|
||||
|
||||
def process_output_sockets(self, output_paths: list[str],
                           coord_output_path: Optional[str],
                           engine_index: int):
    """Output socket IO thread.

    Drains self.output_queue and sends serialized EngineCoreOutputs to
    the per-client PUSH sockets (and small stats messages to the
    coordinator socket when client_index == -1). Uses zero-copy sends
    with a small pool of reusable bytearray buffers. Exits after
    forwarding the ENGINE_CORE_DEAD sentinel to all clients.
    """

    # Msgpack serialization encoding.
    encoder = MsgpackEncoder()
    # Send buffers to reuse.
    reuse_buffers: list[bytearray] = []
    # Keep references to outputs and buffers until zmq is finished
    # with them (outputs may contain tensors/np arrays whose
    # backing buffers were extracted for zero-copy send).
    pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

    # We must set linger to ensure the ENGINE_CORE_DEAD
    # message is sent prior to closing the socket.
    with ExitStack() as stack, zmq.Context() as ctx:
        sockets = [
            stack.enter_context(
                make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
            for output_path in output_paths
        ]
        coord_socket = stack.enter_context(
            make_zmq_socket(
                ctx, coord_output_path, zmq.PUSH, bind=False,
                linger=4000)) if coord_output_path is not None else None
        # Cap on pooled buffers: one in flight per client plus a spare.
        max_reuse_bufs = len(sockets) + 1

        while True:
            output = self.output_queue.get()
            if output == EngineCoreProc.ENGINE_CORE_DEAD:
                # Broadcast the death sentinel to every client, then stop.
                for socket in sockets:
                    socket.send(output)
                break
            assert not isinstance(output, bytes)
            client_index, outputs = output
            outputs.engine_index = engine_index

            if client_index == -1:
                # Don't reuse buffer for coordinator message
                # which will be very small.
                assert coord_socket is not None
                coord_socket.send_multipart(encoder.encode(outputs))
                continue

            # Reclaim buffers that zmq is finished with.
            # pending is ordered newest-first, so completed sends are
            # popped from the right (oldest) end.
            while pending and pending[-1][0].done:
                reuse_buffers.append(pending.pop()[2])

            buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
            buffers = encoder.encode_into(outputs, buffer)
            tracker = sockets[client_index].send_multipart(buffers,
                                                           copy=False,
                                                           track=True)
            if not tracker.done:
                # Send still in flight: keep the outputs object alive only
                # if extra zero-copy frames reference its backing memory.
                ref = outputs if len(buffers) > 1 else None
                pending.appendleft((tracker, ref, buffer))
            elif len(reuse_buffers) < max_reuse_bufs:
                # Limit the number of buffers to reuse.
                reuse_buffers.append(buffer)
class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context.

    Extends EngineCoreProc with DP "wave" coordination: engines in the
    same DP group synchronize on whether any of them still has
    unfinished requests, and execute dummy batches to stay in lockstep
    while peers are busy.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        # Tag stdout/stderr with this process's identity before any
        # engine-init logging happens.
        self._decorate_logs()

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0
        # Current DP "wave" number; incremented each time all engines
        # in the group go idle together.
        self.current_wave = 0
        # Last (waiting, running) request counts published to the
        # coordinator (see _maybe_publish_request_counts).
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(vllm_config, on_head_node, handshake_address,
                         executor_class, log_stats, dp_rank)

    def _decorate_logs(self):
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Configure device visibility and the stateless DP process group."""

        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        # Restrict this rank to its own contiguous slice of physical devices.
        os.environ[device_control_env_var] = ",".join(
            str(current_platform.device_id_to_physical_device_id(i))
            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
                           world_size))
        # NOTE(review): mirrors the visibility list into MACA_VISIBLE_DEVICES
        # unconditionally — presumably for MetaX MACA hardware support;
        # confirm this is harmless on other platforms.
        os.environ["MACA_VISIBLE_DEVICES"] = os.environ[device_control_env_var]

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        super().shutdown()
        # dp_group may not exist if init failed early; getattr guards that.
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: EngineCoreRequest):
        # Keep our wave counter in sync with the wave stamped on the
        # incoming request by the coordinator/front-end.
        if self.has_coordinator and request.current_wave != self.current_wave:
            if request.current_wave > self.current_wave:
                self.current_wave = request.current_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

        super().add_request(request)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            # Another engine started a wave; join it unless we are the
            # engine that triggered it (exclude_eng_index).
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                    new_wave >= self.current_wave):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send (waiting, running) counts to the coordinator when changed."""
        if not self.has_coordinator:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts)
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.engines_running:
                if self.dp_rank == 0:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
                    self.output_queue.put_nowait(
                        (-1,
                         EngineCoreOutputs(wave_complete=self.current_wave)))
                self.current_wave += 1

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any engine in the DP group has unfinished work.

        Between syncs we conservatively report True, so engines keep
        stepping until the next real all-reduce.
        """

        # Optimization - only perform finish-sync all-reduce every 24 steps.
        self.counter += 1
        if self.counter != 24:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)
class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context.

    Behaves like DPEngineCoreProc but skips the ZMQ handshake (all
    addresses are known before actor creation) and per-process log
    decoration (Ray manages actor logs).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Stash addresses so _perform_handshake can yield them directly.
        self.addresses = addresses

        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = \
            local_dp_rank

        # Ray sets CUDA_VISIBLE_DEVICES to empty string,
        # we clean this up to be able to properly initialize
        # data parallel groups.
        # NOTE(review): the cleanup below is currently disabled —
        # confirm whether it is still required.
        # del os.environ['CUDA_VISIBLE_DEVICES']

        # Empty handshake_address: the Ray path performs no handshake.
        super().__init__(vllm_config, on_head_node, "", executor_class,
                         log_stats)

    def _decorate_logs(self):
        # No-op: Ray already prefixes/redirects actor output.
        pass

    @contextmanager
    def _perform_handshake(self, handshake_address: str, identity: bytes,
                           on_head_node: bool, vllm_config: VllmConfig):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        Always shuts the engine down on exit, whether the loop ends via
        SystemExit or an unexpected exception.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()
1129
vllm/v1/engine/core_client.py
Normal file
1129
vllm/v1/engine/core_client.py
Normal file
File diff suppressed because it is too large
Load Diff
286
vllm/v1/engine/detokenizer.py
Normal file
286
vllm/v1/engine/detokenizer.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import tokenizers
|
||||
from packaging import version
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.decoders import DecodeStream
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.detokenizer_utils import (
|
||||
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Only tokenizers >= 0.21.1 supports DecodeStream used for
|
||||
# FastIncrementalDetokenizer.
|
||||
USE_FAST_DETOKENIZER = version.parse(
|
||||
tokenizers.__version__) >= version.parse("0.21.1")
|
||||
|
||||
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
|
||||
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
|
||||
|
||||
class IncrementalDetokenizer:
|
||||
|
||||
def __init__(self):
|
||||
self.token_ids: list[int] = []
|
||||
|
||||
@property
|
||||
def output_token_ids(self) -> list[int]:
|
||||
return self.token_ids
|
||||
|
||||
def update(self, new_token_ids: list[int],
|
||||
stop_terminated: bool) -> Optional[str]:
|
||||
self.token_ids.extend(new_token_ids)
|
||||
return None
|
||||
|
||||
def get_next_output_text(self, finished: bool, delta: bool) -> str:
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def from_new_request(
|
||||
cls,
|
||||
tokenizer: Optional[AnyTokenizer],
|
||||
request: EngineCoreRequest,
|
||||
) -> "IncrementalDetokenizer":
|
||||
|
||||
if tokenizer is None:
|
||||
# No tokenizer => skipping detokenization.
|
||||
return IncrementalDetokenizer()
|
||||
|
||||
if USE_FAST_DETOKENIZER and isinstance(tokenizer,
|
||||
PreTrainedTokenizerFast):
|
||||
# Fast tokenizer => use tokenizers library DecodeStream.
|
||||
return FastIncrementalDetokenizer(tokenizer, request)
|
||||
|
||||
# Fall back to slow python-based incremental detokenization.
|
||||
return SlowIncrementalDetokenizer(tokenizer, request)
|
||||
|
||||
|
||||
class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Shared stop-string handling for real incremental detokenizers.

    Subclasses implement `decode_next` (one token id -> text); this base
    class accumulates output text, evaluates stop strings, and manages
    delta-text streaming offsets.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        # Stop strings
        params = request.sampling_params
        self.stop = stop = params.stop
        self.include_stop_str_in_output = params.include_stop_str_in_output

        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output.
        if stop and not self.include_stop_str_in_output:
            self.stop_buffer_length = max(len(s) for s in stop) - 1
        else:
            self.stop_buffer_length = 0
        # Offset into output_text already returned as delta text.
        self._last_output_text_offset: int = 0

        # Generation data
        self.output_text = ""

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """
        Update RequestState for the request_id by:
        1) Detokenize the new token ids incrementally.
        2) Evaluate stop criteria.

        Return matched stop string or None.
        """
        if not new_token_ids:
            # Skip detokenization if no new token ids.
            return None

        if stop_terminated and not self.include_stop_str_in_output:
            # If stop-terminated, exclude last token from detokenization
            # based on include_stop_str_in_output parameter.
            skipped_stop_token_id = new_token_ids[-1]
            new_token_ids = new_token_ids[:-1]
        else:
            skipped_stop_token_id = None

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        offset_before = len(self.output_text)
        for new_token_id in new_token_ids:
            self.token_ids.append(new_token_id)
            self.output_text += self.decode_next(new_token_id)

        if stop_terminated:
            if skipped_stop_token_id is not None:
                # Cleanup after skipping detokenization.
                # The stop token id still belongs in token_ids even though
                # its text was not appended to output_text.
                self.token_ids.append(skipped_stop_token_id)
            # Stop token triggered; skip stop string check.
            return None

        # 2) Evaluate stop strings.
        stop_string = None
        if self.stop:
            stop = StopChecker.check_stop_strings(
                output_text=self.output_text,
                new_char_count=len(self.output_text) - offset_before,
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
            if stop is not None:
                stop_string, truncate_to = stop
                if truncate_to != -1:
                    # Trim text that extends past the matched stop string.
                    self.output_text = self.output_text[:truncate_to]

        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Decode a single token id to its incremental text fragment."""
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned"""

        # We return the full output text if the sequence is finished.
        # Otherwise hold back stop_buffer_length chars in case a stop
        # string is still being formed at the tail.
        buffer_length = 0 if finished else self.stop_buffer_length
        if not delta:
            return self.output_text[:-buffer_length] if buffer_length else (
                self.output_text)
        length = len(self.output_text) - buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer backed by the `tokenizers` DecodeStream.

    Requires a PreTrainedTokenizerFast (Rust-backed); decodes one token
    per step at native speed and recovers from the library's
    "Invalid prefix encountered" failure mode by resetting the stream.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerFast,
                 request: EngineCoreRequest):
        super().__init__(request)

        sampling_params = request.sampling_params

        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(
            skip_special_tokens=self.skip_special_tokens)

        # Underlying Rust tokenizer of the fast tokenizer wrapper.
        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Find a safe place to start: prime the stream with the smallest
        # prompt suffix (between 4 and 23 tokens) that decodes cleanly.
        prompt_suffix = request.prompt_token_ids
        prompt_len = len(prompt_suffix)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 24)):
                suffix = request.prompt_token_ids[-i:]
                # BUGFIX: check for the actual U+FFFD replacement character
                # that decode() emits for an incomplete byte sequence. The
                # previous literal was a mojibake artifact ("<EFBFBD>" is
                # U+FFFD's UTF-8 bytes rendered as text) and never matched,
                # so the safe-start search always accepted the first suffix.
                if '\ufffd' not in self.tokenizer.decode(suffix):
                    prompt_suffix = suffix
                    break

        # Prime the stream.
        for tid in prompt_suffix:
            self._protected_step(tid)

        # Spaces are only suppressed between consecutive special tokens
        # when both flags are off.
        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens)

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them. Cached on the tokenizer object so
            # it is computed once per tokenizer, not per request.
            if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
                                           None)) is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in
                    self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        """Decode one token id; suppresses inter-special-token spaces."""
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> Optional[str]:
        """Step the DecodeStream, resetting it on invalid-prefix errors."""
        try:
            token = self.stream.step(self.tokenizer, next_token_id)
        except Exception as e:
            if str(e) != INVALID_PREFIX_ERR_MSG:
                raise e
            # Recover from edge case where tokenizer can produce non-monotonic,
            # invalid UTF-8 output, which breaks the internal state of
            # tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.", self.request_id)
            self.stream = DecodeStream(self.skip_special_tokens)
            token = self.stream.step(self.tokenizer, next_token_id)
        return token
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Pure-Python fallback detokenizer.

    Used when the fast Rust DecodeStream path is unavailable (slow
    tokenizer, or tokenizers < 0.21.1). Tracks token strings and
    prefix/read offsets for `detokenize_incrementally`.
    """

    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
        super().__init__(request)

        self.tokenizer = tokenizer

        # Metadata for incremental detokenization.
        self.tokens, self.prefix_offset, self.read_offset = (
            convert_prompt_ids_to_tokens(
                tokenizer=tokenizer,
                prompt_ids=request.prompt_token_ids,
                skip_special_tokens=request.sampling_params.
                skip_special_tokens,
            ))

        # Unlike the fast path, token_ids here includes the prompt;
        # prompt_len lets output_token_ids slice it off.
        self.token_ids.extend(request.prompt_token_ids)
        self.prompt_len = len(request.prompt_token_ids)

        params = request.sampling_params
        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = (
            params.spaces_between_special_tokens)

    @property
    def output_token_ids(self) -> list[int]:
        # Exclude the prompt prefix from the reported output ids.
        return self.token_ids if not self.prompt_len else (
            self.token_ids[self.prompt_len:])

    def decode_next(self, next_token_id: int) -> str:
        """Decode one token id via the incremental helper, advancing
        the prefix/read offsets."""
        new_tokens, decoded_text, prefix_offset, read_offset = (
            detokenize_incrementally(
                tokenizer=self.tokenizer,
                all_input_ids=self.token_ids,
                prev_tokens=self.tokens,
                prefix_offset=self.prefix_offset,
                read_offset=self.read_offset,
                skip_special_tokens=self.skip_special_tokens,
                spaces_between_special_tokens=self.
                spaces_between_special_tokens,
            ))

        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text
17
vllm/v1/engine/exceptions.py
Normal file
17
vllm/v1/engine/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
class EngineGenerateError(Exception):
    """Raised when an AsyncLLM.generate() call fails. Recoverable:
    the engine itself remains usable for other requests."""
    pass
class EngineDeadError(Exception):
    """Raised when the EngineCore dies. Unrecoverable.

    Args:
        suppress_context: When True, hides the implicit exception
            context (``__context__``) so tracebacks omit the irrelevant
            ZMQError seen when used with LLMEngine.
    """

    def __init__(self, *args, suppress_context: bool = False, **kwargs):
        headline = ("EngineCore encountered an issue. "
                    "See stack trace (above) for the root cause.")
        super().__init__(headline, *args, **kwargs)
        # Make stack trace clearer when using with LLMEngine by
        # silencing irrelevant ZMQError.
        self.__suppress_context__ = suppress_context
317
vllm/v1/engine/llm_engine.py
Normal file
317
vllm/v1/engine/llm_engine.py
Normal file
@@ -0,0 +1,317 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import copy
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import (
|
||||
TokenizerGroup, init_tokenizer_from_configs)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Device
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
|
||||
StatLoggerFactory)
|
||||
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""Legacy LLMEngine for backwards compatibility."""
|
||||
|
||||
def __init__(
    self,
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[list[StatLoggerFactory]] = None,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    use_cached_outputs: bool = False,
    multiprocess_mode: bool = False,
) -> None:
    """Build the V1 LLMEngine: tokenizer, input Processor, output
    OutputProcessor, and the EngineCore client (in-process or
    multiprocess depending on multiprocess_mode).

    Raises:
        ValueError: if VLLM_USE_V1 is not set.
        NotImplementedError: if stat_loggers is passed (unsupported in V1).
    """
    if not envs.VLLM_USE_V1:
        raise ValueError(
            "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
            "This should not happen. As a workaround, try using "
            "LLMEngine.from_vllm_config(...) or explicitly set "
            "VLLM_USE_V1=0 or 1 and report this issue on Github.")

    if stat_loggers is not None:
        raise NotImplementedError(
            "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
            "Set VLLM_USE_V1=0 and file and issue on Github.")

    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.cache_config = vllm_config.cache_config

    self.log_stats = log_stats
    self.stat_logger: Optional[StatLoggerBase] = None
    if self.log_stats:
        self.stat_logger = PrometheusStatLogger(vllm_config)

    # important: init dp group before init the engine_core
    # In the decoupled engine case this is handled in EngineCoreProc.
    parallel_config = vllm_config.parallel_config
    if not multiprocess_mode and parallel_config.data_parallel_size > 1:
        self.dp_group = parallel_config.stateless_init_dp_group()
    else:
        self.dp_group = None
    # Set by has_unfinished_requests_dp when a DP peer still has work.
    self.should_execute_dummy_batch = False

    # Tokenizer (+ ensure liveness if running in another process).
    self.tokenizer = init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config)

    # Processor (convert Inputs --> EngineCoreRequests)
    self.processor = Processor(vllm_config=vllm_config,
                               tokenizer=self.tokenizer,
                               mm_registry=mm_registry)

    # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
    self.output_processor = OutputProcessor(self.tokenizer,
                                            log_stats=self.log_stats)

    # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
    self.engine_core = EngineCoreClient.make_client(
        multiprocess_mode=multiprocess_mode,
        asyncio_mode=False,
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=self.log_stats,
    )

    if not multiprocess_mode:
        # for v0 compatibility
        self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

    # Don't keep the dummy data in memory
    self.reset_mm_cache()
@classmethod
def from_vllm_config(
    cls,
    vllm_config: VllmConfig,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[list[StatLoggerFactory]] = None,
    disable_log_stats: bool = False,
) -> "LLMEngine":
    """Alternate constructor from a prebuilt VllmConfig; multiprocess
    mode follows the VLLM_ENABLE_V1_MULTIPROCESSING env setting."""
    return cls(vllm_config=vllm_config,
               executor_class=Executor.get_class(vllm_config),
               log_stats=(not disable_log_stats),
               usage_context=usage_context,
               stat_loggers=stat_loggers,
               multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
@classmethod
def from_engine_args(
    cls,
    engine_args: EngineArgs,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[list[StatLoggerFactory]] = None,
    enable_multiprocessing: bool = False,
) -> "LLMEngine":
    """Creates an LLM engine from the engine arguments."""

    # Create the engine configs.
    vllm_config = engine_args.create_engine_config(usage_context)
    executor_class = Executor.get_class(vllm_config)

    # Env var overrides the caller's enable_multiprocessing=False.
    if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
        logger.debug("Enabling multiprocessing for LLMEngine.")
        enable_multiprocessing = True

    # Create the LLMEngine.
    return cls(vllm_config=vllm_config,
               executor_class=executor_class,
               log_stats=not engine_args.disable_log_stats,
               usage_context=usage_context,
               stat_loggers=stat_loggers,
               multiprocess_mode=enable_multiprocessing)
||||
def get_num_unfinished_requests(self) -> int:
    """Number of requests not yet finished (local to this engine)."""
    return self.output_processor.get_num_unfinished_requests()
||||
def has_unfinished_requests(self) -> bool:
    """Whether any request is unfinished; aggregated across the DP
    group when one exists."""
    has_unfinished = self.output_processor.has_unfinished_requests()
    if self.dp_group is None:
        return has_unfinished
    return self.has_unfinished_requests_dp(has_unfinished)
||||
def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
    """All-reduce the unfinished flag across the DP group.

    Side effect: if a peer has work but we don't, schedule a dummy
    batch so this rank stays in lockstep with the group.
    """
    aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
        self.dp_group, has_unfinished)
    if not has_unfinished and aggregated_has_unfinished:
        self.should_execute_dummy_batch = True
    return aggregated_has_unfinished
||||
@classmethod
def validate_outputs(cls, outputs, output_type):
    """V0-compatibility hook; V1 performs no output validation."""
    return outputs
def abort_request(self, request_ids: list[str]) -> None:
    """Remove request_ids from EngineCore and Detokenizer."""

    # abort_requests returns the subset actually tracked locally,
    # which is what gets forwarded to the core.
    request_ids = self.output_processor.abort_requests(request_ids)
    self.engine_core.abort_requests(request_ids)
def add_request(
    self,
    request_id: str,
    prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
    arrival_time: Optional[float] = None,
    lora_request: Optional[LoRARequest] = None,
    tokenization_kwargs: Optional[dict[str, Any]] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    priority: int = 0,
) -> None:
    """Tokenize/validate the prompt and submit it to the EngineCore.

    For SamplingParams with n > 1, fans out into n child requests that
    share a ParentRequest for result aggregation.
    """
    # Process raw inputs into the request.
    prompt_str, request = self.processor.process_inputs(
        request_id, prompt, params, arrival_time, lora_request,
        tokenization_kwargs, trace_headers, prompt_adapter_request,
        priority)

    n = params.n if isinstance(params, SamplingParams) else 1

    if n == 1:
        # Make a new RequestState and queue.
        self.output_processor.add_request(request, prompt_str, None, 0)
        # Add the request to EngineCore.
        self.engine_core.add_request(request)
        return

    # Fan out child requests (for n>1).
    parent_req = ParentRequest(request_id, params)
    for idx in range(n):
        # NOTE: request_id/params are intentionally rebound to the
        # per-child values from here on.
        request_id, params = parent_req.get_child_info(idx)
        # Last child reuses the original request object; earlier
        # children get shallow copies.
        child_request = request if idx == n - 1 else copy(request)
        child_request.request_id = request_id
        child_request.sampling_params = params

        # Make a new RequestState and queue.
        self.output_processor.add_request(child_request, prompt_str,
                                          parent_req, idx)
        # Add the request to EngineCore.
        self.engine_core.add_request(child_request)
def step(self) -> list[RequestOutput]:
    """Run one engine iteration and return finished/updated outputs.

    In DP mode a pending dummy batch (set by
    has_unfinished_requests_dp) is executed instead, returning [].
    """

    if self.should_execute_dummy_batch:
        self.should_execute_dummy_batch = False
        self.engine_core.execute_dummy_batch()
        return []

    # 1) Get EngineCoreOutput from the EngineCore.
    outputs = self.engine_core.get_output()

    # 2) Process EngineCoreOutputs.
    iteration_stats = IterationStats() if self.log_stats else None
    processed_outputs = self.output_processor.process_outputs(
        outputs.outputs,
        engine_core_timestamp=outputs.timestamp,
        iteration_stats=iteration_stats)

    # 3) Abort any reqs that finished due to stop strings.
    self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

    # 4) Record stats
    if self.stat_logger is not None:
        assert outputs.scheduler_stats is not None
        self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
                                iteration_stats=iteration_stats)

    return processed_outputs.request_outputs
||||
    def get_vllm_config(self):
        """Return the vLLM config this engine was constructed with."""
        return self.vllm_config
|
||||
|
||||
    def get_model_config(self):
        """Return the model config this engine was constructed with."""
        return self.model_config
|
||||
|
||||
    def start_profile(self):
        """Start profiling in the engine core process."""
        self.engine_core.profile(True)
|
||||
|
||||
    def stop_profile(self):
        """Stop profiling in the engine core process."""
        self.engine_core.profile(False)
|
||||
|
||||
    def reset_mm_cache(self):
        """Clear all multimodal preprocessing caches.

        Resets both the frontend (P0) caches and the engine-core (P1)
        cache so the mirrored client/server caches stay in sync.
        """
        self.processor.mm_registry.reset_processor_cache()
        self.processor.mm_input_cache_client.reset()
        self.engine_core.reset_mm_cache()
|
||||
|
||||
    def reset_prefix_cache(self, device: Optional[Device] = None):
        """Reset the KV prefix cache in the engine core.

        NOTE(review): `device` is accepted for interface compatibility but
        is not forwarded to the engine core — confirm this is intentional.
        """
        self.engine_core.reset_prefix_cache()
|
||||
|
||||
    def sleep(self, level: int = 1):
        """Put the engine core to sleep at the given level."""
        self.engine_core.sleep(level)
|
||||
|
||||
    def wake_up(self, tags: Optional[list[str]] = None):
        """Wake the engine core; `tags` optionally selects what to restore."""
        self.engine_core.wake_up(tags)
|
||||
|
||||
    def is_sleeping(self) -> bool:
        """Return whether the engine core is currently sleeping."""
        return self.engine_core.is_sleeping()
|
||||
|
||||
    def get_metrics(self) -> list[Metric]:
        """Return a snapshot of engine metrics.

        Raises:
            AssertionError: if stat logging was disabled at construction.
        """
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()
|
||||
|
||||
def get_tokenizer_group(self) -> TokenizerGroup:
|
||||
if self.tokenizer is None:
|
||||
raise ValueError("Unable to get tokenizer because "
|
||||
"skip_tokenizer_init is True")
|
||||
|
||||
return self.tokenizer
|
||||
|
||||
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)
|
||||
|
||||
    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)
|
||||
|
||||
    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()
|
||||
|
||||
    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)
|
||||
|
||||
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Invoke `method` on all workers via the engine core.

        Returns the per-worker results as a list.
        """
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
    def __del__(self):
        # Destroy the stateless data-parallel process group if one was
        # created; getattr guards against a partially-initialized instance.
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)
|
||||
199
vllm/v1/engine/logprobs.py
Normal file
199
vllm/v1/engine/logprobs.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
|
||||
from vllm.transformers_utils.detokenizer_utils import (
|
||||
AnyTokenizer, convert_ids_list_to_tokens)
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NONES = itertools.repeat(None)
|
||||
|
||||
|
||||
@dataclass
class LogprobsProcessor:
    """Accumulates sample and prompt logprobs for a single request.

    Converts the raw token-id/logprob lists and tensors emitted by the
    EngineCore into the ``Logprob`` dictionaries exposed on
    ``RequestOutput``, detokenizing the top tokens when a tokenizer is
    available.
    """

    # Tokenizer for this request,
    # None if detokenization is disabled.
    tokenizer: Optional[AnyTokenizer]

    # Logprobs for this request.
    # Each is None when the corresponding feature was not requested.
    logprobs: Optional[SampleLogprobs]
    prompt_logprobs: Optional[PromptLogprobs]
    cumulative_logprob: Optional[float]
    num_logprobs: Optional[int]
    num_prompt_logprobs: Optional[int]

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "LogprobsProcessor":
        """Build a processor configured from the request's sampling params."""
        num_logprobs = request.sampling_params.logprobs
        num_prompt_logprobs = request.sampling_params.prompt_logprobs
        return cls(
            tokenizer=tokenizer,
            cumulative_logprob=(None if num_logprobs is None else 0.),
            logprobs=(None if num_logprobs is None else []),
            # NOTE: logprob of first prompt token is None.
            prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
            num_prompt_logprobs=num_prompt_logprobs,
            num_logprobs=num_logprobs,
        )

    def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
        """Update with sample logprobs from EngineCore.

        Outer lists are only of len > 1 if EngineCore made
        >1 tokens in prior step (e.g. in spec decoding).

        Args:
          logprobs_lists: the lists of logprob tokens, logprobs, and ranks.

        """

        assert self.num_logprobs is not None
        assert self.logprobs is not None
        assert self.cumulative_logprob is not None

        token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists

        for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
                                             token_ids_lst):

            # Detokenize (non-incrementally).
            decoded_tokens = NONES if self.tokenizer is None else (
                convert_ids_list_to_tokens(self.tokenizer, token_ids))

            # Sampler puts the sampled logprob in first.
            sampled_token_logprob = logprobs[0]
            self.cumulative_logprob += sampled_token_logprob

            # Update with the Logprob dictionary for this pos.
            self.logprobs.append(
                self._make_logprob_dict(
                    logprobs,
                    token_ids,
                    decoded_tokens,
                    rank,
                    self.num_logprobs,
                ))

    def _update_prompt_logprobs(
        self,
        prompt_logprobs_tensors: LogprobsTensors,
    ) -> None:
        """Update with prompt logprobs from EngineCore.

        Args:
          prompt_logprobs_tensors: tuple containing the prompt logprobs
                                   tensors.

        """

        # Prompt logprobs are enabled.
        assert self.num_prompt_logprobs is not None
        assert self.prompt_logprobs is not None

        token_ids, logprobs, ranks = prompt_logprobs_tensors

        # Detokenize non-incrementally.
        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
        decoded_tokens = None if self.tokenizer is None else (
            convert_ids_list_to_tokens(self.tokenizer,
                                       token_ids.flatten().tolist()))

        # Recover shapes.
        num_prompt_tokens, num_logprobs = logprobs.shape

        # Pythonize the torch tensors.
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        token_ids = token_ids.tolist()

        # Make Logprob for each position.
        for pos in range(num_prompt_tokens):
            # Handle flattening.
            offset = pos * num_logprobs
            offset_end = offset + num_logprobs
            decoded_tokens_for_pos = NONES \
                if decoded_tokens is None else decoded_tokens[offset:offset_end]

            # Update with the Logprob dictionary for this pos.
            self.prompt_logprobs.append(
                self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
                                        decoded_tokens_for_pos,
                                        prompt_token_ranks[pos],
                                        self.num_prompt_logprobs))

    def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
        """Pop and return all request prompt logprobs

        The logprobs processor aggregates prompt chunk logprobs
        over one or more prefill chunks. This method returns
        all prompt logprobs at once and then forgets them.
        Ensures correct RequestOutputKind.DELTA semantics
        wherein all prompt logprobs are returned at once at
        the end of prefill.

        Returns:
          None if prompt logprobs are disabled for this request.
          List of all prompt logprobs, otherwise.
        """
        plp = self.prompt_logprobs
        if plp:
            self.prompt_logprobs = []
        return plp

    @staticmethod
    def _make_logprob_dict(
        logprobs: list[float],
        logprob_token_ids: list[int],
        decoded_tokens: Iterable[Optional[str]],
        rank: int,
        num_logprobs: int,
    ) -> dict[int, Logprob]:
        """Make a Logprob dictionary for a position.

        Args:
          logprobs: list of log probabilities
          logprob_token_ids: list of top token ids
          decoded_tokens: list of decoded top tokens
          rank: rank of the sampled token
          num_logprobs: number of logprobs requested
            by the user (in addition to sampled logprob)

        Returns:
          dict[token id, Logprob]
        """

        # We do not need a special case for the sampled token
        # being in the topk, since inserting duplicated data
        # into a dictionary twice is the same as doing it once.
        topk_ranks = range(1, num_logprobs + 1)
        ranks = itertools.chain((rank, ), topk_ranks)

        return {
            token_id: Logprob(
                logprob=logprob,
                rank=rank,
                decoded_token=token,
            )
            for token_id, logprob, rank, token in zip(
                logprob_token_ids, logprobs, ranks, decoded_tokens)
        }

    def update_from_output(self, output: EngineCoreOutput) -> None:
        """Fold one EngineCoreOutput's logprob payloads into this state."""
        if output.new_logprobs is not None:
            self._update_sample_logprobs(output.new_logprobs)
        if output.new_prompt_logprobs_tensors is not None:
            self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)
|
||||
91
vllm/v1/engine/mm_input_cache.py
Normal file
91
vllm/v1/engine/mm_input_cache.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.processing import ProcessingCache
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
# The idea of multimodal preprocessing caching is based on having a client and
|
||||
# a server, where the client executes in the frontend process (=P0) and the
|
||||
# server in the core process (=P1).
|
||||
#
|
||||
# -- Client:
|
||||
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
|
||||
# with built-in caching functionality, with mm_hash as its identifier.
|
||||
# - MirroredProcessingCache to keep track of the cached entries and
|
||||
# determine whether to send the MultiModalKwargs to P1.
|
||||
#
|
||||
# -- Server:
|
||||
# - MirroredProcessingCache to store the MultiModalKwargs from P0.
|
||||
#
|
||||
# The caching for both client and server is mirrored, and this allows us
|
||||
# to avoid the serialization of "mm_inputs" (like pixel values) between
|
||||
# client (=P0) and server (=P1) processes if the mm_hash is found in the client
|
||||
# cache.
|
||||
|
||||
# Both Client and Server must use the same cache size
|
||||
# (to perform mirrored caching). This cache size is set by the environment
|
||||
# variable VLLM_MM_INPUT_CACHE_GIB.
|
||||
|
||||
|
||||
class MirroredProcessingCache:
    """LRU cache of multimodal inputs, mirrored between P0 and P1.

    The client (P0) and server (P1) each hold an instance with identical
    size/eviction behavior, so the client can predict which entries the
    server already has and skip serializing their payloads.
    """

    def __init__(self, model_config):
        mm_config = model_config.multimodal_config
        disable_mm_preprocessor_cache = (
            mm_config is not None and mm_config.disable_mm_preprocessor_cache)
        self.use_cache = not disable_mm_preprocessor_cache
        self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
                                                      MultiModalKwargs)

    def get_and_update_p0(
        self,
        mm_inputs: Sequence[MultiModalKwargs],
        mm_hashes: list[str],
    ) -> Sequence[Optional[MultiModalKwargs]]:
        """Client side: replace already-cached inputs with None placeholders.

        New entries are inserted into the local cache so the next request
        with the same hash can elide the payload.
        """
        assert len(mm_inputs) == len(mm_hashes)

        if not self.use_cache:
            assert is_list_of(mm_inputs, MultiModalKwargs)
            return mm_inputs

        result: list[Optional[MultiModalKwargs]] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if self.mm_cache.get(mm_hash) is not None:
                # Server is assumed to hold this entry already;
                # send only the hash.
                result.append(None)
            else:
                self.mm_cache[mm_hash] = mm_input
                result.append(mm_input)
        return result

    def get_and_update_p1(
        self,
        mm_inputs: Sequence[Optional[MultiModalKwargs]],
        mm_hashes: list[str],
    ) -> Sequence[MultiModalKwargs]:
        """Server side: resolve None placeholders from the local cache.

        Concrete inputs are stored so future placeholders can be resolved.
        """
        assert len(mm_inputs) == len(mm_hashes)

        if not self.use_cache:
            assert is_list_of(mm_inputs, MultiModalKwargs)
            return mm_inputs

        result: list[MultiModalKwargs] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                result.append(self.mm_cache[mm_hash])
            else:
                self.mm_cache[mm_hash] = mm_input
                result.append(mm_input)
        return result

    def reset(self) -> bool:
        """Drop every cached entry. Always returns True."""
        self.mm_cache.clear()
        return True
|
||||
428
vllm/v1/engine/output_processor.py
Normal file
428
vllm/v1/engine/output_processor.py
Normal file
@@ -0,0 +1,428 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
|
||||
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
|
||||
from vllm.v1.engine.logprobs import LogprobsProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
|
||||
RequestStateStats)
|
||||
|
||||
|
||||
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Merge pending outputs only in DELTA mode.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.output: Optional[Union[RequestOutput, Exception]] = None
        self.ready = asyncio.Event()

    def put(self, output: Union[RequestOutput, Exception]) -> None:
        """Non-blocking put operation."""
        pending = self.output
        if pending is None or isinstance(output, Exception):
            # Empty slot, or an error that must preempt pending output.
            self.output = output
            self.ready.set()
        elif isinstance(pending, RequestOutput):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            pending.add(output, aggregate=self.aggregate)

    async def get(self) -> RequestOutput:
        """Get operation blocks on put event."""
        while True:
            output = self.output
            if output is not None:
                break
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(self) -> Optional[RequestOutput]:
        """Non-blocking get operation."""
        output = self.output
        if output is None:
            return None
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output
|
||||
|
||||
|
||||
@dataclass
class OutputProcessorOutput:
    """Result of one OutputProcessor.process_outputs() call."""

    # Outputs to return to LLMEngine callers (empty in AsyncLLM mode,
    # where outputs go through per-request queues instead).
    request_outputs: list[RequestOutput]
    # Request ids that finished in the frontend (e.g. stop string hit)
    # and must still be aborted in the EngineCore.
    reqs_to_abort: list[str]
|
||||
|
||||
|
||||
class RequestState:
    """Per-request frontend state: detokenizer, logprobs, and stats.

    Owned by OutputProcessor; turns incremental EngineCoreOutputs into
    CompletionOutput/RequestOutput objects for one (possibly child)
    request.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: Optional[ParentRequest],
        request_index: int,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: list[int],
        logprobs_processor: LogprobsProcessor,
        detokenizer: IncrementalDetokenizer,
        max_tokens_param: Optional[int],
        arrival_time: float,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_len = len(prompt_token_ids)
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Flipped to False once the first EngineCoreOutput arrives.
        self.is_prefilling = True
        self.queue = queue

        # Per-request stats are only tracked when logging is enabled.
        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest],
        request_index: int,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ) -> "RequestState":
        """Construct state for a newly added request."""
        # Disabling detokenization is signalled by passing no tokenizer
        # down to the logprobs processor and detokenizer.
        if not request.sampling_params.detokenize:
            tokenizer = None
        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(request.lora_request.name
                       if request.lora_request is not None else None),
            output_kind=request.sampling_params.output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            detokenizer=IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            max_tokens_param=(request.sampling_params.max_tokens if
                              request.sampling_params is not None else None),
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
        kv_transfer_params: Optional[dict[str, Any]] = None,
        num_cached_tokens: int = 0,
    ) -> Optional[RequestOutput]:
        """Build the RequestOutput for this step, or None if nothing to emit.

        Returns None in FINAL_ONLY mode before the request finishes, and
        for parallel-sampling children whose siblings are still running.
        """

        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        completion_output = self._new_completion_output(
            new_token_ids, finish_reason, stop_reason)

        request_id = self.request_id
        if self.parent_req is None:
            outputs = [completion_output]
        else:
            # Parallel sampling: the parent decides what (if anything)
            # to surface, under the parent's request id.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, completion_output)
            if not outputs:
                return None

        return self._new_request_output(request_id, outputs, finished,
                                        kv_transfer_params, num_cached_tokens)

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput],
        finished: bool,
        kv_transfer_params: Optional[dict[str, Any]] = None,
        num_cached_tokens: int = 0,
    ) -> RequestOutput:
        """Assemble a RequestOutput from completed CompletionOutputs."""

        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=self.prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=outputs,
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=num_cached_tokens,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step's new tokens.

        In DELTA mode only the newly generated text/tokens/logprobs are
        included; otherwise the cumulative values are returned.
        """

        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            logprobs = logprobs[-len(token_ids):]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None)
|
||||
|
||||
|
||||
class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerGroup,
        log_stats: bool,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Active requests keyed by (child) request id.
        self.request_states: dict[str, RequestState] = {}
        # Parallel-sampling parents keyed by parent request id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()

    def get_num_unfinished_requests(self):
        """Return the number of requests still being processed."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if any request is still being processed."""
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""

        for _, state in self.request_states.items():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Drop state for the given request ids.

        A parent (n>1) request id expands to its still-running children.
        Returns the ids that were actually aborted.
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
            else:
                # May be a parallel-sampling parent id; abort its children.
                parent = self.parent_requests.pop(request_id, None)
                if parent and parent.child_requests:
                    self.abort_requests(parent.child_requests)
                    request_ids_to_abort.extend(parent.child_requests)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest] = None,
        request_index: int = 0,
        queue: Optional[RequestOutputCollector] = None,
    ) -> None:
        """Register a new (child) request for output processing.

        Raises:
            ValueError: if the request id is already active.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            # 2) Detokenize the token ids into text and perform stop checks.
            stop_string = req_state.detokenizer.update(
                new_token_ids, finish_reason == FinishReason.STOP)
            if stop_string:
                # Stop string detected in the frontend; override the
                # finish/stop reasons reported by the EngineCore.
                finish_reason = FinishReason.STOP
                stop_reason = stop_string

            # 3) Compute sample and prompt logprobs for request, if required.
            req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                    new_token_ids, finish_reason, stop_reason,
                    kv_transfer_params, num_cached_tokens):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(req_state, finish_reason,
                                                 iteration_stats)

        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
        # Per-iteration stats update; no-op when stat logging is disabled.
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats, lora_stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
        # Finalize stats for a finished request; no-op when logging disabled.
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=len(req_state.prompt_token_ids),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats)
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats,
            req_state.stats.num_generation_tokens)
|
||||
133
vllm/v1/engine/parallel_sampling.py
Normal file
133
vllm/v1/engine/parallel_sampling.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import copy
|
||||
from typing import Optional
|
||||
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
|
||||
class ParentRequest:
    """Info, state & processing for parallel sampling request.

    Store parent request ID and sampling params.
    Facilitate generating child request sampling params.
    """

    request_id: str
    sampling_params: SamplingParams

    # To track the completion of child requests
    child_requests: set[str]

    # To aggregate child completions when not streaming
    output_aggregator: list[CompletionOutput]

    # To find the max number of generated tokens across all children
    max_num_generation_tokens: int

    # To efficiently obtain child sampling params
    cached_child_sampling_params: Optional[SamplingParams]

    def __init__(self, request_id: str,
                 sampling_params: SamplingParams) -> None:
        self.request_id = request_id
        self.sampling_params = sampling_params

        self.child_requests = set()
        # In FINAL_ONLY mode, pre-size the aggregator so children can be
        # slotted in by index as they finish; otherwise it stays unused.
        self.output_aggregator = [None] * sampling_params.n if (
            sampling_params.output_kind
            == RequestOutputKind.FINAL_ONLY) else []
        self.max_num_generation_tokens = 0
        self.cached_child_sampling_params = None

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Efficiently obtain child `sampling_params`

        If `sampling_params.seed` is not `None` then
        each child request requires a unique clone of
        parent `sampling_params` with a unique seed.

        Args:
          index: index within `n` child requests

        Returns:
          Child `sampling_params` instance.
        """
        seed = self.sampling_params.seed
        if self.cached_child_sampling_params:
            # Reuse child sampling_params data structure
            return self.cached_child_sampling_params
        # Build child sampling_params
        child_sampling_params = copy(self.sampling_params)
        child_sampling_params.n = 1
        if seed is None:
            # Cache child sampling_params for later reuse
            self.cached_child_sampling_params = child_sampling_params
        else:
            # Each child gets a clone with a unique seed
            child_sampling_params.seed = seed + index
        return child_sampling_params

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        """Get child request ID and sampling params.

        Args:
          index: index within `n` child requests.

        Returns:
          (request ID, sampling_params) tuple
        """
        child_req_id = f"{index}_{self.request_id}"
        self.child_requests.add(child_req_id)
        return child_req_id, self._get_child_sampling_params(index)

    @property
    def n(self) -> int:
        # Number of parallel samples requested by the user.
        return self.sampling_params.n

    def get_outputs(
        self,
        child_request_id: str,
        completion_output: CompletionOutput,
    ) -> tuple[str, list[CompletionOutput], bool]:
        """Fold a child's CompletionOutput into the parent's view.

        Returns (parent request id, outputs to emit now, parent finished).
        When streaming, the child's output is passed through immediately;
        in FINAL_ONLY mode outputs are withheld until every child is done.
        """
        if completion_output.finished():
            self.child_requests.remove(child_request_id)

        if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
            # If streaming, just return the current output.
            outputs = [completion_output]
        else:
            # If not streaming, aggregate the n final outputs.
            self.output_aggregator[completion_output.index] = completion_output
            outputs = [] if self.child_requests else self.output_aggregator

        finished = not self.child_requests
        return self.request_id, outputs, finished

    def observe_num_generation_tokens(self, num_generation_tokens: int):
        """Track and return the max generation length across children."""
        self.max_num_generation_tokens = max(num_generation_tokens,
                                             self.max_num_generation_tokens)
        return self.max_num_generation_tokens

    @staticmethod
    def observe_finished_request(parent_req: Optional['ParentRequest'],
                                 iteration_stats: IterationStats,
                                 num_generation_tokens: int):
        """Record per-parent generation stats once all children finish.

        For non-parallel requests (parent_req is None) the stats are
        recorded immediately.
        """
        n_param = parent_req.n if parent_req is not None else 1

        if parent_req is not None:
            num_generation_tokens = parent_req.observe_num_generation_tokens(
                num_generation_tokens)

        # Child requests finished, we can now record to iteration stats
        if parent_req is None or not parent_req.child_requests:
            iteration_stats.max_num_generation_tokens_iter.append(
                num_generation_tokens)
            iteration_stats.n_params_iter.append(n_param)
# --- New file: vllm/v1/engine/processor.py (407 lines) ---
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||
MultiModalRegistry)
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.processing import EncDecMultiModalProcessor
|
||||
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
||||
from vllm.v1.structured_output.backend_guidance import (
|
||||
validate_guidance_grammar)
|
||||
from vllm.v1.structured_output.backend_xgrammar import (
|
||||
validate_xgrammar_grammar)
|
||||
|
||||
|
||||
class Processor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
tokenizer: TokenizerGroup,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
):
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.decoding_config = vllm_config.decoding_config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.generation_config_fields = (
|
||||
self.model_config.try_get_generation_config())
|
||||
self.input_preprocessor = InputPreprocessor(self.model_config,
|
||||
self.tokenizer,
|
||||
mm_registry)
|
||||
|
||||
self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
|
||||
|
||||
# Multi-modal hasher (for images)
|
||||
self.use_hash = self.mm_input_cache_client.use_cache or \
|
||||
self.cache_config.enable_prefix_caching
|
||||
|
||||
@property
|
||||
def mm_registry(self):
|
||||
return self.input_preprocessor.mm_registry
|
||||
|
||||
def _validate_logprobs(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
max_logprobs = self.model_config.max_logprobs
|
||||
# Validate sample logprobs.
|
||||
if params.logprobs and params.logprobs > max_logprobs:
|
||||
raise ValueError(
|
||||
f"Requested sample logprobs of {params.logprobs}, "
|
||||
f"which is greater than max allowed: {max_logprobs}")
|
||||
|
||||
# Validate prompt logprobs.
|
||||
if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
|
||||
raise ValueError(
|
||||
f"Requested prompt logprobs of {params.prompt_logprobs}, "
|
||||
f"which is greater than max allowed: {max_logprobs}")
|
||||
|
||||
def _validate_sampling_params(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> None:
|
||||
self._validate_structured_output(params)
|
||||
self._validate_logit_bias(params)
|
||||
|
||||
if params.allowed_token_ids is None:
|
||||
return
|
||||
if not params.allowed_token_ids:
|
||||
raise ValueError("allowed_token_ids is not None and empty!")
|
||||
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
|
||||
vocab_size = len(tokenizer)
|
||||
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
|
||||
raise ValueError(
|
||||
"allowed_token_ids contains out-of-vocab token id!")
|
||||
|
||||
def _validate_logit_bias(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
"""Validate logit_bias token IDs are within vocabulary range."""
|
||||
if not params.logit_bias:
|
||||
return
|
||||
|
||||
vocab_size = self.model_config.get_vocab_size()
|
||||
invalid_token_ids = []
|
||||
|
||||
for token_id in params.logit_bias:
|
||||
if token_id < 0 or token_id >= vocab_size:
|
||||
invalid_token_ids.append(token_id)
|
||||
|
||||
if invalid_token_ids:
|
||||
raise ValueError(
|
||||
f"token_id(s) {invalid_token_ids} in logit_bias contain "
|
||||
f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
|
||||
|
||||
def _validate_supported_sampling_params(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
# Best of not yet supported.
|
||||
if params.best_of is not None and params.best_of > 1:
|
||||
raise ValueError("vLLM V1 does not yet support best_of.")
|
||||
# Logits processors not supported.
|
||||
if params.logits_processors:
|
||||
raise ValueError("vLLM V1 does not support per request "
|
||||
"user provided logits processors.")
|
||||
|
||||
def _validate_params(
|
||||
self,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
lora_request: Optional[LoRARequest],
|
||||
):
|
||||
"""
|
||||
Validate supported SamplingParam.
|
||||
Should raise ValueError if unsupported for API Server.
|
||||
"""
|
||||
|
||||
if not isinstance(params, SamplingParams):
|
||||
raise ValueError("V1 does not yet support Pooling models.")
|
||||
|
||||
self._validate_logprobs(params)
|
||||
self._validate_sampling_params(params, lora_request)
|
||||
self._validate_supported_sampling_params(params)
|
||||
|
||||
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
|
||||
    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Validate and normalize structured-output (guided decoding) params.

        Ensures the request's guided-decoding backend matches the
        engine-level backend, then validates the grammar against the
        selected backend. May rewrite ``params.guided_decoding.backend``
        in place (engine backend propagation, or "auto" resolution).

        Raises:
            ValueError: if a conflicting backend was requested, or the
                grammar fails validation for the selected backend.
        """
        # No-op unless guided decoding was requested and configured.
        if not params.guided_decoding or not self.decoding_config:
            return

        engine_level_backend = self.decoding_config.backend
        if params.guided_decoding.backend:
            # Request-level backend selection is not supported in V1.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_auto` option set on the backend in the params.
            if (params.guided_decoding.backend != engine_level_backend
                    and not (engine_level_backend == "auto"
                             and params.guided_decoding.backend_was_auto)):
                raise ValueError(
                    "Request-level structured output backend selection is no "
                    "longer supported. The request specified "
                    f"'{params.guided_decoding.backend}', but vLLM was "
                    f"initialised with '{engine_level_backend}'. This error "
                    "can be resolved by removing backend selection from the "
                    "request.")
        else:
            # Propagate the engine-level backend onto the request.
            params.guided_decoding.backend = engine_level_backend

        # Request content validation
        if engine_level_backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif engine_level_backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
        else:
            # NOTE: engine_level_backend must be "auto" here, because we have
            # checked supported_backends above.
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
                validate_xgrammar_grammar(params)
                params.guided_decoding.backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar. Fall back to guidance.
                validate_guidance_grammar(params, tokenizer=None)
                params.guided_decoding.backend = "guidance"
            # Remember that this backend was set automatically
            params.guided_decoding.backend_was_auto = True
|
||||
    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:
        """Validate a new request and convert it into an EngineCoreRequest.

        Steps, in order: parameter/LoRA validation, rejection of features
        unsupported in V1 (priority, tracing, prompt adapters),
        tokenization and multimodal preprocessing, platform-specific
        request validation, sampling-params finalization, and multimodal
        metadata sorting/caching.

        Args:
          request_id: unique ID for this request.
          prompt: the user prompt (text/tokens, possibly multimodal).
          params: sampling params; pooling params are rejected.
          arrival_time: request arrival timestamp; defaults to now.
          lora_request: optional LoRA adapter to apply.
          tokenization_kwargs: extra kwargs forwarded to the tokenizer.
          trace_headers: must be None (tracing unsupported in V1).
          prompt_adapter_request: must be None (unsupported in V1).
          priority: must be 0 (priority unsupported in V1).
          data_parallel_rank: optional target DP rank; range-checked.

        Returns:
          (decoder prompt text or None, EngineCoreRequest) tuple.

        Raises:
          ValueError: on validation failure or unsupported features.
          NotImplementedError: for encoder-decoder (encoder) inputs.
        """

        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.
        self._validate_lora(lora_request)
        self._validate_params(params, lora_request)
        if priority != 0:
            raise ValueError("V1 does not support priority yet.")
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")
        if prompt_adapter_request is not None:
            raise ValueError("V1 does not support prompt_adapter_request.")

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #    multimodal data and expand prompt token ids accordingly.
        # 3. Apply prompt adapter to prompt token ids if one exists.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=self.use_hash,
        )
        # Local import — presumably to avoid an import cycle / premature
        # platform resolution at module load time; TODO confirm.
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        assert isinstance(params, SamplingParams)
        # TODO: can we avoid cloning here in multiproc case?
        sampling_params = params.clone()
        # If unset max tokens, then generate up to the max_model_len.
        if sampling_params.max_tokens is None:
            sampling_params.max_tokens = (
                self.model_config.max_model_len -
                len(decoder_inputs["prompt_token_ids"]))
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)
        sampling_params.update_from_tokenizer(
            self.tokenizer.get_lora_tokenizer(lora_request))

        # Multimodal related.
        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
        sorted_mm_hashes: Optional[list[str]] = None
        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            (
                sorted_item_modalities,
                sorted_mm_positions,
                sorted_mm_hashes,
            ) = merge_and_sort_multimodal_metadata(
                decoder_inputs["mm_placeholders"],
                decoder_inputs["mm_hashes"] if self.use_hash else None,
            )

            # The output of merged multi-modal processor (`decoder_mm_inputs`)
            # is a single MultiModalKwargs for all items from all modalities.
            # This code flattens kwargs for individual items in a list and
            # sorts them by each item's position in the input sequence if there
            # are multiple modalities.
            unique_modalities = set(sorted_item_modalities)
            if len(unique_modalities) > 1:
                orig_sorted_mm_inputs = []
                used_indices = {modality: 0 for modality in unique_modalities}

                for modality in sorted_item_modalities:
                    items = decoder_mm_inputs.get_items(modality)
                    item = items[used_indices[modality]]

                    orig_sorted_mm_inputs.append(
                        MultiModalKwargs.from_items([item]))
                    used_indices[modality] += 1
            else:
                orig_sorted_mm_inputs = [
                    MultiModalKwargs.from_items([item]) for item in
                    decoder_mm_inputs.get_items(sorted_item_modalities[0])
                ]

            # Route through the mirrored cache only when hashes exist
            # (hashes are the cache keys).
            if sorted_mm_hashes is not None:
                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
                    orig_sorted_mm_inputs, sorted_mm_hashes)
            else:
                sorted_mm_inputs = orig_sorted_mm_inputs

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_inputs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            data_parallel_rank=data_parallel_rank,
        )
|
||||
def _validate_model_inputs(self,
|
||||
inputs: ProcessorInputs,
|
||||
lora_request: Optional[LoRARequest] = None):
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
|
||||
|
||||
if encoder_inputs is not None:
|
||||
self._validate_model_input(encoder_inputs,
|
||||
lora_request,
|
||||
prompt_type="encoder")
|
||||
|
||||
self._validate_model_input(decoder_inputs,
|
||||
lora_request,
|
||||
prompt_type="decoder")
|
||||
|
||||
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Validate one prompt's token ids and length.

        Checks: non-empty prompt (empty encoder prompts are allowed for
        multimodal models, e.g. Mllama), token ids within the tokenizer's
        vocabulary, and prompt length within ``max_model_len`` (with an
        exemption for encoder prompts padded by the multi-modal
        processor, e.g. Whisper).

        Raises:
          ValueError: if the prompt is empty, contains an out-of-vocab
            token id, or exceeds the maximum model length.
        """
        model_config = self.model_config
        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)

        prompt_ids = prompt_inputs["prompt_token_ids"]
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        # `default=0` keeps the check safe for the empty-encoder case above.
        max_input_id = max(prompt_ids, default=0)
        if max_input_id > tokenizer.max_token_id:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

        max_prompt_len = self.model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

            # Tailor the error's suggestion to whether multimodal tokens
            # could be inflating the prompt length.
            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
# (end of diff view)