# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import os
import pickle
import signal
import sys
import threading
import time
import traceback
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from threading import Thread
from typing import Any, Callable, Optional, Union, cast

import cloudpickle

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                 MessageQueue)
from vllm.executor.multiproc_worker_utils import (
    _add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
                        get_open_port)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class MultiprocExecutor(Executor):
    """Executor that runs one WorkerProc per rank as local subprocesses.

    Requests are broadcast to all workers over a shared MessageQueue and
    replies are read from one response MessageQueue per worker.
    """

    def _init_executor(self) -> None:
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        # NOTE: the bound method holds a strong reference to self, so this
        # finalizer fires at interpreter exit (or on an explicit call),
        # not when the executor is garbage collected.
        self._finalizer = weakref.finalize(self, self.shutdown)
        self.is_failed = False
        self.shutdown_event = threading.Event()
        self.failure_callback: Optional[FailureCallback] = None
        self.io_thread_pool: Optional[ThreadPoolExecutor] = None

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        pp_parallel_size = self.parallel_config.pipeline_parallel_size
        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
            f"world_size ({self.world_size}) must be equal to the "
            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
            f"_parallel_size ({pp_parallel_size}). ")

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # 127.0.0.1 for communication.
        distributed_init_method = get_distributed_init_method(
            "127.0.0.1", get_open_port())

        # Initialize worker and set up message queues for SchedulerOutputs
        # and ModelRunnerOutputs
        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
        self.rpc_broadcast_mq = MessageQueue(self.world_size,
                                             self.world_size,
                                             max_chunk_bytes=max_chunk_bytes)
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Create workers
        unready_workers: list[UnreadyWorkerProcHandle] = []
        success = False
        try:
            for rank in range(self.world_size):
                unready_workers.append(
                    WorkerProc.make_worker_process(
                        vllm_config=self.vllm_config,
                        local_rank=rank,
                        rank=rank,
                        distributed_init_method=distributed_init_method,
                        input_shm_handle=scheduler_output_handle,
                    ))

            # Workers must be created before wait_for_ready to avoid
            # deadlock, since worker.init_device() does a device sync.
            self.workers = WorkerProc.wait_for_ready(unready_workers)

            # Ensure message queues are ready. Will deadlock if re-ordered
            # Must be kept consistent with the WorkerProc.
            self.rpc_broadcast_mq.wait_until_ready()
            for w in self.workers:
                w.worker_response_mq.wait_until_ready()

            self.start_worker_monitor()
            success = True
        finally:
            if not success:
                # Clean up the worker procs if there was a failure.
                self._ensure_worker_termination(
                    [w.proc for w in unready_workers])

        # For pipeline parallel, we use a thread pool for asynchronous
        # execute_model.
        if self.max_concurrent_batches > 1:
            # Note: must use only 1 IO thread to keep dequeue sequence
            # from the response queue
            self.io_thread_pool = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="mp_exec_io")

        self.output_rank = self._get_output_rank()

    def start_worker_monitor(self):
        workers = self.workers
        # Weak reference so the monitor thread does not keep the executor
        # alive on its own.
        self_ref = weakref.ref(self)

        # Monitors worker process liveness. If any die unexpectedly,
        # logs an error, shuts down the executor and invokes the failure
        # callback to inform the engine.
        def monitor_workers():
            sentinels = [h.proc.sentinel for h in workers]
            died = multiprocessing.connection.wait(sentinels)
            _self = self_ref()
            if not _self or getattr(_self, 'shutting_down', False):
                return
            _self.is_failed = True
            proc_name = next(h.proc.name for h in workers
                             if h.proc.sentinel == died[0])
            logger.error(
                "Worker proc %s died unexpectedly, "
                "shutting down executor.", proc_name)
            _self.shutdown()
            callback = _self.failure_callback
            if callback is not None:
                _self.failure_callback = None
                callback()

        Thread(target=monitor_workers,
               daemon=True,
               name="MultiprocWorkerMonitor").start()

    def register_failure_callback(self, callback: FailureCallback):
        # If a worker already died, fire immediately; otherwise the monitor
        # thread invokes it (once) on failure.
        if self.is_failed:
            callback()
        else:
            self.failure_callback = callback

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        (output, ) = self.collective_rpc(
            "execute_model",
            args=(scheduler_output, ),
            unique_reply_rank=self.output_rank,
            non_block=self.max_concurrent_batches > 1,
            timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
        return output

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False,
                       unique_reply_rank: Optional[int] = None) -> list[Any]:
        """Broadcast an RPC to every worker and collect the replies.

        Args:
            method: Name of a worker method, or a callable that is
                cloudpickled and invoked on each worker with the worker
                object as its first argument.
            timeout: Overall deadline in seconds for all replies, or None
                to wait indefinitely.
            args: Positional arguments passed to the method on every worker.
            kwargs: Keyword arguments passed to the method on every worker.
            non_block: If True, return Futures that are resolved on the
                single IO thread instead of blocking for the results.
            unique_reply_rank: If set, only that rank replies; otherwise
                every rank replies.

        Returns:
            One entry per replying worker (a Future when ``non_block``).

        Raises:
            RuntimeError: The executor has failed, a worker reported a
                failure, or ``non_block`` was requested without an IO
                thread pool.
            TimeoutError: Re-raised with the method name when waiting for
                a reply times out.
        """
        if self.is_failed:
            raise RuntimeError("Executor failed.")

        if non_block and self.io_thread_pool is None:
            # Reject before broadcasting: once the request is enqueued the
            # workers will reply, and unread replies would be consumed by
            # the next caller, desynchronizing the response queues.
            raise RuntimeError(
                "non_block=True requires an IO thread pool, which is only "
                "created when max_concurrent_batches > 1.")

        deadline = None if timeout is None else time.monotonic() + timeout
        kwargs = kwargs or {}

        # NOTE: If the args are heterogeneous, then we pack them into a list,
        # and unpack them in the method of every worker, because every worker
        # knows their own rank.
        try:
            if isinstance(method, str):
                send_method = method
            else:
                send_method = cloudpickle.dumps(
                    method, protocol=pickle.HIGHEST_PROTOCOL)
            self.rpc_broadcast_mq.enqueue(
                (send_method, args, kwargs, unique_reply_rank))

            workers = (self.workers[unique_reply_rank],
                       ) if unique_reply_rank is not None else self.workers
            responses = []

            def get_response(w: WorkerProcHandle,
                             dequeue_timeout: Optional[float] = None,
                             cancel_event: Optional[threading.Event] = None):
                status, result = w.worker_response_mq.dequeue(
                    timeout=dequeue_timeout, cancel=cancel_event)

                if status != WorkerProc.ResponseStatus.SUCCESS:
                    raise RuntimeError(
                        f"Worker failed with error '{result}', please check the"
                        " stack trace above for the root cause")
                return result

            for w in workers:
                # Remaining time until the shared deadline for this worker.
                dequeue_timeout = None if deadline is None else (
                    deadline - time.monotonic())

                if non_block:
                    result = self.io_thread_pool.submit(  # type: ignore
                        get_response, w, dequeue_timeout, self.shutdown_event)
                else:
                    result = get_response(w, dequeue_timeout)

                responses.append(result)

            return responses
        except TimeoutError as e:
            raise TimeoutError(f"RPC call to {method} timed out.") from e

    @staticmethod
    def _ensure_worker_termination(worker_procs: list[BaseProcess]):
        """Ensure that all worker processes are terminated.

        Sends SIGTERM (``terminate()``) to every worker that is still
        alive, waits up to 4 seconds for them to exit, then sends SIGKILL
        (``kill()``) to any that are still running.
        """

        def wait_for_termination(procs, timeout):
            if not time:
                # If we are in late stage shutdown, the interpreter may replace
                # `time` with `None`.
                return all(not proc.is_alive() for proc in procs)
            # Monotonic clock so wall-clock adjustments cannot shorten or
            # extend the grace period.
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if all(not proc.is_alive() for proc in procs):
                    return True
                time.sleep(0.1)
            return False

        # Send SIGTERM if still running
        active_procs = [proc for proc in worker_procs if proc.is_alive()]
        for p in active_procs:
            p.terminate()
        if not wait_for_termination(active_procs, 4):
            # Send SIGKILL if still running
            active_procs = [p for p in active_procs if p.is_alive()]
            for p in active_procs:
                p.kill()

    def shutdown(self):
        """Properly shut down the executor and its workers.

        Only the first call does the work (guarded by ``shutting_down``):
        it signals ``shutdown_event``, stops the IO thread pool without
        waiting, drops the response queues and terminates the worker
        processes. Every call clears ``rpc_broadcast_mq``.
        """
        if not getattr(self, 'shutting_down', False):
            self.shutting_down = True
            self.shutdown_event.set()

            if self.io_thread_pool is not None:
                self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
                self.io_thread_pool = None

            if workers := getattr(self, 'workers', None):
                for w in workers:
                    w.worker_response_mq = None
                self._ensure_worker_termination([w.proc for w in workers])

        self.rpc_broadcast_mq = None

    def check_health(self) -> None:
        self.collective_rpc("check_health", timeout=10)
        return

    @property
    def max_concurrent_batches(self) -> int:
        return self.parallel_config.pipeline_parallel_size

    def _get_output_rank(self) -> int:
        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
        # (the first TP worker of the last PP stage).
        # Example:
        # Assuming TP=8, PP=4, then the world_size=32
        # 0-7, PP rank 0
        # 8-15, PP rank 1
        # 16-23, PP rank 2
        # 24-31, PP rank 3
        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
        return self.world_size - self.parallel_config.tensor_parallel_size


@dataclass
class UnreadyWorkerProcHandle:
    """WorkerProcess handle before READY."""
    proc: BaseProcess
    rank: int
    ready_pipe: Connection


@dataclass
class WorkerProcHandle:
    proc: BaseProcess
    rank: int
    worker_response_mq: MessageQueue  # The worker process writes to this MQ

    @classmethod
    def from_unready_handle(
            cls, unready_handle: UnreadyWorkerProcHandle,
            worker_response_mq: MessageQueue) -> "WorkerProcHandle":
        return cls(
            proc=unready_handle.proc,
            rank=unready_handle.rank,
            worker_response_mq=worker_response_mq,
        )


class WorkerProc:
    """Wrapper that runs one Worker in a separate process."""

    READY_STR = "READY"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        input_shm_handle: Handle,
    ):
        self.rank = rank
        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
        # TODO: move `init_worker` to executor level as a collective rpc call
        all_kwargs: list[dict] = [
            {} for _ in range(vllm_config.parallel_config.world_size)
        ]
        is_driver_worker = (
            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
        all_kwargs[rank] = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "is_driver_worker": is_driver_worker,
        }
        wrapper.init_worker(all_kwargs)
        self.worker = wrapper

        pid = os.getpid()
        _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
        _add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)

        # Initialize MessageQueue for receiving SchedulerOutput
        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
            input_shm_handle, self.worker.rank)

        # Initializes a message queue for sending the model output
        self.worker_response_mq = MessageQueue(1, 1)

        # Initialize device and loads weights
        self.worker.init_device()
        self.worker.load_model()

    @staticmethod
    def make_worker_process(
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            input_shm_handle,  # Receive SchedulerOutput
    ) -> UnreadyWorkerProcHandle:
        context = get_mp_context()
        # (reader, writer)
        reader, writer = context.Pipe(duplex=False)

        process_kwargs = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "input_shm_handle": input_shm_handle,
            "ready_pipe": (reader, writer),
        }
        # Run WorkerProc.worker_main (init + busy loop) in a background
        # process.
        proc = context.Process(target=WorkerProc.worker_main,
                               kwargs=process_kwargs,
                               name=f"VllmWorker-{rank}",
                               daemon=True)

        proc.start()
        # The parent only reads; close its copy of the write end.
        writer.close()
        return UnreadyWorkerProcHandle(proc, rank, reader)

    @staticmethod
    def wait_for_ready(
        unready_proc_handles: list[UnreadyWorkerProcHandle]
    ) -> list[WorkerProcHandle]:
        """Block until every worker reports READY over its ready pipe.

        Returns the ready handles ordered by rank. Raises an Exception if
        a worker reports a non-READY status or its pipe is closed before
        it reports.
        """

        e = Exception("WorkerProc initialization failed due to "
                      "an exception in a background process. "
                      "See stack trace for root cause.")

        pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
        ready_proc_handles: list[Optional[WorkerProcHandle]] = (
            [None] * len(unready_proc_handles))
        while pipes:
            ready = multiprocessing.connection.wait(pipes.keys())
            for pipe in ready:
                assert isinstance(pipe, Connection)
                try:
                    # Wait until the WorkerProc is ready.
                    unready_proc_handle = pipes.pop(pipe)
                    response: dict[str, Any] = pipe.recv()
                    if response["status"] != WorkerProc.READY_STR:
                        raise e

                    # Extract the message queue handle.
                    worker_response_mq = MessageQueue.create_from_handle(
                        response["handle"], 0)
                    ready_proc_handles[unready_proc_handle.rank] = (
                        WorkerProcHandle.from_unready_handle(
                            unready_proc_handle, worker_response_mq))

                except EOFError:
                    # The worker exited before sending its status.
                    e.__suppress_context__ = True
                    raise e from None

                finally:
                    # Close connection.
                    pipe.close()

        return cast(list[WorkerProcHandle], ready_proc_handles)

    def shutdown(self):
        self.rpc_broadcast_mq = None
        self.worker_response_mq = None
        destroy_model_parallel()
        destroy_distributed_environment()

    @staticmethod
    def worker_main(*args, **kwargs):
        """ Worker initialization and execution loops.
        This runs a background process """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the worker
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        worker = None
        # tuple[Connection, Connection]
        reader, ready_writer = kwargs.pop("ready_pipe")
        try:
            reader.close()
            worker = WorkerProc(*args, **kwargs)

            # Send READY once we know everything is loaded
            ready_writer.send({
                "status": WorkerProc.READY_STR,
                "handle": worker.worker_response_mq.export_handle(),
            })

            # Ensure message queues are ready. Will deadlock if re-ordered.
            # Must be kept consistent with the Executor
            worker.rpc_broadcast_mq.wait_until_ready()
            worker.worker_response_mq.wait_until_ready()
            ready_writer.close()
            ready_writer = None

            worker.worker_busy_loop()

        except Exception:
            # NOTE: if an Exception arises in busy_loop, we send
            # a FAILURE message over the MQ RPC to notify the Executor,
            # which triggers system shutdown.
            # TODO(rob): handle case where the MQ itself breaks.

            if ready_writer is not None:
                logger.exception("WorkerProc failed to start.")
            else:
                logger.exception("WorkerProc failed.")

            # The parent sends a SIGTERM to all worker processes if
            # any worker dies. Set this value so we don't re-throw
            # SystemExit() to avoid zmq exceptions in __del__.
            shutdown_requested = True

        finally:
            if ready_writer is not None:
                ready_writer.close()
            # Clean up once worker exits busy loop
            if worker is not None:
                worker.shutdown()

    class ResponseStatus(Enum):
        SUCCESS = auto()
        FAILURE = auto()

    def worker_busy_loop(self):
        """Main busy loop for Multiprocessing Workers.

        Dequeues (method, args, kwargs, output_rank) requests forever. The
        method is either a worker attribute name (str) or a cloudpickled
        callable (bytes) invoked with the worker as its first argument.
        Only the rank matching ``output_rank`` (or every rank when it is
        None) enqueues a SUCCESS/FAILURE reply.
        """
        while True:
            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

            try:
                if isinstance(method, str):
                    func = getattr(self.worker, method)
                elif isinstance(method, bytes):
                    func = partial(cloudpickle.loads(method), self.worker)
                else:
                    # Without this, `func` from a previous iteration would be
                    # silently re-invoked for an unsupported request.
                    raise TypeError(
                        f"Unsupported RPC method type: {type(method)!r}")
                output = func(*args, **kwargs)
            except Exception as e:
                # Notes have been introduced in python 3.11
                if hasattr(e, "add_note"):
                    e.add_note(traceback.format_exc())
                logger.exception("WorkerProc hit an exception.")
                # exception might not be serializable, so we convert it to
                # string, only for logging purpose.
                if output_rank is None or self.rank == output_rank:
                    self.worker_response_mq.enqueue(
                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
                continue

            if output_rank is None or self.rank == output_rank:
                self.worker_response_mq.enqueue(
                    (WorkerProc.ResponseStatus.SUCCESS, output))