init
This commit is contained in:
0
v1/executor/__init__.py
Normal file
0
v1/executor/__init__.py
Normal file
113
v1/executor/abstract.py
Normal file
113
v1/executor/abstract.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from concurrent.futures import Future
|
||||
from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.uniproc_executor import ( # noqa
|
||||
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
|
||||
from vllm.executor.uniproc_executor import ( # noqa
|
||||
UniProcExecutor as UniProcExecutorV0)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
FailureCallback = Callable[[], None]
|
||||
|
||||
|
||||
class Executor(ExecutorBase):
|
||||
"""
|
||||
Abstract class for v1 executors, mainly define some methods for v1.
|
||||
For methods shared by v0 and v1, define them in ExecutorBase"""
|
||||
|
||||
@staticmethod
|
||||
def get_class(vllm_config: VllmConfig) -> type["Executor"]:
|
||||
executor_class: type[Executor]
|
||||
parallel_config = vllm_config.parallel_config
|
||||
distributed_executor_backend = (
|
||||
parallel_config.distributed_executor_backend)
|
||||
# distributed_executor_backend must be set in VllmConfig.__post_init__
|
||||
if isinstance(distributed_executor_backend, type):
|
||||
if not issubclass(distributed_executor_backend, ExecutorBase):
|
||||
raise TypeError(
|
||||
"distributed_executor_backend must be a subclass of "
|
||||
f"ExecutorBase. Got {distributed_executor_backend}.")
|
||||
executor_class = distributed_executor_backend
|
||||
elif distributed_executor_backend == "ray":
|
||||
from vllm.v1.executor.ray_distributed_executor import ( # noqa
|
||||
RayDistributedExecutor)
|
||||
executor_class = RayDistributedExecutor
|
||||
elif distributed_executor_backend == "mp":
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
executor_class = MultiprocExecutor
|
||||
elif distributed_executor_backend == "uni":
|
||||
executor_class = UniProcExecutor
|
||||
elif distributed_executor_backend == "external_launcher":
|
||||
# TODO: make v1 scheduling deterministic
|
||||
# to support external launcher
|
||||
executor_class = ExecutorWithExternalLauncher
|
||||
else:
|
||||
raise ValueError("Unknown distributed executor backend: "
|
||||
f"{distributed_executor_backend}")
|
||||
return executor_class
|
||||
|
||||
def initialize_from_config(self,
|
||||
kv_cache_configs: list[KVCacheConfig]) -> None:
|
||||
"""
|
||||
Initialize the KV caches and begin the model execution loop of the
|
||||
underlying workers.
|
||||
"""
|
||||
self.collective_rpc("initialize_from_config",
|
||||
args=(kv_cache_configs, ))
|
||||
self.collective_rpc("compile_or_warm_up_model")
|
||||
|
||||
def register_failure_callback(self, callback: FailureCallback):
|
||||
"""
|
||||
Register a function to be called if the executor enters a permanent
|
||||
failed state.
|
||||
"""
|
||||
pass
|
||||
|
||||
def determine_available_memory(self) -> list[int]: # in bytes
|
||||
output = self.collective_rpc("determine_available_memory")
|
||||
return output
|
||||
|
||||
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
|
||||
output = self.collective_rpc("get_kv_cache_spec")
|
||||
return output
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output,
|
||||
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
|
||||
output = self.collective_rpc("execute_model",
|
||||
args=(scheduler_output, ))
|
||||
return output[0]
|
||||
|
||||
@property
|
||||
def max_concurrent_batches(self) -> int:
|
||||
return 1
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
self.collective_rpc("profile", args=(is_start, ))
|
||||
|
||||
|
||||
class UniProcExecutor(UniProcExecutorV0, Executor):
|
||||
pass
|
||||
|
||||
|
||||
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
|
||||
|
||||
def determine_available_memory(self) -> list[int]: # in bytes
|
||||
# same as determine_num_available_blocks in v0,
|
||||
# we need to get the min across all ranks.
|
||||
memory = super().determine_available_memory()
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
cpu_group = get_world_group().cpu_group
|
||||
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
|
||||
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
|
||||
return [memory_tensor.item()]
|
||||
537
v1/executor/multiproc_executor.py
Normal file
537
v1/executor/multiproc_executor.py
Normal file
@@ -0,0 +1,537 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import weakref
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from multiprocessing.connection import Connection
|
||||
from multiprocessing.process import BaseProcess
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
|
||||
import cloudpickle
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (destroy_distributed_environment,
|
||||
destroy_model_parallel)
|
||||
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
|
||||
MessageQueue)
|
||||
from vllm.executor.multiproc_worker_utils import (
|
||||
_add_prefix, set_multiprocessing_worker_envs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_distributed_init_method, get_mp_context,
|
||||
get_open_port)
|
||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_MS = 5000
|
||||
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
|
||||
|
||||
EXECUTE_MODEL_TIMEOUT_S = 300
|
||||
|
||||
|
||||
class MultiprocExecutor(Executor):
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
# Call self.shutdown at exit to clean up
|
||||
# and ensure workers will be terminated.
|
||||
self._finalizer = weakref.finalize(self, self.shutdown)
|
||||
self.is_failed = False
|
||||
self.shutdown_event = threading.Event()
|
||||
self.failure_callback: Optional[FailureCallback] = None
|
||||
self.io_thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
|
||||
self.world_size = self.parallel_config.world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
pp_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
|
||||
f"world_size ({self.world_size}) must be equal to the "
|
||||
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
|
||||
f"_parallel_size ({pp_parallel_size}). ")
|
||||
|
||||
# Set multiprocessing envs that are common to V0 and V1
|
||||
set_multiprocessing_worker_envs(self.parallel_config)
|
||||
|
||||
# Multiprocessing-based executor does not support multi-node setting.
|
||||
# Since it only works for single node, we can use the loopback address
|
||||
# 127.0.0.1 for communication.
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
"127.0.0.1", get_open_port())
|
||||
|
||||
# Initialize worker and set up message queues for SchedulerOutputs
|
||||
# and ModelRunnerOutputs
|
||||
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
|
||||
self.rpc_broadcast_mq = MessageQueue(self.world_size,
|
||||
self.world_size,
|
||||
max_chunk_bytes=max_chunk_bytes)
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
|
||||
# Create workers
|
||||
unready_workers: list[UnreadyWorkerProcHandle] = []
|
||||
success = False
|
||||
try:
|
||||
for rank in range(self.world_size):
|
||||
unready_workers.append(
|
||||
WorkerProc.make_worker_process(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
input_shm_handle=scheduler_output_handle,
|
||||
))
|
||||
|
||||
# Workers must be created before wait_for_ready to avoid
|
||||
# deadlock, since worker.init_device() does a device sync.
|
||||
self.workers = WorkerProc.wait_for_ready(unready_workers)
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered
|
||||
# Must be kept consistent with the WorkerProc.
|
||||
self.rpc_broadcast_mq.wait_until_ready()
|
||||
for w in self.workers:
|
||||
w.worker_response_mq.wait_until_ready()
|
||||
|
||||
self.start_worker_monitor()
|
||||
success = True
|
||||
finally:
|
||||
if not success:
|
||||
# Clean up the worker procs if there was a failure.
|
||||
self._ensure_worker_termination(
|
||||
[w.proc for w in unready_workers])
|
||||
|
||||
# For pipeline parallel, we use a thread pool for asynchronous
|
||||
# execute_model.
|
||||
if self.max_concurrent_batches > 1:
|
||||
# Note: must use only 1 IO thread to keep dequeue sequence
|
||||
# from the response queue
|
||||
self.io_thread_pool = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="mp_exec_io")
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
|
||||
def start_worker_monitor(self):
|
||||
workers = self.workers
|
||||
self_ref = weakref.ref(self)
|
||||
|
||||
# Monitors worker process liveness. If any die unexpectedly,
|
||||
# logs an error, shuts down the executor and invokes the failure
|
||||
# callback to inform the engine.
|
||||
def monitor_workers():
|
||||
sentinels = [h.proc.sentinel for h in workers]
|
||||
died = multiprocessing.connection.wait(sentinels)
|
||||
_self = self_ref()
|
||||
if not _self or getattr(_self, 'shutting_down', False):
|
||||
return
|
||||
_self.is_failed = True
|
||||
proc_name = next(h.proc.name for h in workers
|
||||
if h.proc.sentinel == died[0])
|
||||
logger.error(
|
||||
"Worker proc %s died unexpectedly, "
|
||||
"shutting down executor.", proc_name)
|
||||
_self.shutdown()
|
||||
callback = _self.failure_callback
|
||||
if callback is not None:
|
||||
_self.failure_callback = None
|
||||
callback()
|
||||
|
||||
Thread(target=monitor_workers,
|
||||
daemon=True,
|
||||
name="MultiprocWorkerMonitor").start()
|
||||
|
||||
def register_failure_callback(self, callback: FailureCallback):
|
||||
if self.is_failed:
|
||||
callback()
|
||||
else:
|
||||
self.failure_callback = callback
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output,
|
||||
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
|
||||
(output, ) = self.collective_rpc("execute_model",
|
||||
args=(scheduler_output, ),
|
||||
unique_reply_rank=self.output_rank,
|
||||
non_block=self.max_concurrent_batches
|
||||
> 1,
|
||||
timeout=EXECUTE_MODEL_TIMEOUT_S)
|
||||
return output
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict] = None,
|
||||
non_block: bool = False,
|
||||
unique_reply_rank: Optional[int] = None) -> list[Any]:
|
||||
if self.is_failed:
|
||||
raise RuntimeError("Executor failed.")
|
||||
|
||||
deadline = None if timeout is None else time.monotonic() + timeout
|
||||
kwargs = kwargs or {}
|
||||
|
||||
# NOTE: If the args are heterogeneous, then we pack them into a list,
|
||||
# and unpack them in the method of every worker, because every worker
|
||||
# knows their own rank.
|
||||
try:
|
||||
if isinstance(method, str):
|
||||
send_method = method
|
||||
else:
|
||||
send_method = cloudpickle.dumps(
|
||||
method, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.rpc_broadcast_mq.enqueue(
|
||||
(send_method, args, kwargs, unique_reply_rank))
|
||||
|
||||
workers = (self.workers[unique_reply_rank],
|
||||
) if unique_reply_rank is not None else self.workers
|
||||
responses = []
|
||||
|
||||
def get_response(w: WorkerProcHandle,
|
||||
dequeue_timeout: Optional[float] = None,
|
||||
cancel_event: Optional[threading.Event] = None):
|
||||
status, result = w.worker_response_mq.dequeue(
|
||||
timeout=dequeue_timeout, cancel=cancel_event)
|
||||
|
||||
if status != WorkerProc.ResponseStatus.SUCCESS:
|
||||
raise RuntimeError(
|
||||
f"Worker failed with error '{result}', please check the"
|
||||
" stack trace above for the root cause")
|
||||
return result
|
||||
|
||||
for w in workers:
|
||||
dequeue_timeout = None if deadline is None else (
|
||||
deadline - time.monotonic())
|
||||
|
||||
if non_block:
|
||||
result = self.io_thread_pool.submit( # type: ignore
|
||||
get_response, w, dequeue_timeout, self.shutdown_event)
|
||||
else:
|
||||
result = get_response(w, dequeue_timeout)
|
||||
|
||||
responses.append(result)
|
||||
|
||||
return responses
|
||||
except TimeoutError as e:
|
||||
raise TimeoutError(f"RPC call to {method} timed out.") from e
|
||||
|
||||
@staticmethod
|
||||
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
||||
"""Ensure that all worker processes are terminated. Assumes workers have
|
||||
received termination requests. Waits for processing, then sends
|
||||
termination and kill signals if needed."""
|
||||
|
||||
def wait_for_termination(procs, timeout):
|
||||
if not time:
|
||||
# If we are in late stage shutdown, the interpreter may replace
|
||||
# `time` with `None`.
|
||||
return all(not proc.is_alive() for proc in procs)
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
if all(not proc.is_alive() for proc in procs):
|
||||
return True
|
||||
time.sleep(0.1)
|
||||
return False
|
||||
|
||||
# Send SIGTERM if still running
|
||||
active_procs = [proc for proc in worker_procs if proc.is_alive()]
|
||||
for p in active_procs:
|
||||
p.terminate()
|
||||
if not wait_for_termination(active_procs, 4):
|
||||
# Send SIGKILL if still running
|
||||
active_procs = [p for p in active_procs if p.is_alive()]
|
||||
for p in active_procs:
|
||||
p.kill()
|
||||
|
||||
def shutdown(self):
|
||||
"""Properly shut down the executor and its workers"""
|
||||
if not getattr(self, 'shutting_down', False):
|
||||
self.shutting_down = True
|
||||
self.shutdown_event.set()
|
||||
|
||||
if self.io_thread_pool is not None:
|
||||
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
|
||||
self.io_thread_pool = None
|
||||
|
||||
if workers := getattr(self, 'workers', None):
|
||||
for w in workers:
|
||||
w.worker_response_mq = None
|
||||
self._ensure_worker_termination([w.proc for w in workers])
|
||||
|
||||
self.rpc_broadcast_mq = None
|
||||
|
||||
def check_health(self) -> None:
|
||||
self.collective_rpc("check_health", timeout=10)
|
||||
return
|
||||
|
||||
@property
|
||||
def max_concurrent_batches(self) -> int:
|
||||
return self.parallel_config.pipeline_parallel_size
|
||||
|
||||
def _get_output_rank(self) -> int:
|
||||
# Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
|
||||
# (the first TP worker of the last PP stage).
|
||||
# Example:
|
||||
# Assuming TP=8, PP=4, then the world_size=32
|
||||
# 0-7, PP rank 0
|
||||
# 8-15, PP rank 1
|
||||
# 16-23, PP rank 2
|
||||
# 24-31, PP rank 3
|
||||
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
|
||||
return self.world_size - self.parallel_config.tensor_parallel_size
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnreadyWorkerProcHandle:
|
||||
"""WorkerProcess handle before READY."""
|
||||
proc: BaseProcess
|
||||
rank: int
|
||||
ready_pipe: Connection
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkerProcHandle:
|
||||
proc: BaseProcess
|
||||
rank: int
|
||||
worker_response_mq: MessageQueue # The worker process writes to this MQ
|
||||
|
||||
@classmethod
|
||||
def from_unready_handle(
|
||||
cls, unready_handle: UnreadyWorkerProcHandle,
|
||||
worker_response_mq: MessageQueue) -> "WorkerProcHandle":
|
||||
return cls(
|
||||
proc=unready_handle.proc,
|
||||
rank=unready_handle.rank,
|
||||
worker_response_mq=worker_response_mq,
|
||||
)
|
||||
|
||||
|
||||
class WorkerProc:
|
||||
"""Wrapper that runs one Worker in a separate process."""
|
||||
|
||||
READY_STR = "READY"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
input_shm_handle: Handle,
|
||||
):
|
||||
self.rank = rank
|
||||
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
|
||||
# TODO: move `init_worker` to executor level as a collective rpc call
|
||||
all_kwargs: list[dict] = [
|
||||
{} for _ in range(vllm_config.parallel_config.world_size)
|
||||
]
|
||||
is_driver_worker = (
|
||||
rank % vllm_config.parallel_config.tensor_parallel_size == 0)
|
||||
all_kwargs[rank] = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_rank": local_rank,
|
||||
"rank": rank,
|
||||
"distributed_init_method": distributed_init_method,
|
||||
"is_driver_worker": is_driver_worker,
|
||||
}
|
||||
wrapper.init_worker(all_kwargs)
|
||||
self.worker = wrapper
|
||||
|
||||
pid = os.getpid()
|
||||
_add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
|
||||
_add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)
|
||||
|
||||
# Initialize MessageQueue for receiving SchedulerOutput
|
||||
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
|
||||
input_shm_handle, self.worker.rank)
|
||||
|
||||
# Initializes a message queue for sending the model output
|
||||
self.worker_response_mq = MessageQueue(1, 1)
|
||||
|
||||
# Initialize device and loads weights
|
||||
self.worker.init_device()
|
||||
self.worker.load_model()
|
||||
|
||||
@staticmethod
|
||||
def make_worker_process(
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
input_shm_handle, # Receive SchedulerOutput
|
||||
) -> UnreadyWorkerProcHandle:
|
||||
context = get_mp_context()
|
||||
# (reader, writer)
|
||||
reader, writer = context.Pipe(duplex=False)
|
||||
|
||||
process_kwargs = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_rank": local_rank,
|
||||
"rank": rank,
|
||||
"distributed_init_method": distributed_init_method,
|
||||
"input_shm_handle": input_shm_handle,
|
||||
"ready_pipe": (reader, writer),
|
||||
}
|
||||
# Run EngineCore busy loop in background process.
|
||||
proc = context.Process(target=WorkerProc.worker_main,
|
||||
kwargs=process_kwargs,
|
||||
name=f"VllmWorker-{rank}",
|
||||
daemon=True)
|
||||
|
||||
proc.start()
|
||||
writer.close()
|
||||
return UnreadyWorkerProcHandle(proc, rank, reader)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_ready(
|
||||
unready_proc_handles: list[UnreadyWorkerProcHandle]
|
||||
) -> list[WorkerProcHandle]:
|
||||
|
||||
e = Exception("WorkerProc initialization failed due to "
|
||||
"an exception in a background process. "
|
||||
"See stack trace for root cause.")
|
||||
|
||||
pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
|
||||
ready_proc_handles: list[Optional[WorkerProcHandle]] = (
|
||||
[None] * len(unready_proc_handles))
|
||||
while pipes:
|
||||
ready = multiprocessing.connection.wait(pipes.keys())
|
||||
for pipe in ready:
|
||||
assert isinstance(pipe, Connection)
|
||||
try:
|
||||
# Wait until the WorkerProc is ready.
|
||||
unready_proc_handle = pipes.pop(pipe)
|
||||
response: dict[str, Any] = pipe.recv()
|
||||
if response["status"] != "READY":
|
||||
raise e
|
||||
|
||||
# Extract the message queue handle.
|
||||
worker_response_mq = MessageQueue.create_from_handle(
|
||||
response["handle"], 0)
|
||||
ready_proc_handles[unready_proc_handle.rank] = (
|
||||
WorkerProcHandle.from_unready_handle(
|
||||
unready_proc_handle, worker_response_mq))
|
||||
|
||||
except EOFError:
|
||||
e.__suppress_context__ = True
|
||||
raise e from None
|
||||
|
||||
finally:
|
||||
# Close connection.
|
||||
pipe.close()
|
||||
|
||||
return cast(list[WorkerProcHandle], ready_proc_handles)
|
||||
|
||||
def shutdown(self):
|
||||
self.rpc_broadcast_mq = None
|
||||
self.worker_response_mq = None
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
|
||||
@staticmethod
|
||||
def worker_main(*args, **kwargs):
|
||||
""" Worker initialization and execution loops.
|
||||
This runs a background process """
|
||||
|
||||
# Signal handler used for graceful termination.
|
||||
# SystemExit exception is only raised once to allow this and worker
|
||||
# processes to terminate without error
|
||||
shutdown_requested = False
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit()
|
||||
|
||||
# Either SIGTERM or SIGINT will terminate the worker
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
worker = None
|
||||
# tuple[Connection, Connection]
|
||||
reader, ready_writer = kwargs.pop("ready_pipe")
|
||||
try:
|
||||
reader.close()
|
||||
worker = WorkerProc(*args, **kwargs)
|
||||
|
||||
# Send READY once we know everything is loaded
|
||||
ready_writer.send({
|
||||
"status":
|
||||
WorkerProc.READY_STR,
|
||||
"handle":
|
||||
worker.worker_response_mq.export_handle(),
|
||||
})
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered.
|
||||
# Must be kept consistent with the Executor
|
||||
worker.rpc_broadcast_mq.wait_until_ready()
|
||||
worker.worker_response_mq.wait_until_ready()
|
||||
ready_writer.close()
|
||||
ready_writer = None
|
||||
|
||||
worker.worker_busy_loop()
|
||||
|
||||
except Exception:
|
||||
# NOTE: if an Exception arises in busy_loop, we send
|
||||
# a FAILURE message over the MQ RPC to notify the Executor,
|
||||
# which triggers system shutdown.
|
||||
# TODO(rob): handle case where the MQ itself breaks.
|
||||
|
||||
if ready_writer is not None:
|
||||
logger.exception("WorkerProc failed to start.")
|
||||
else:
|
||||
logger.exception("WorkerProc failed.")
|
||||
|
||||
# The parent sends a SIGTERM to all worker processes if
|
||||
# any worker dies. Set this value so we don't re-throw
|
||||
# SystemExit() to avoid zmq exceptions in __del__.
|
||||
shutdown_requested = True
|
||||
|
||||
finally:
|
||||
if ready_writer is not None:
|
||||
ready_writer.close()
|
||||
# Clean up once worker exits busy loop
|
||||
if worker is not None:
|
||||
worker.shutdown()
|
||||
|
||||
class ResponseStatus(Enum):
|
||||
SUCCESS = auto()
|
||||
FAILURE = auto()
|
||||
|
||||
def worker_busy_loop(self):
|
||||
"""Main busy loop for Multiprocessing Workers"""
|
||||
while True:
|
||||
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
|
||||
|
||||
try:
|
||||
if isinstance(method, str):
|
||||
func = getattr(self.worker, method)
|
||||
elif isinstance(method, bytes):
|
||||
func = partial(cloudpickle.loads(method), self.worker)
|
||||
output = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# Notes have been introduced in python 3.11
|
||||
if hasattr(e, "add_note"):
|
||||
e.add_note(traceback.format_exc())
|
||||
logger.exception("WorkerProc hit an exception.")
|
||||
# exception might not be serializable, so we convert it to
|
||||
# string, only for logging purpose.
|
||||
if output_rank is None or self.rank == output_rank:
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.FAILURE, str(e)))
|
||||
continue
|
||||
|
||||
if output_rank is None or self.rank == output_rank:
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.SUCCESS, output))
|
||||
62
v1/executor/ray_distributed_executor.py
Normal file
62
v1/executor/ray_distributed_executor.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from concurrent.futures import Future
|
||||
from typing import Union
|
||||
|
||||
from vllm.executor.ray_distributed_executor import ( # noqa
|
||||
RayDistributedExecutor as RayDistributedExecutorV0)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
|
||||
class FutureWrapper(Future):
|
||||
"""A wrapper around a Ray output reference to meet the interface
|
||||
of .execute_model().
|
||||
"""
|
||||
|
||||
def __init__(self, ref):
|
||||
super().__init__()
|
||||
self.ref = ref
|
||||
|
||||
def result(self, timeout=None):
|
||||
if timeout is not None:
|
||||
raise NotImplementedError("timeout is not supported")
|
||||
return self.ref.get()
|
||||
|
||||
|
||||
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
|
||||
"""Ray distributed executor using Ray Compiled Graphs."""
|
||||
|
||||
@property
|
||||
def max_concurrent_batches(self) -> int:
|
||||
"""Ray distributed executor supports pipeline parallelism,
|
||||
meaning that it allows PP size batches to be executed concurrently.
|
||||
"""
|
||||
return self.parallel_config.pipeline_parallel_size
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output,
|
||||
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
|
||||
"""Execute the model on the Ray workers.
|
||||
|
||||
Args:
|
||||
scheduler_output: The scheduler output to execute.
|
||||
|
||||
Returns:
|
||||
The model runner output.
|
||||
"""
|
||||
# Build the compiled DAG for the first time.
|
||||
if self.forward_dag is None: # type: ignore
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
refs = self.forward_dag.execute(scheduler_output) # type: ignore
|
||||
|
||||
# When PP is not used, we block here until the result is available.
|
||||
if self.max_concurrent_batches == 1:
|
||||
return refs[0].get()
|
||||
|
||||
# When PP is used, we return a FutureWrapper immediately so that
|
||||
# the scheduler can yield to the next batch.
|
||||
return FutureWrapper(refs[0])
|
||||
Reference in New Issue
Block a user