[Bugfix][SHM] Fix weak memory ordering problem in share memory (#3988)

### What this PR does / why we need it?
This PR aims to fix weak memory ordering problem in share memory by
patching message queue with an additional lock. The detailed issue can
be found here https://github.com/vllm-project/vllm/issues/27858. The key
point is to use the writer lock to enforce memory fence before the ready
flag `metadata_buffer[0] = 1` is set.

This is a temporary solution, and you can use it by setting env
`SHM_BARRIER=true`. By default, we disable this modification.

### Does this PR introduce _any_ user-facing change?
`SHM_BARRIER=true` enables this change while `SHM_BARRIER=false`
disables this change. The latter is the default choice.

### How was this patch tested?
by ci

---------

Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
Zetong Li
2025-11-04 23:07:23 +08:00
committed by GitHub
parent 954dab64fb
commit 66b67f9cf2
3 changed files with 236 additions and 0 deletions

View File

@@ -24,3 +24,7 @@ import vllm_ascend.patch.platform.patch_sched_yield # noqa
if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv(
"EXPERT_MAP_RECORD", "false") == "true":
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa
if os.getenv("SHM_BARRIER", "false") == "true":
import vllm_ascend.patch.platform.patch_core # noqa
import vllm_ascend.patch.platform.patch_message_queue # noqa

View File

@@ -0,0 +1,68 @@
import signal
from typing import Optional
from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm.transformers_utils.config import \
maybe_register_config_serialize_by_value
from vllm.utils import decorate_logs, set_process_title
from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc
def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
"""Launch EngineCore busy loop in background process."""
from vllm.distributed.device_communicators.shm_broadcast import \
MessageQueue # noqa
# Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker
# processes to terminate without error
shutdown_requested = False
# Ensure we can serialize transformer config after spawning
maybe_register_config_serialize_by_value()
def signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the engine_core
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
engine_core: Optional[EngineCoreProc] = None
try:
parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
set_process_title("EngineCore", f"DP{dp_rank}")
decorate_logs()
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
engine_core = DPEngineCoreProc(*args, **kwargs)
else:
set_process_title("EngineCore")
decorate_logs()
engine_core = EngineCoreProc(*args, **kwargs)
engine_core.run_busy_loop()
except SystemExit:
logger.debug("EngineCore exiting.")
raise
except Exception as e:
if engine_core is None:
logger.exception("EngineCore failed to start.")
else:
logger.exception("EngineCore encountered a fatal error.")
engine_core._send_engine_dead()
raise e
finally:
if engine_core is not None:
engine_core.shutdown()
EngineCoreProc.run_engine_core = run_engine_core

View File

@@ -0,0 +1,164 @@
import time
from contextlib import contextmanager
from typing import Optional
import vllm.envs as envs
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue,
ShmRingBuffer,
SpinTimer)
from vllm.distributed.utils import sched_yield
from vllm.logger import logger
from vllm.utils import (get_ip, get_mp_context, get_open_port,
get_open_zmq_ipc_path, is_valid_ipv6_address)
from zmq import IPV6, XPUB, XPUB_VERBOSE, Context # type: ignore
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
def __init__(
self,
n_reader, # number of all readers
n_local_reader, # number of local readers through shared memory
local_reader_ranks: Optional[list[int]] = None,
max_chunk_bytes: int = 1024 * 1024 * 10,
max_chunks: int = 10,
connect_ip: Optional[str] = None,
):
if local_reader_ranks is None:
local_reader_ranks = list(range(n_local_reader))
else:
assert len(local_reader_ranks) == n_local_reader
self.n_local_reader = n_local_reader
n_remote_reader = n_reader - n_local_reader
self.n_remote_reader = n_remote_reader
context = Context()
if n_local_reader > 0:
# for local readers, we will:
# 1. create a shared memory ring buffer to communicate small data
# 2. create a publish-subscribe socket to communicate large data
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks)
# XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self.local_socket = context.socket(XPUB)
# set the verbose option so that we can receive every subscription
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_addr = get_open_zmq_ipc_path()
logger.debug("Binding to %s", local_subscribe_addr)
self.local_socket.bind(local_subscribe_addr)
self.current_idx = 0
self.writer_lock = get_mp_context().Lock()
else:
self.buffer = None # type: ignore
local_subscribe_addr = None
self.local_socket = None
self.current_idx = -1
remote_addr_ipv6 = False
if n_remote_reader > 0:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
if not connect_ip:
connect_ip = get_ip()
self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port()
if is_valid_ipv6_address(connect_ip):
self.remote_socket.setsockopt(IPV6, 1)
remote_addr_ipv6 = True
connect_ip = f"[{connect_ip}]"
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
self.remote_socket.bind(socket_addr)
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
else:
remote_subscribe_addr = None
self.remote_socket = None
self._is_writer = True
self._is_local_reader = False
self.local_reader_rank = -1
# rank does not matter for remote readers
self._is_remote_reader = False
self._read_spin_timer = SpinTimer()
self.handle = Handle(
local_reader_ranks=local_reader_ranks,
buffer_handle=self.buffer.handle()
if self.buffer is not None else None,
local_subscribe_addr=local_subscribe_addr,
remote_subscribe_addr=remote_subscribe_addr,
remote_addr_ipv6=remote_addr_ipv6,
)
logger.info("vLLM message queue communication handle: %s", self.handle)
@contextmanager
def acquire_write(self, timeout: Optional[float] = None):
assert self._is_writer, "Only writers can acquire write"
start_time = time.monotonic()
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
read_count = sum(metadata_buffer[1:])
written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader:
# this block is written and not read by all readers
# for writers, `self.current_idx` is the next block to write
# if this block is not ready to write,
# we need to wait until it is read by all readers
# Release the processor to other threads
sched_yield()
# if we time out, raise an exception
elapsed = time.monotonic() - start_time
if timeout is not None and elapsed > timeout:
raise TimeoutError
# if we wait for a long time, log a message
if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
logger.info(
"No available shared memory broadcast block found"
" in %s seconds. This typically happens when some"
" processes are hanging or doing some"
" time-consuming work (e.g. compilation)",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
continue
# found a block that is either
# (1) not written
# (2) read by all readers
with self.writer_lock:
# mark the block as not written
metadata_buffer[0] = 0
# let caller write to the buffer
with self.buffer.get_data(self.current_idx) as buf:
yield buf
# caller has written to the buffer
# NOTE: order is important here
# first set the read flags to 0
# then set the written flag to 1
# otherwise, the readers may think they already read the block
for i in range(1, self.buffer.n_reader + 1):
# set read flag to 0, meaning it is not read yet
metadata_buffer[i] = 0
# mark the block as written
metadata_buffer[0] = 1
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break
MessageQueue.__init__ = __init__
MessageQueue.acquire_write = acquire_write