# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pickle
import signal
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
import cloudpickle
import vllm.envs as envs
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
import zmq
from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import VllmConfig
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
RPCIsSleepingRequest,
RPCIsSleepingResponse,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest,
RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest, RPCWakeUpRequest)
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext
from vllm.worker.model_runner_base import InputProcessingError
import time
logger = init_logger(__name__)

# Idle-poll slice (ms): while the engine has no unfinished work, the input
# socket is polled in slices of this length, with a health check and stats
# log between slices (see MQLLMEngine.run_engine_loop).
POLLING_TIMEOUT_MS = 10000

# Pre-pickled success payload, built once at import time so the heartbeat
# path does no per-send serialization work.
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
class MQLLMEngine:
    """A multiprocessing wrapper for
    [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].

    This class is used to wrap the
    [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
    in concurrent manner. It runs a background loop and uses zeromq to
    receive new requests and stream outputs incrementally via ipc.

    The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
    process is kicked off when a new RPCProcessRequest is received by the
    input_socket.

    The self.engine_loop checks the input_socket for new requests,
    adds them to the LLMEngine if there are any, calls the internal
    [`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends
    the RequestOutputs back over the output_socket.

    If use_async_sockets is set, the logic associated with reading new
    requests from the socket and sending data to the socket is passed
    as a callback to the llm_engine, which calls the logic asynchronously
    such that the IPC can be overlapped with the GPU.

    Args:
        ipc_path: Base path for zeromq interprocess messaging
        use_async_sockets: Whether to make send/recv async with GPU
        log_requests: Whether to log the requests.
        *args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
        **kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
    """

    def __init__(self,
                 ipc_path: str,
                 use_async_sockets: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        # For MQLLMEngine, we can use cached outputs, since each new request
        # output is immediately pickled and sent over the socket, which frees
        # the python object to be reused again.
        kwargs['use_cached_outputs'] = True

        # Zero-overhead mode substitutes ZeroOverheadEngine; it is used
        # through the same interface as LLMEngine everywhere below.
        if envs.VLLM_ZERO_OVERHEAD:
            self.engine = ZeroOverheadEngine(*args, **kwargs)
        else:
            self.engine = LLMEngine(*args, **kwargs)
        self.log_requests = log_requests

        self.use_async_sockets = use_async_sockets
        if self.use_async_sockets:
            # The engine invokes this callback after each batch of outputs,
            # overlapping IPC with GPU execution.
            self.engine.process_request_outputs_callback = \
                self._async_socket_engine_callback

        self.ctx = zmq.Context()  # type: ignore[attr-defined]

        # Receive input from the client.
        self.input_socket = self.ctx.socket(zmq.constants.PULL)
        self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")

        # Send output stream back to client.
        self.output_socket = self.ctx.socket(zmq.constants.PUSH)
        self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

        # Send heartbeats back to client.
        self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
        self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

        # IPC path for the data socket (bound lazily in make_data_socket).
        self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

        # Error state: first fatal exception wins (see _set_errored).
        self._errored_with: Optional[BaseException] = None

    @property
    def dead_error(self) -> BaseException:
        # Wrap the original fatal error (if any) so clients see the cause.
        if self._errored_with is not None:
            return ENGINE_DEAD_ERROR(self._errored_with)
        else:
            return ENGINE_DEAD_ERROR()

    @classmethod
    def from_vllm_config(cls, vllm_config: VllmConfig,
                         usage_context: UsageContext,
                         disable_log_requests: bool, disable_log_stats: bool,
                         ipc_path: str) -> "MQLLMEngine":
        """Alternate constructor: build from an already-created VllmConfig."""
        # Setup plugins for each process
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        use_async_sockets = vllm_config.model_config.use_async_output_proc

        return cls(
            vllm_config=vllm_config,
            executor_class=LLMEngine._get_executor_cls(vllm_config),
            ipc_path=ipc_path,
            usage_context=usage_context,
            use_async_sockets=use_async_sockets,
            log_requests=(not disable_log_requests),
            log_stats=(not disable_log_stats),
        )

    @staticmethod
    def from_engine_args(engine_args: AsyncEngineArgs,
                         usage_context: UsageContext, ipc_path: str):
        """Creates an MQLLMEngine from the engine arguments."""
        vllm_config = engine_args.create_engine_config(usage_context)

        return MQLLMEngine.from_vllm_config(
            ipc_path=ipc_path,
            vllm_config=vllm_config,
            usage_context=usage_context,
            disable_log_requests=engine_args.disable_log_requests,
            disable_log_stats=engine_args.disable_log_stats,
        )

    def start(self):
        """Run the startup handshake then the engine loop until shutdown.

        Exceptions from the loops are logged (not re-raised); cleanup()
        always runs.
        """
        try:
            try:
                logger.debug("Starting Startup Loop.")
                self.run_startup_loop()
                logger.debug("Starting Engine Loop.")
                self.run_engine_loop()
            except Exception as e:
                logger.exception(repr(e))
        except KeyboardInterrupt:
            # Raised by the SIGTERM handler (signal_handler) or Ctrl-C.
            logger.debug("Shutting down MQLLMEngine.")
        finally:
            logger.debug("MQLLMEngine is shut down.")
            self.cleanup()

    def cleanup(self):
        """Cleanup zeromq state on shutdown."""
        # Closes all sockets and destroys context.
        self.ctx.destroy(linger=0)
        # Drop the engine reference so executor resources can be released.
        del self.engine

    @contextmanager
    def make_data_socket(
            self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
        # ROUTER socket: replies must be prefixed with the client identity
        # frame (see run_startup_loop).
        socket = self.ctx.socket(zmq.constants.ROUTER)
        try:
            socket.bind(self.data_ipc_path)
            yield socket
        finally:
            socket.close(linger=0)

    def run_startup_loop(self) -> None:
        """Startup loop for sending data from Engine -> Client."""
        with self.make_data_socket() as socket:
            response: Union[RPCStartupResponse, BaseException]
            try:
                identity, message = socket.recv_multipart(copy=False)
                request: RPCStartupRequest = pickle.loads(message.buffer)

                # Handle the query from the Client.
                if request == RPCStartupRequest.IS_SERVER_READY:
                    tracing_enabled = self.engine.is_tracing_enabled()
                    response = RPCStartupResponse(
                        tracing_enabled=tracing_enabled)

            except Exception as e:
                # Ship the exception back so the client fails fast.
                response = e

            # NOTE(review): if recv_multipart itself raises, `identity` is
            # unbound here and this send raises NameError; `response` is
            # likewise unbound for an unrecognized request value — confirm
            # these paths cannot occur in practice.
            socket.send_multipart((identity, pickle.dumps(response)),
                                  copy=False)

    def run_engine_loop(self):
        """Core busy loop of the LLMEngine."""
        last_no_req_time_refreshed = True
        last_no_req_time = time.perf_counter()
        while True:
            if not self.engine.has_unfinished_requests():
                # Poll until there is work to do.
                while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                    # When there's no work, check on engine health and send
                    # health status back to client
                    self._health_check()
                    self.engine.do_log_stats()
                    logger.debug("Waiting for new requests in engine loop.")
                # Work just arrived after an idle period: remember when, so
                # the TBO delay below can batch closely-spaced requests.
                last_no_req_time = time.perf_counter()
                last_no_req_time_refreshed = True

            # Handle any input from the client.
            self.handle_new_input()

            # TBO request merging: for a short window after an idle period,
            # hold off stepping while fewer than 2 requests are queued so
            # additional requests can merge into a single batch.
            if envs.VLLM_TBO_REQ_DELAY_MS > 0 and last_no_req_time_refreshed and envs.VLLM_ENABLE_TBO:
                if self.engine.get_num_unfinished_requests() < 2:
                    time_diff_ms = int((time.perf_counter() - last_no_req_time) * 1000)
                    if time_diff_ms < envs.VLLM_TBO_REQ_DELAY_MS:
                        time.sleep(0.01)  # sleep and waiting more request to merge in one batch
                        continue
                # Window expired (or enough requests): stop delaying until
                # the next idle period refreshes the timestamp.
                last_no_req_time_refreshed = False

            # Engine step.
            request_outputs = self.engine_step()

            # Send request outputs (if async, done in engine_step callback).
            if not self.use_async_sockets:
                self._send_outputs(request_outputs)

    def engine_step(self) -> List[RequestOutput]:
        """Engine step wrapper with error handling."""
        try:
            return self.engine.step()
        except SystemExit:
            raise
        except InputProcessingError as e:
            # Special case where we handle an error preparing the inputs for
            # a single request in the batch
            rpc_err = RPCError(request_id=e.request_id,
                               is_engine_errored=False,
                               exception=e.__cause__)
            self._send_outputs(rpc_err)
            return []
        except BaseException as e:
            # Any other failure is fatal: mark errored, notify client, and
            # let the exception unwind to start()'s handler.
            self._set_errored(e)
            rpc_err = RPCError(request_id=None,
                               is_engine_errored=True,
                               exception=e)
            self._send_outputs(rpc_err)
            raise e

    def handle_new_input(self):
        """Handle new input from the socket"""
        try:
            # Drain everything currently queued without blocking.
            while self.input_socket.poll(timeout=0) != 0:
                frames = self.input_socket.recv_multipart(copy=False)
                request = pickle.loads(frames[0].buffer)

                if isinstance(request, RPCProcessRequest):
                    if len(frames) > 1:
                        # Use cloudpickle for logits processors
                        assert isinstance(request.params, SamplingParams)
                        lprocs = cloudpickle.loads(frames[1].buffer)
                        request.params.logits_processors = lprocs
                    self._handle_process_request(request)
                elif isinstance(request, RPCAbortRequest):
                    self._handle_abort_request(request)
                elif isinstance(request, RPCUProfileRequest):
                    if request == RPCUProfileRequest.START_PROFILE:
                        self.start_profile()
                    else:
                        self.stop_profile()
                elif isinstance(request, RPCLoadAdapterRequest):
                    self._handle_load_adapter_request(request)
                elif isinstance(request, RPCResetMultiModalCacheRequest):
                    self.reset_mm_cache()
                elif isinstance(request, RPCResetPrefixCacheRequest):
                    self.reset_prefix_cache()
                elif isinstance(request, RPCSleepRequest):
                    self.sleep(request.value)
                elif isinstance(request, RPCWakeUpRequest):
                    self.wake_up(request.tags)
                elif isinstance(request, RPCIsSleepingRequest):
                    self._handle_is_sleeping_request(request)
                else:
                    raise ValueError("Unknown RPCRequest Type: "
                                     f"{type(request)}")

        except Exception as e:
            # Input-handling failures are treated as fatal for the engine.
            self._set_errored(e)
            self._send_unhealthy(e)
            raise e from None

    def _handle_process_request(self, request: RPCProcessRequest):
        """Handle RPCProcessRequest by adding it to the LLMEngine."""
        request_id = request.request_id

        if self._errored_with is not None:
            # NOTE(review): no early return here — the request is still
            # added to the (errored) engine below after the error is sent.
            # This mirrors upstream behavior; confirm it is intended.
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=True,
                               exception=ENGINE_DEAD_ERROR(self._errored_with))
            self._send_outputs(rpc_err)

        try:
            self.engine.add_request(
                request_id=request_id,
                prompt=request.prompt,
                params=request.params,
                lora_request=request.lora_request,
                trace_headers=request.trace_headers,
                prompt_adapter_request=request.prompt_adapter_request,
                priority=request.priority)

            if self.log_requests:
                logger.info("Added request %s.", request.request_id)

        except Exception as e:
            # We do not set self._errored = True here, since the error
            # is due to an issue adding this request to the engine,
            # rather than an issue with the engine itself.
            logger.debug("Failed to add request %s to engine. %s",
                         request.request_id, e)
            is_errored = self._errored_with is not None
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=is_errored,
                               exception=e)
            self._send_outputs(rpc_err)

            # Remove request from the engine.
            self.engine.abort_request(request_id)

    def _handle_abort_request(self, request: RPCAbortRequest):
        self.engine.abort_request(request.request_id)
        if self.log_requests:
            logger.info("Aborted request %s.", request.request_id)

    def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
        try:
            self.engine.add_lora(request.lora_request)
        except BaseException as e:
            # Send back an error if the adapter fails to load
            rpc_err = RPCError(request_id=request.request_id,
                               is_engine_errored=False,
                               exception=e)
            self._send_outputs(rpc_err)
            return
        # Otherwise, send back the successful load message
        self._send_outputs(
            RPCAdapterLoadedResponse(request_id=request.request_id))

    def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
        is_sleeping = self.is_sleeping()
        self._send_outputs(
            RPCIsSleepingResponse(request_id=request.request_id,
                                  is_sleeping=is_sleeping))

    def _health_check(self):
        # Send unhealthy if engine has already errored
        if self._errored_with is not None:
            self._send_unhealthy(self._errored_with)
        try:
            self.engine.check_health()
            self._send_healthy()
        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)

    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
        """Send outputs back to the engine client. These can be:
        - Exceptions
        - A list of generation outputs
        - A response from loading a lora adapter
        """
        if outputs:
            try:
                from ray.exceptions import RayTaskError
                # RayTaskError might not be picklable here. We need to unpack
                # the underlying exception as the real exception in the
                # output.
                if (isinstance(outputs, RPCError)
                        and isinstance(outputs.exception, RayTaskError)):
                    outputs.exception = outputs.exception.cause
            except ImportError:
                # Ray not installed: nothing to unwrap.
                pass

            output_bytes = pickle.dumps(outputs)
            self.output_socket.send_multipart((output_bytes, ), copy=False)

    def _send_healthy(self):
        """Send HEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

    def _send_unhealthy(self, error: BaseException):
        """Send UNHEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            error_bytes = pickle.dumps(error)
            self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)

    def _async_socket_engine_callback(self,
                                      request_outputs: REQUEST_OUTPUTS_T):
        """Callback used by engine to make socket handling async with GPU."""
        self._send_outputs(request_outputs)
        self.handle_new_input()

    def _set_errored(self, e: BaseException):
        """Log and set errored status if this is the first issue."""
        if self._errored_with is None:
            self._errored_with = e

    # Thin pass-throughs to the wrapped engine, dispatched from
    # handle_new_input for the corresponding RPC request types.

    def start_profile(self) -> None:
        self.engine.start_profile()

    def stop_profile(self) -> None:
        self.engine.stop_profile()

    def reset_mm_cache(self) -> bool:
        return self.engine.reset_mm_cache()

    def reset_prefix_cache(self) -> bool:
        return self.engine.reset_prefix_cache()

    def sleep(self, level: int = 1) -> None:
        self.engine.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.engine.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.engine.is_sleeping()
def signal_handler(*_) -> None:
    """SIGTERM handler: surface process termination as a KeyboardInterrupt.

    Raising KeyboardInterrupt lets MQLLMEngine.start() unwind through its
    normal shutdown/cleanup path instead of dying mid-step.
    """
    raise KeyboardInterrupt("MQLLMEngine terminated")
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
                  ipc_path: str, disable_log_stats: bool,
                  disable_log_requests: bool, engine_alive):
    """Subprocess entry point: construct an MQLLMEngine and run it.

    On any failure (including during construction) the exception is logged,
    ``engine_alive.value`` is cleared so the parent process can detect the
    death, and the exception is re-raised.
    """
    try:
        # Transformer configs may not pickle by default; register by-value
        # serialization before the engine spawns worker processes.
        maybe_register_config_serialize_by_value()

        mq_engine = MQLLMEngine.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            ipc_path=ipc_path,
            disable_log_stats=disable_log_stats,
            disable_log_requests=disable_log_requests)

        # Translate SIGTERM into KeyboardInterrupt so start() can unwind
        # through its cleanup path.
        signal.signal(signal.SIGTERM, signal_handler)

        mq_engine.start()

    except BaseException as exc:
        logger.exception(exc)
        engine_alive.value = False
        raise exc from None