# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import multiprocessing
import time
import weakref

import msgspec.msgpack
import zmq

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.system_utils import get_mp_context, set_process_title
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown

logger = init_logger(__name__)

class DPCoordinator:
|
|
"""Coordinator process used for data-parallel deployments (DP>1).
|
|
|
|
Intermediates between multiple DP engine rank processes and one or more
|
|
front-end API server processes.
|
|
|
|
* Collects stats from each DP engine (currently just waiting and running
|
|
queue lengths), and publishes these to all front-ends for use in
|
|
load-balancing decisions.
|
|
|
|
* Keeps track of the current DP "request wave" number and running state
|
|
of the engines. This is received from the DP rank 0 engine and published
|
|
to the front-end processes along with the current load stats.
|
|
|
|
The engines alternate between a global running/paused state. The global
|
|
"request wave" number is a count of the number of times that the workers
|
|
collectively move from a running state to a paused state. This transition
|
|
is synchronized via the all-reduce operation performed in the
|
|
DPEngineCoreProc._has_global_unfinished_reqs method.
|
|
|
|
* Broadcasts the START_DP_WAVE message to engines to move them from paused
|
|
to running state when one engine receives a new request. This can happen
|
|
in two cases:
|
|
1) A front-end sending a new request while the engines are paused will
|
|
concurrently notify the coordinator.
|
|
2) An engine receiving a request for a stale request wave while in paused
|
|
state will notify the coordinator.
|
|
|
|
Engines will move into running state when receiving a new request or
|
|
START_DP_WAVE message.
|
|
|
|
Note that when deployed in External LB mode, no stats will be published by
|
|
the engines and thus updates will only be sent to front-ends when the
|
|
request wave / running state changes.
|
|
"""
|
|
|
|
def __init__(self, parallel_config: ParallelConfig):
|
|
dp_size = parallel_config.data_parallel_size
|
|
assert dp_size > 1, "Coordinator only used for data parallel"
|
|
|
|
host = parallel_config.data_parallel_master_ip
|
|
external_lb = parallel_config.data_parallel_external_lb
|
|
hybrid_lb = parallel_config.data_parallel_hybrid_lb
|
|
|
|
# Assume coordinator is colocated with front-end procs when not in
|
|
# either external or hybrid DP LB mode.
|
|
local_only = not (external_lb or hybrid_lb)
|
|
front_publish_address = get_engine_client_zmq_addr(
|
|
local_only=local_only, host=host
|
|
)
|
|
|
|
local_only_eng = dp_size == parallel_config.data_parallel_size_local
|
|
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
|
|
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
|
|
|
|
context = get_mp_context()
|
|
self.proc: multiprocessing.Process = context.Process(
|
|
target=DPCoordinatorProc.run_coordinator,
|
|
name="VLLM_DP_Coordinator",
|
|
kwargs={
|
|
"engine_count": parallel_config.data_parallel_size,
|
|
"front_publish_address": front_publish_address,
|
|
"back_output_address": back_output_address,
|
|
"back_publish_address": back_publish_address,
|
|
},
|
|
daemon=True,
|
|
)
|
|
self.proc.start()
|
|
|
|
self.stats_publish_address = front_publish_address
|
|
self.coord_in_address = back_publish_address
|
|
self.coord_out_address = back_output_address
|
|
self._finalizer = weakref.finalize(self, shutdown, [self.proc])
|
|
|
|
def get_stats_publish_address(self) -> str:
|
|
return self.stats_publish_address
|
|
|
|
def get_engine_socket_addresses(self) -> tuple[str, str]:
|
|
"""Returns tuple of ZMQ input address, output address."""
|
|
return self.coord_in_address, self.coord_out_address
|
|
|
|
def close(self):
|
|
self._finalizer()
|
|
|
|
|
|
class EngineState:
|
|
def __init__(self):
|
|
self.request_counts = [0, 0] # [waiting, running]
|
|
|
|
|
|
class DPCoordinatorProc:
|
|
def __init__(self, engine_count: int, min_stats_update_interval_ms: int = 100):
|
|
set_process_title("DPCoordinator")
|
|
self.ctx = zmq.Context()
|
|
|
|
self.engines = [EngineState() for _ in range(engine_count)]
|
|
|
|
self.stats_update_interval_ms = min_stats_update_interval_ms
|
|
|
|
@staticmethod
|
|
def run_coordinator(
|
|
engine_count: int,
|
|
front_publish_address: str,
|
|
back_output_address: str,
|
|
back_publish_address: str,
|
|
min_stats_update_interval_ms: int = 100,
|
|
):
|
|
coordinator = DPCoordinatorProc(
|
|
engine_count=engine_count,
|
|
min_stats_update_interval_ms=min_stats_update_interval_ms,
|
|
)
|
|
try:
|
|
coordinator.process_input_socket(
|
|
front_publish_address,
|
|
back_output_address,
|
|
back_publish_address,
|
|
)
|
|
except KeyboardInterrupt:
|
|
logger.info("DP Coordinator process exiting")
|
|
|
|
def process_input_socket(
|
|
self,
|
|
front_publish_address: str,
|
|
back_output_address: str,
|
|
back_publish_address: str,
|
|
):
|
|
decoder = MsgpackDecoder(EngineCoreOutputs)
|
|
|
|
# For tracking request wave progression.
|
|
current_wave = 0
|
|
engines_running = False
|
|
|
|
# For tracking request counts for internal load-balancing.
|
|
stats_changed = False
|
|
last_stats_step = -1
|
|
last_stats_wave = -1
|
|
last_step_counts: list[list[int]] | None = None
|
|
|
|
with (
|
|
make_zmq_socket(
|
|
path=front_publish_address, # IPC
|
|
ctx=self.ctx,
|
|
socket_type=zmq.XPUB,
|
|
bind=True,
|
|
) as publish_front,
|
|
make_zmq_socket(
|
|
path=back_output_address, # IPC or TCP
|
|
ctx=self.ctx,
|
|
socket_type=zmq.PULL,
|
|
bind=True,
|
|
) as output_back,
|
|
make_zmq_socket(
|
|
path=back_publish_address, # IPC or TCP
|
|
ctx=self.ctx,
|
|
socket_type=zmq.XPUB,
|
|
bind=True,
|
|
) as publish_back,
|
|
):
|
|
# Wait until all engines subscribe.
|
|
for _ in self.engines:
|
|
if publish_back.recv() != b"\x01":
|
|
logger.error(
|
|
"DP Coordinator received unexpected message while "
|
|
"waiting for engines to subscribe"
|
|
)
|
|
return
|
|
# Send ready message to engines.
|
|
publish_back.send(b"READY")
|
|
|
|
logger.info("All engine subscriptions received by DP coordinator")
|
|
|
|
poller = zmq.Poller()
|
|
poller.register(publish_front, zmq.POLLIN)
|
|
poller.register(output_back, zmq.POLLIN)
|
|
last_publish_time = 0
|
|
while True:
|
|
elapsed = int(time.time() * 1000) - last_publish_time
|
|
# Send at stats_update_interval_ms interval if the stats have
|
|
# changed, or otherwise every 5 seconds.
|
|
wait_for = self.stats_update_interval_ms if stats_changed else 5000
|
|
|
|
# Wait at least 50ms to ensure we've received all stats for
|
|
# the current step.
|
|
min_timeout = 50 if last_step_counts is None else 0
|
|
|
|
events = poller.poll(timeout=max(min_timeout, wait_for - elapsed))
|
|
if not events:
|
|
# Poller timeout - publish current stats to front-ends.
|
|
if last_step_counts is not None:
|
|
engine_req_counts_list = last_step_counts
|
|
last_step_counts = None
|
|
else:
|
|
engine_req_counts_list = self._get_engine_counts()
|
|
stats_changed = False
|
|
|
|
to_publish = (engine_req_counts_list, current_wave, engines_running)
|
|
publish_front.send(msgspec.msgpack.encode(to_publish))
|
|
last_publish_time = int(time.time() * 1000)
|
|
continue
|
|
|
|
events = dict(events)
|
|
wave_state_changed = False
|
|
|
|
if publish_front in events:
|
|
buffer = publish_front.recv()
|
|
if buffer in (b"\x01", b"\x00"):
|
|
# Ignore subscription messages.
|
|
continue
|
|
|
|
decoded = msgspec.msgpack.decode(buffer)
|
|
if (
|
|
isinstance(decoded, (list, tuple))
|
|
and len(decoded) == 2
|
|
and decoded[0] == "SCALE_ELASTIC_EP"
|
|
):
|
|
# Handle scale up notification
|
|
new_engine_count = decoded[1]
|
|
current_count = len(self.engines)
|
|
if new_engine_count > current_count:
|
|
for _ in range(new_engine_count - current_count):
|
|
self.engines.append(EngineState())
|
|
# NOTE(yongji): handle the case
|
|
# where newly started engines have current_wave = 0
|
|
# if existing engines just finished a wave
|
|
# and engine_running isn't updated yet at
|
|
# CoordinatorProc requests routed to newly started
|
|
# engines may not wake up existing engines, as long
|
|
# as 0 < request.wave < existing engines'
|
|
# current_wave
|
|
# we note that 0 is the wave number for the new
|
|
# engine
|
|
engines_running = False
|
|
logger.info(
|
|
"DPCoordinator scaled up from %s to %s engines",
|
|
current_count,
|
|
new_engine_count,
|
|
)
|
|
else:
|
|
self.engines = self.engines[:new_engine_count]
|
|
logger.info(
|
|
"DPCoordinator scaled down from %s to %s engines",
|
|
current_count,
|
|
new_engine_count,
|
|
)
|
|
continue # Skip normal engine notification processing
|
|
|
|
# We received a message on the front-end XPUB socket,
|
|
# from an API server sending a new request while the
|
|
# engines are paused, so that we can wake the other
|
|
# engines.
|
|
engine_to_exclude, wave = decoded
|
|
if not engines_running:
|
|
if wave < current_wave:
|
|
# If the wave number is stale, ensure the message
|
|
# is handled by all the engines.
|
|
engine_to_exclude = None
|
|
|
|
engines_running = True
|
|
wave_state_changed = True
|
|
self._send_start_wave(
|
|
publish_back, current_wave, engine_to_exclude
|
|
)
|
|
|
|
if output_back in events:
|
|
# We received a message from one of the engines.
|
|
|
|
buffer = output_back.recv()
|
|
outputs: EngineCoreOutputs = decoder.decode(buffer)
|
|
|
|
assert not outputs.outputs
|
|
assert outputs.utility_output is None
|
|
|
|
eng_index = outputs.engine_index
|
|
scheduler_stats = outputs.scheduler_stats
|
|
if scheduler_stats:
|
|
# 1. Updated request load stats - update our local
|
|
# state with these.
|
|
stats = self.engines[eng_index].request_counts
|
|
stats_step = scheduler_stats.step_counter
|
|
stats_wave = scheduler_stats.current_wave
|
|
if (
|
|
stats_wave > last_stats_wave
|
|
or stats_wave == last_stats_wave
|
|
and stats_step > last_stats_step
|
|
):
|
|
if stats_changed:
|
|
last_step_counts = self._get_engine_counts(do_copy=True)
|
|
last_stats_step = stats_step
|
|
last_stats_wave = stats_wave
|
|
elif stats_wave != last_stats_wave or (
|
|
stats_step != last_stats_step
|
|
):
|
|
logger.warning(
|
|
"Received stats for out-of-order "
|
|
"step (%d, %d) from engine %d (expected "
|
|
"> (%d, %d))",
|
|
stats_wave,
|
|
stats_step,
|
|
eng_index,
|
|
last_stats_wave,
|
|
last_stats_step,
|
|
)
|
|
stats[0] = scheduler_stats.num_waiting_reqs
|
|
stats[1] = scheduler_stats.num_running_reqs
|
|
stats_changed = True
|
|
|
|
if (wave := outputs.wave_complete) is not None:
|
|
# 2. Notification from rank 0 engine that we've
|
|
# moved into the global paused state
|
|
# (engines_running==False).
|
|
if current_wave <= wave:
|
|
new_wave = wave + 1
|
|
logger.debug(
|
|
"Moving DP wave from %d to %d.", current_wave, new_wave
|
|
)
|
|
current_wave = new_wave
|
|
engines_running = False
|
|
wave_state_changed = True
|
|
elif (wave := outputs.start_wave) is not None and (
|
|
wave > current_wave
|
|
or (wave == current_wave and not engines_running)
|
|
):
|
|
# 3. The engine received request for a non-current wave
|
|
# so we must ensure that other engines progress to the
|
|
# next wave (race condition handling).
|
|
logger.debug(
|
|
"Starting wave %d after notification of "
|
|
"stale wave request from engine.",
|
|
wave,
|
|
)
|
|
current_wave = wave
|
|
engines_running = True
|
|
wave_state_changed = True
|
|
self._send_start_wave(publish_back, wave, eng_index)
|
|
|
|
if wave_state_changed:
|
|
message = (None, current_wave, engines_running)
|
|
publish_front.send(msgspec.msgpack.encode(message))
|
|
|
|
@staticmethod
|
|
def _send_start_wave(
|
|
socket: zmq.Socket, wave: int, exclude_engine_index: int | None
|
|
):
|
|
"""Broadcast the START_DP_WAVE message to all the engines.
|
|
It includes the current wave number and index of engine which
|
|
has already received a request with this wave number and so doesn't
|
|
require additional notification.
|
|
"""
|
|
wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
|
|
socket.send_multipart((EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
|
|
|
|
def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
|
|
"""Return list of [waiting, running] count lists for each engine."""
|
|
if do_copy:
|
|
return [copy.copy(e.request_counts) for e in self.engines]
|
|
return [e.request_counts for e in self.engines]
|