[4/N] DP refactor: support watching mode get_load and shortest queue strategy (#10201)
@@ -27,7 +27,7 @@ import tempfile
 import threading
 import time
 from http import HTTPStatus
-from typing import Any, AsyncIterator, Callable, Dict, List, Optional
+from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union

 import setproctitle

@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
     MultiTokenizerManager,
+    MultiTokenizerRouter,
     get_main_process_id,
     monkey_patch_uvicorn_multiprocessing,
     read_from_shared_memory,
@@ -127,7 +128,9 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
-    tokenizer_manager: TokenizerManager
+    tokenizer_manager: Union[
+        TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
+    ]
     template_manager: TemplateManager
     scheduler_info: Dict

@@ -21,6 +21,7 @@ import struct
 import sys
 import threading
 import time
+from collections import deque
 from enum import Enum, auto
 from multiprocessing import shared_memory
 from typing import Dict, List
@@ -34,6 +35,7 @@ from sglang.srt.managers.io_struct import (
     BlockReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
@@ -46,7 +48,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -67,6 +69,42 @@ class LoadBalanceMethod(Enum):
             raise ValueError(f"Invalid load balance method: {method}") from exc


+class DPBudget:
+    def __init__(self):
+        # TODO: support minimum tokens method
+        self.budget_queue = deque()
+
+    def update_budget(self, load_update: WatchLoadUpdateReq):
+        """Update the budget queue.
+        Use num_reqs instead of num_waiting_reqs to balance decode running batch.
+        """
+        loads = load_update.loads
+        self.budget_queue.clear()
+
+        num_reqs = [load.num_reqs for load in loads]
+        if not num_reqs:
+            return
+
+        max_num_reqs = max(num_reqs)
+        if all(x == max_num_reqs for x in num_reqs):
+            return
+
+        while any(x != num_reqs[0] for x in num_reqs):
+            min_load = min(num_reqs)
+            min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
+            second_min_load = min(x for x in num_reqs if x > min_load)
+            self.budget_queue.extend(
+                [loads[i].dp_rank for i in min_indices] * (second_min_load - min_load)
+            )
+            for idx in min_indices:
+                num_reqs[idx] = second_min_load
+
+    def dispatch(self):
+        if self.budget_queue:
+            return self.budget_queue.popleft()
+        return None
+
+
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""

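For intuition, the leveling loop in DPBudget.update_budget can be traced standalone. The sketch below uses hypothetical load values and stand-in dataclasses in place of the real GetLoadReqOutput/WatchLoadUpdateReq; it shows how ranks below the maximum receive one budget slot per request of headroom:

    from collections import deque
    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Load:            # stand-in for GetLoadReqOutput
        dp_rank: int
        num_reqs: int

    budget = deque()
    loads = [Load(0, 5), Load(1, 2), Load(2, 2)]  # hypothetical per-rank loads
    num_reqs = [l.num_reqs for l in loads]

    # Repeatedly raise the least-loaded ranks to the next load level,
    # granting them one budget slot per request of headroom.
    while any(x != num_reqs[0] for x in num_reqs):
        min_load = min(num_reqs)
        min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
        second_min = min(x for x in num_reqs if x > min_load)
        budget.extend([loads[i].dp_rank for i in min_indices] * (second_min - min_load))
        for i in min_indices:
            num_reqs[i] = second_min

    print(list(budget))  # [1, 2, 1, 2, 1, 2] -> next 6 requests fill ranks 1 and 2

Once every rank sits at the common maximum, the queue is empty and dispatch() returns None, signalling the controller to fall back to its default strategy.
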
@@ -104,6 +142,9 @@ class DataParallelController:
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]

+        # Load balance budget
+        self.dp_budget = DPBudget()
+
         # Launch data parallel workers
         self.scheduler_procs = []
         self.workers: List[zmq.Socket] = [None] * server_args.dp_size
@@ -127,6 +168,31 @@ class DataParallelController:

         self.max_req_input_len = None

+        self.init_dispatcher()
+
+    def send_to_all_workers(self, obj):
+        for worker in self.workers:
+            worker.send_pyobj(obj)
+
+    def send_control_message(self, obj):
+        # Send control messages to first worker of tp group
+        for worker in self.workers[:: self.control_message_step]:
+            worker.send_pyobj(obj)
+
+    def handle_load_update_req(self, obj):
+        self.dp_budget.update_budget(obj)
+
+    def init_dispatcher(self):
+        self._request_dispatcher = TypeBasedDispatcher(
+            [
+                (TokenizedGenerateReqInput, self.dispatching),
+                (TokenizedEmbeddingReqInput, self.dispatching),
+                (BlockReqInput, self.send_to_all_workers),
+                (WatchLoadUpdateReq, self.handle_load_update_req),
+            ]
+        )
+        self._request_dispatcher.add_fallback_fn(self.send_control_message)
+
     def launch_dp_schedulers(self, server_args, port_args):
         base_gpu_id = 0

@@ -291,10 +357,14 @@ class DataParallelController:
         else:
             self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

-    def shortest_queue_scheduler(self, input_requests):
-        raise NotImplementedError()
+    def shortest_queue_scheduler(self, req):
+        if self.maybe_external_dp_rank_routing(req):
+            return
+        target_worker = self.dp_budget.dispatch()
+        if target_worker is None:
+            self.round_robin_scheduler(req)
+        else:
+            self.workers[target_worker].send_pyobj(req)

     def minimum_tokens_scheduler(self, req):
         if self.maybe_external_dp_rank_routing(req):
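The fallback here is deliberate: the budget only encodes the headroom known at the last WatchLoadUpdateReq, so once it drains, round robin keeps traffic flowing until the watcher publishes fresh loads. A condensed sketch of the decision (a hypothetical helper, not the controller's actual method):

    from collections import deque

    def pick_worker(budget: deque, rr_counter: int, num_workers: int):
        # Prefer a precomputed budget slot; otherwise fall back to round robin.
        if budget:
            return budget.popleft(), rr_counter
        return rr_counter % num_workers, rr_counter + 1
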
@@ -333,22 +403,7 @@ class DataParallelController:
                 recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
             except zmq.ZMQError:
                 break

-            if isinstance(
-                recv_req,
-                (
-                    TokenizedGenerateReqInput,
-                    TokenizedEmbeddingReqInput,
-                ),
-            ):
-                self.dispatching(recv_req)
-            elif isinstance(recv_req, BlockReqInput):
-                for worker in self.workers:
-                    worker.send_pyobj(recv_req)
-            else:
-                # Send other control messages to first worker of tp group
-                for worker in self.workers[:: self.control_message_step]:
-                    worker.send_pyobj(recv_req)
+            self._request_dispatcher(recv_req)


 def run_data_parallel_controller_process(

@@ -297,7 +297,7 @@ def run_detokenizer_process(
         else:
             manager.event_loop()
     except Exception:
-        manager.socket_mapping.clear_all_sockets()
+        manager.maybe_clear_socket_mapping()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)

@@ -1374,3 +1374,21 @@ class BlockReqType(Enum):
 @dataclass
 class BlockReqInput:
     type: BlockReqType
+
+
+@dataclass
+class GetLoadReqInput:
+    pass
+
+
+@dataclass
+class GetLoadReqOutput:
+    dp_rank: int
+    num_reqs: int
+    num_waiting_reqs: int
+    num_tokens: int
+
+
+@dataclass
+class WatchLoadUpdateReq:
+    loads: List[GetLoadReqOutput]

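These payloads travel scheduler -> tokenizer -> DP controller. A minimal construction for a hypothetical two-rank deployment, using the dataclasses defined above:

    loads = [
        GetLoadReqOutput(dp_rank=0, num_reqs=7, num_waiting_reqs=3, num_tokens=4096),
        GetLoadReqOutput(dp_rank=1, num_reqs=2, num_waiting_reqs=0, num_tokens=1024),
    ]
    update = WatchLoadUpdateReq(loads=loads)  # consumed by DPBudget.update_budget
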
@@ -354,6 +354,10 @@ class MultiHttpWorkerDetokenizerMixin:
             worker_ids = []
         return worker_ids

+    def maybe_clear_socket_mapping(self):
+        if hasattr(self, "socket_mapping"):
+            self.socket_mapping.clear_all_sockets()
+
     def multi_http_worker_event_loop(self):
         """The event loop that handles requests, for multi-http-worker mode"""
         self.socket_mapping = SocketMapping()

@@ -79,6 +79,8 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     HealthCheckOutput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
@@ -577,6 +579,7 @@ class Scheduler(
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                 (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+                (GetLoadReqInput, self.get_load),
             ]
         )

@@ -2279,39 +2282,50 @@ class Scheduler(
             if_success = False
         return if_success

-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
+
         if self.is_hybrid:
-            load_full = (
+            num_tokens_full = (
                 self.full_tokens_per_layer
                 - self.token_to_kv_pool_allocator.full_available_size()
                 - self.tree_cache.full_evictable_size()
             )
-            load_swa = (
+            num_tokens_swa = (
                 self.swa_tokens_per_layer
                 - self.token_to_kv_pool_allocator.swa_available_size()
                 - self.tree_cache.swa_evictable_size()
             )
-            load = max(load_full, load_swa)
+            num_tokens = max(num_tokens_full, num_tokens_swa)
         else:
-            load = (
+            num_tokens = (
                 self.max_total_num_tokens
                 - self.token_to_kv_pool_allocator.available_size()
                 - self.tree_cache.evictable_size()
             )
-        load += sum(len(req.origin_input_ids) for req in self.waiting_queue)

+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            load += sum(
+            num_tokens += sum(
                 len(req.origin_input_ids)
                 for req in self.disagg_prefill_bootstrap_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            load += sum(
+            num_tokens += sum(
                 len(req.req.origin_input_ids)
                 for req in self.disagg_decode_prealloc_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

-        return load
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )

     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
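In the non-hybrid branch the token metric reduces to a simple identity: tokens pinned in the KV pool, plus the input lengths of everything still queued. With hypothetical numbers:

    max_total_num_tokens = 100_000  # KV pool capacity
    available = 60_000              # free slots reported by the allocator
    evictable = 15_000              # cached tokens the tree cache could reclaim
    queued_inputs = [512, 2048]     # origin_input_ids lengths of waiting requests

    num_tokens = (max_total_num_tokens - available - evictable) + sum(queued_inputs)
    print(num_tokens)  # 27560
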
@@ -2337,8 +2351,6 @@ class Scheduler(
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict

-        ret["load"] = self.get_load()
-
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -279,7 +279,7 @@ class SchedulerMetricsMixin:
             self.server_args.load_balance_method == "minimum_tokens"
             and self.forward_ct % 40 == 0
         ):
-            holding_tokens = self.get_load()
+            holding_tokens = self.get_load().num_tokens

             new_recv_dp_balance_id_list, holding_token_list = (
                 self.gather_dp_balance_info(holding_tokens)

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+import copy
 import logging
 import os
 import time
@@ -18,6 +19,7 @@ from typing import (
 )

 import fastapi
+import zmq

 from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
@@ -28,6 +30,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
@@ -75,14 +79,17 @@ class _Communicator(Generic[T]):

     enable_multi_tokenizer = False

-    def __init__(self, sender, fan_out: int):
+    def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
         self._sender = sender
         self._fan_out = fan_out
+        self._mode = mode
         self._result_event: Optional[asyncio.Event] = None
         self._result_values: Optional[List[T]] = None
         self._ready_queue: Deque[asyncio.Future] = deque()

-    async def __call__(self, obj):
+        assert mode in ["queueing", "watching"]
+
+    async def queueing_call(self, obj: T):
         ready_event = asyncio.Event()
         if self._result_event is not None or len(self._ready_queue) > 0:
             self._ready_queue.append(ready_event)
@@ -106,6 +113,28 @@ class _Communicator(Generic[T]):

         return result_values

+    async def watching_call(self, obj):
+        if self._result_event is None:
+            assert self._result_values is None
+            self._result_values = []
+            self._result_event = asyncio.Event()
+
+            if obj:
+                if _Communicator.enable_multi_tokenizer:
+                    obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
+                self._sender.send_pyobj(obj)
+
+        await self._result_event.wait()
+        result_values = copy.deepcopy(self._result_values)
+        self._result_event = self._result_values = None
+        return result_values
+
+    async def __call__(self, obj):
+        if self._mode == "queueing":
+            return await self.queueing_call(obj)
+        else:
+            return await self.watching_call(obj)
+
     def handle_recv(self, recv_obj: T):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
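The fan-out contract is the same in both modes: one logical call completes only after every DP rank has replied. A toy model of that part (a hypothetical class, not the real _Communicator):

    import asyncio

    class FanOut:
        def __init__(self, fan_out: int):
            self.fan_out = fan_out
            self.values = []
            self.event = asyncio.Event()

        def handle_recv(self, obj):
            self.values.append(obj)
            if len(self.values) == self.fan_out:
                self.event.set()  # all ranks reported: wake the waiting caller

    async def main():
        comm = FanOut(fan_out=2)
        for rank in range(2):  # simulate one GetLoadReqOutput per DP rank
            comm.handle_recv({"dp_rank": rank})
        await comm.event.wait()
        print(comm.values)

    asyncio.run(main())

The modes differ in arrival semantics: queueing_call lines callers up so each gets its own batch of replies, while watching_call lets a late caller attach to the batch already in flight, which suits a periodic poller.
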
@@ -165,6 +194,9 @@ class TokenizerCommunicatorMixin:
         self.update_lora_adapter_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.get_load_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size, mode="watching"
+        )

         self._result_dispatcher += self._get_communicator_dispatcher()
@@ -235,6 +267,10 @@ class TokenizerCommunicatorMixin:
                 (
                     LoRAUpdateResult,
                     self.update_lora_adapter_communicator.handle_recv,
                 ),
+                (
+                    GetLoadReqOutput,
+                    self.get_load_communicator.handle_recv,
+                ),
             ]
         )
@@ -528,10 +564,6 @@ class TokenizerCommunicatorMixin:
         )
         return [res.updated for res in responses]

-    async def get_load(self: TokenizerManager) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
+    async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
+        req = GetLoadReqInput()
+        return await self.get_load_communicator(req)
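With the communicator in watching mode, get_load becomes a single await that fans out to every scheduler. A usage sketch (the caller is hypothetical):

    async def report_load(tokenizer_manager) -> None:
        loads = await tokenizer_manager.get_load()  # List[GetLoadReqOutput]
        for load in loads:
            print(f"rank={load.dp_rank} reqs={load.num_reqs} tokens={load.num_tokens}")
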
@@ -64,6 +64,7 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FreezeGCReq,
     GenerateReqInput,
+    GetLoadReqInput,
     HealthCheckOutput,
     MultiTokenizerWrapper,
     OpenSessionReqInput,
@@ -73,6 +74,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
@@ -1240,6 +1242,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )

     def dump_requests_before_crash(self):
         if self.crash_dump_performed:
@@ -1844,6 +1849,20 @@ class TokenizerManager(TokenizerCommunicatorMixin):

         return scores

+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return
+
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_update_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_update_req)
+

 class ServerStatus(Enum):
     Up = "Up"

@@ -233,6 +233,7 @@ class ServerArgs:
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+    load_watch_interval: float = 0.1
     # FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
     prefill_round_robin_balance: bool = False
@@ -663,6 +664,7 @@ class ServerArgs:

         if self.dp_size == 1:
             self.enable_dp_attention = False
+            self.enable_dp_lm_head = False

         # Data parallelism attention
         if self.enable_dp_attention:
@@ -1488,6 +1490,12 @@ class ServerArgs:
                 "minimum_tokens",
             ],
         )
+        parser.add_argument(
+            "--load-watch-interval",
+            type=float,
+            default=ServerArgs.load_watch_interval,
+            help="The interval of load watching in seconds.",
+        )
         parser.add_argument(
            "--prefill-round-robin-balance",
            default=ServerArgs.prefill_round_robin_balance,
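Putting the new knobs together, a configuration that exercises the watcher might look like the following sketch; the model path is a placeholder, and "shortest_queue" is assumed to be the LoadBalanceMethod value this series wires up:

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        dp_size=4,
        load_balance_method="shortest_queue",  # assumed enum value
        load_watch_interval=0.1,               # seconds between load polls
    )
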
@@ -1160,7 +1160,7 @@ def pytorch_profile(name, func, *args, data_size=-1):

 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-):
+) -> zmq.Socket:
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -472,6 +472,10 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
 class TypeBasedDispatcher:
     def __init__(self, mapping: List[Tuple[Type, Callable]]):
         self._mapping = mapping
+        self._fallback_fn = None
+
+    def add_fallback_fn(self, fallback_fn: Callable):
+        self._fallback_fn = fallback_fn

     def __iadd__(self, other: "TypeBasedDispatcher"):
         self._mapping.extend(other._mapping)
@@ -481,6 +485,9 @@ class TypeBasedDispatcher:
         for ty, fn in self._mapping:
             if isinstance(obj, ty):
                 return fn(obj)
+
+        if self._fallback_fn is not None:
+            return self._fallback_fn(obj)
         raise ValueError(f"Invalid object: {obj}")
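The fallback hook is what lets the DP controller's init_dispatcher route unknown message types as control messages. A self-contained sketch of the same contract, with toy handlers and the class simplified from the real one in sglang.utils:

    from typing import Callable, List, Tuple, Type

    class TypeBasedDispatcher:
        def __init__(self, mapping: List[Tuple[Type, Callable]]):
            self._mapping = mapping
            self._fallback_fn = None

        def add_fallback_fn(self, fallback_fn: Callable):
            self._fallback_fn = fallback_fn

        def __call__(self, obj):
            # First matching registered type wins; otherwise use the fallback.
            for ty, fn in self._mapping:
                if isinstance(obj, ty):
                    return fn(obj)
            if self._fallback_fn is not None:
                return self._fallback_fn(obj)
            raise ValueError(f"Invalid object: {obj}")

    dispatcher = TypeBasedDispatcher([(int, lambda x: f"dispatch {x}")])
    dispatcher.add_fallback_fn(lambda x: f"control message {x!r}")
    print(dispatcher(3))        # dispatch 3
    print(dispatcher("flush"))  # control message 'flush' -> fallback path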