diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py
index b4ef633..42d86a3 100644
--- a/vllm_ascend/patch/platform/__init__.py
+++ b/vllm_ascend/patch/platform/__init__.py
@@ -24,3 +24,7 @@ import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv(
         "EXPERT_MAP_RECORD", "false") == "true":
     import vllm_ascend.patch.platform.patch_multiproc_executor  # noqa
+
+if os.getenv("SHM_BARRIER", "false") == "true":
+    import vllm_ascend.patch.platform.patch_core  # noqa
+    import vllm_ascend.patch.platform.patch_message_queue  # noqa
diff --git a/vllm_ascend/patch/platform/patch_core.py b/vllm_ascend/patch/platform/patch_core.py
new file mode 100644
index 0000000..56a519f
--- /dev/null
+++ b/vllm_ascend/patch/platform/patch_core.py
@@ -0,0 +1,68 @@
+import signal
+from typing import Optional
+
+from vllm.config import ParallelConfig
+from vllm.logger import logger
+from vllm.transformers_utils.config import \
+    maybe_register_config_serialize_by_value
+from vllm.utils import decorate_logs, set_process_title
+from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc
+
+
+def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
+    """Launch EngineCore busy loop in background process."""
+
+    from vllm.distributed.device_communicators.shm_broadcast import \
+        MessageQueue  # noqa
+
+    # Signal handler used for graceful termination.
+    # SystemExit exception is only raised once to allow this and worker
+    # processes to terminate without error
+    shutdown_requested = False
+
+    # Ensure we can serialize transformer config after spawning
+    maybe_register_config_serialize_by_value()
+
+    def signal_handler(signum, frame):
+        nonlocal shutdown_requested
+        if not shutdown_requested:
+            shutdown_requested = True
+            raise SystemExit()
+
+    # Either SIGTERM or SIGINT will terminate the engine_core
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+    engine_core: Optional[EngineCoreProc] = None
+    try:
+        parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
+        if parallel_config.data_parallel_size > 1 or dp_rank > 0:
+            set_process_title("EngineCore", f"DP{dp_rank}")
+            decorate_logs()
+            # Set data parallel rank for this engine process.
+            parallel_config.data_parallel_rank = dp_rank
+            parallel_config.data_parallel_rank_local = local_dp_rank
+            engine_core = DPEngineCoreProc(*args, **kwargs)
+        else:
+            set_process_title("EngineCore")
+            decorate_logs()
+            engine_core = EngineCoreProc(*args, **kwargs)
+
+        engine_core.run_busy_loop()
+
+    except SystemExit:
+        logger.debug("EngineCore exiting.")
+        raise
+    except Exception as e:
+        if engine_core is None:
+            logger.exception("EngineCore failed to start.")
+        else:
+            logger.exception("EngineCore encountered a fatal error.")
+            engine_core._send_engine_dead()
+        raise e
+    finally:
+        if engine_core is not None:
+            engine_core.shutdown()
+
+
+EngineCoreProc.run_engine_core = run_engine_core
diff --git a/vllm_ascend/patch/platform/patch_message_queue.py b/vllm_ascend/patch/platform/patch_message_queue.py
new file mode 100644
index 0000000..7bf183c
--- /dev/null
+++ b/vllm_ascend/patch/platform/patch_message_queue.py
@@ -0,0 +1,164 @@
+import time
+from contextlib import contextmanager
+from typing import Optional
+
+import vllm.envs as envs
+from vllm.distributed.device_communicators.shm_broadcast import (Handle,
+                                                                 MessageQueue,
+                                                                 ShmRingBuffer,
+                                                                 SpinTimer)
+from vllm.distributed.utils import sched_yield
+from vllm.logger import logger
+from vllm.utils import (get_ip, get_mp_context, get_open_port,
+                        get_open_zmq_ipc_path, is_valid_ipv6_address)
+from zmq import IPV6, XPUB, XPUB_VERBOSE, Context  # type: ignore
+
+VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
+
+
+def __init__(
+    self,
+    n_reader,  # number of all readers
+    n_local_reader,  # number of local readers through shared memory
+    local_reader_ranks: Optional[list[int]] = None,
+    max_chunk_bytes: int = 1024 * 1024 * 10,
+    max_chunks: int = 10,
+    connect_ip: Optional[str] = None,
+):
+    if local_reader_ranks is None:
+        local_reader_ranks = list(range(n_local_reader))
+    else:
+        assert len(local_reader_ranks) == n_local_reader
+    self.n_local_reader = n_local_reader
+    n_remote_reader = n_reader - n_local_reader
+    self.n_remote_reader = n_remote_reader
+
+    context = Context()
+
+    if n_local_reader > 0:
+        # for local readers, we will:
+        # 1. create a shared memory ring buffer to communicate small data
+        # 2. create a publish-subscribe socket to communicate large data
+        self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
+                                    max_chunks)
+
+        # XPUB is very similar to PUB,
+        # except that it can receive subscription messages
+        # to confirm the number of subscribers
+        self.local_socket = context.socket(XPUB)
+        # set the verbose option so that we can receive every subscription
+        # message. otherwise, we will only receive the first subscription
+        # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
+        self.local_socket.setsockopt(XPUB_VERBOSE, True)
+        local_subscribe_addr = get_open_zmq_ipc_path()
+        logger.debug("Binding to %s", local_subscribe_addr)
+        self.local_socket.bind(local_subscribe_addr)
+
+        self.current_idx = 0
+        self.writer_lock = get_mp_context().Lock()
+    else:
+        self.buffer = None  # type: ignore
+        local_subscribe_addr = None
+        self.local_socket = None
+        self.current_idx = -1
+
+    remote_addr_ipv6 = False
+    if n_remote_reader > 0:
+        # for remote readers, we will:
+        # create a publish-subscribe socket to communicate large data
+        if not connect_ip:
+            connect_ip = get_ip()
+        self.remote_socket = context.socket(XPUB)
+        self.remote_socket.setsockopt(XPUB_VERBOSE, True)
+        remote_subscribe_port = get_open_port()
+        if is_valid_ipv6_address(connect_ip):
+            self.remote_socket.setsockopt(IPV6, 1)
+            remote_addr_ipv6 = True
+            connect_ip = f"[{connect_ip}]"
+        socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
+        self.remote_socket.bind(socket_addr)
+        remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
+    else:
+        remote_subscribe_addr = None
+        self.remote_socket = None
+
+    self._is_writer = True
+    self._is_local_reader = False
+    self.local_reader_rank = -1
+    # rank does not matter for remote readers
+    self._is_remote_reader = False
+    self._read_spin_timer = SpinTimer()
+
+    self.handle = Handle(
+        local_reader_ranks=local_reader_ranks,
+        buffer_handle=self.buffer.handle()
+        if self.buffer is not None else None,
+        local_subscribe_addr=local_subscribe_addr,
+        remote_subscribe_addr=remote_subscribe_addr,
+        remote_addr_ipv6=remote_addr_ipv6,
+    )
+
+    logger.info("vLLM message queue communication handle: %s", self.handle)
+
+
+@contextmanager
+def acquire_write(self, timeout: Optional[float] = None):
+    assert self._is_writer, "Only writers can acquire write"
+    start_time = time.monotonic()
+    n_warning = 1
+    while True:
+        with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
+            read_count = sum(metadata_buffer[1:])
+            written_flag = metadata_buffer[0]
+            if written_flag and read_count != self.buffer.n_reader:
+                # this block is written and not read by all readers
+                # for writers, `self.current_idx` is the next block to write
+                # if this block is not ready to write,
+                # we need to wait until it is read by all readers
+
+                # Release the processor to other threads
+                sched_yield()
+
+                # if we time out, raise an exception
+                elapsed = time.monotonic() - start_time
+                if timeout is not None and elapsed > timeout:
+                    raise TimeoutError
+
+                # if we wait for a long time, log a message
+                if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
+                    logger.info(
+                        "No available shared memory broadcast block found"
+                        " in %s seconds. This typically happens when some"
+                        " processes are hanging or doing some"
+                        " time-consuming work (e.g. compilation)",
+                        VLLM_RINGBUFFER_WARNING_INTERVAL)
+                    n_warning += 1
+
+                continue
+            # found a block that is either
+            # (1) not written
+            # (2) read by all readers
+
+            with self.writer_lock:
+                # mark the block as not written
+                metadata_buffer[0] = 0
+                # let caller write to the buffer
+                with self.buffer.get_data(self.current_idx) as buf:
+                    yield buf
+
+                # caller has written to the buffer
+                # NOTE: order is important here
+                # first set the read flags to 0
+                # then set the written flag to 1
+                # otherwise, the readers may think they already read the block
+                for i in range(1, self.buffer.n_reader + 1):
+                    # set read flag to 0, meaning it is not read yet
+                    metadata_buffer[i] = 0
+                # mark the block as written
+                metadata_buffer[0] = 1
+            self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
+            break
+
+
+MessageQueue.__init__ = __init__
+MessageQueue.acquire_write = acquire_write
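
Note: both new patch modules are gated by the SHM_BARRIER environment variable, which vllm_ascend/patch/platform/__init__.py reads at import time, so the variable must be set before the patch package is imported. The snippet below is a minimal activation sketch, not part of the patch; it assumes vllm and vllm-ascend are installed and only verifies that the monkey-patched attributes are in place.

    import os

    # The guard in vllm_ascend/patch/platform/__init__.py is evaluated at
    # import time, so SHM_BARRIER must be set before the import below.
    os.environ["SHM_BARRIER"] = "true"

    import vllm_ascend.patch.platform  # noqa: E402  applies the patches

    from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
    from vllm.v1.engine.core import EngineCoreProc
    from vllm_ascend.patch.platform import patch_core, patch_message_queue

    # The replacement functions should now be installed on the vLLM classes.
    assert EngineCoreProc.run_engine_core is patch_core.run_engine_core
    assert MessageQueue.acquire_write is patch_message_queue.acquire_write

The patched acquire_write keeps the upstream context-manager interface; the behavioral change is that clearing and setting the written flag, resetting the read flags, and the caller's write into the data chunk now run while holding the writer_lock created in the patched __init__.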