[Bugfix][SHM] Fix weak memory ordering problem in shared memory (#3988)
### What this PR does / why we need it?
This PR fixes a weak memory ordering problem in the shared-memory message queue by patching `MessageQueue` with an additional lock. The underlying issue is described in https://github.com/vllm-project/vllm/issues/27858. The key point is to hold the writer lock so that a memory fence is enforced before the ready flag `metadata_buffer[0] = 1` is set. This is a temporary solution; enable it by setting the environment variable `SHM_BARRIER=true`. It is disabled by default.

### Does this PR introduce _any_ user-facing change?
`SHM_BARRIER=true` enables this change and `SHM_BARRIER=false` disables it; the latter is the default.

### How was this patch tested?
By CI.

---------

Signed-off-by: Zetong Li <slippersss@126.com>
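For context, here is a minimal standalone sketch of the ordering problem and the lock-as-fence idea (illustrative only, not the vLLM code; the two-byte payload/ready-flag layout and the process setup are invented for this example). On a weakly ordered CPU such as AArch64, the reader may observe the ready flag before the payload unless the writer performs a synchronizing operation, such as the lock acquire/release this patch relies on, between the two stores.

```python
# Illustration only -- not the vLLM implementation.
# buf[1] is a payload byte, buf[0] is the "ready" flag.
import multiprocessing as mp
from multiprocessing import shared_memory


def reader(shm_name: str) -> None:
    shm = shared_memory.SharedMemory(name=shm_name)
    while shm.buf[0] != 1:  # spin until the ready flag becomes visible
        pass
    # Without any fence on the writer side, a weakly ordered CPU may let this
    # read observe stale data even though the flag is already set.
    assert shm.buf[1] == 42
    shm.close()


if __name__ == "__main__":
    shm = shared_memory.SharedMemory(create=True, size=2)  # zero-initialized
    lock = mp.Lock()  # only the writer touches it, mirroring the patch
    p = mp.Process(target=reader, args=(shm.name,))
    p.start()

    shm.buf[1] = 42  # payload write
    with lock:  # the acquire/release is assumed to act as a memory fence,
        shm.buf[0] = 1  # so the payload becomes visible before the ready flag
    p.join()
    shm.close()
    shm.unlink()
```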
@@ -24,3 +24,7 @@ import vllm_ascend.patch.platform.patch_sched_yield  # noqa
if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv(
        "EXPERT_MAP_RECORD", "false") == "true":
    import vllm_ascend.patch.platform.patch_multiproc_executor  # noqa

if os.getenv("SHM_BARRIER", "false") == "true":
    import vllm_ascend.patch.platform.patch_core  # noqa
    import vllm_ascend.patch.platform.patch_message_queue  # noqa
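As a usage note (a sketch, assuming the patches are activated by importing the `vllm_ascend.patch.platform` package as the hunk above suggests), the env check runs at import time, so `SHM_BARRIER` has to be set before that import, e.g. in the launcher's environment:

```python
import os

# Must be set before the vllm_ascend platform patches are imported;
# setting it afterwards has no effect because the check above runs once
# at import time.
os.environ["SHM_BARRIER"] = "true"

import vllm_ascend.patch.platform  # noqa: F401  # triggers the patch imports
```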
vllm_ascend/patch/platform/patch_core.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import signal
from typing import Optional

from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm.transformers_utils.config import \
    maybe_register_config_serialize_by_value
from vllm.utils import decorate_logs, set_process_title
from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc


def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
    """Launch EngineCore busy loop in background process."""

    from vllm.distributed.device_communicators.shm_broadcast import \
        MessageQueue  # noqa

    # Signal handler used for graceful termination.
    # SystemExit exception is only raised once to allow this and worker
    # processes to terminate without error
    shutdown_requested = False

    # Ensure we can serialize transformer config after spawning
    maybe_register_config_serialize_by_value()

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit()

    # Either SIGTERM or SIGINT will terminate the engine_core
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    engine_core: Optional[EngineCoreProc] = None
    try:
        parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
        if parallel_config.data_parallel_size > 1 or dp_rank > 0:
            set_process_title("EngineCore", f"DP{dp_rank}")
            decorate_logs()
            # Set data parallel rank for this engine process.
            parallel_config.data_parallel_rank = dp_rank
            parallel_config.data_parallel_rank_local = local_dp_rank
            engine_core = DPEngineCoreProc(*args, **kwargs)
        else:
            set_process_title("EngineCore")
            decorate_logs()
            engine_core = EngineCoreProc(*args, **kwargs)

        engine_core.run_busy_loop()

    except SystemExit:
        logger.debug("EngineCore exiting.")
        raise
    except Exception as e:
        if engine_core is None:
            logger.exception("EngineCore failed to start.")
        else:
            logger.exception("EngineCore encountered a fatal error.")
            engine_core._send_engine_dead()
        raise e
    finally:
        if engine_core is not None:
            engine_core.shutdown()


EngineCoreProc.run_engine_core = run_engine_core
vllm_ascend/patch/platform/patch_message_queue.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import time
from contextlib import contextmanager
from typing import Optional

import vllm.envs as envs
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                 MessageQueue,
                                                                 ShmRingBuffer,
                                                                 SpinTimer)
from vllm.distributed.utils import sched_yield
from vllm.logger import logger
from vllm.utils import (get_ip, get_mp_context, get_open_port,
                        get_open_zmq_ipc_path, is_valid_ipv6_address)
from zmq import IPV6, XPUB, XPUB_VERBOSE, Context  # type: ignore

VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL


def __init__(
    self,
    n_reader,  # number of all readers
    n_local_reader,  # number of local readers through shared memory
    local_reader_ranks: Optional[list[int]] = None,
    max_chunk_bytes: int = 1024 * 1024 * 10,
    max_chunks: int = 10,
    connect_ip: Optional[str] = None,
):
    if local_reader_ranks is None:
        local_reader_ranks = list(range(n_local_reader))
    else:
        assert len(local_reader_ranks) == n_local_reader
    self.n_local_reader = n_local_reader
    n_remote_reader = n_reader - n_local_reader
    self.n_remote_reader = n_remote_reader

    context = Context()

    if n_local_reader > 0:
        # for local readers, we will:
        # 1. create a shared memory ring buffer to communicate small data
        # 2. create a publish-subscribe socket to communicate large data
        self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
                                    max_chunks)

        # XPUB is very similar to PUB,
        # except that it can receive subscription messages
        # to confirm the number of subscribers
        self.local_socket = context.socket(XPUB)
        # set the verbose option so that we can receive every subscription
        # message. otherwise, we will only receive the first subscription
        # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
        self.local_socket.setsockopt(XPUB_VERBOSE, True)
        local_subscribe_addr = get_open_zmq_ipc_path()
        logger.debug("Binding to %s", local_subscribe_addr)
        self.local_socket.bind(local_subscribe_addr)

        self.current_idx = 0
        self.writer_lock = get_mp_context().Lock()
    else:
        self.buffer = None  # type: ignore
        local_subscribe_addr = None
        self.local_socket = None
        self.current_idx = -1

    remote_addr_ipv6 = False
    if n_remote_reader > 0:
        # for remote readers, we will:
        # create a publish-subscribe socket to communicate large data
        if not connect_ip:
            connect_ip = get_ip()
        self.remote_socket = context.socket(XPUB)
        self.remote_socket.setsockopt(XPUB_VERBOSE, True)
        remote_subscribe_port = get_open_port()
        if is_valid_ipv6_address(connect_ip):
            self.remote_socket.setsockopt(IPV6, 1)
            remote_addr_ipv6 = True
            connect_ip = f"[{connect_ip}]"
        socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
        self.remote_socket.bind(socket_addr)
        remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
    else:
        remote_subscribe_addr = None
        self.remote_socket = None

    self._is_writer = True
    self._is_local_reader = False
    self.local_reader_rank = -1
    # rank does not matter for remote readers
    self._is_remote_reader = False
    self._read_spin_timer = SpinTimer()

    self.handle = Handle(
        local_reader_ranks=local_reader_ranks,
        buffer_handle=self.buffer.handle()
        if self.buffer is not None else None,
        local_subscribe_addr=local_subscribe_addr,
        remote_subscribe_addr=remote_subscribe_addr,
        remote_addr_ipv6=remote_addr_ipv6,
    )

    logger.info("vLLM message queue communication handle: %s", self.handle)


@contextmanager
def acquire_write(self, timeout: Optional[float] = None):
    assert self._is_writer, "Only writers can acquire write"
    start_time = time.monotonic()
    n_warning = 1
    while True:
        with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
            read_count = sum(metadata_buffer[1:])
            written_flag = metadata_buffer[0]
            if written_flag and read_count != self.buffer.n_reader:
                # this block is written and not read by all readers
                # for writers, `self.current_idx` is the next block to write
                # if this block is not ready to write,
                # we need to wait until it is read by all readers

                # Release the processor to other threads
                sched_yield()

                # if we time out, raise an exception
                elapsed = time.monotonic() - start_time
                if timeout is not None and elapsed > timeout:
                    raise TimeoutError

                # if we wait for a long time, log a message
                if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
                    logger.info(
                        "No available shared memory broadcast block found"
                        " in %s seconds. This typically happens when some"
                        " processes are hanging or doing some"
                        " time-consuming work (e.g. compilation)",
                        VLLM_RINGBUFFER_WARNING_INTERVAL)
                    n_warning += 1

                continue
            # found a block that is either
            # (1) not written
            # (2) read by all readers

            with self.writer_lock:
                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
            self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
            break


MessageQueue.__init__ = __init__
MessageQueue.acquire_write = acquire_write
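A quick, hypothetical sanity check (not part of the PR) that the module-level assignments above actually rebind the upstream class once the patch module is imported:

```python
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

import vllm_ascend.patch.platform.patch_message_queue  # noqa: F401

# Both methods now point at the patched definitions in that module.
patched = "vllm_ascend.patch.platform.patch_message_queue"
assert MessageQueue.__init__.__module__ == patched
assert MessageQueue.acquire_write.__module__ == patched
```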