# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import queue
import sys
import uuid
import weakref
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass
from threading import Thread
from typing import Any, Callable, Optional, TypeVar, Union

import msgspec.msgpack
import zmq
import zmq.asyncio

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket,
                        zmq_socket_ctx)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import (CoreEngine, CoreEngineActorManager,
                           CoreEngineProcManager, EngineZmqAddresses,
                           get_engine_client_zmq_addr,
                           wait_for_engine_startup)

logger = init_logger(__name__)

AnyFuture = Union[asyncio.Future[Any], Future[Any]]

_R = TypeVar('_R')  # Return type for collective_rpc


class EngineCoreClient(ABC):
    """
    EngineCoreClient: subclasses handle different methods for pushing
        and pulling from the EngineCore for asyncio / multiprocessing.

    Subclasses:
    * InprocClient: In process EngineCore (for V0-style LLMEngine use)
    * SyncMPClient: ZMQ + background proc EngineCore (for LLM)
    * AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)
    """

    @staticmethod
    def make_client(
        multiprocess_mode: bool,
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ) -> "EngineCoreClient":
        """Create the client matching the requested execution mode.

        Raises:
            NotImplementedError: if asyncio mode is requested without
                multiprocessing.
        """

        # TODO: support this for debugging purposes.
        if asyncio_mode and not multiprocess_mode:
            raise NotImplementedError(
                "Running EngineCore in asyncio without multiprocessing "
                "is not currently supported.")

        if multiprocess_mode and asyncio_mode:
            return EngineCoreClient.make_async_mp_client(
                vllm_config, executor_class, log_stats)

        if multiprocess_mode and not asyncio_mode:
            return SyncMPClient(vllm_config, executor_class, log_stats)

        return InprocClient(vllm_config, executor_class, log_stats)

    @staticmethod
    def make_async_mp_client(
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> "MPClient":
        """Create an asyncio multi-proc client.

        A data-parallel client is returned when data_parallel_size > 1
        (the Ray variant if the data parallel backend is "ray"),
        otherwise a plain AsyncMPClient.
        """
        if vllm_config.parallel_config.data_parallel_size > 1:
            if vllm_config.parallel_config.data_parallel_backend == "ray":
                return RayDPClient(vllm_config, executor_class, log_stats,
                                   client_addresses, client_index)
            return DPAsyncMPClient(vllm_config, executor_class, log_stats,
                                   client_addresses, client_index)
        return AsyncMPClient(vllm_config, executor_class, log_stats,
                             client_addresses, client_index)

    @abstractmethod
    def shutdown(self):
        ...

    # Synchronous interface. Subclasses override the subset they support.

    def get_output(self) -> EngineCoreOutputs:
        raise NotImplementedError

    def add_request(self, request: EngineCoreRequest) -> None:
        raise NotImplementedError

    def profile(self, is_start: bool = True) -> None:
        raise NotImplementedError

    def reset_mm_cache(self) -> None:
        raise NotImplementedError

    def reset_prefix_cache(self) -> None:
        raise NotImplementedError

    def sleep(self, level: int = 1) -> None:
        raise NotImplementedError

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        raise NotImplementedError

    def is_sleeping(self) -> bool:
        raise NotImplementedError

    def execute_dummy_batch(self) -> None:
        raise NotImplementedError

    async def execute_dummy_batch_async(self) -> None:
        raise NotImplementedError

    def abort_requests(self, request_ids: list[str]) -> None:
        raise NotImplementedError

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    def list_loras(self) -> set[int]:
        raise NotImplementedError

    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    def save_sharded_state(self,
                           path: str,
                           pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        raise NotImplementedError

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        raise NotImplementedError

    # Asynchronous interface. Subclasses override the subset they support.

    async def get_output_async(self) -> EngineCoreOutputs:
        raise NotImplementedError

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        raise NotImplementedError

    async def profile_async(self, is_start: bool = True) -> None:
        raise NotImplementedError

    async def reset_mm_cache_async(self) -> None:
        raise NotImplementedError

    async def reset_prefix_cache_async(self) -> None:
        raise NotImplementedError

    async def sleep_async(self, level: int = 1) -> None:
        raise NotImplementedError

    async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
        raise NotImplementedError

    async def is_sleeping_async(self) -> bool:
        raise NotImplementedError

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        raise NotImplementedError

    async def add_lora_async(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    async def remove_lora_async(self, lora_id: int) -> bool:
        raise NotImplementedError

    async def list_loras_async(self) -> set[int]:
        raise NotImplementedError

    async def pin_lora_async(self, lora_id: int) -> bool:
        raise NotImplementedError

    async def save_sharded_state_async(self,
                                       path: str,
                                       pattern: Optional[str] = None,
                                       max_size: Optional[int] = None) -> None:
        raise NotImplementedError

    async def collective_rpc_async(
            self,
            method: Union[str, Callable[..., _R]],
            timeout: Optional[float] = None,
            args: tuple = (),
            kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        raise NotImplementedError


class InprocClient(EngineCoreClient):
    """
    InprocClient: client for in-process EngineCore. Intended
    for use in LLMEngine for V0-style add_request() and step()
        EngineCore setup in this process (no busy loop).

        * pushes EngineCoreRequest directly into the EngineCore
        * pulls EngineCoreOutputs by stepping the EngineCore
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to EngineCore.
        self.engine_core = EngineCore(*args, **kwargs)

    def get_output(self) -> EngineCoreOutputs:
        """Step the engine once and return the outputs stored under key 0,
        or an empty EngineCoreOutputs if there are none."""
        outputs, _ = self.engine_core.step()
        return outputs.get(0) or EngineCoreOutputs()

    def add_request(self, request: EngineCoreRequest) -> None:
        self.engine_core.add_request(request)

    def abort_requests(self, request_ids: list[str]) -> None:
        if len(request_ids) > 0:
            self.engine_core.abort_requests(request_ids)

    def shutdown(self) -> None:
        self.engine_core.shutdown()

    def profile(self, is_start: bool = True) -> None:
        self.engine_core.profile(is_start)

    def reset_mm_cache(self) -> None:
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(self) -> None:
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1) -> None:
        self.engine_core.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.engine_core.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.engine_core.is_sleeping()

    def execute_dummy_batch(self) -> None:
        self.engine_core.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.engine_core.pin_lora(lora_id)

    def save_sharded_state(self,
                           path: str,
                           pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        self.engine_core.save_sharded_state(path, pattern, max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)


@dataclass
class BackgroundResources:
    """Used as a finalizer for clean shutdown, avoiding
    circular reference back to the client object."""

    ctx: zmq.Context
    # If CoreEngineProcManager, it manages local engines;
    # if CoreEngineActorManager, it manages all engines.
    engine_manager: Optional[Union[CoreEngineProcManager,
                                   CoreEngineActorManager]] = None
    coordinator: Optional[DPCoordinator] = None
    output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
    input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
    first_req_send_socket: Optional[zmq.asyncio.Socket] = None
    output_queue_task: Optional[asyncio.Task] = None
    stats_update_task: Optional[asyncio.Task] = None
    shutdown_path: Optional[str] = None

    # Set if any of the engines are dead. Here so that the output
    # processing threads can access it without holding a ref to the client.
    engine_dead: bool = False

    def __call__(self):
        """Clean up background resources."""

        self.engine_dead = True
        if self.engine_manager is not None:
            self.engine_manager.close()
        if self.coordinator is not None:
            self.coordinator.close()

        if self.output_queue_task is not None:
            self.output_queue_task.cancel()
        if self.stats_update_task is not None:
            self.stats_update_task.cancel()

        # ZMQ context termination can hang if the sockets
        # aren't explicitly closed first.
        for socket in (self.output_socket, self.input_socket,
                       self.first_req_send_socket):
            if socket is not None:
                socket.close(linger=0)

        if self.shutdown_path is not None:
            # We must ensure that the sync output socket is
            # closed cleanly in its own thread.
            with self.ctx.socket(zmq.PAIR) as shutdown_sender:
                shutdown_sender.connect(self.shutdown_path)
                # Send shutdown signal.
                shutdown_sender.send(b'')

    def validate_alive(self, frames: Sequence[zmq.Frame]):
        """Mark the engine dead and raise if `frames` is the dead sentinel.

        Raises:
            EngineDeadError: if `frames` is a single frame whose buffer
                equals EngineCoreProc.ENGINE_CORE_DEAD.
        """
        if len(frames) == 1 and (frames[0].buffer
                                 == EngineCoreProc.ENGINE_CORE_DEAD):
            self.engine_dead = True
            raise EngineDeadError()


class MPClient(EngineCoreClient):
    """
    MPClient: base client for multi-proc EngineCore.
        EngineCore runs in a background process busy loop, getting
        new EngineCoreRequests and returning EngineCoreOutputs

        * pushes EngineCoreRequests via input_socket
        * pulls EngineCoreOutputs via output_socket

        * AsyncMPClient subclass for AsyncLLM usage
        * SyncMPClient subclass for LLM usage
    """

    def __init__(
        self,
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        client_addresses: Optional[dict[str, str]] = None,
    ):
        self.vllm_config = vllm_config
        # Serialization setup.
        self.encoder = MsgpackEncoder()
        self.decoder = MsgpackDecoder(EngineCoreOutputs)

        # ZMQ setup.
        sync_ctx = zmq.Context(io_threads=2)
        self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx

        # This will ensure resources created so far are closed
        # when the client is garbage collected, even if an
        # exception is raised mid-construction.
        self.resources = BackgroundResources(ctx=sync_ctx)
        self._finalizer = weakref.finalize(self, self.resources)
        success = False
        try:
            parallel_config = vllm_config.parallel_config
            local_engine_count = parallel_config.data_parallel_size_local
            local_start_index = parallel_config.data_parallel_rank_local
            dp_size = parallel_config.data_parallel_size
            dp_rank = parallel_config.data_parallel_rank

            # SPMD mode is where there is an LLM instance per DP rank and
            # one core engine per LLM, see
            # examples/offline_inference/data_parallel.py.
            spmd_mode = local_start_index is not None
            if spmd_mode:
                assert local_engine_count == 1
                self.core_engines = [CoreEngine(index=dp_rank, local=True)]
            else:
                assert dp_rank == 0
                local_start_index = 0
                self.core_engines = [
                    CoreEngine(index=i, local=(i < local_engine_count))
                    for i in range(dp_size)
                ]

            local_only = spmd_mode or local_engine_count == dp_size

            self.stats_update_address: Optional[str] = None
            if client_addresses is not None:
                input_address = client_addresses["input_address"]
                output_address = client_addresses["output_address"]
                self.stats_update_address = client_addresses.get(
                    "stats_update_address")
            else:
                host = parallel_config.data_parallel_master_ip
                input_address = get_engine_client_zmq_addr(local_only, host)
                output_address = get_engine_client_zmq_addr(local_only, host)

            # Create input and output sockets.
            self.input_socket = self.resources.input_socket = make_zmq_socket(
                self.ctx, input_address, zmq.ROUTER, bind=True)
            self.resources.output_socket = make_zmq_socket(
                self.ctx, output_address, zmq.PULL)

            if client_addresses is None:
                self._init_engines_direct(vllm_config, local_only,
                                          local_start_index, input_address,
                                          output_address, executor_class,
                                          log_stats)
                coordinator = self.resources.coordinator
                if coordinator:
                    self.stats_update_address = (
                        coordinator.get_stats_publish_address())

            # Wait for ready messages from each engine on the input socket.
            identities = {e.identity for e in self.core_engines}
            sync_input_socket = zmq.Socket.shadow(self.input_socket)
            while identities:
                if not sync_input_socket.poll(timeout=600_000):
                    raise TimeoutError("Timed out waiting for engines to send "
                                       "initial message on input socket.")
                identity, _ = sync_input_socket.recv_multipart()
                identities.remove(identity)

            self.core_engine = self.core_engines[0]
            self.utility_results: dict[int, AnyFuture] = {}

            # Request objects which may contain pytorch-allocated tensors
            # that we need to keep references to until zmq is done with the
            # underlying data.
            self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()

            success = True
        finally:
            if not success:
                self._finalizer()

    def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
                             local_start_index: int, input_address: str,
                             output_address: str,
                             executor_class: type[Executor], log_stats: bool):
        """Self-contained client mode, launch engine and coordinator process
        as needed."""

        parallel_config = vllm_config.parallel_config
        local_engine_count = parallel_config.data_parallel_size_local
        start_index = parallel_config.data_parallel_rank
        host = parallel_config.data_parallel_master_ip

        if len(self.core_engines) > 1:
            self.resources.coordinator = DPCoordinator(parallel_config)

        handshake_address = get_engine_client_zmq_addr(
            local_only, host, parallel_config.data_parallel_rpc_port)

        with zmq_socket_ctx(handshake_address, zmq.ROUTER,
                            bind=True) as handshake_socket:

            # Start local engines.
            if local_engine_count:
                # In server mode, start_index and local_start_index will
                # both be 0.
                self.resources.engine_manager = CoreEngineProcManager(
                    EngineCoreProc.run_engine_core,
                    vllm_config=vllm_config,
                    executor_class=executor_class,
                    log_stats=log_stats,
                    handshake_address=handshake_address,
                    on_head_node=True,
                    local_engine_count=local_engine_count,
                    start_index=start_index,
                    local_start_index=local_start_index)

            # Wait for engine core process(es) to start.
            self._wait_for_engine_startup(handshake_socket, input_address,
                                          output_address)

    def _wait_for_engine_startup(self, handshake_socket: zmq.Socket,
                                 input_address: str, output_address: str):
        """Assemble the engine ZMQ addresses and delegate to
        wait_for_engine_startup() with this client's engines and config."""
        addresses = EngineZmqAddresses(
            inputs=[input_address],
            outputs=[output_address],
        )

        coordinator = self.resources.coordinator
        if coordinator is not None:
            addresses.coordinator_input, addresses.coordinator_output = (
                coordinator.get_engine_socket_addresses())

        proc_manager = self.resources.engine_manager
        assert isinstance(proc_manager, (type(None), CoreEngineProcManager)), (
            "_wait_for_engine_startup should only be "
            "called with CoreEngineProcManager")

        wait_for_engine_startup(
            handshake_socket,
            addresses,
            self.core_engines,
            self.vllm_config.parallel_config,
            self.vllm_config.cache_config,
            proc_manager,
            coordinator.proc if coordinator else None,
        )

    def shutdown(self):
        # Terminate background resources.
        self._finalizer()

    def _format_exception(self, e: Exception) -> Exception:
        """If errored, use EngineDeadError so root cause is clear."""
        return EngineDeadError(
            suppress_context=True) if self.resources.engine_dead else e

    def ensure_alive(self):
        """Raise EngineDeadError if the engine has been marked dead."""
        if self.resources.engine_dead:
            raise EngineDeadError()

    def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any):
        """Retain `msg` until `tracker` reports done.

        New entries go on the left so that the oldest entry is always the
        rightmost one, which is where free_pending_messages() looks.
        """
        if not tracker.done:
            self.pending_messages.appendleft((tracker, msg))

    def free_pending_messages(self):
        """Drop retained messages, oldest first, while their tracker
        reports done."""
        while self.pending_messages and self.pending_messages[-1][0].done:
            self.pending_messages.pop()


def _process_utility_output(output: UtilityOutput,
                            utility_results: dict[int, AnyFuture]):
    """Set the result from a utility method in the waiting future.

    If the waiting future is already done (for instance because the caller
    was cancelled), setting a result would raise InvalidStateError and take
    down the output-processing loop, so the output is dropped instead; a
    failure reported by the engine is still logged.
    """
    future = utility_results.pop(output.call_id)
    failure_message = output.failure_message
    if future.done():
        if failure_message is not None:
            logger.error(
                "Abandoned call to utility method failed with error: %s",
                failure_message)
        return
    if failure_message is not None:
        future.set_exception(Exception(failure_message))
    else:
        future.set_result(output.result)


class SyncMPClient(MPClient):
    """Synchronous client for multi-proc EngineCore."""

    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                 log_stats: bool):
        super().__init__(
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
        )

        self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()

        # Ensure that the outputs socket processing thread does not have
        # a ref to the client which prevents gc.
        ctx = self.ctx
        out_socket = self.resources.output_socket
        assert out_socket is not None
        decoder = self.decoder
        utility_results = self.utility_results
        outputs_queue = self.outputs_queue

        shutdown_path = get_open_zmq_inproc_path()
        resources = self.resources
        resources.shutdown_path = shutdown_path

        def process_outputs_socket():
            shutdown_socket = ctx.socket(zmq.PAIR)
            try:
                shutdown_socket.bind(shutdown_path)
                poller = zmq.Poller()
                poller.register(shutdown_socket, zmq.POLLIN)
                poller.register(out_socket, zmq.POLLIN)
                while True:
                    socks = poller.poll()
                    if not socks:
                        continue
                    if len(socks) == 2 or socks[0][0] == shutdown_socket:
                        # shutdown signal, exit thread.
                        break

                    frames = out_socket.recv_multipart(copy=False)
                    resources.validate_alive(frames)
                    outputs = decoder.decode(frames)
                    if outputs.utility_output:
                        _process_utility_output(outputs.utility_output,
                                                utility_results)
                    else:
                        outputs_queue.put_nowait(outputs)
            except Exception as e:
                outputs_queue.put_nowait(e)
            finally:
                # Close sockets.
                shutdown_socket.close(linger=0)
                out_socket.close(linger=0)

        # Process outputs from engine in separate thread.
        self.output_queue_thread = Thread(target=process_outputs_socket,
                                          name="EngineCoreOutputQueueThread",
                                          daemon=True)
        self.output_queue_thread.start()

        # The thread takes on responsibility for closing the socket.
        self.resources.output_socket = None

    def get_output(self) -> EngineCoreOutputs:
        # If an exception arises in process_outputs_socket task,
        # it is forwarded to the outputs_queue so we can raise it
        # from this (run_output_handler) task to shut down the server.
        outputs = self.outputs_queue.get()
        if isinstance(outputs, Exception):
            raise self._format_exception(outputs) from None
        return outputs

    def _send_input(self, request_type: EngineCoreRequestType, request: Any):
        """Encode `request` and send it to the core engine.

        Raises:
            EngineDeadError: if the engine has been marked dead.
        """
        self.ensure_alive()
        self.free_pending_messages()
        # (Identity, RequestType, SerializedRequest)
        msg = (self.core_engine.identity, request_type.value,
               *self.encoder.encode(request))

        if len(msg) <= 3:
            # No auxiliary buffers => no tensor backing buffers in request.
            self.input_socket.send_multipart(msg, copy=False)
            return

        tracker = self.input_socket.send_multipart(msg, copy=False, track=True)
        self.add_pending_message(tracker, request)

    def call_utility(self, method: str, *args) -> Any:
        """Send a utility call to the engine and block until its result is
        available, returning the result or raising the reported failure."""
        call_id = uuid.uuid1().int >> 64
        future: Future[Any] = Future()
        self.utility_results[call_id] = future
        try:
            self._send_input(EngineCoreRequestType.UTILITY,
                             (0, call_id, method, args))
        except Exception:
            # Nothing will ever resolve this future; don't leave it behind.
            self.utility_results.pop(call_id, None)
            raise

        return future.result()

    def add_request(self, request: EngineCoreRequest) -> None:
        self._send_input(EngineCoreRequestType.ADD, request)

    def abort_requests(self, request_ids: list[str]) -> None:
        if request_ids and not self.resources.engine_dead:
            self._send_input(EngineCoreRequestType.ABORT, request_ids)

    def profile(self, is_start: bool = True) -> None:
        self.call_utility("profile", is_start)

    def reset_mm_cache(self) -> None:
        self.call_utility("reset_mm_cache")

    def reset_prefix_cache(self) -> None:
        self.call_utility("reset_prefix_cache")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.call_utility("add_lora", lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.call_utility("remove_lora", lora_id)

    def list_loras(self) -> set[int]:
        return self.call_utility("list_loras")

    def pin_lora(self, lora_id: int) -> bool:
        return self.call_utility("pin_lora", lora_id)

    def sleep(self, level: int = 1) -> None:
        self.call_utility("sleep", level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.call_utility("wake_up", tags)

    def is_sleeping(self) -> bool:
        return self.call_utility("is_sleeping")

    def execute_dummy_batch(self) -> None:
        self.call_utility("execute_dummy_batch")

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return self.call_utility("collective_rpc", method, timeout, args,
                                 kwargs)

    def save_sharded_state(self,
                           path: str,
                           pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        self.call_utility("save_sharded_state", path, pattern, max_size)


class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 client_addresses: Optional[dict[str, str]] = None,
                 client_index: int = 0):
        super().__init__(
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
            client_addresses=client_addresses,
        )

        self.client_index = client_index
        self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
                                                 Exception]]()
        try:
            # If we are running in an asyncio event loop, start the queue task.
            # Otherwise, it will be started lazily. If it is not started here,
            # we could miss EXECUTOR_FAILED messages from engine core if they
            # occur prior to any requests being sent.
            asyncio.get_running_loop()
            self._ensure_output_queue_task()
        except RuntimeError:
            pass

    def _ensure_output_queue_task(self):
        """Start the output-processing task if it is not already running."""
        resources = self.resources
        if resources.output_queue_task is not None:
            return

        # Perform IO in separate task to parallelize as much as possible.
        # Avoid task having direct reference back to the client.
        decoder = self.decoder
        utility_results = self.utility_results
        outputs_queue = self.outputs_queue
        output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs],
                                          Awaitable[None]]] = getattr(
                                              self.__class__,
                                              "process_engine_outputs", None)
        _self_ref = weakref.ref(self) if output_handler else None
        output_socket = resources.output_socket
        assert output_socket is not None

        async def process_outputs_socket():
            try:
                while True:
                    frames = await output_socket.recv_multipart(copy=False)
                    resources.validate_alive(frames)
                    outputs: EngineCoreOutputs = decoder.decode(frames)
                    if outputs.utility_output:
                        _process_utility_output(outputs.utility_output,
                                                utility_results)
                        continue

                    if output_handler is not None:
                        assert _self_ref is not None
                        _self = _self_ref()
                        if not _self:
                            # Client has been garbage collected, abort.
                            return
                        await output_handler(_self, outputs)

                    if outputs.outputs or outputs.scheduler_stats:
                        outputs_queue.put_nowait(outputs)
            except Exception as e:
                outputs_queue.put_nowait(e)

        resources.output_queue_task = asyncio.create_task(
            process_outputs_socket(), name="EngineCoreOutputQueueTask")

    async def get_output_async(self) -> EngineCoreOutputs:
        self._ensure_output_queue_task()
        # If an exception arises in process_outputs_socket task,
        # it is forwarded to the outputs_queue so we can raise it
        # from this (run_output_handler) task to shut down the server.
        assert self.outputs_queue is not None
        outputs = await self.outputs_queue.get()
        if isinstance(outputs, Exception):
            raise self._format_exception(outputs) from None
        return outputs

    def _send_input(self,
                    request_type: EngineCoreRequestType,
                    request: Any,
                    engine: Optional[CoreEngine] = None) -> Awaitable[Any]:
        """Encode `request` and send it to `engine` (default: the first
        core engine). Returns the awaitable from the socket send.

        Raises:
            EngineDeadError: if the engine has been marked dead.
        """
        self.ensure_alive()
        if engine is None:
            engine = self.core_engine

        message = (request_type.value, *self.encoder.encode(request))
        return self._send_input_message(message, engine, request)

    def _send_input_message(self, message: tuple[bytestr,
                                                 ...], engine: CoreEngine,
                            objects: Any) -> Awaitable[Any]:
        """
        objects is a reference to retain until zmq is finished with the
        buffers, in case they were extracted from tensors in the request.
        """
        self.ensure_alive()
        self.free_pending_messages()

        msg = (engine.identity, ) + message
        if not objects or len(msg) <= 3:
            # No auxiliary buffers => no tensor backing buffers in request.
            return self.input_socket.send_multipart(msg, copy=False)

        future: asyncio.Future[zmq.MessageTracker]
        future = self.input_socket.send_multipart(msg, copy=False, track=True)

        def add_pending(f: asyncio.Future[zmq.MessageTracker]):
            # Best-effort: a failed or cancelled send has no tracker to
            # retain, so any error from f.result() is deliberately ignored.
            with contextlib.suppress(BaseException):
                self.add_pending_message(f.result(), objects)

        future.add_done_callback(add_pending)
        return future

    async def call_utility_async(self, method: str, *args) -> Any:
        return await self._call_utility_async(method,
                                              *args,
                                              engine=self.core_engine)

    async def _call_utility_async(self, method: str, *args,
                                  engine: CoreEngine) -> Any:
        """Send a utility call to `engine` and await its result."""
        call_id = uuid.uuid1().int >> 64
        future = asyncio.get_running_loop().create_future()
        self.utility_results[call_id] = future
        try:
            message = (EngineCoreRequestType.UTILITY.value,
                       *self.encoder.encode(
                           (self.client_index, call_id, method, args)))
            await self._send_input_message(message, engine, args)
        except Exception:
            # Nothing will ever resolve this future; don't leave it behind.
            # Cancellation is not handled here since the message may already
            # be on its way and its response must still find the entry.
            self.utility_results.pop(call_id, None)
            raise
        self._ensure_output_queue_task()
        return await future

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        request.client_index = self.client_index
        await self._send_input(EngineCoreRequestType.ADD, request)
        self._ensure_output_queue_task()

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        if request_ids and not self.resources.engine_dead:
            await self._send_input(EngineCoreRequestType.ABORT, request_ids)

    async def profile_async(self, is_start: bool = True) -> None:
        await self.call_utility_async("profile", is_start)

    async def reset_mm_cache_async(self) -> None:
        await self.call_utility_async("reset_mm_cache")

    async def reset_prefix_cache_async(self) -> None:
        await self.call_utility_async("reset_prefix_cache")

    async def sleep_async(self, level: int = 1) -> None:
        await self.call_utility_async("sleep", level)

    async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
        await self.call_utility_async("wake_up", tags)

    async def is_sleeping_async(self) -> bool:
        return await self.call_utility_async("is_sleeping")

    async def execute_dummy_batch_async(self) -> None:
        await self.call_utility_async("execute_dummy_batch")

    async def add_lora_async(self, lora_request: LoRARequest) -> bool:
        return await self.call_utility_async("add_lora", lora_request)

    async def remove_lora_async(self, lora_id: int) -> bool:
        return await self.call_utility_async("remove_lora", lora_id)

    async def list_loras_async(self) -> set[int]:
        return await self.call_utility_async("list_loras")

    async def pin_lora_async(self, lora_id: int) -> bool:
        return await self.call_utility_async("pin_lora", lora_id)

    async def save_sharded_state_async(self,
                                       path: str,
                                       pattern: Optional[str] = None,
                                       max_size: Optional[int] = None) -> None:
        await self.call_utility_async("save_sharded_state", path, pattern,
                                      max_size)

    async def collective_rpc_async(
            self,
            method: Union[str, Callable[..., _R]],
            timeout: Optional[float] = None,
            args: tuple = (),
            kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return await self.call_utility_async("collective_rpc", method, timeout,
                                             args, kwargs)


class DPAsyncMPClient(AsyncMPClient):
    """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
    EngineCore."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 client_addresses: Optional[dict[str, str]] = None,
                 client_index: int = 0):
        self.current_wave = 0
        self.engines_running = False
        # To route aborts to the correct engine.
        self.reqs_in_flight: dict[str, CoreEngine] = {}

        super().__init__(vllm_config, executor_class, log_stats,
                         client_addresses, client_index)

        assert len(self.core_engines) > 1

        # List of [waiting, running] pair per engine.
        self.lb_engines: list[list[int]] = []

        self.first_req_sock_addr = get_open_zmq_inproc_path()
        self.first_req_send_socket = self.resources.first_req_send_socket = (
            make_zmq_socket(self.ctx,
                            self.first_req_sock_addr,
                            zmq.PAIR,
                            bind=True))
        try:
            # If we are running in an asyncio event loop, start the stats task.
            # Otherwise, it will be started lazily.
            asyncio.get_running_loop()
            self._ensure_stats_update_task()
        except RuntimeError:
            pass

    def _ensure_stats_update_task(self):
        """Start the stats-update task if it is not already running."""
        resources = self.resources
        if resources.stats_update_task is not None:
            return

        assert self.stats_update_address is not None

        async def run_engine_stats_update_task():
            with make_zmq_socket(self.ctx, self.stats_update_address,
                                 zmq.XSUB) as socket, make_zmq_socket(
                                     self.ctx,
                                     self.first_req_sock_addr,
                                     zmq.PAIR,
                                     bind=False) as first_req_rcv_socket:
                # Send subscription message.
                await socket.send(b'\x01')

                poller = zmq.asyncio.Poller()
                poller.register(socket, zmq.POLLIN)
                poller.register(first_req_rcv_socket, zmq.POLLIN)

                while True:
                    events = await poller.poll()
                    # Parentheses spell out the existing operator precedence
                    # (`and` binds tighter than `or`).
                    if (not self.engines_running and len(events) == 2) or (
                            events[0][0] == first_req_rcv_socket):
                        # Send a message to notify the coordinator that
                        # we're sending a request while the engines are
                        # paused, so that it can wake the others up
                        # (to run dummy EP loop).
                        self.engines_running = True
                        buf = first_req_rcv_socket.recv(
                            flags=zmq.NOBLOCK).result()
                        target_eng_index = int.from_bytes(buf, "little")
                        msg = msgspec.msgpack.encode(
                            (target_eng_index, self.current_wave))
                        await socket.send(msg)

                    buf = None
                    while True:
                        # Drain all stats events (we only care about latest).
                        future: asyncio.Future[bytes] = socket.recv(
                            flags=zmq.NOBLOCK)
                        if isinstance(future.exception(), zmq.Again):
                            break
                        buf = future.result()
                    if buf is None:
                        continue

                    # Update local load-balancing state.
                    counts, wave, running = msgspec.msgpack.decode(buf)
                    self.current_wave = wave
                    self.engines_running = running
                    self.lb_engines = counts

        resources.stats_update_task = asyncio.create_task(
            run_engine_stats_update_task())

    def get_core_engine_for_request(self,
                                    dp_rank: Optional[int] = None
                                    ) -> CoreEngine:
        """Pick the engine for a new request.

        An explicit `dp_rank` selects that engine directly. Otherwise the
        engine with the smallest [waiting, running] counts is chosen, or
        the first engine if no counts have been received yet.
        """
        if dp_rank is not None:
            # engines are already in rank order
            return self.core_engines[dp_rank]

        if not self.lb_engines:
            return self.core_engines[0]
        # TODO use P2C alg for larger DP sizes
        num_engines = len(self.lb_engines)
        min_counts = [sys.maxsize, sys.maxsize]
        eng_index = 0
        for i in range(num_engines):
            # Start from client_index to help with balancing when engines
            # are empty.
            idx = (self.client_index + i) % num_engines
            counts = self.lb_engines[idx]
            if counts < min_counts:
                min_counts = counts
                eng_index = idx
        # Adjust local counts for better balancing between stats updates
        # from the coordinator (which happen every 100ms).
        # min_counts aliases the chosen entry of lb_engines, so this
        # updates the stored counts in place.
        if min_counts[0]:
            min_counts[0] += 1
        else:
            min_counts[1] += 1
        return self.core_engines[eng_index]

    async def call_utility_async(self, method: str, *args) -> Any:
        # Only the result from the first engine is returned.
        return (await asyncio.gather(*[
            self._call_utility_async(method, *args, engine=engine)
            for engine in self.core_engines
        ]))[0]

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        self._ensure_stats_update_task()

        request.current_wave = self.current_wave
        request.client_index = self.client_index

        chosen_engine = self.get_core_engine_for_request(
            request.data_parallel_rank)
        self.reqs_in_flight[request.request_id] = chosen_engine

        to_await = self._send_input(EngineCoreRequestType.ADD, request,
                                    chosen_engine)
        if not self.engines_running:
            # Notify coordinator that we're sending a request
            await self.first_req_send_socket.send(chosen_engine.identity)

        await to_await

        self._ensure_output_queue_task()

    @staticmethod
    async def process_engine_outputs(self: "DPAsyncMPClient",
                                     outputs: EngineCoreOutputs):
        """Forget the engine routing of requests reported as finished."""
        if outputs.finished_requests and self.reqs_in_flight:
            for req_id in outputs.finished_requests:
                self.reqs_in_flight.pop(req_id, None)

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        """Send aborts to the engine each request was routed to; ids with
        no recorded engine are skipped."""
        if not request_ids:
            return

        if len(request_ids) == 1:
            # Fast-path common case.
            if engine := self.reqs_in_flight.get(request_ids[0]):
                await self._abort_requests(request_ids, engine)
            return

        by_engine: dict[CoreEngine, list[str]] = {}
        for req_id in request_ids:
            if engine := self.reqs_in_flight.get(req_id):
                by_engine.setdefault(engine, []).append(req_id)
        for engine, req_ids in by_engine.items():
            await self._abort_requests(req_ids, engine)

    async def _abort_requests(self, request_ids: list[str],
                              engine: CoreEngine) -> None:
        if not self.resources.engine_dead:
            await self._send_input(EngineCoreRequestType.ABORT, request_ids,
                                   engine)


class RayDPClient(DPAsyncMPClient):
    """
    Ray-based client for multi-proc, multi-engine (data parallel)
    EngineCore.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ):
        super().__init__(vllm_config, executor_class, log_stats,
                         client_addresses, client_index)

    def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
                             local_start_index: int, input_address: str,
                             output_address: str,
                             executor_class: type[Executor], log_stats: bool):
        """Self-contained client mode, launch engine and coordinator process
        as needed."""

        parallel_config = vllm_config.parallel_config
        assert parallel_config.data_parallel_rank == 0
        assert local_start_index == 0

        addresses = EngineZmqAddresses(
            inputs=[input_address],
            outputs=[output_address],
        )

        if len(self.core_engines) > 1:
            coordinator = DPCoordinator(parallel_config)
            self.resources.coordinator = coordinator
            addresses.coordinator_input, addresses.coordinator_output = (
                coordinator.get_engine_socket_addresses())

        # Start all engines.
        self.resources.engine_manager = CoreEngineActorManager(
            vllm_config=vllm_config,
            addresses=addresses,
            executor_class=executor_class,
            log_stats=log_stats)