### What this PR does / why we need it?

This PR fixes a weak memory ordering problem in shared memory by patching the message queue with an additional lock. The detailed issue can be found at https://github.com/vllm-project/vllm/issues/27858. The key point is to use the writer lock to enforce a memory fence before the ready flag is set (`metadata_buffer[0] = 1`). This is a temporary solution; enable it by setting the environment variable `SHM_BARRIER=true`. By default, this modification is disabled.

### Does this PR introduce _any_ user-facing change?

`SHM_BARRIER=true` enables this change, while `SHM_BARRIER=false` disables it. The latter is the default.

### How was this patch tested?

By CI.

---------

Signed-off-by: Zetong Li <slippersss@126.com>
165 lines
6.5 KiB
Python
165 lines
6.5 KiB
Python
import time
|
|
from contextlib import contextmanager
|
|
from typing import Optional
|
|
|
|
import vllm.envs as envs
|
|
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
|
|
MessageQueue,
|
|
ShmRingBuffer,
|
|
SpinTimer)
|
|
from vllm.distributed.utils import sched_yield
|
|
from vllm.logger import logger
|
|
from vllm.utils import (get_ip, get_mp_context, get_open_port,
|
|
get_open_zmq_ipc_path, is_valid_ipv6_address)
|
|
from zmq import IPV6, XPUB, XPUB_VERBOSE, Context # type: ignore
|
|
|
|
# Interval, in seconds, between the "no available block" log messages that
# `acquire_write` emits while it waits for readers; captured once from
# `vllm.envs` when this module is imported.
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
|
|
|
|
|
def __init__(
    self,
    n_reader,  # number of all readers
    n_local_reader,  # number of local readers through shared memory
    local_reader_ranks: Optional[list[int]] = None,
    max_chunk_bytes: int = 1024 * 1024 * 10,
    max_chunks: int = 10,
    connect_ip: Optional[str] = None,
):
    """Replacement ``MessageQueue`` constructor that also creates a writer lock.

    Sets up the writer side of the queue: a shared-memory ring buffer plus an
    IPC XPUB socket for local readers, a TCP XPUB socket for remote readers,
    and the ``Handle`` describing how to connect to them.

    Args:
        n_reader: Total number of readers (local + remote).
        n_local_reader: Number of readers served through shared memory.
        local_reader_ranks: Ranks of the local readers; defaults to
            ``range(n_local_reader)``. Its length must equal
            ``n_local_reader``.
        max_chunk_bytes: Forwarded to ``ShmRingBuffer`` (presumably the size
            of one chunk in bytes; confirm against ``ShmRingBuffer``).
        max_chunks: Forwarded to ``ShmRingBuffer`` (presumably the number of
            chunks in the ring; confirm against ``ShmRingBuffer``).
        connect_ip: IP that remote readers connect to; when falsy, the
            result of ``get_ip()`` is used.

    Raises:
        AssertionError: If ``local_reader_ranks`` is given and its length
            differs from ``n_local_reader``.
    """
    if local_reader_ranks is None:
        local_reader_ranks = list(range(n_local_reader))
    else:
        assert len(local_reader_ranks) == n_local_reader
    self.n_local_reader = n_local_reader
    n_remote_reader = n_reader - n_local_reader
    self.n_remote_reader = n_remote_reader

    context = Context()

    if n_local_reader > 0:
        # for local readers, we will:
        # 1. create a shared memory ring buffer to communicate small data
        # 2. create a publish-subscribe socket to communicate large data
        self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
                                    max_chunks)

        # XPUB is very similar to PUB,
        # except that it can receive subscription messages
        # to confirm the number of subscribers
        self.local_socket = context.socket(XPUB)
        # set the verbose option so that we can receive every subscription
        # message. otherwise, we will only receive the first subscription
        # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
        self.local_socket.setsockopt(XPUB_VERBOSE, True)
        local_subscribe_addr = get_open_zmq_ipc_path()
        logger.debug("Binding to %s", local_subscribe_addr)
        self.local_socket.bind(local_subscribe_addr)

        self.current_idx = 0
        # Extra lock used by `acquire_write` around its shared-memory
        # writes; this is the addition made by the patched constructor.
        self.writer_lock = get_mp_context().Lock()
    else:
        self.buffer = None  # type: ignore
        local_subscribe_addr = None
        self.local_socket = None
        self.current_idx = -1
        # Keep the attribute defined on every instance, mirroring the
        # placeholders above; there is no ring buffer to guard here.
        self.writer_lock = None  # type: ignore

    remote_addr_ipv6 = False
    if n_remote_reader > 0:
        # for remote readers, we will:
        # create a publish-subscribe socket to communicate large data
        if not connect_ip:
            connect_ip = get_ip()
        self.remote_socket = context.socket(XPUB)
        self.remote_socket.setsockopt(XPUB_VERBOSE, True)
        remote_subscribe_port = get_open_port()
        if is_valid_ipv6_address(connect_ip):
            self.remote_socket.setsockopt(IPV6, 1)
            remote_addr_ipv6 = True
            # IPv6 literals must be bracketed inside a tcp:// endpoint
            connect_ip = f"[{connect_ip}]"
        socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
        self.remote_socket.bind(socket_addr)
        # readers subscribe to exactly the address the socket is bound to
        remote_subscribe_addr = socket_addr
    else:
        remote_subscribe_addr = None
        self.remote_socket = None

    self._is_writer = True
    self._is_local_reader = False
    self.local_reader_rank = -1
    # rank does not matter for remote readers
    self._is_remote_reader = False
    self._read_spin_timer = SpinTimer()

    self.handle = Handle(
        local_reader_ranks=local_reader_ranks,
        buffer_handle=self.buffer.handle()
        if self.buffer is not None else None,
        local_subscribe_addr=local_subscribe_addr,
        remote_subscribe_addr=remote_subscribe_addr,
        remote_addr_ipv6=remote_addr_ipv6,
    )

    logger.info("vLLM message queue communication handle: %s", self.handle)
|
|
|
|
|
|
@contextmanager
def acquire_write(self, timeout: Optional[float] = None):
    """Wait for a writable ring-buffer block and yield its data buffer.

    Spins on the block at ``self.current_idx`` until it is either unwritten
    or has been read by every reader, then yields that block's data buffer
    so the caller can fill it. After the caller's ``with`` body finishes,
    the per-reader read flags are cleared, the written flag is set, and
    ``self.current_idx`` advances to the next block (wrapping around).

    The payload write and the read-flag reset happen while
    ``self.writer_lock`` is held; the lock is released just before the
    written flag is set. The intent is for the lock release to act as a
    memory fence, so readers that observe the written flag also observe the
    payload.

    Args:
        timeout: Maximum number of seconds to wait for a free block, or
            ``None`` to wait indefinitely.

    Yields:
        The data buffer of the acquired block.

    Raises:
        AssertionError: If this queue instance is not the writer.
        TimeoutError: If no block became writable within ``timeout``.
    """
    assert self._is_writer, "Only writers can acquire write"
    start_time = time.monotonic()
    n_warning = 1
    while True:
        with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
            # metadata layout: slot 0 is the written flag, the remaining
            # slots are summed and compared against the reader count
            read_count = sum(metadata_buffer[1:])
            written_flag = metadata_buffer[0]
            if written_flag and read_count != self.buffer.n_reader:
                # this block is written and not read by all readers
                # for writers, `self.current_idx` is the next block to write
                # if this block is not ready to write,
                # we need to wait until it is read by all readers

                # Release the processor to other threads
                sched_yield()

                # if we time out, raise an exception
                elapsed = time.monotonic() - start_time
                if timeout is not None and elapsed > timeout:
                    raise TimeoutError

                # if we wait for a long time, log a message
                if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
                    logger.info(
                        "No available shared memory broadcast block found"
                        " in %s seconds. This typically happens when some"
                        " processes are hanging or doing some"
                        " time-consuming work (e.g. compilation)",
                        VLLM_RINGBUFFER_WARNING_INTERVAL)
                    n_warning += 1

                continue
            # found a block that is either
            # (1) not written
            # (2) read by all readers

            with self.writer_lock:
                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
            # The lock has been released at this point, so the stores above
            # are intended to be visible before the ready flag is set.
            # mark the block as written
            metadata_buffer[0] = 1
            self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
            break
|
|
|
|
|
|
# Monkey-patch: replace vLLM's MessageQueue constructor with the one defined
# in this module, which additionally creates `writer_lock`.
MessageQueue.__init__ = __init__
|
|
# Monkey-patch: replace vLLM's MessageQueue.acquire_write with the
# lock-guarded version defined in this module.
MessageQueue.acquire_write = acquire_write
|