Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
246
vllm/v1/engine/__init__.py
Normal file
246
vllm/v1/engine/__init__.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
from vllm.v1.serial_utils import UtilityResult
|
||||
|
||||
# Type for pause_generation mode parameter.
# - "abort": Abort all in-flight requests immediately (default).
# - "wait": Wait for in-flight requests to complete before pausing.
# - "keep": Freeze requests in queue; they resume on resume_generation().
PauseMode = Literal["abort", "wait", "keep"]

# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
# NOTE: tuple order must stay in sync with the integer values of
# FinishReason, whose __str__ indexes this tuple by enum value.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
|
||||
|
||||
|
||||
class FinishReason(enum.IntEnum):
    """Why a request finished.

    Encoded as an int rather than a string for more compact serialization;
    the string form is recovered by indexing FINISH_REASON_STRINGS.

    STOP   - a stop string was emitted.
    LENGTH - max_tokens was consumed, or max_model_len was reached.
    ABORT  - aborted by the client.
    ERROR  - retryable request-level internal error (e.g. KV load failure);
             invariant: always converted to 500 Internal Server Error.
    """

    STOP = 0
    LENGTH = 1
    ABORT = 2
    ERROR = 3

    def __str__(self):
        # Enum values index the module-level FINISH_REASON_STRINGS tuple.
        return FINISH_REASON_STRINGS[int(self)]
|
||||
|
||||
|
||||
class EngineCoreRequest(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A request sent from the front-end to the engine core.

    Serialized with msgspec; array_like=True encodes fields positionally,
    so the declared field order is part of the wire format and must not
    be changed.
    """

    request_id: str
    # NOTE(review): presumably None when prompt embeddings are supplied
    # via prompt_embeds instead — confirm against the input processor.
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec] | None
    # The params property below returns sampling_params when set, and
    # otherwise asserts that pooling_params is set.
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    arrival_time: float
    lora_request: LoRARequest | None
    cache_salt: str | None
    data_parallel_rank: int | None
    prompt_embeds: torch.Tensor | None = None

    # Index of the client, used to ensure outputs are sent back to the same
    # client for this request when scaling out the front-end.
    client_index: int = 0

    # Used in DP case to indicate which wave of requests this is expected to
    # belong to, to cover a race condition where the request is sent before
    # a wave finished notification is received.
    current_wave: int = 0
    priority: int = 0

    trace_headers: Mapping[str, str] | None = None
    resumable: bool = False

    # The user-provided request ID. This field is set internally,
    # copied from the provided request_id that's originally assigned
    # to the request_id field, see InputProcessor.assign_request_id().
    # Used in outputs and to support abort(req_id, internal=False).
    external_req_id: str | None = None

    reasoning_ended: bool | None = None

    @property
    def params(self) -> SamplingParams | PoolingParams:
        """Return the processed params (sampling or pooling)."""
        if self.sampling_params is not None:
            return self.sampling_params
        assert self.pooling_params is not None
        return self.pooling_params

    @property
    @deprecated(
        "EngineCoreRequest.eos_token_id will be removed in v0.18. "
        "Please use EngineCoreRequest.sampling_params.eos_token_id instead."
    )
    def eos_token_id(self) -> int | None:
        """Deprecated accessor; None when sampling_params is not set."""
        if self.sampling_params is None:
            return None

        return self.sampling_params.eos_token_id
|
||||
|
||||
|
||||
class EngineCoreEventType(enum.IntEnum):
    """The kind of engine core request event being recorded."""

    # auto() assigns 1, 2, 3 in declaration order.
    QUEUED = enum.auto()
    SCHEDULED = enum.auto()
    PREEMPTED = enum.auto()
|
||||
|
||||
|
||||
class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is a monotonic timestamp and is used by the engine
    frontend to calculate intervals between engine core events. These
    timestamps should not be compared with timestamps from other processes.
    """

    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(
        cls, event_type: EngineCoreEventType, timestamp: float | None = None
    ) -> "EngineCoreEvent":
        """Create an event, defaulting the timestamp to time.monotonic()."""
        timestamp = time.monotonic() if timestamp is None else timestamp
        return cls(event_type, timestamp)
|
||||
|
||||
|
||||
class EngineCoreOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Per-request output produced by the engine core for one step.

    array_like=True: field order is part of the wire format.
    """

    request_id: str
    new_token_ids: list[int]

    new_logprobs: LogprobsLists | None = None
    new_prompt_logprobs_tensors: LogprobsTensors | None = None

    pooling_output: torch.Tensor | None = None

    # Non-None finish_reason marks the request as finished (see the
    # finished property below).
    finish_reason: FinishReason | None = None
    stop_reason: int | str | None = None
    events: list[EngineCoreEvent] | None = None
    kv_transfer_params: dict[str, Any] | None = None

    trace_headers: Mapping[str, str] | None = None
    # The number of tokens with prefix cache hits (local + external).
    num_cached_tokens: int = 0
    # The number of tokens computed remotely (original count from connector).
    num_external_computed_tokens: int = 0
    # NOTE(review): shape/dtype not established here — presumably per-token
    # expert routing indices; confirm against the producer.
    routed_experts: np.ndarray | None = None
    # The number of NaNs in logits.
    # A value greater than 0 indicates that the output is corrupted.
    num_nans_in_logits: int = 0

    @property
    def finished(self) -> bool:
        """True once a finish_reason has been set."""
        return self.finish_reason is not None
|
||||
|
||||
|
||||
class UtilityOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Result of a utility (RPC-style) call, matched to its call_id."""

    call_id: int

    # Non-None implies the call failed, result should be None.
    failure_message: str | None = None
    result: UtilityResult | None = None
|
||||
|
||||
|
||||
class EngineCoreOutputs(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Batch of engine core outputs sent back to the front-end.

    array_like=True: field order is part of the wire format.
    """

    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    # Identifies which engine (DP rank) produced these outputs.
    engine_index: int = 0

    # [num_reqs]
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: SchedulerStats | None = None
    # Monotonic send time; the 0.0 default is replaced with
    # time.monotonic() in __post_init__ when not explicitly set.
    timestamp: float = 0.0

    utility_output: UtilityOutput | None = None
    finished_requests: set[str] | None = None

    # In DP case, used to signal that the current wave of requests
    # has finished and the engines are paused.
    wave_complete: int | None = None
    # In DP case, used to signal that a request was received for an
    # "old" wave, so the next wave needs to be started in other engines.
    start_wave: int | None = None

    def __post_init__(self):
        # Stamp with the send time if the sender didn't provide one.
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()
|
||||
|
||||
|
||||
class EngineCoreRequestType(enum.Enum):
    """Wire-level request type tags.

    Each value is a single hex byte string, so the tag can be sent over
    sockets directly without a separate encoding step.
    """

    ADD = b"\x00"
    ABORT = b"\x01"
    START_DP_WAVE = b"\x02"
    UTILITY = b"\x03"
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b"\x04"
|
||||
|
||||
|
||||
class ReconfigureDistributedRequest(msgspec.Struct):
    """Parameters for reconfiguring the data-parallel topology at runtime.

    NOTE(review): the rank fields appear to also accept the sentinel
    values defined in ReconfigureRankType (-1 keep / -2 shutdown) —
    confirm against the handler in the engine core.
    """

    new_data_parallel_size: int
    new_data_parallel_rank: int
    new_data_parallel_rank_local: int
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int
|
||||
|
||||
|
||||
class ReconfigureRankType(enum.IntEnum):
    """Sentinel rank values for reconfiguring a distributed request."""

    # This rank keeps its current assignment.
    KEEP_CURRENT_RANK = -1
    # This rank should be shut down.
    SHUTDOWN_CURRENT_RANK = -2
||||
1059
vllm/v1/engine/async_llm.py
Normal file
1059
vllm/v1/engine/async_llm.py
Normal file
File diff suppressed because it is too large
Load Diff
394
vllm/v1/engine/coordinator.py
Normal file
394
vllm/v1/engine/coordinator.py
Normal file
@@ -0,0 +1,394 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import multiprocessing
|
||||
import time
|
||||
import weakref
|
||||
|
||||
import msgspec.msgpack
|
||||
import zmq
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.utils.system_utils import get_mp_context, set_process_title
|
||||
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DPCoordinator:
    """Coordinator process used for data-parallel deployments (DP>1).

    Intermediates between multiple DP engine rank processes and one or more
    front-end API server processes.

    * Collects stats from each DP engine (currently just waiting and running
      queue lengths), and publishes these to all front-ends for use in
      load-balancing decisions.

    * Keeps track of the current DP "request wave" number and running state
      of the engines. This is received from the DP rank 0 engine and published
      to the front-end processes along with the current load stats.

      The engines alternate between a global running/paused state. The global
      "request wave" number is a count of the number of times that the workers
      collectively move from a running state to a paused state. This transition
      is synchronized via the all-reduce operation performed in the
      DPEngineCoreProc._has_global_unfinished_reqs method.

    * Broadcasts the START_DP_WAVE message to engines to move them from paused
      to running state when one engine receives a new request. This can happen
      in two cases:
      1) A front-end sending a new request while the engines are paused will
         concurrently notify the coordinator.
      2) An engine receiving a request for a stale request wave while in paused
         state will notify the coordinator.

      Engines will move into running state when receiving a new request or
      START_DP_WAVE message.

    Note that when deployed in External LB mode, no stats will be published by
    the engines and thus updates will only be sent to front-ends when the
    request wave / running state changes.
    """

    def __init__(
        self, parallel_config: ParallelConfig, enable_wave_coordination: bool = True
    ):
        """Spawn the coordinator daemon process and record its ZMQ addresses.

        Args:
            parallel_config: Provides DP size and master host; DP size must
                be > 1.
            enable_wave_coordination: Forwarded to the coordinator process to
                enable/disable request-wave tracking.
        """
        dp_size = parallel_config.data_parallel_size
        assert dp_size > 1, "Coordinator only used for data parallel"

        host = parallel_config.data_parallel_master_ip

        # Assume coordinator is colocated with front-end procs when not in
        # either external or hybrid DP LB mode.
        local_only = not parallel_config.local_engines_only
        front_publish_address = get_engine_client_zmq_addr(
            local_only=local_only, host=host
        )

        # Engine-facing sockets can be local (IPC) only when all engine
        # ranks run on this node.
        local_only_eng = dp_size == parallel_config.data_parallel_size_local
        back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
        back_output_address = get_engine_client_zmq_addr(local_only_eng, host)

        context = get_mp_context()
        self.proc: multiprocessing.Process = context.Process(
            target=DPCoordinatorProc.run_coordinator,
            name="VLLM_DP_Coordinator",
            kwargs={
                "engine_count": parallel_config.data_parallel_size,
                "front_publish_address": front_publish_address,
                "back_output_address": back_output_address,
                "back_publish_address": back_publish_address,
                "enable_wave_coordination": enable_wave_coordination,
            },
            daemon=True,
        )
        self.proc.start()

        self.stats_publish_address = front_publish_address
        self.coord_in_address = back_publish_address
        self.coord_out_address = back_output_address
        # Ensure the subprocess is shut down when this object is collected.
        self._finalizer = weakref.finalize(self, shutdown, [self.proc])

    def get_stats_publish_address(self) -> str:
        """Address front-ends subscribe to for stats/wave-state updates."""
        return self.stats_publish_address

    def get_engine_socket_addresses(self) -> tuple[str, str]:
        """Returns tuple of ZMQ input address, output address."""
        return self.coord_in_address, self.coord_out_address

    def close(self):
        """Shut down the coordinator process (safe to call more than once)."""
        self._finalizer()
|
||||
|
||||
|
||||
class EngineState:
    """Per-engine load state tracked by the DP coordinator."""

    def __init__(self):
        # request_counts[0] = waiting requests, request_counts[1] = running.
        self.request_counts: list[int] = [0, 0]
|
||||
|
||||
|
||||
class DPCoordinatorProc:
    """Body of the DP coordinator background process.

    Bridges front-end XPUB traffic and per-engine stats/wave messages in a
    single poll loop (see process_input_socket).
    """

    def __init__(
        self,
        engine_count: int,
        min_stats_update_interval_ms: int = 100,
        enable_wave_coordination: bool = True,
    ):
        """Initialize coordinator state.

        Args:
            engine_count: Number of DP engine ranks to track.
            min_stats_update_interval_ms: Minimum interval between stats
                publications to front-ends when stats are changing.
            enable_wave_coordination: Whether to track/broadcast request
                wave state in addition to load stats.
        """
        set_process_title("DPCoordinator")
        self.ctx = zmq.Context()

        # One EngineState ([waiting, running] counts) per DP engine rank.
        self.engines = [EngineState() for _ in range(engine_count)]

        self.stats_update_interval_ms = min_stats_update_interval_ms
        self.enable_wave_coordination = enable_wave_coordination

    @staticmethod
    def run_coordinator(
        engine_count: int,
        front_publish_address: str,
        back_output_address: str,
        back_publish_address: str,
        min_stats_update_interval_ms: int = 100,
        enable_wave_coordination: bool = True,
    ):
        """Process entry point: construct the coordinator and run its loop.

        Exits quietly on KeyboardInterrupt (normal shutdown path).
        """
        coordinator = DPCoordinatorProc(
            engine_count=engine_count,
            min_stats_update_interval_ms=min_stats_update_interval_ms,
            enable_wave_coordination=enable_wave_coordination,
        )
        try:
            coordinator.process_input_socket(
                front_publish_address,
                back_output_address,
                back_publish_address,
            )
        except KeyboardInterrupt:
            logger.info("DP Coordinator process exiting")

    def process_input_socket(
        self,
        front_publish_address: str,
        back_output_address: str,
        back_publish_address: str,
    ):
        """Main loop: poll front-end and engine sockets until interrupted.

        Sockets:
            front_publish_address: XPUB to front-ends (also receives their
                new-request / scale notifications as XPUB upstream messages).
            back_output_address: PULL receiving EngineCoreOutputs-encoded
                stats/wave messages from engines.
            back_publish_address: XPUB broadcasting START_DP_WAVE to engines.
        """
        decoder = MsgpackDecoder(EngineCoreOutputs)

        # For tracking request wave progression.
        current_wave = 0
        engines_running = False

        # For tracking request counts for internal load-balancing.
        stats_changed = False
        last_stats_step = -1
        last_stats_wave = -1
        last_step_counts: list[list[int]] | None = None

        with (
            make_zmq_socket(
                path=front_publish_address,  # IPC
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
            ) as publish_front,
            make_zmq_socket(
                path=back_output_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.PULL,
                bind=True,
            ) as output_back,
            make_zmq_socket(
                path=back_publish_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
            ) as publish_back,
        ):
            # Wait until all engines subscribe (XPUB delivers a \x01
            # subscription message per subscriber).
            for _ in self.engines:
                if publish_back.recv() != b"\x01":
                    logger.error(
                        "DP Coordinator received unexpected message while "
                        "waiting for engines to subscribe"
                    )
                    return
            # Send ready message to engines.
            publish_back.send(b"READY")

            logger.info("All engine subscriptions received by DP coordinator")

            poller = zmq.Poller()
            poller.register(publish_front, zmq.POLLIN)
            poller.register(output_back, zmq.POLLIN)
            last_publish_time = 0
            while True:
                elapsed = int(time.time() * 1000) - last_publish_time
                # Send at stats_update_interval_ms interval if the stats have
                # changed, or otherwise every 5 seconds.
                wait_for = self.stats_update_interval_ms if stats_changed else 5000

                # Wait at least 50ms to ensure we've received all stats for
                # the current step.
                min_timeout = 50 if last_step_counts is None else 0

                events = poller.poll(timeout=max(min_timeout, wait_for - elapsed))
                if not events:
                    # Poller timeout - publish current stats to front-ends.
                    if last_step_counts is not None:
                        # Prefer the snapshot taken at the last step boundary.
                        engine_req_counts_list = last_step_counts
                        last_step_counts = None
                    else:
                        engine_req_counts_list = self._get_engine_counts()
                        stats_changed = False

                    to_publish = (engine_req_counts_list, current_wave, engines_running)
                    publish_front.send(msgspec.msgpack.encode(to_publish))
                    last_publish_time = int(time.time() * 1000)
                    continue

                events = dict(events)
                wave_state_changed = False

                if publish_front in events:
                    buffer = publish_front.recv()
                    if buffer in (b"\x01", b"\x00"):
                        # Ignore subscription messages.
                        continue

                    decoded = msgspec.msgpack.decode(buffer)
                    if (
                        isinstance(decoded, (list, tuple))
                        and len(decoded) == 2
                        and decoded[0] == "SCALE_ELASTIC_EP"
                    ):
                        # Handle scale up notification
                        new_engine_count = decoded[1]
                        current_count = len(self.engines)
                        if new_engine_count > current_count:
                            for _ in range(new_engine_count - current_count):
                                self.engines.append(EngineState())
                            # NOTE(yongji): handle the case
                            # where newly started engines have current_wave = 0
                            # if existing engines just finished a wave
                            # and engine_running isn't updated yet at
                            # CoordinatorProc requests routed to newly started
                            # engines may not wake up existing engines, as long
                            # as 0 < request.wave < existing engines'
                            # current_wave
                            # we note that 0 is the wave number for the new
                            # engine
                            engines_running = False
                            logger.info(
                                "DPCoordinator scaled up from %s to %s engines",
                                current_count,
                                new_engine_count,
                            )
                        else:
                            self.engines = self.engines[:new_engine_count]
                            logger.info(
                                "DPCoordinator scaled down from %s to %s engines",
                                current_count,
                                new_engine_count,
                            )
                        continue  # Skip normal engine notification processing

                    # Wave coordination: handle new-request messages from front-end.
                    # Only process these when wave coordination is enabled
                    if self.enable_wave_coordination:
                        # We received a message on the front-end XPUB socket,
                        # from an API server sending a new request while the
                        # engines are paused, so that we can wake the other
                        # engines.
                        engine_to_exclude, wave = decoded
                        if not engines_running:
                            if wave < current_wave:
                                # If the wave number is stale, ensure the message
                                # is handled by all the engines.
                                engine_to_exclude = None

                            engines_running = True
                            wave_state_changed = True
                            self._send_start_wave(
                                publish_back, current_wave, engine_to_exclude
                            )

                if output_back in events:
                    # We received a message from one of the engines.

                    buffer = output_back.recv()
                    outputs: EngineCoreOutputs = decoder.decode(buffer)

                    # Engines only send stats/wave metadata on this socket.
                    assert not outputs.outputs
                    assert outputs.utility_output is None

                    eng_index = outputs.engine_index
                    scheduler_stats = outputs.scheduler_stats
                    if scheduler_stats:
                        # 1. Updated request load stats - update our local
                        # state with these.
                        stats = self.engines[eng_index].request_counts
                        stats_step = scheduler_stats.step_counter
                        stats_wave = scheduler_stats.current_wave
                        if (
                            stats_wave > last_stats_wave
                            or stats_wave == last_stats_wave
                            and stats_step > last_stats_step
                        ):
                            # Newer (wave, step): snapshot counts at the step
                            # boundary before overwriting.
                            if stats_changed:
                                last_step_counts = self._get_engine_counts(do_copy=True)
                            last_stats_step = stats_step
                            last_stats_wave = stats_wave
                        elif stats_wave != last_stats_wave or (
                            stats_step != last_stats_step
                        ):
                            logger.warning(
                                "Received stats for out-of-order "
                                "step (%d, %d) from engine %d (expected "
                                "> (%d, %d))",
                                stats_wave,
                                stats_step,
                                eng_index,
                                last_stats_wave,
                                last_stats_step,
                            )
                        stats[0] = scheduler_stats.num_waiting_reqs
                        stats[1] = scheduler_stats.num_running_reqs
                        stats_changed = True

                    # Wave coordination: handle wave completion and start notifications
                    # Only process these when wave coordination is enabled
                    if self.enable_wave_coordination:
                        if (wave := outputs.wave_complete) is not None:
                            # 2. Notification from rank 0 engine that we've
                            # moved into the global paused state
                            # (engines_running==False).
                            if current_wave <= wave:
                                new_wave = wave + 1
                                logger.debug(
                                    "Moving DP wave from %d to %d.",
                                    current_wave,
                                    new_wave,
                                )
                                current_wave = new_wave
                            engines_running = False
                            wave_state_changed = True
                        elif (wave := outputs.start_wave) is not None and (
                            wave > current_wave
                            or (wave == current_wave and not engines_running)
                        ):
                            # 3. The engine received request for a non-current wave
                            # so we must ensure that other engines progress to the
                            # next wave (race condition handling).
                            logger.debug(
                                "Starting wave %d after notification of "
                                "stale wave request from engine.",
                                wave,
                            )
                            current_wave = wave
                            engines_running = True
                            wave_state_changed = True
                            self._send_start_wave(publish_back, wave, eng_index)

                if wave_state_changed:
                    # Push the new wave/running state to all front-ends
                    # (None in place of stats).
                    message = (None, current_wave, engines_running)
                    publish_front.send(msgspec.msgpack.encode(message))

    @staticmethod
    def _send_start_wave(
        socket: zmq.Socket, wave: int, exclude_engine_index: int | None
    ):
        """Broadcast the START_DP_WAVE message to all the engines.
        It includes the current wave number and index of engine which
        has already received a request with this wave number and so doesn't
        require additional notification.
        """
        wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
        socket.send_multipart((EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))

    def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
        """Return list of [waiting, running] count lists for each engine."""
        if do_copy:
            # Copy so the snapshot isn't mutated by later stats updates.
            return [copy.copy(e.request_counts) for e in self.engines]
        return [e.request_counts for e in self.engines]
|
||||
1842
vllm/v1/engine/core.py
Normal file
1842
vllm/v1/engine/core.py
Normal file
File diff suppressed because it is too large
Load Diff
1456
vllm/v1/engine/core_client.py
Normal file
1456
vllm/v1/engine/core_client.py
Normal file
File diff suppressed because it is too large
Load Diff
341
vllm/v1/engine/detokenizer.py
Normal file
341
vllm/v1/engine/detokenizer.py
Normal file
@@ -0,0 +1,341 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import tokenizers
|
||||
from packaging import version
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.decoders import DecodeStream
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.detokenizer_utils import (
|
||||
convert_prompt_ids_to_tokens,
|
||||
detokenize_incrementally,
|
||||
)
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Only tokenizers >= 0.22.0 supports DecodeStream with native prefill
# (ids parameter) used for FastIncrementalDetokenizer.
USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.22.0")

# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
# Matched by prefix in FastIncrementalDetokenizer._protected_step to detect
# a corrupted DecodeStream state that warrants a stream reset.
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
|
||||
|
||||
|
||||
class IncrementalDetokenizer:
    """No-op base detokenizer: accumulates token ids, produces no text.

    Used directly when no tokenizer is available; subclasses add real
    incremental detokenization.
    """

    def __init__(self):
        # All output token ids recorded so far for this request.
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        """All accumulated output token ids."""
        return self.token_ids

    def num_output_tokens(self) -> int:
        """Number of output token ids accumulated so far."""
        return len(self.token_ids)

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """Record new token ids; the base class never matches a stop string."""
        self.token_ids.extend(new_token_ids)
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """The base class produces no output text."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Select and construct the appropriate detokenizer for a request."""
        assert request.sampling_params is not None

        if tokenizer is None:
            # No tokenizer => skipping detokenization.
            return IncrementalDetokenizer()

        use_fast = USE_FAST_DETOKENIZER and isinstance(
            tokenizer, PreTrainedTokenizerFast
        )
        if use_fast:
            # Fast tokenizer => use tokenizers library DecodeStream.
            return FastIncrementalDetokenizer(tokenizer, request)

        # Fall back to slow python-based incremental detokenization.
        return SlowIncrementalDetokenizer(tokenizer, request)
|
||||
|
||||
|
||||
class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Shared stop-string and streaming logic for real detokenizers.

    Subclasses implement decode_next() to turn one token id into text.

    NOTE(review): update() calls check_stop_strings, which does not appear
    among the imports visible in this file — verify it is imported (likely
    from vllm.tokenizers.detokenizer_utils).
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        # Stop strings
        params = request.sampling_params
        assert params is not None
        stop_list: list[str]
        if params.stop is None:
            stop_list = []
        elif isinstance(params.stop, str):
            stop_list = [params.stop]
        else:
            stop_list = params.stop
        self.stop = stop_list
        self.min_tokens = params.min_tokens
        self.include_stop_str_in_output = params.include_stop_str_in_output

        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output.
        if self.stop and not self.include_stop_str_in_output:
            self.stop_buffer_length = max(len(s) for s in self.stop) - 1
        else:
            self.stop_buffer_length = 0
        self._last_output_text_offset: int = 0

        # Generation data
        self.output_text = ""

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """
        Update RequestState for the request_id by:
        1) Detokenize the new token ids incrementally.
        2) Evaluate stop criteria.

        Return matched stop string or None.
        """
        if not new_token_ids:
            # Skip detokenization if no new token ids.
            return None

        if stop_terminated and not self.include_stop_str_in_output:
            # If stop-terminated, exclude last token from detokenization
            # based on include_stop_str_in_output parameter.
            skipped_stop_token_id = new_token_ids[-1]
            new_token_ids = new_token_ids[:-1]
        else:
            skipped_stop_token_id = None

        # 1) Detokenize the new token ids incrementally.
        stop_check_offset = len(self.output_text)
        for new_token_id in new_token_ids:
            self.token_ids.append(new_token_id)
            self.output_text += self.decode_next(new_token_id)
            # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014
            # (advance the stop-check offset so stop strings are not matched
            # before min_tokens tokens have been produced).
            if self.min_tokens and self.num_output_tokens() <= self.min_tokens:
                stop_check_offset = len(self.output_text)

        if skipped_stop_token_id is not None:
            # Cleanup after skipping detokenization.
            self.token_ids.append(skipped_stop_token_id)

        # 2) Evaluate stop strings.
        stop_string = None
        if self.stop and self.num_output_tokens() > self.min_tokens:
            stop = check_stop_strings(
                output_text=self.output_text,
                new_char_count=len(self.output_text) - stop_check_offset,
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
            if stop is not None:
                stop_string, truncate_to = stop
                # truncate_to == -1 means no truncation required.
                if truncate_to != -1:
                    self.output_text = self.output_text[:truncate_to]

        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Decode a single token id into its incremental text piece."""
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned"""

        # We return the full output text if the sequence is finished.
        # Otherwise, hold back stop_buffer_length chars in case a stop
        # string is completed by subsequent tokens.
        buffer_length = 0 if finished else self.stop_buffer_length
        if not delta:
            if not buffer_length:
                return self.output_text
            return self.output_text[:-buffer_length]

        length = len(self.output_text) - buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""
|
||||
|
||||
|
||||
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Detokenizer backed by the tokenizers library's DecodeStream."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest):
        super().__init__(request)

        sampling_params = request.sampling_params
        assert sampling_params is not None

        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens

        # Underlying Rust tokenizer of the HF fast tokenizer.
        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Use native prefill to prime the decode stream with prompt tokens.
        self.stream = DecodeStream(
            ids=request.prompt_token_ids,
            skip_special_tokens=self.skip_special_tokens,
        )

        # When special tokens are skipped there is nothing between which to
        # suppress spaces, so treat as spaces_between_special_tokens=True.
        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens
        )

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them.
            added_token_ids = getattr(self.tokenizer, "added_token_ids", None)
            if added_token_ids is None:
                # Cache the mapping on the tokenizer so it is built once
                # and shared across requests.
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        """Decode one token id, suppressing spaces between consecutive
        special tokens when configured."""
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> str | None:
        """Step the DecodeStream, recovering from known failure modes."""
        try:
            token = self.stream.step(self.tokenizer, next_token_id)
        except (OverflowError, TypeError):
            # Handle rare observed overflow, still to be diagnosed.
            # See https://github.com/vllm-project/vllm/issues/21951.
            logger.exception("Encountered invalid token id: %r", next_token_id)
            token = None
        except Exception as e:
            if not str(e).startswith(INVALID_PREFIX_ERR_MSG):
                raise e
            # Recover from edge case where tokenizer can produce non-monotonic,
            # invalid UTF-8 output, which breaks the internal state of
            # tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.",
                self.request_id,
            )
            # Fresh stream (without prompt prefill); retry the step once.
            self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
            token = self.stream.step(self.tokenizer, next_token_id)
        return token
|
||||
|
||||
|
||||
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Fallback incremental detokenizer for any ``TokenizerLike``.

    Uses the generic ``detokenize_incrementally`` helper instead of a native
    fast-tokenizer decode stream, threading prefix/read offsets through each
    decode step.
    """

    def __init__(self, tokenizer: TokenizerLike, request: EngineCoreRequest):
        super().__init__(request)

        params = request.sampling_params
        assert params is not None

        self.tokenizer = tokenizer
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )

        # Seed the incremental-detokenization state from the prompt.
        prompt_ids = request.prompt_token_ids
        if prompt_ids is None:
            # Prompt embedding requests cannot be detokenized, in general;
            # use empty placeholder state of the right length.
            self.tokens = [""] * self.prompt_len
            self.prefix_offset = 0
            self.read_offset = 0
        else:
            (
                self.tokens,
                self.prefix_offset,
                self.read_offset,
            ) = convert_prompt_ids_to_tokens(
                tokenizer=tokenizer,
                prompt_ids=prompt_ids,
                skip_special_tokens=params.skip_special_tokens,
            )

        self.token_ids.extend(prompt_ids or [0] * self.prompt_len)

        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = params.spaces_between_special_tokens

    @property
    def output_token_ids(self) -> list[int]:
        """Token ids generated after the prompt."""
        prompt_len = self.prompt_len
        return self.token_ids[prompt_len:] if prompt_len else self.token_ids

    def num_output_tokens(self) -> int:
        """Number of generated (non-prompt) tokens."""
        return len(self.token_ids) - self.prompt_len

    def decode_next(self, next_token_id: int) -> str:
        """Detokenize one more token and return the newly decoded text."""
        (
            fresh_tokens,
            fresh_text,
            new_prefix_offset,
            new_read_offset,
        ) = detokenize_incrementally(
            tokenizer=self.tokenizer,
            all_input_ids=self.token_ids,
            prev_tokens=self.tokens,
            prefix_offset=self.prefix_offset,
            read_offset=self.read_offset,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
        )

        self.tokens.extend(fresh_tokens)
        self.prefix_offset = new_prefix_offset
        self.read_offset = new_read_offset

        return fresh_text
|
||||
|
||||
def check_stop_strings(
|
||||
output_text: str,
|
||||
new_char_count: int,
|
||||
stop: list[str],
|
||||
include_in_output: bool,
|
||||
) -> tuple[str, int] | None:
|
||||
"""Check if any stop strings are matched and truncate sequence
|
||||
output text accordingly.
|
||||
|
||||
Returns tuple (stop_string, offset) if matched or else None.
|
||||
|
||||
Where stop_string is the matched stop string and offset is the
|
||||
length to which output_text should be truncated, or -1 for no
|
||||
truncation.
|
||||
"""
|
||||
if not new_char_count or not stop:
|
||||
return None
|
||||
|
||||
for stop_str in stop:
|
||||
stop_string_len = len(stop_str)
|
||||
# Avoid searching already-searched text.
|
||||
stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len)
|
||||
if stop_index == -1:
|
||||
continue
|
||||
|
||||
if include_in_output:
|
||||
# Truncate to end of stop string.
|
||||
stop_index += stop_string_len
|
||||
if stop_index >= len(output_text):
|
||||
# No truncation required.
|
||||
return stop_str, -1
|
||||
|
||||
# Truncate the output text to either the beginning
|
||||
# or end of the stop string.
|
||||
return stop_str, stop_index
|
||||
return None
|
||||
18
vllm/v1/engine/exceptions.py
Normal file
18
vllm/v1/engine/exceptions.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
class EngineGenerateError(Exception):
    """Raised when a AsyncLLM.generate() fails. Recoverable."""
|
||||
|
||||
class EngineDeadError(Exception):
    """Raised when the EngineCore dies. Unrecoverable.

    The exception message is fixed; any extra positional/keyword arguments
    are forwarded to ``Exception`` after it.
    """

    # Hoisted to a class-level constant so it is not rebuilt on every raise
    # and is introspectable without instantiating the exception.
    ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause."  # noqa: E501

    def __init__(self, *args, suppress_context: bool = False, **kwargs):
        super().__init__(self.ENGINE_DEAD_MESSAGE, *args, **kwargs)
        # Make stack trace clearer when using with LLMEngine by
        # silencing irrelevant ZMQError.
        self.__suppress_context__ = suppress_context
|
||||
476
vllm/v1/engine/input_processor.py
Normal file
476
vllm/v1/engine/input_processor.py
Normal file
@@ -0,0 +1,476 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs.data import (
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonInputs,
|
||||
)
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.encoder_budget import MultiModalBudget
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalFeatureSpec,
|
||||
)
|
||||
from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer, renderer_from_config
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import GENERATION_TASKS, POOLING_TASKS, SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
from vllm.utils.jsontree import json_iter_leaves
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class InputProcessor:
    """Validates requests and converts preprocessed prompts into
    ``EngineCoreRequest`` objects ready to be sent to the engine core.

    Responsibilities: parameter/LoRA validation, prompt-length checks,
    multimodal feature flattening, and request-id assignment.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        renderer: BaseRenderer | None = None,
        *,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        self.vllm_config = vllm_config
        self.model_config = model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.structured_outputs_config = vllm_config.structured_outputs_config
        self.observability_config = vllm_config.observability_config

        self.generation_config_fields = model_config.try_get_generation_config()

        self.renderer = renderer or renderer_from_config(vllm_config)

        self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(model_config)
        self.mm_encoder_cache_size = 0
        self.skip_prompt_length_check = False
        if self.supports_mm_inputs:
            # The budget object is only needed to read these two values;
            # its cache is dropped immediately below.
            mm_budget = MultiModalBudget(vllm_config, mm_registry)
            self.mm_encoder_cache_size = mm_budget.encoder_cache_size
            self.skip_prompt_length_check = (
                mm_budget.processor.info.skip_prompt_length_check
            )
            mm_budget.reset_cache()  # Not used anymore

        self.input_preprocessor = InputPreprocessor(
            vllm_config,
            renderer=renderer,
            mm_registry=mm_registry,
        )

        # Local import — presumably to avoid an import cycle or to defer
        # platform detection; TODO confirm.
        from vllm.platforms import current_platform

        platform_validate_request = current_platform.validate_request
        if supports_kw(platform_validate_request, "prompt"):
            # Old 3-argument signature detected: warn and adapt it to the
            # new 2-argument calling convention.
            logger.warning_once(
                "The signature of Platform.validate_request has changed from "
                "`(cls, prompt, params, processed_inputs) -> None` to "
                "`(cls, processed_inputs, params) -> None`. The old signature "
                "will no longer be supported starting from v0.18."
            )

            orig_validate_request = platform_validate_request

            def compat_validate_request(
                processed_inputs: ProcessorInputs,
                params: SamplingParams | PoolingParams,
            ):
                # Old signature expected (prompt, params, processed_inputs);
                # pass processed_inputs in both positions.
                return orig_validate_request(
                    processed_inputs,
                    params,
                    processed_inputs,  # type: ignore
                )  # type: ignore

            platform_validate_request = compat_validate_request

        self._platform_validate_request = platform_validate_request

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The renderer's tokenizer, if one is configured."""
        return self.renderer.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the renderer's tokenizer (delegates; may raise if absent)."""
        return self.renderer.get_tokenizer()

    def _validate_params(
        self,
        params: SamplingParams | PoolingParams,
        supported_tasks: tuple[SupportedTask, ...],
    ) -> None:
        """Raise `ValueError` if SamplingParams or PoolingParams is not valid."""
        if params.truncate_prompt_tokens is not None:
            params_type = type(params).__name__
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{params_type}` "
                "is deprecated and will be removed in v0.17. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(params, SamplingParams):
            supported_generation_tasks = [
                task for task in supported_tasks if task in GENERATION_TASKS
            ]
            if not supported_generation_tasks:
                raise ValueError("This model does not support generation")

            params.verify(
                self.model_config,
                self.speculative_config,
                self.structured_outputs_config,
                self.tokenizer,
            )
        elif isinstance(params, PoolingParams):
            supported_pooling_tasks = [
                task for task in supported_tasks if task in POOLING_TASKS
            ]
            if not supported_pooling_tasks:
                raise ValueError("This model does not support pooling")

            # Default the pooling task in a fixed preference order.
            if params.task is None:
                if "token_embed" in supported_pooling_tasks:
                    params.task = "token_embed"
                elif "token_classify" in supported_pooling_tasks:
                    params.task = "token_classify"
                elif "plugin" in supported_pooling_tasks:
                    params.task = "plugin"

            if params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

            params.verify(self.model_config)
        else:
            raise TypeError(
                f"params must be either SamplingParams or PoolingParams, "
                f"but got {type(params).__name__}"
            )

    def _validate_lora(self, lora_request: LoRARequest | None) -> None:
        """Raise ``ValueError`` if a LoRA request is given while LoRA is
        disabled; warn about the deprecated per-LoRA tokenizer behavior."""
        if lora_request is None:
            return

        # LoRA request passed in while LoRA is not enabled
        if not self.lora_config:
            raise ValueError(
                f"Got lora_request {lora_request} but LoRA is not enabled!"
            )

        if self.tokenizer is not None:
            logger.warning_once(
                "vLLM has deprecated support for supporting different "
                "tokenizers for different LoRAs. By default, vLLM uses base "
                "model's tokenizer. If you are using a LoRA "
                "with its own tokenizer, consider specifying `--tokenizer "
                "[lora_path]` to use the LoRA tokenizer."
            )

    def _get_mm_identifier(
        self,
        mm_hash: str,
        lora_request: LoRARequest | None,
    ) -> str:
        """
        When enable_tower_connector_lora is True, multi-modal embeddings
        vary depending on the LoRA request. Therefore, the mm_hash must be
        generated based on the LoRA request to prevent incorrect cache hits.
        """
        if (
            lora_request is None
            or self.lora_config is None
            or not self.lora_config.enable_tower_connector_lora
        ):
            return mm_hash
        return f"{lora_request.lora_name}:{mm_hash}"

    @staticmethod
    def assign_request_id(request: EngineCoreRequest):
        """Replace the externally supplied request ID with an internal request ID
        that adds 8 random characters in order to ensure uniqueness.
        """
        if request.external_req_id is not None:
            raise ValueError(
                "The external_req_id field should not be set on EngineCoreRequests"
                " passed to vLLM; use the request_id field."
            )
        request.external_req_id = request.request_id
        if envs.VLLM_DISABLE_REQUEST_ID_RANDOMIZATION:
            logger.warning_once(
                "VLLM_DISABLE_REQUEST_ID_RANDOMIZATION is set and will be "
                "removed in a future release. Duplicate externally-provided "
                "request IDs may cause failures and/or subtle correctness errors."
            )
        else:
            # Append the first 8 chars of a random uuid for uniqueness.
            request.request_id = f"{request.external_req_id}-{random_uuid():.8}"

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType | ProcessorInputs,
        params: SamplingParams | PoolingParams,
        supported_tasks: tuple[SupportedTask, ...],
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        resumable: bool = False,
    ) -> EngineCoreRequest:
        """Validate and convert a (possibly preprocessed) prompt plus params
        into an ``EngineCoreRequest``.

        Raises ``ValueError``/``TypeError`` for invalid params, LoRA usage,
        data-parallel rank, prompt lengths, or malformed multimodal hashes.
        """
        self._validate_params(params, supported_tasks)
        self._validate_lora(lora_request)

        parallel_config = self.vllm_config.parallel_config
        dp_size = parallel_config.data_parallel_size
        dp_local_size = parallel_config.data_parallel_size_local
        num_ranks = dp_local_size if parallel_config.local_engines_only else dp_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank < num_ranks):
            raise ValueError(
                f"data_parallel_rank {data_parallel_rank} "
                f"is out of range [0, {num_ranks})."
            )

        # A dict with a "type" key is treated as already-processed inputs.
        if isinstance(prompt, dict) and "type" in prompt:
            if tokenization_kwargs:
                logger.warning_once(
                    "Passing tokenization_kwargs to InputProcessor is deprecated "
                    "and will be removed in v0.18. You should instead pass "
                    "them to Renderer.render_cmpl() or Renderer.render_chat()."
                )

            if arrival_time is None:
                arrival_time = prompt.get("arrival_time", time.time())  # type: ignore[assignment]

            processed_inputs: ProcessorInputs = prompt  # type: ignore[assignment]
        else:
            logger.warning_once(
                "Passing raw prompts to InputProcessor is deprecated "
                "and will be removed in v0.18. You should instead pass "
                "the outputs of Renderer.render_cmpl() or Renderer.render_chat()."
            )

            if arrival_time is None:
                arrival_time = time.time()

            processed_inputs = self.input_preprocessor.preprocess(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        self._platform_validate_request(processed_inputs, params)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
        self._validate_model_inputs(encoder_inputs, decoder_inputs)

        # Mypy can be conservative for TypedDict unions; normalize access.
        if decoder_inputs["type"] == "embeds":
            prompt_token_ids = None
            prompt_embeds = decoder_inputs["prompt_embeds"]
        else:
            prompt_token_ids = decoder_inputs["prompt_token_ids"]
            prompt_embeds = None

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                seq_len = length_from_prompt_token_ids_or_embeds(
                    prompt_token_ids, prompt_embeds
                )
                sampling_params.max_tokens = self.model_config.max_model_len - seq_len

            sampling_params.update_from_generation_config(
                self.generation_config_fields,
                self.renderer.get_eos_token_id(),
            )
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()

        # Multimodal related.
        mm_features: list[MultiModalFeatureSpec] | None = None

        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]

            if not all(
                isinstance(leaf, str) for leaf in json_iter_leaves(decoder_mm_hashes)
            ):
                raise ValueError(
                    f"mm_hashes must contain only strings, got: {decoder_mm_hashes}. "
                    "This is likely due to an incorrect custom implementation of "
                    "MultiModalProcessor.apply method."
                )

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            mm_features = []
            for modality, idx in sorted_mm_idxs:
                base_mm_hash = decoder_mm_hashes[modality][idx]
                mm_features.append(
                    MultiModalFeatureSpec(
                        data=decoder_mm_inputs[modality][idx],
                        modality=modality,
                        # identifier may be LoRA-scoped (see _get_mm_identifier)
                        # while mm_hash stays the raw content hash.
                        identifier=self._get_mm_identifier(
                            base_mm_hash,
                            lora_request,
                        ),
                        mm_position=decoder_mm_positions[modality][idx],
                        mm_hash=base_mm_hash,
                    )
                )

        return EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=prompt_token_ids,
            prompt_embeds=prompt_embeds,
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
            resumable=resumable,
        )

    def _validate_prompt_len(
        self,
        prompt_len: int,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Raise ``ValueError`` if the prompt is empty (decoder) or does not
        leave room within the relevant length budget."""
        if self.skip_prompt_length_check:
            return

        if prompt_len == 0 and prompt_type == "decoder":
            raise ValueError(f"The {prompt_type} prompt cannot be empty")

        model_config = self.model_config
        # Encoder prompts are bounded by the encoder cache, decoder prompts
        # by the model context length.
        max_prompt_len = (
            model_config.max_model_len
            if prompt_type == "decoder"
            else self.mm_encoder_cache_size
        )
        if prompt_len > max_prompt_len:
            if self.supports_mm_inputs:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well."
                )
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens."
                )

            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}"
            )
        elif prompt_len == max_prompt_len and model_config.runner_type == "generate":
            # Exactly-full prompts leave no room for even one output token.
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens (prompt + requested output tokens)."
            )
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than the maximum "
                f"model length of {max_prompt_len}. {suggestion}"
            )

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        prompt_type: Literal["encoder", "decoder"],
    ) -> None:
        """Validate one (encoder or decoder) input: length budget, per-item
        multimodal size, and token-id vocabulary range."""
        model_config = self.model_config
        tokenizer = self.tokenizer

        prompt_ids = (
            None
            if prompt_inputs["type"] == "embeds"
            else prompt_inputs["prompt_token_ids"]
        )
        prompt_embeds = (
            prompt_inputs["prompt_embeds"]
            if prompt_inputs["type"] == "embeds"
            else None
        )

        prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds)
        self._validate_prompt_len(prompt_len, prompt_type)

        if prompt_inputs["type"] == "multimodal":
            decoder_mm_positions = prompt_inputs["mm_placeholders"]
            for modality, mm_positions in decoder_mm_positions.items():
                for mm_position in mm_positions:
                    embed_length = mm_position.get_num_embeds()
                    if embed_length > self.mm_encoder_cache_size:
                        raise ValueError(
                            f"The {prompt_type} prompt contains a(n) {modality} item "
                            f"with length {embed_length}, which exceeds the "
                            f"pre-allocated encoder cache size "
                            f"{self.mm_encoder_cache_size}. Please reduce the input "
                            f"size or increase the encoder cache size "
                            f"by setting --limit-mm-per-prompt at startup."
                        )

        if prompt_ids and tokenizer is not None:
            max_input_id = max(prompt_ids, default=0)

            # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
            # self.model_config.get_vocab_size() is the model’s vocab size.
            # For Qwen3 models, the language model has extra tokens that do
            # not exist in the tokenizer, and vice versa for multimodal
            # placeholder tokens in some multimodal models.
            # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
            # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501

            # Here we take the max of the two to determine if a token id is
            # truly out-of-vocabulary.
            model_vocab_size = model_config.get_vocab_size()
            if max_input_id > max(tokenizer.max_token_id, model_vocab_size - 1):
                raise ValueError(f"Token id {max_input_id} is out of vocabulary")

    def _validate_model_inputs(
        self,
        encoder_inputs: SingletonInputs | None,
        decoder_inputs: SingletonInputs,
    ):
        """Validate the encoder input (when present) and the decoder input."""
        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs, prompt_type="encoder")

        self._validate_model_input(decoder_inputs, prompt_type="decoder")
|
||||
429
vllm/v1/engine/llm_engine.py
Normal file
429
vllm/v1/engine/llm_engine.py
Normal file
@@ -0,0 +1,429 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from collections.abc import Callable, Mapping
|
||||
from copy import copy
|
||||
from typing import Any
|
||||
|
||||
import torch.nn as nn
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import renderer_from_config
|
||||
from vllm.renderers.inputs.preprocess import extract_prompt_components
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tracing import init_tracer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine import EngineCoreRequest, PauseMode
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
|
||||
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
from vllm.v1.utils import record_function_or_nullcontext
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""Legacy LLMEngine for backwards compatibility."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
aggregate_engine_logging: bool = False,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
use_cached_outputs: bool = False,
|
||||
multiprocess_mode: bool = False,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
tracing_endpoint = self.observability_config.otlp_traces_endpoint
|
||||
if tracing_endpoint is not None:
|
||||
init_tracer("vllm.llm_engine", tracing_endpoint)
|
||||
|
||||
self.log_stats = log_stats
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
executor_backend = parallel_config.distributed_executor_backend
|
||||
|
||||
self.external_launcher_dp = (
|
||||
parallel_config.data_parallel_size > 1
|
||||
and executor_backend == "external_launcher"
|
||||
)
|
||||
# important: init dp group before init the engine_core
|
||||
# In the decoupled engine case this is handled in EngineCoreProc.
|
||||
if (
|
||||
not multiprocess_mode
|
||||
and parallel_config.data_parallel_size > 1
|
||||
and not self.external_launcher_dp
|
||||
):
|
||||
self.dp_group = parallel_config.stateless_init_dp_group()
|
||||
else:
|
||||
self.dp_group = None
|
||||
self.should_execute_dummy_batch = False
|
||||
|
||||
self.renderer = renderer = renderer_from_config(self.vllm_config)
|
||||
self.io_processor = get_io_processor(
|
||||
self.vllm_config,
|
||||
self.model_config.io_processor_plugin,
|
||||
)
|
||||
|
||||
# Convert TokPrompt --> EngineCoreRequest.
|
||||
self.input_processor = InputProcessor(self.vllm_config, renderer)
|
||||
|
||||
# Converts EngineCoreOutputs --> RequestOutput.
|
||||
self.output_processor = OutputProcessor(
|
||||
renderer.tokenizer,
|
||||
log_stats=self.log_stats,
|
||||
stream_interval=self.vllm_config.scheduler_config.stream_interval,
|
||||
tracing_enabled=tracing_endpoint is not None,
|
||||
)
|
||||
|
||||
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
|
||||
self.engine_core = EngineCoreClient.make_client(
|
||||
multiprocess_mode=multiprocess_mode,
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=self.log_stats,
|
||||
)
|
||||
|
||||
self.logger_manager: StatLoggerManager | None = None
|
||||
if self.log_stats:
|
||||
self.logger_manager = StatLoggerManager(
|
||||
vllm_config=vllm_config,
|
||||
custom_stat_loggers=stat_loggers,
|
||||
enable_default_loggers=log_stats,
|
||||
aggregate_engine_logging=aggregate_engine_logging,
|
||||
)
|
||||
self.logger_manager.log_engine_initialized()
|
||||
|
||||
if not multiprocess_mode:
|
||||
# for v0 compatibility
|
||||
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
|
||||
|
||||
if self.external_launcher_dp:
|
||||
# If we use DP in external launcher mode, we reuse the
|
||||
# existing DP group used for data communication.
|
||||
self.dp_group = get_dp_group().cpu_group
|
||||
|
||||
# Don't keep the dummy data in memory
|
||||
self.reset_mm_cache()
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||
disable_log_stats: bool = False,
|
||||
) -> "LLMEngine":
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=(not disable_log_stats),
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: EngineArgs,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||
enable_multiprocessing: bool = False,
|
||||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
|
||||
# Create the engine configs.
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
|
||||
logger.debug("Enabling multiprocessing for LLMEngine.")
|
||||
enable_multiprocessing = True
|
||||
|
||||
# Create the LLMEngine.
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
multiprocess_mode=enable_multiprocessing,
|
||||
)
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
return self.output_processor.get_num_unfinished_requests()
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
has_unfinished = self.output_processor.has_unfinished_requests()
|
||||
if self.dp_group is None:
|
||||
return has_unfinished or self.engine_core.dp_engines_running()
|
||||
return self.has_unfinished_requests_dp(has_unfinished)
|
||||
|
||||
def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
    """Aggregate the unfinished-request flag across the DP group.

    If other ranks still have work while this rank is idle, mark that a
    dummy batch must be executed so collective ops stay in lockstep.
    """
    any_unfinished = ParallelConfig.has_unfinished_dp(self.dp_group, has_unfinished)
    if any_unfinished and not has_unfinished:
        self.should_execute_dummy_batch = True
    return any_unfinished
|
||||
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
    """Return the tasks the engine core supports (fetched once, then cached)."""
    try:
        return self._supported_tasks
    except AttributeError:
        # First call: ask the engine core and memoize the answer.
        self._supported_tasks = self.engine_core.get_supported_tasks()
        return self._supported_tasks
|
||||
|
||||
def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
    """Remove request_ids from EngineCore and Detokenizer."""
    # The output processor may translate/expand the ids (e.g. external ->
    # internal ids); forward the resolved list to the engine core.
    resolved_ids = self.output_processor.abort_requests(request_ids, internal)
    self.engine_core.abort_requests(resolved_ids)
|
||||
|
||||
def add_request(
    self,
    request_id: str,
    prompt: EngineCoreRequest | PromptType | ProcessorInputs,
    params: SamplingParams | PoolingParams,
    arrival_time: float | None = None,
    lora_request: LoRARequest | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
    trace_headers: Mapping[str, str] | None = None,
    priority: int = 0,
    prompt_text: str | None = None,
) -> str:
    """Register a new request with the engine and return its internal id.

    Accepts either raw prompt inputs (processed via the input processor)
    or a pre-built EngineCoreRequest (deprecated path). For sampling
    requests with n > 1, the request is fanned out into n child requests
    tracked under a single ParentRequest.

    Returns:
        The (possibly reassigned) internal request id of the parent request.
    """
    # Validate the request_id type.
    if not isinstance(request_id, str):
        raise TypeError(f"request_id must be a string, got {type(request_id)}")

    # Process raw inputs into the request.
    if isinstance(prompt, EngineCoreRequest):
        logger.warning_once(
            "Passing EngineCoreRequest to LLMEngine.generate() and .add_requests() "
            "is deprecated and will be removed in v0.18. You should instead pass "
            "the outputs of Renderer.render_cmpl() or Renderer.render_chat()."
        )
        request = prompt
        # The embedded id wins over the explicit parameter on this path.
        if request_id != request.request_id:
            logger.warning_once(
                "LLMEngine.add_request() was passed a request_id parameter that "
                "does not match the EngineCoreRequest.request_id attribute. The "
                "latter will be used, and the former will be ignored."
            )
    else:
        request = self.input_processor.process_inputs(
            request_id,
            prompt,
            params,
            supported_tasks=self.get_supported_tasks(),
            arrival_time=arrival_time,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
        )
        # Recover the prompt text for output bookkeeping; overrides any
        # caller-supplied prompt_text on this path.
        prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)

    # Assign the internal id before any fan-out so children derive from it.
    self.input_processor.assign_request_id(request)

    req_id = request.request_id

    # Use cloned params that may have been updated in process_inputs()
    params = request.params

    # Pooling requests have no notion of n; treat as a single request.
    n = params.n if isinstance(params, SamplingParams) else 1

    if n == 1:
        # Make a new RequestState and queue.
        self.output_processor.add_request(request, prompt_text, None, 0)
        # Add the request to EngineCore.
        self.engine_core.add_request(request)
        return req_id

    # Fan out child requests (for n>1).
    parent_req = ParentRequest(request)
    for idx in range(n):
        request_id, child_params = parent_req.get_child_info(idx)
        # Reuse the original request object for the last child; copy for
        # the others so each child carries its own id/params.
        child_request = request if idx == n - 1 else copy(request)
        child_request.request_id = request_id
        child_request.sampling_params = child_params

        # Make a new RequestState and queue.
        self.output_processor.add_request(
            child_request, prompt_text, parent_req, idx
        )
        # Add the request to EngineCore.
        self.engine_core.add_request(child_request)

    return req_id
|
||||
|
||||
def step(self) -> list[RequestOutput | PoolingRequestOutput]:
    """Run one engine iteration and return the finished/streamed outputs.

    When a dummy batch is pending (DP lockstep, see
    has_unfinished_requests_dp), it is executed instead of a real step
    and an empty list is returned.
    """
    if self.should_execute_dummy_batch:
        self.should_execute_dummy_batch = False
        self.engine_core.execute_dummy_batch()
        return []

    # 1) Get EngineCoreOutput from the EngineCore.
    with record_function_or_nullcontext("llm_engine step: get_output"):
        outputs = self.engine_core.get_output()

    # 2) Process EngineCoreOutputs.
    with record_function_or_nullcontext("llm_engine step: process_outputs"):
        # Stats are only collected when logging is enabled.
        iteration_stats = IterationStats() if self.log_stats else None
        processed_outputs = self.output_processor.process_outputs(
            outputs.outputs,
            engine_core_timestamp=outputs.timestamp,
            iteration_stats=iteration_stats,
        )
        self.output_processor.update_scheduler_stats(outputs.scheduler_stats)

    # 3) Abort any reqs that finished due to stop strings.
    with record_function_or_nullcontext("llm_engine step: abort_requests"):
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

    # 4) Record stats
    with record_function_or_nullcontext("llm_engine step: record_stats"):
        # Only record when there is something to report this iteration.
        if (
            self.logger_manager is not None
            and outputs.scheduler_stats is not None
            and len(outputs.outputs) > 0
        ):
            self.logger_manager.record(
                scheduler_stats=outputs.scheduler_stats,
                iteration_stats=iteration_stats,
                mm_cache_stats=self.renderer.stat_mm_cache(),
            )
        self.do_log_stats_with_interval()

    return processed_outputs.request_outputs
|
||||
|
||||
def start_profile(self, profile_prefix: str | None = None):
    """Start profiling on the engine core, optionally tagging output with a prefix."""
    self.engine_core.profile(True, profile_prefix)
|
||||
|
||||
def stop_profile(self):
    """Stop a previously started profiling session on the engine core."""
    self.engine_core.profile(False)
|
||||
|
||||
def reset_mm_cache(self):
    """Clear multimodal caches on both the renderer and the engine core."""
    self.renderer.clear_mm_cache()
    self.engine_core.reset_mm_cache()
|
||||
|
||||
def reset_prefix_cache(
    self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
    """Reset the prefix cache on the engine core; returns its success flag."""
    result = self.engine_core.reset_prefix_cache(
        reset_running_requests, reset_connector
    )
    return result
|
||||
|
||||
def reset_encoder_cache(self) -> None:
    """Reset the encoder cache to invalidate all cached encoder outputs.

    This should be called when model weights are updated to ensure
    stale vision embeddings computed with old weights are not reused.
    """
    self.engine_core.reset_encoder_cache()
|
||||
|
||||
def sleep(self, level: int = 1, mode: PauseMode = "abort"):
    """Put the engine core to sleep and record the state with the stat loggers."""
    self.engine_core.sleep(level, mode)

    if self.logger_manager is not None:
        # 1 == asleep; the second argument records the sleep level.
        self.logger_manager.record_sleep_state(1, level)
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None):
    """Wake the engine core up (optionally only the given tags) and record it."""
    self.engine_core.wake_up(tags)

    if self.logger_manager is not None:
        # 0 == awake, level 0.
        self.logger_manager.record_sleep_state(0, 0)
|
||||
|
||||
def is_sleeping(self) -> bool:
    """Report whether the engine core is currently asleep."""
    return self.engine_core.is_sleeping()
|
||||
|
||||
def get_metrics(self) -> list[Metric]:
    """Return a snapshot of engine metrics; requires stat logging to be on."""
    if not self.log_stats:
        raise AssertionError("Stat logging disabled")
    return get_metrics_snapshot()
|
||||
|
||||
@property
def tokenizer(self) -> TokenizerLike | None:
    """The renderer's tokenizer, or None if there isn't one."""
    return self.renderer.tokenizer
|
||||
|
||||
def get_tokenizer(self) -> TokenizerLike:
    """Return the renderer's tokenizer (raising if unavailable is the renderer's call)."""
    return self.renderer.get_tokenizer()
|
||||
|
||||
def do_log_stats(self) -> None:
    """Log stats if logging is enabled."""
    if self.logger_manager is not None:
        self.logger_manager.log()
|
||||
|
||||
def do_log_stats_with_interval(self) -> None:
    """Log stats when the configured time interval has elapsed."""
    now = time.time()
    last = getattr(self, "_last_log_time", None)
    if last is None:
        # First call: start the clock; nothing is logged yet.
        self._last_log_time = now
        last = now
    if now - last >= envs.VLLM_LOG_STATS_INTERVAL:
        self.do_log_stats()
        self._last_log_time = now
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Load a new LoRA adapter into the engine for future requests."""
    return self.engine_core.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
    """Remove an already loaded LoRA adapter."""
    return self.engine_core.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> set[int]:
    """List all registered adapters."""
    return self.engine_core.list_loras()
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
    """Prevent an adapter from being evicted."""
    return self.engine_core.pin_lora(lora_id)
|
||||
|
||||
def collective_rpc(
    self,
    method: str | Callable[[WorkerBase], _R],
    timeout: float | None = None,
    args: tuple = (),
    kwargs: dict[str, Any] | None = None,
) -> list[_R]:
    """Invoke ``method`` (a name or a callable) on all workers via the engine core.

    Returns the per-worker results as a list.
    """
    return self.engine_core.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
    """Run ``func`` against the model on every worker and collect the results."""
    return self.collective_rpc("apply_model", args=(func,))
|
||||
|
||||
def __del__(self):
    """Best-effort teardown of the data-parallel process group.

    Uses getattr with defaults throughout because __del__ may run on a
    partially initialized instance (e.g. if __init__ raised part-way
    through); the original unguarded attribute access could emit an
    AttributeError warning at garbage-collection time.
    """
    dp_group = getattr(self, "dp_group", None)
    # Only destroy groups this engine created itself; externally launched
    # DP groups are owned by the launcher.
    if dp_group is not None and not getattr(self, "external_launcher_dp", False):
        stateless_destroy_torch_distributed_process_group(dp_group)
|
||||
245
vllm/v1/engine/logprobs.py
Normal file
245
vllm/v1/engine/logprobs.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import (
|
||||
PromptLogprobs,
|
||||
SampleLogprobs,
|
||||
append_logprobs_for_next_position,
|
||||
create_prompt_logprobs,
|
||||
create_sample_logprobs,
|
||||
)
|
||||
from vllm.tokenizers.detokenizer_utils import (
|
||||
TokenizerLike,
|
||||
convert_ids_list_to_tokens,
|
||||
)
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NONES = itertools.repeat(None)
|
||||
|
||||
|
||||
@dataclass
class LogprobsProcessor:
    """Accumulates and detokenizes sample/prompt logprobs for one request."""

    # Tokenizer for this request,
    # None if detokenization is disabled.
    tokenizer: TokenizerLike | None

    # Logprobs for this request
    logprobs: SampleLogprobs | None
    prompt_logprobs: PromptLogprobs | None
    cumulative_logprob: float | None
    num_logprobs: int | None
    num_prompt_logprobs: int | None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
    ) -> "LogprobsProcessor":
        """Create a processor configured from the request's sampling params.

        Containers are only allocated for the logprob kinds the request
        actually asked for; the rest stay None.
        """
        sampling_params = request.sampling_params
        assert sampling_params is not None
        num_logprobs = sampling_params.logprobs
        num_prompt_logprobs = sampling_params.prompt_logprobs
        return cls(
            tokenizer=tokenizer,
            cumulative_logprob=(None if num_logprobs is None else 0.0),
            logprobs=(
                None
                if num_logprobs is None
                else create_sample_logprobs(sampling_params.flat_logprobs)
            ),
            prompt_logprobs=(
                None
                if num_prompt_logprobs is None
                else create_prompt_logprobs(sampling_params.flat_logprobs)
            ),
            num_prompt_logprobs=num_prompt_logprobs,
            num_logprobs=num_logprobs,
        )

    def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
        """Update with sample logprobs from EngineCore.

        Outer lists are only of len > 1 if EngineCore made
        >1 tokens in prior step (e.g. in spec decoding).

        Args:
            logprobs_lists: the lists of logprob tokens, logprobs, and ranks.

        """

        assert self.num_logprobs is not None
        assert self.logprobs is not None
        assert self.cumulative_logprob is not None

        token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists

        # One iteration per generated position.
        for rank_np, logprobs_np, token_ids_np in zip(
            ranks_lst, logprobs_lst, token_ids_lst
        ):
            # Pythonize the per-position numpy arrays.
            rank = rank_np.tolist()
            logprobs = logprobs_np.tolist()
            token_ids = token_ids_np.tolist()
            # Detokenize (non-incrementally).
            decoded_tokens: list[str] | Iterable[None]
            if self.tokenizer is None:
                # Detokenization disabled: yield None for every token.
                decoded_tokens = NONES
            else:
                decoded_tokens_list = convert_ids_list_to_tokens(
                    self.tokenizer, token_ids
                )
                # Repair any tokens that decoded to an incomplete UTF-8
                # sequence (replacement character at the end).
                decoded_tokens = self._verify_tokens(
                    decoded_tokens_list=decoded_tokens_list, tokens=token_ids
                )

            # Sampler puts the sampled logprob in first.
            sampled_token_logprob = logprobs[0]
            self.cumulative_logprob += sampled_token_logprob

            # Update with the Logprob container for this pos.
            append_logprobs_for_next_position(
                self.logprobs,
                token_ids,
                logprobs,
                decoded_tokens,
                rank,
                self.num_logprobs,
            )

    def _update_prompt_logprobs(
        self,
        prompt_logprobs_tensors: LogprobsTensors,
    ) -> None:
        """Update with prompt logprobs from EngineCore.

        Args:
            prompt_logprobs_tensors: tuple containing the prompt logprobs
                tensors.

        """

        # Prompt logprobs are enabled.
        assert self.num_prompt_logprobs is not None
        assert self.prompt_logprobs is not None

        token_ids, logprobs, ranks, _ = prompt_logprobs_tensors

        # Recover shapes.
        num_prompt_tokens, num_logprobs = logprobs.shape

        # Detokenize non-incrementally.
        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
        all_decoded_tokens: list[str] | None = (
            None
            if self.tokenizer is None
            else convert_ids_list_to_tokens(
                self.tokenizer, token_ids.flatten().tolist()
            )
        )

        # Pythonize the torch tensors.
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        token_ids_list = token_ids.tolist()

        # Make Logprob for each position.
        for pos in range(num_prompt_tokens):
            # Handle flattening and UTF-8 correction per position
            offset = pos * num_logprobs
            offset_end = offset + num_logprobs

            decoded_tokens_for_pos: list[str] | Iterable[None]
            if all_decoded_tokens is None:
                # Detokenization disabled: yield None for every token.
                decoded_tokens_for_pos = NONES
            else:
                # Extract decoded tokens for this position
                decoded_tokens_slice = all_decoded_tokens[offset:offset_end]
                # Apply UTF-8 correction within this position's token boundaries
                decoded_tokens_for_pos = self._verify_tokens(
                    decoded_tokens_list=decoded_tokens_slice, tokens=token_ids_list[pos]
                )

            # Update with the Logprob container for this pos.
            append_logprobs_for_next_position(
                self.prompt_logprobs,
                token_ids_list[pos],
                prompt_logprobs[pos],
                decoded_tokens_for_pos,
                prompt_token_ranks[pos],
                self.num_prompt_logprobs,
            )

    def pop_prompt_logprobs(self) -> PromptLogprobs | None:
        """Pop and return all request prompt logprobs

        The logprobs processor aggregates prompt chunk logprobs
        over one or more prefill chunks. This method returns
        all prompt logprobs at once and then forgets them.
        Ensures correct RequestOutputKind.DELTA semantics
        wherein all prompt logprobs are returned at once at
        the end of prefill.

        Returns:
            None if prompt logprobs are disabled for this request.
            List of all prompt logprobs, otherwise.
        """
        plp = self.prompt_logprobs
        if plp:
            # Reset so subsequent pops return an empty container.
            self.prompt_logprobs = []
        return plp

    def _correct_decoded_token(self, idx: int, tokens: list[int]) -> str:
        """Try to re-decode token `idx` together with neighboring token ids
        to resolve an incomplete UTF-8 (byte-fallback) sequence.

        Returns the corrected text, or "" if no combination decodes cleanly.
        """
        assert self.tokenizer is not None, "self.tokenizer should not be None"

        # try with prev token id in same list
        if idx > 0:
            possible_decoded_token = self.tokenizer.decode(tokens[idx - 1 : idx + 1])
            if not possible_decoded_token.endswith("\ufffd"):
                return possible_decoded_token
        # try with previous logprob token id
        if self.logprobs:
            # First key of the most recent position's logprob dict.
            latest_token_id = next(iter(self.logprobs[-1]))

            decode_ids = [latest_token_id]
            if idx > 0:
                decode_ids.extend(tokens[idx - 1 : idx + 1])
            else:
                decode_ids.extend(tokens[idx : idx + 1])

            possible_decoded_token = self.tokenizer.decode(decode_ids)
            if not possible_decoded_token.endswith("\ufffd"):
                return possible_decoded_token

        # by default return empty string
        return ""

    def _verify_tokens(
        self, decoded_tokens_list: list[str], tokens: list[int]
    ) -> list[str]:
        """Replace tokens that decoded to a trailing replacement character
        (incomplete byte sequence) with a corrected decoding in place.
        """
        corrected_decoded_token_map = dict()
        for idx, text in enumerate(decoded_tokens_list):
            if text.endswith("\ufffd"):
                # utf-8 char at the end means it's a potential unfinished byte sequence
                # from byte fallback tokenization.
                corrected_decoded_token_map[idx] = self._correct_decoded_token(
                    idx, tokens
                )

        for idx, text in corrected_decoded_token_map.items():
            decoded_tokens_list[idx] = text

        return decoded_tokens_list

    def update_from_output(self, output: EngineCoreOutput) -> None:
        """Fold any new sample/prompt logprobs from an engine output into state."""
        if output.new_logprobs is not None:
            self._update_sample_logprobs(output.new_logprobs)
        if output.new_prompt_logprobs_tensors is not None:
            self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)
|
||||
807
vllm/v1/engine/output_processor.py
Normal file
807
vllm/v1/engine/output_processor.py
Normal file
@@ -0,0 +1,807 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import (
|
||||
STREAM_FINISHED,
|
||||
CompletionOutput,
|
||||
PoolingOutput,
|
||||
PoolingRequestOutput,
|
||||
RequestOutput,
|
||||
)
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tracing import (
|
||||
SpanAttributes,
|
||||
SpanKind,
|
||||
extract_trace_context,
|
||||
instrument_manual,
|
||||
)
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
|
||||
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
|
||||
from vllm.v1.engine.logprobs import LogprobsProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.metrics.stats import (
|
||||
IterationStats,
|
||||
LoRARequestStates,
|
||||
RequestStateStats,
|
||||
SchedulerStats,
|
||||
)
|
||||
|
||||
# shared empty CPU tensor used as a placeholder pooling output
|
||||
EMPTY_CPU_TENSOR = torch.empty(0, device="cpu")
|
||||
|
||||
|
||||
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind, request_id: str):
        # Only DELTA outputs are merged/aggregated on overwrite.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.request_id = request_id
        # Single pending slot; an Exception here is re-raised to the consumer.
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        # Signals the consumer that an output (or error) is available.
        self.ready = asyncio.Event()

        # Optional task feeding streamed input; cancelled on close/del.
        self._input_stream_task: asyncio.Task | None = None

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Non-blocking put operation."""
        # Exceptions always take over the slot so the consumer sees them.
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, RequestOutput) and isinstance(
            output, RequestOutput
        ):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            self.output.add(output, aggregate=self.aggregate)
        elif isinstance(self.output, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        ):
            # Pooling outputs are not mergeable; the newest one wins.
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Get operation blocks on put event."""
        while (output := self.output) is None:
            await self.ready.wait()
        # Clear the slot and the event before handing the output over.
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Non-blocking get operation."""
        output = self.output
        if output is not None:
            self.output = None
            self.ready.clear()
            if isinstance(output, Exception):
                raise output
        return output

    def close(self):
        """Cancel the input-stream task, if any."""
        if self._input_stream_task is not None:
            self._input_stream_task.cancel()
            self._input_stream_task = None

    def __del__(self):
        # Cancel via the task's own loop: __del__ may run on another thread.
        if (task := self._input_stream_task) is not None:
            task.get_loop().call_soon_threadsafe(task.cancel)
            self._input_stream_task = None
|
||||
|
||||
|
||||
@dataclass
class OutputProcessorOutput:
    """Result of one OutputProcessor.process_outputs() pass."""

    # Outputs ready to hand to consumers this iteration.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # Request ids that must also be aborted in the engine core
    # (e.g. finished via stop strings on the frontend side).
    reqs_to_abort: list[str]
|
||||
|
||||
|
||||
@dataclass
class StreamingUpdate:
    """Streaming input update data for output processor.

    Contains the incremental prompt data to be applied to a request state
    when the current sub-request completes.
    """

    # Incremental prompt text to append (None if unavailable).
    prompt: str | None
    # Incremental prompt token ids to append.
    prompt_token_ids: list[int] | None
    # Arrival time of this chunk; replaces the request's stats arrival time.
    arrival_time: float
    # True when this is the last chunk of streamed input.
    final: bool = False
|
||||
|
||||
|
||||
class RequestState:
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
external_req_id: str,
|
||||
parent_req: ParentRequest | None,
|
||||
request_index: int,
|
||||
lora_request: LoRARequest | None,
|
||||
output_kind: RequestOutputKind,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None,
|
||||
prompt_embeds: torch.Tensor | None,
|
||||
logprobs_processor: LogprobsProcessor | None,
|
||||
detokenizer: IncrementalDetokenizer | None,
|
||||
max_tokens_param: int | None,
|
||||
arrival_time: float,
|
||||
queue: RequestOutputCollector | None,
|
||||
log_stats: bool,
|
||||
stream_interval: int,
|
||||
top_p: float | None = None,
|
||||
n: int | None = None,
|
||||
temperature: float | None = None,
|
||||
stream_input: bool = False,
|
||||
):
|
||||
self.request_id = request_id
|
||||
self.external_req_id = external_req_id
|
||||
self.parent_req = parent_req
|
||||
self.request_index = request_index
|
||||
self.lora_request = lora_request
|
||||
self.lora_name = lora_request.lora_name if lora_request is not None else None
|
||||
self.output_kind = output_kind
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_embeds = prompt_embeds
|
||||
self.prompt_len = length_from_prompt_token_ids_or_embeds(
|
||||
self.prompt_token_ids, self.prompt_embeds
|
||||
)
|
||||
self.logprobs_processor = logprobs_processor
|
||||
self.detokenizer = detokenizer
|
||||
self.max_tokens_param = max_tokens_param
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.temperature = temperature
|
||||
self.is_prefilling = True
|
||||
self.queue = queue
|
||||
self.num_cached_tokens = 0
|
||||
|
||||
self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None
|
||||
|
||||
# Stream Interval
|
||||
self.stream_interval = stream_interval
|
||||
self.sent_tokens_offset = 0 # Offset of sent tokens
|
||||
|
||||
# Streaming input queue
|
||||
self.streaming_input = stream_input
|
||||
self.input_chunk_queue: deque[StreamingUpdate] | None = (
|
||||
deque() if stream_input else None
|
||||
)
|
||||
|
||||
def apply_streaming_update(self, update: StreamingUpdate) -> None:
|
||||
# Apply the update to the request state.
|
||||
self.streaming_input = not update.final
|
||||
# TODO also include relevant output tokens in new prompt here
|
||||
# (match scheduler behavior).
|
||||
if update.prompt:
|
||||
self.prompt = (
|
||||
(self.prompt + update.prompt) if self.prompt else update.prompt
|
||||
)
|
||||
if self.prompt_token_ids:
|
||||
self.prompt_token_ids.extend(update.prompt_token_ids or ())
|
||||
else:
|
||||
self.prompt_token_ids = update.prompt_token_ids or []
|
||||
assert self.prompt_token_ids is not None
|
||||
self.prompt_len = len(self.prompt_token_ids)
|
||||
if self.stats is not None:
|
||||
self.stats.arrival_time = update.arrival_time
|
||||
self.is_prefilling = True
|
||||
|
||||
@classmethod
|
||||
def from_new_request(
|
||||
cls,
|
||||
tokenizer: TokenizerLike | None,
|
||||
request: EngineCoreRequest,
|
||||
prompt: str | None,
|
||||
parent_req: ParentRequest | None,
|
||||
request_index: int,
|
||||
queue: RequestOutputCollector | None,
|
||||
log_stats: bool,
|
||||
stream_interval: int,
|
||||
) -> "RequestState":
|
||||
if sampling_params := request.sampling_params:
|
||||
if not sampling_params.detokenize:
|
||||
tokenizer = None
|
||||
output_kind = sampling_params.output_kind
|
||||
logprobs_processor = LogprobsProcessor.from_new_request(
|
||||
tokenizer=tokenizer,
|
||||
request=request,
|
||||
)
|
||||
detokenizer = IncrementalDetokenizer.from_new_request(
|
||||
tokenizer=tokenizer,
|
||||
request=request,
|
||||
)
|
||||
max_tokens_param = sampling_params.max_tokens
|
||||
top_p = sampling_params.top_p
|
||||
n = sampling_params.n
|
||||
temperature = sampling_params.temperature
|
||||
else:
|
||||
logprobs_processor = None
|
||||
detokenizer = None
|
||||
max_tokens_param = None
|
||||
top_p = None
|
||||
n = None
|
||||
temperature = None
|
||||
assert request.pooling_params is not None
|
||||
output_kind = request.pooling_params.output_kind
|
||||
|
||||
assert request.external_req_id is not None
|
||||
return cls(
|
||||
request_id=request.request_id,
|
||||
external_req_id=request.external_req_id,
|
||||
parent_req=parent_req,
|
||||
request_index=request_index,
|
||||
lora_request=request.lora_request,
|
||||
output_kind=output_kind,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
logprobs_processor=logprobs_processor,
|
||||
detokenizer=detokenizer,
|
||||
max_tokens_param=max_tokens_param,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
temperature=temperature,
|
||||
arrival_time=request.arrival_time,
|
||||
queue=queue,
|
||||
log_stats=log_stats,
|
||||
stream_interval=stream_interval,
|
||||
stream_input=request.resumable,
|
||||
)
|
||||
|
||||
def make_request_output(
|
||||
self,
|
||||
new_token_ids: list[int],
|
||||
pooling_output: torch.Tensor | None,
|
||||
finish_reason: FinishReason | None,
|
||||
stop_reason: int | str | None,
|
||||
kv_transfer_params: dict[str, Any] | None = None,
|
||||
routed_experts: np.ndarray | None = None,
|
||||
) -> RequestOutput | PoolingRequestOutput | None:
|
||||
finished = finish_reason is not None
|
||||
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
|
||||
|
||||
if not finished and final_only:
|
||||
# Only the final output is required in FINAL_ONLY mode.
|
||||
return None
|
||||
|
||||
if self.stream_interval > 1:
|
||||
assert self.detokenizer is not None
|
||||
|
||||
# Send output request only when
|
||||
# 1. It has finished, or
|
||||
# 2. It is the first token, or
|
||||
# 3. It has reached the stream interval number of tokens
|
||||
if not (
|
||||
finished
|
||||
or self.sent_tokens_offset == 0
|
||||
or self.detokenizer.num_output_tokens() - self.sent_tokens_offset
|
||||
>= self.stream_interval
|
||||
):
|
||||
return None
|
||||
|
||||
if self.output_kind == RequestOutputKind.DELTA:
|
||||
# Send tokens from the offset in DELTA mode, otherwise all
|
||||
# tokens are sent.
|
||||
new_token_ids = self.detokenizer.output_token_ids[
|
||||
self.sent_tokens_offset :
|
||||
]
|
||||
self.sent_tokens_offset = self.detokenizer.num_output_tokens()
|
||||
|
||||
external_req_id = self.external_req_id
|
||||
|
||||
if pooling_output is not None:
|
||||
return self._new_request_output(
|
||||
external_req_id,
|
||||
[self._new_pooling_output(pooling_output)],
|
||||
finished,
|
||||
)
|
||||
|
||||
output = self._new_completion_output(
|
||||
new_token_ids, finish_reason, stop_reason, routed_experts
|
||||
)
|
||||
|
||||
if self.parent_req is None:
|
||||
outputs = [output]
|
||||
else:
|
||||
outputs, finished = self.parent_req.get_outputs(self.request_id, output)
|
||||
if not outputs:
|
||||
return None
|
||||
external_req_id = self.parent_req.external_req_id
|
||||
|
||||
return self._new_request_output(
|
||||
external_req_id, outputs, finished, kv_transfer_params
|
||||
)
|
||||
|
||||
def _new_request_output(
|
||||
self,
|
||||
external_req_id: str,
|
||||
outputs: list[CompletionOutput] | list[PoolingOutput],
|
||||
finished: bool,
|
||||
kv_transfer_params: dict[str, Any] | None = None,
|
||||
) -> RequestOutput | PoolingRequestOutput:
|
||||
# If prompt embeds were used, put placeholder prompt token ids
|
||||
prompt_token_ids = self.prompt_token_ids
|
||||
if prompt_token_ids is None and self.prompt_embeds is not None:
|
||||
prompt_token_ids = [0] * len(self.prompt_embeds)
|
||||
assert prompt_token_ids is not None
|
||||
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, PoolingOutput):
|
||||
assert len(outputs) == 1
|
||||
return PoolingRequestOutput(
|
||||
request_id=external_req_id,
|
||||
outputs=first_output,
|
||||
num_cached_tokens=self.num_cached_tokens,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
finished=finished,
|
||||
)
|
||||
assert self.logprobs_processor is not None
|
||||
if self.output_kind == RequestOutputKind.DELTA:
|
||||
# Side effect: logprobs processor forgets prompt logprobs
|
||||
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
|
||||
else:
|
||||
prompt_logprobs = self.logprobs_processor.prompt_logprobs
|
||||
|
||||
return RequestOutput(
|
||||
request_id=external_req_id, # request_id is what was provided externally
|
||||
lora_request=self.lora_request,
|
||||
prompt=self.prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
outputs=cast(list[CompletionOutput], outputs),
|
||||
finished=finished,
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
num_cached_tokens=self.num_cached_tokens,
|
||||
metrics=self.stats,
|
||||
)
|
||||
|
||||
def _new_completion_output(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
finish_reason: FinishReason | None,
|
||||
stop_reason: int | str | None,
|
||||
routed_experts: np.ndarray | None = None,
|
||||
) -> CompletionOutput:
|
||||
assert self.detokenizer is not None
|
||||
assert self.logprobs_processor is not None
|
||||
finished = finish_reason is not None
|
||||
delta = self.output_kind == RequestOutputKind.DELTA
|
||||
|
||||
# Prepare text and token_ids, based on delta mode
|
||||
text = self.detokenizer.get_next_output_text(finished, delta)
|
||||
if not delta:
|
||||
token_ids = self.detokenizer.output_token_ids
|
||||
|
||||
# Prepare logprobs, based on delta mode
|
||||
logprobs = self.logprobs_processor.logprobs
|
||||
if delta and logprobs:
|
||||
logprobs = logprobs[-len(token_ids) :]
|
||||
|
||||
return CompletionOutput(
|
||||
index=self.request_index,
|
||||
text=text,
|
||||
token_ids=token_ids,
|
||||
routed_experts=routed_experts,
|
||||
logprobs=logprobs,
|
||||
cumulative_logprob=self.logprobs_processor.cumulative_logprob,
|
||||
finish_reason=str(finish_reason) if finished else None,
|
||||
stop_reason=stop_reason if finished else None,
|
||||
)
|
||||
|
||||
def _new_pooling_output(self, pooling_output: torch.Tensor) -> PoolingOutput:
    """Wrap a raw pooling tensor in the public PoolingOutput container."""
    return PoolingOutput(data=pooling_output)
class OutputProcessor:
|
||||
"""Process EngineCoreOutputs into RequestOutputs."""
|
||||
|
||||
def __init__(
    self,
    tokenizer: TokenizerLike | None,
    *,
    log_stats: bool,
    stream_interval: int = 1,
    tracing_enabled: bool = False,
):
    """Initialize the output-processor bookkeeping.

    Args:
        tokenizer: Tokenizer forwarded to each new RequestState for
            detokenization; may be None (presumably tokenizer-free /
            pooling deployments — TODO confirm against callers).
        log_stats: Whether to record per-request / LoRA statistics.
        stream_interval: Forwarded to RequestState.from_new_request to
            control how often streamed outputs are emitted.
        tracing_enabled: Whether do_tracing() emits a span when a
            request finishes.
    """
    self.log_stats = log_stats
    self.tokenizer = tokenizer
    self.stream_interval = stream_interval
    # Internal request ID -> per-request processing state.
    self.request_states: dict[str, RequestState] = {}
    # Parent request ID -> ParentRequest (parallel sampling, n > 1).
    self.parent_requests: dict[str, ParentRequest] = {}
    # External (client-supplied) request ID -> list of internal IDs.
    self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
    self.lora_states = LoRARequestStates(log_stats)
    self.tracing_enabled = tracing_enabled
def get_num_unfinished_requests(self) -> int:
    """Return the number of requests currently being tracked.

    Annotated with ``-> int`` for consistency with the sibling
    ``has_unfinished_requests() -> bool``.
    """
    return len(self.request_states)
def has_unfinished_requests(self) -> bool:
    """Return True while at least one request is still being tracked."""
    return bool(self.request_states)
def propagate_error(self, e: Exception):
    """Propagate error to all generate() tasks.

    Puts the exception onto every tracked request's output queue so each
    waiting generate() coroutine observes the failure.
    """
    for state in self.request_states.values():
        # Every tracked request is expected to have a queue here
        # (async/AsyncLLM usage).
        assert state.queue is not None
        state.queue.put(e)
def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
    """Abort a list of requests.

    The request_ids may be either external request IDs (those passed to
    InputProcessor.process_inputs()) or internal request IDs (those randomly
    generated when creating the EngineCoreRequest).

    If an external request ID is provided, and that external request ID
    was used for multiple requests, all requests associated with that external
    request ID are aborted.

    In the case of parallel sampling, a request ID may be used to identify
    a parent request, in which case the associated child requests are aborted
    also.

    Args:
        request_ids: IDs to abort.
        internal: True if the IDs are internal request IDs, False if they
            are external (client-supplied) IDs.

    Returns:
        The list of internal request IDs that were actually aborted (and
        should therefore also be aborted in EngineCore).
    """
    # Phase 1: resolve every incoming ID to internal request IDs,
    # keeping the external->internal mapping consistent.
    internal_req_ids = []
    for request_id in request_ids:
        if internal:
            # Internal ID - this may be a parent request
            internal_req_ids.append(request_id)

            # Remove internal ID from the external->internal mapping
            if req_state := self.request_states.get(request_id):
                external_req_id = req_state.external_req_id
                internal_ids = self.external_req_ids[external_req_id]
                internal_ids.remove(request_id)
                # Drop the mapping entry entirely once its last
                # internal ID is gone.
                if not internal_ids:
                    del self.external_req_ids[external_req_id]
        elif internal_ids := self.external_req_ids.pop(request_id, []):
            # External ID - abort all requests in the external->internal mapping
            internal_req_ids.extend(internal_ids)

    # Phase 2: tear down per-request state and emit a final abort output
    # for each resolved internal ID.
    request_ids_to_abort = []
    for request_id in internal_req_ids:
        req_state = self.request_states.pop(request_id, None)
        if req_state is not None:
            self.lora_states.request_finished(request_id, req_state.lora_name)
            request_ids_to_abort.append(request_id)
            # Produce final abort output.
            if req_state.queue is not None and (
                request_output := req_state.make_request_output(
                    new_token_ids=[],
                    # Set pooling_output is not None to
                    # correctly enter the abort pooling branch
                    pooling_output=EMPTY_CPU_TENSOR
                    if req_state.detokenizer is None
                    else None,
                    finish_reason=FinishReason.ABORT,
                    stop_reason=None,
                    kv_transfer_params=None,
                )
            ):
                req_state.queue.put(request_output)
        elif parent := self.parent_requests.get(request_id):
            # Abort children prior to removing the parent.
            # NOTE: recursion depth is 1 here - children are plain
            # requests, never parents themselves.
            if parent.child_requests:
                child_reqs = list(parent.child_requests)
                child_reqs = self.abort_requests(child_reqs, internal=True)
                request_ids_to_abort.extend(child_reqs)
            self.parent_requests.pop(request_id, None)
    return request_ids_to_abort
def add_request(
    self,
    request: EngineCoreRequest,
    prompt: str | None,
    parent_req: ParentRequest | None = None,
    request_index: int = 0,
    queue: RequestOutputCollector | None = None,
) -> None:
    """Register a new request (or a streaming-input continuation).

    If the request ID is already tracked, this call is treated as a
    streaming-input update for the existing request and is routed to
    _update_streaming_request_state() instead of creating new state.

    Args:
        request: The engine-core request being added.
        prompt: Decoded prompt text, if available.
        parent_req: Parent request for parallel sampling (n > 1).
        request_index: Index of this child within the parent's n outputs.
        queue: Output collector for AsyncLLM usage; None for LLMEngine.
    """
    request_id = request.request_id
    req_state = self.request_states.get(request_id)
    if req_state is not None:
        # Existing request: this is an incremental streaming input.
        self._update_streaming_request_state(req_state, request, prompt)
        return

    req_state = RequestState.from_new_request(
        tokenizer=self.tokenizer,
        request=request,
        prompt=prompt,
        parent_req=parent_req,
        request_index=request_index,
        queue=queue,
        log_stats=self.log_stats,
        stream_interval=self.stream_interval,
    )
    self.request_states[request_id] = req_state
    if parent_req:
        self.parent_requests[parent_req.request_id] = parent_req

    # Track the external_req_id -> [internal_req_id, ...] mapping
    self.external_req_ids[req_state.external_req_id].append(request_id)
def _update_streaming_request_state(
    self, req_state: RequestState, request: EngineCoreRequest, prompt: str | None
) -> None:
    """Queue a streaming update instead of immediately applying it.

    Handles two cases for an already-tracked streaming request:
    the terminal (non-resumable) chunk, which only marks completion,
    and a resumable chunk, which either applies immediately (when the
    engine is idle for this request) or is queued for later.
    """
    if not request.resumable:
        # Final request - just mark completion, don't add its dummy tokens.
        if req_state.input_chunk_queue is None:
            # Engine already finished - emit final output and clean up.
            self._finish_request(req_state)
            if req_state.queue is not None:
                # Emit a final output with finished=True
                # to unblock the generate() loop.
                req_state.queue.put(STREAM_FINISHED)
        elif req_state.input_chunk_queue:
            # Pending chunks exist: mark the last one as final so the
            # request terminates after they drain.
            req_state.input_chunk_queue[-1].final = True
        else:
            # Queue exists but is empty: the engine is still working on
            # the current input; just stop expecting more input.
            req_state.streaming_input = False
        return

    update = StreamingUpdate(
        prompt=prompt,
        prompt_token_ids=request.prompt_token_ids,
        arrival_time=request.arrival_time,
    )

    # Apply request updates now if the last input already completed.
    if req_state.input_chunk_queue is None:
        req_state.apply_streaming_update(update)
        # Non-None queue signals "engine busy" for subsequent chunks.
        req_state.input_chunk_queue = deque()
    else:
        # Queue the streaming update otherwise.
        req_state.input_chunk_queue.append(update)
def process_outputs(
    self,
    engine_core_outputs: list[EngineCoreOutput],
    engine_core_timestamp: float | None = None,
    iteration_stats: IterationStats | None = None,
) -> OutputProcessorOutput:
    """
    Process the EngineCoreOutputs:
    1) Compute stats for logging
    2) Detokenize
    3) Create and handle RequestOutput objects:
        * If there is a queue (for usage with AsyncLLM),
          put the RequestOutput objects into the queue for
          handling by the per-request generate() tasks.

        * If there is no queue (for usage with LLMEngine),
          return a list of RequestOutput objects.

    NOTE FOR DEVELOPERS

    vLLM V1 minimizes the number of python loops over the full
    batch to ensure system overheads are minimized. This is the
    only function that should loop over EngineCoreOutputs.

    If you need to touch every element of the batch, do it from
    within the loop below.
    """

    request_outputs: list[RequestOutput | PoolingRequestOutput] = []
    reqs_to_abort: list[str] = []
    for engine_core_output in engine_core_outputs:
        req_id = engine_core_output.request_id
        req_state = self.request_states.get(req_id)
        if req_state is None:
            # Ignore output for already-aborted request.
            continue

        # 1) Compute stats for this iteration.
        self._update_stats_from_output(
            req_state, engine_core_output, engine_core_timestamp, iteration_stats
        )

        new_token_ids = engine_core_output.new_token_ids
        pooling_output = engine_core_output.pooling_output
        finish_reason = engine_core_output.finish_reason
        stop_reason = engine_core_output.stop_reason
        kv_transfer_params = engine_core_output.kv_transfer_params
        routed_experts = engine_core_output.routed_experts
        req_state.num_cached_tokens = engine_core_output.num_cached_tokens
        req_state.is_prefilling = False

        # pooling_output is None for generation (sampling) requests;
        # only those need detokenization and logprobs processing.
        if pooling_output is None:
            assert req_state.detokenizer is not None
            assert req_state.logprobs_processor is not None
            # 2) Detokenize the token ids into text and perform stop checks.
            stop_string = req_state.detokenizer.update(
                new_token_ids, finish_reason == FinishReason.STOP
            )
            if stop_string:
                # Detokenizer hit a stop string the engine didn't see.
                finish_reason = FinishReason.STOP
                stop_reason = stop_string

            # 3) Compute sample and prompt logprobs for request,
            # if required.
            req_state.logprobs_processor.update_from_output(engine_core_output)

        # 4) Create and handle RequestOutput objects.
        if request_output := req_state.make_request_output(
            new_token_ids,
            pooling_output,
            finish_reason,
            stop_reason,
            kv_transfer_params,
            routed_experts,
        ):
            # A streaming-input request is never reported finished to the
            # client mid-stream.
            if req_state.streaming_input:
                request_output.finished = False

            if req_state.queue is not None:
                # AsyncLLM: put into queue for handling by generate().
                req_state.queue.put(request_output)
            else:
                # LLMEngine: return list of RequestOutputs.
                request_outputs.append(request_output)

        # Free completed requests.
        if finish_reason is not None:
            if req_state.streaming_input:
                # Streaming input: this "finish" only ends the current
                # chunk - apply/park the next one instead of cleaning up.
                if req_state.input_chunk_queue:
                    update = req_state.input_chunk_queue.popleft()
                    req_state.apply_streaming_update(update)
                else:
                    # No queued chunks: mark the engine idle for this
                    # request (None sentinel).
                    req_state.input_chunk_queue = None
            else:
                self._finish_request(req_state)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracing_enabled:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)

    return OutputProcessorOutput(
        request_outputs=request_outputs,
        reqs_to_abort=reqs_to_abort,
    )
def _finish_request(self, req_state: RequestState) -> None:
    """Remove a completed request from all tracking structures."""
    req_id = req_state.request_id
    self.request_states.pop(req_id)

    # Detach this internal ID from its external ID; delete the whole
    # mapping entry once no internal IDs remain for it.
    external_id = req_state.external_req_id
    remaining = self.external_req_ids[external_id]
    remaining.remove(req_id)
    if not remaining:
        del self.external_req_ids[external_id]

    # Remove parent request if applicable.
    parent = req_state.parent_req
    if parent and not parent.child_requests:
        self.parent_requests.pop(parent.request_id, None)
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
    # Forward scheduler stats so per-LoRA metrics stay current.
    self.lora_states.update_scheduler_stats(scheduler_stats)
def do_tracing(
    self,
    engine_core_output: EngineCoreOutput,
    req_state: RequestState,
    iteration_stats: IterationStats | None,
) -> None:
    """Emit an OpenTelemetry-style span for a finished request.

    Builds latency/usage attributes from the request's recorded
    timestamps and records a "llm_request" span starting at arrival time.
    Only called when tracing_enabled is set and the request finished.
    """
    assert req_state.stats is not None
    assert iteration_stats is not None

    metrics = req_state.stats
    # Span start time must be in nanoseconds.
    arrival_time_ns = int(metrics.arrival_time * 1e9)
    # Continue the client's trace if headers were propagated.
    trace_context = extract_trace_context(engine_core_output.trace_headers)
    prompt_length = length_from_prompt_token_ids_or_embeds(
        req_state.prompt_token_ids, req_state.prompt_embeds
    )

    # Calculate timing metrics (all derived from recorded timestamps).
    e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
    queued_time = metrics.scheduled_ts - metrics.queued_ts
    prefill_time = metrics.first_token_ts - metrics.scheduled_ts
    decode_time = metrics.last_token_ts - metrics.first_token_ts
    inference_time = metrics.last_token_ts - metrics.scheduled_ts

    # Build attributes dict
    attributes: dict[str, Any] = {
        SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN: (
            metrics.first_token_latency
        ),
        SpanAttributes.GEN_AI_LATENCY_E2E: e2e_time,
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE: queued_time,
        SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS: prompt_length,
        SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS: (
            metrics.num_generation_tokens
        ),
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL: prefill_time,
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE: decode_time,
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE: inference_time,
        SpanAttributes.GEN_AI_REQUEST_ID: req_state.external_req_id,
    }

    # Add optional request parameters
    # NOTE(review): truthiness checks skip 0/0.0 values (e.g. top_p or
    # temperature of 0) as well as None - presumably intentional; confirm.
    if req_state.top_p:
        attributes[SpanAttributes.GEN_AI_REQUEST_TOP_P] = req_state.top_p
    if req_state.max_tokens_param:
        attributes[SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS] = (
            req_state.max_tokens_param
        )
    if req_state.temperature:
        attributes[SpanAttributes.GEN_AI_REQUEST_TEMPERATURE] = (
            req_state.temperature
        )
    if req_state.n:
        attributes[SpanAttributes.GEN_AI_REQUEST_N] = req_state.n

    instrument_manual(
        span_name="llm_request",
        start_time=arrival_time_ns,
        attributes=attributes,
        context=trace_context,
        kind=SpanKind.SERVER,
    )
def _update_stats_from_output(
    self,
    req_state: RequestState,
    engine_core_output: EngineCoreOutput,
    engine_core_timestamp: float | None,
    iteration_stats: IterationStats | None,
):
    """Fold this iteration's engine output into the iteration stats.

    No-op when stats collection is disabled (iteration_stats is None).
    """
    if iteration_stats is None:
        return

    # When stats are enabled, callers must supply a timestamp and each
    # request must carry a stats object.
    assert engine_core_timestamp is not None
    assert req_state.stats is not None
    iteration_stats.update_from_output(
        engine_core_output,
        engine_core_timestamp,
        req_state.is_prefilling,
        req_state.prompt_len,
        req_state.stats,
        self.lora_states,
        req_state.lora_name,
    )
def _update_stats_from_finished(
    self,
    req_state: RequestState,
    finish_reason: FinishReason | None,
    iteration_stats: IterationStats | None,
):
    """Record end-of-request statistics (no-op when stats are disabled)."""
    if iteration_stats is None:
        return

    # This is only called on the finish path, so a reason must exist.
    assert finish_reason is not None
    assert req_state.stats is not None
    iteration_stats.update_from_finished_request(
        finish_reason=finish_reason,
        num_prompt_tokens=req_state.prompt_len,
        max_tokens_param=req_state.max_tokens_param,
        req_stats=req_state.stats,
        num_cached_tokens=req_state.num_cached_tokens,
    )
    self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

    # For parallel sampling, aggregate per-child token counts on the parent.
    ParentRequest.observe_finished_request(
        req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
    )
150
vllm/v1/engine/parallel_sampling.py
Normal file
150
vllm/v1/engine/parallel_sampling.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import copy
|
||||
from typing import cast
|
||||
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
|
||||
class ParentRequest:
    """Info, state & processing for parallel sampling request.

    Store parent request ID and sampling params.
    Facilitate generating child request sampling params.
    """

    # Internal ID of the parent request.
    request_id: str
    # Client-supplied request ID.
    external_req_id: str
    sampling_params: SamplingParams

    # To track the completion of child requests
    child_requests: set[str]

    # To aggregate child completions when not streaming
    output_aggregator: list[CompletionOutput]

    # To find the max number of generated tokens across all children
    max_num_generation_tokens: int

    # To efficiently obtain child sampling params
    cached_child_sampling_params: SamplingParams | None

    def __init__(self, request: EngineCoreRequest) -> None:
        """Initialize parent-request state from the original request."""
        assert request.external_req_id is not None
        sampling_params = request.params
        self.request_id = request.request_id
        self.external_req_id = request.external_req_id
        self.sampling_params = sampling_params

        self.child_requests = set()
        # FINAL_ONLY: pre-size a slot per child (filled by index later);
        # streaming kinds need no aggregation buffer.
        self.output_aggregator = (
            [cast(CompletionOutput, None)] * sampling_params.n
            if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY)
            else []
        )
        self.max_num_generation_tokens = 0
        self.cached_child_sampling_params = None

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Efficiently obtain child `sampling_params`

        If `sampling_params.seed` is not `None` then
        each child request requires a unique clone of
        parent `sampling_params` with a unique seed.

        Args:
            index: index within `n` child requests

        Returns:
            Child `sampling_params` instance.
        """
        seed = self.sampling_params.seed
        if self.cached_child_sampling_params:
            # Reuse child sampling_params data structure
            # (only populated in the seed-less case below).
            return self.cached_child_sampling_params
        # Build child sampling_params
        child_sampling_params = copy(self.sampling_params)
        child_sampling_params.n = 1
        if seed is None:
            # Cache child sampling_params for later reuse
            self.cached_child_sampling_params = child_sampling_params
        else:
            # Each child gets a clone with a unique seed
            child_sampling_params.seed = seed + index
        return child_sampling_params

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        """Get child request ID and sampling params.

        Args:
            index: index within `n` child requests.

        Returns:
            (request ID, sampling_params) tuple
        """
        # Child IDs are derived deterministically from the parent ID.
        child_req_id = f"{index}_{self.request_id}"
        self.child_requests.add(child_req_id)
        return child_req_id, self._get_child_sampling_params(index)

    @property
    def n(self) -> int:
        # Number of parallel samples requested by the client.
        return self.sampling_params.n

    def get_outputs(
        self,
        child_request_id: str,
        completion_output: CompletionOutput,
    ) -> tuple[list[CompletionOutput], bool]:
        """Fold one child's output into the parent's aggregate.

        Returns:
            (outputs to emit to the client now, whether all children
            have finished).
        """
        already_finished_and_returned: bool = False
        if completion_output.finished():
            if child_request_id in self.child_requests:
                self.child_requests.remove(child_request_id)
            else:
                # child request ID is not available in child_requests
                # which means the request had finished in previous
                # batch step and returned to the client earlier
                already_finished_and_returned = True

        if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
            # If streaming, just return the current output
            #
            # DO NOT output finished and already returned child request to client again
            outputs = [] if already_finished_and_returned else [completion_output]
        else:
            # If not streaming, aggregate the n final outputs.
            self.output_aggregator[completion_output.index] = completion_output
            # Emit the aggregate only once the last child completes.
            outputs = [] if self.child_requests else self.output_aggregator

        finished = not self.child_requests
        return outputs, finished

    def observe_num_generation_tokens(self, num_generation_tokens: int):
        """Track and return the max generated-token count across children."""
        self.max_num_generation_tokens = max(
            num_generation_tokens, self.max_num_generation_tokens
        )
        return self.max_num_generation_tokens

    @staticmethod
    def observe_finished_request(
        parent_req: "ParentRequest | None",
        iteration_stats: IterationStats,
        num_generation_tokens: int,
    ):
        """Record n/max-token stats once a request (or its last child) ends.

        Works for both plain requests (parent_req is None) and children
        of a parallel-sampling parent.
        """
        n_param = parent_req.n if parent_req is not None else 1

        if parent_req is not None:
            # Use the running max across all of this parent's children.
            num_generation_tokens = parent_req.observe_num_generation_tokens(
                num_generation_tokens
            )

        # Child requests finished, we can now record to iteration stats
        if parent_req is None or not parent_req.child_requests:
            iteration_stats.max_num_generation_tokens_iter.append(num_generation_tokens)
            iteration_stats.n_params_iter.append(n_param)
1090
vllm/v1/engine/utils.py
Normal file
1090
vllm/v1/engine/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user