diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 28aa897f7..08f964366 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -27,7 +27,7 @@ import tempfile
 import threading
 import time
 from http import HTTPStatus
-from typing import Any, AsyncIterator, Callable, Dict, List, Optional
+from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union

 import setproctitle

@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
     MultiTokenizerManager,
+    MultiTokenizerRouter,
     get_main_process_id,
     monkey_patch_uvicorn_multiprocessing,
     read_from_shared_memory,
@@ -127,7 +128,9 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
-    tokenizer_manager: TokenizerManager
+    tokenizer_manager: Union[
+        TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
+    ]
     template_manager: TemplateManager
     scheduler_info: Dict

diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index a7bb6d13a..781f95695 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -21,6 +21,7 @@ import struct
 import sys
 import threading
 import time
+from collections import deque
 from enum import Enum, auto
 from multiprocessing import shared_memory
 from typing import Dict, List
@@ -34,6 +35,7 @@ from sglang.srt.managers.io_struct import (
     BlockReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
@@ -46,7 +48,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 logger = logging.getLogger(__name__)

@@ -67,6 +69,42 @@ class LoadBalanceMethod(Enum):
         raise ValueError(f"Invalid load balance method: {method}") from exc


+class DPBudget:
+    def __init__(self):
+        # TODO: support minimum tokens method
+        self.budget_queue = deque()
+
+    def update_budget(self, load_update: WatchLoadUpdateReq):
+        """Update the budget queue.
+        Use num_reqs instead of num_waiting_reqs to balance the decode running batch.
+        """
+        loads = load_update.loads
+        self.budget_queue.clear()
+
+        num_reqs = [load.num_reqs for load in loads]
+        if not num_reqs:
+            return
+
+        max_num_reqs = max(num_reqs)
+        if all(x == max_num_reqs for x in num_reqs):
+            return
+
+        while any(x != num_reqs[0] for x in num_reqs):
+            min_load = min(num_reqs)
+            min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
+            second_min_load = min(x for x in num_reqs if x > min_load)
+            self.budget_queue.extend(
+                [loads[i].dp_rank for i in min_indices] * (second_min_load - min_load)
+            )
+            for idx in min_indices:
+                num_reqs[idx] = second_min_load
+
+    def dispatch(self):
+        if self.budget_queue:
+            return self.budget_queue.popleft()
+        return None
+
+
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""

@@ -104,6 +142,9 @@ class DataParallelController:
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]

+        # Load balance budget
+        self.dp_budget = DPBudget()
+
         # Launch data parallel workers
         self.scheduler_procs = []
         self.workers: List[zmq.Socket] = [None] * server_args.dp_size
@@ -127,6 +168,31 @@ class DataParallelController:

         self.max_req_input_len = None

+        self.init_dispatcher()
+
+    def send_to_all_workers(self, obj):
+        for worker in self.workers:
+            worker.send_pyobj(obj)
+
+    def send_control_message(self, obj):
+        # Send control messages to first worker of tp group
+        for worker in self.workers[:: self.control_message_step]:
+            worker.send_pyobj(obj)
+
+    def handle_load_update_req(self, obj):
+        self.dp_budget.update_budget(obj)
+
+    def init_dispatcher(self):
+        self._request_dispatcher = TypeBasedDispatcher(
+            [
+                (TokenizedGenerateReqInput, self.dispatching),
+                (TokenizedEmbeddingReqInput, self.dispatching),
+                (BlockReqInput, self.send_to_all_workers),
+                (WatchLoadUpdateReq, self.handle_load_update_req),
+            ]
+        )
+        self._request_dispatcher.add_fallback_fn(self.send_control_message)
+
     def launch_dp_schedulers(self, server_args, port_args):
         base_gpu_id = 0

@@ -291,10 +357,14 @@ class DataParallelController:
         else:
             self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

-    def shortest_queue_scheduler(self, input_requests):
+    def shortest_queue_scheduler(self, req):
         if self.maybe_external_dp_rank_routing(req):
             return
-        raise NotImplementedError()
+        target_worker = self.dp_budget.dispatch()
+        if target_worker is None:
+            self.round_robin_scheduler(req)
+        else:
+            self.workers[target_worker].send_pyobj(req)

     def minimum_tokens_scheduler(self, req):
         if self.maybe_external_dp_rank_routing(req):
@@ -333,22 +403,7 @@ class DataParallelController:
                 recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
             except zmq.ZMQError:
                 break
-
-            if isinstance(
-                recv_req,
-                (
-                    TokenizedGenerateReqInput,
-                    TokenizedEmbeddingReqInput,
-                ),
-            ):
-                self.dispatching(recv_req)
-            elif isinstance(recv_req, BlockReqInput):
-                for worker in self.workers:
-                    worker.send_pyobj(recv_req)
-            else:
-                # Send other control messages to first worker of tp group
-                for worker in self.workers[:: self.control_message_step]:
-                    worker.send_pyobj(recv_req)
+            self._request_dispatcher(recv_req)


 def run_data_parallel_controller_process(
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index bc58f4ee5..3c5fd4420 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -297,7 +297,7 @@ def run_detokenizer_process(
         else:
             manager.event_loop()
     except Exception:
-        manager.socket_mapping.clear_all_sockets()
+        manager.maybe_clear_socket_mapping()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 152c5a915..cf5406660 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -1374,3 +1374,21 @@ class BlockReqType(Enum):
 @dataclass
 class BlockReqInput:
     type: BlockReqType
+
+
+@dataclass
+class GetLoadReqInput:
+    pass
+
+
+@dataclass
+class GetLoadReqOutput:
+    dp_rank: int
+    num_reqs: int
+    num_waiting_reqs: int
+    num_tokens: int
+
+
+@dataclass
+class WatchLoadUpdateReq:
+    loads: List[GetLoadReqOutput]
diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py
index 0aadfba2c..2d734ab2b 100644
--- a/python/sglang/srt/managers/multi_tokenizer_mixin.py
+++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py
@@ -354,6 +354,10 @@ class MultiHttpWorkerDetokenizerMixin:
             worker_ids = []
         return worker_ids

+    def maybe_clear_socket_mapping(self):
+        if hasattr(self, "socket_mapping"):
+            self.socket_mapping.clear_all_sockets()
+
     def multi_http_worker_event_loop(self):
         """The event loop that handles requests, for multi multi-http-worker mode"""
         self.socket_mapping = SocketMapping()
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index be7bc6a4a..3957bc679 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -79,6 +79,8 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     HealthCheckOutput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
@@ -577,6 +579,7 @@ class Scheduler(
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                 (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+                (GetLoadReqInput, self.get_load),
             ]
         )

@@ -2279,39 +2282,50 @@ class Scheduler(
             if_success = False
         return if_success

-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
+
         if self.is_hybrid:
-            load_full = (
+            num_tokens_full = (
                 self.full_tokens_per_layer
                 - self.token_to_kv_pool_allocator.full_available_size()
                 - self.tree_cache.full_evictable_size()
             )
-            load_swa = (
+            num_tokens_swa = (
                 self.swa_tokens_per_layer
                 - self.token_to_kv_pool_allocator.swa_available_size()
                 - self.tree_cache.swa_evictable_size()
             )
-            load = max(load_full, load_swa)
+            num_tokens = max(num_tokens_full, num_tokens_swa)
         else:
-            load = (
+            num_tokens = (
                 self.max_total_num_tokens
                 - self.token_to_kv_pool_allocator.available_size()
                 - self.tree_cache.evictable_size()
             )
-        load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+
+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            load += sum(
+            num_tokens += sum(
                 len(req.origin_input_ids)
                 for req in self.disagg_prefill_bootstrap_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            load += sum(
+            num_tokens += sum(
                 len(req.req.origin_input_ids)
                 for req in self.disagg_decode_prealloc_queue.queue
            )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

-        return load
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )

     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
@@ -2337,8 +2351,6 @@ class Scheduler(
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict

-        ret["load"] = self.get_load()
-
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py
index 06c7fe7fb..66cdc95bb 100644
--- a/python/sglang/srt/managers/scheduler_metrics_mixin.py
+++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py
@@ -279,7 +279,7 @@ class SchedulerMetricsMixin:
             self.server_args.load_balance_method == "minimum_tokens"
             and self.forward_ct % 40 == 0
         ):
-            holding_tokens = self.get_load()
+            holding_tokens = self.get_load().num_tokens

             new_recv_dp_balance_id_list, holding_token_list = (
                 self.gather_dp_balance_info(holding_tokens)
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index 33c222a94..8970d5ad5 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+import copy
 import logging
 import os
 import time
@@ -18,6 +19,7 @@ from typing import (
 )

 import fastapi
+import zmq

 from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
@@ -28,6 +30,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
@@ -75,14 +79,17 @@ class _Communicator(Generic[T]):

     enable_multi_tokenizer = False

-    def __init__(self, sender, fan_out: int):
+    def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
         self._sender = sender
         self._fan_out = fan_out
+        self._mode = mode
         self._result_event: Optional[asyncio.Event] = None
         self._result_values: Optional[List[T]] = None
         self._ready_queue: Deque[asyncio.Future] = deque()

-    async def __call__(self, obj):
+        assert mode in ["queueing", "watching"]
+
+    async def queueing_call(self, obj: T):
         ready_event = asyncio.Event()
         if self._result_event is not None or len(self._ready_queue) > 0:
             self._ready_queue.append(ready_event)
@@ -106,6 +113,28 @@ class _Communicator(Generic[T]):

         return result_values

+    async def watching_call(self, obj):
+        if self._result_event is None:
+            assert self._result_values is None
+            self._result_values = []
+            self._result_event = asyncio.Event()
+
+        if obj:
+            if _Communicator.enable_multi_tokenizer:
+                obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
+            self._sender.send_pyobj(obj)
+
+        await self._result_event.wait()
+        result_values = copy.deepcopy(self._result_values)
+        self._result_event = self._result_values = None
+        return result_values
+
+    async def __call__(self, obj):
+        if self._mode == "queueing":
+            return await self.queueing_call(obj)
+        else:
+            return await self.watching_call(obj)
+
     def handle_recv(self, recv_obj: T):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
@@ -165,6 +194,9 @@ class TokenizerCommunicatorMixin:
         self.update_lora_adapter_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.get_load_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size, mode="watching"
+        )

         self._result_dispatcher += self._get_communicator_dispatcher()

@@ -235,6 +267,10 @@ class TokenizerCommunicatorMixin:
                 (
                     LoRAUpdateResult,
                     self.update_lora_adapter_communicator.handle_recv,
                 ),
+                (
+                    GetLoadReqOutput,
+                    self.get_load_communicator.handle_recv,
+                ),
             ]
         )

@@ -528,10 +564,6 @@ class TokenizerCommunicatorMixin:
         )
         return [res.updated for res in responses]

-    async def get_load(self: TokenizerManager) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
+    async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
+        req = GetLoadReqInput()
+        return await self.get_load_communicator(req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 25b5fd87c..40f21b17d 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -64,6 +64,7 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FreezeGCReq,
     GenerateReqInput,
+    GetLoadReqInput,
     HealthCheckOutput,
     MultiTokenizerWrapper,
     OpenSessionReqInput,
@@ -73,6 +74,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
@@ -1240,6 +1242,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )

     def dump_requests_before_crash(self):
         if self.crash_dump_performed:
@@ -1844,6 +1849,20 @@ class TokenizerManager(TokenizerCommunicatorMixin):

         return scores

+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return
+
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_update_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_update_req)
+

 class ServerStatus(Enum):
     Up = "Up"
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 9e1d9e0c2..a556febaa 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -233,6 +233,7 @@ class ServerArgs:
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+    load_watch_interval: float = 0.1
     # FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
     prefill_round_robin_balance: bool = False

@@ -663,6 +664,7 @@ class ServerArgs:

         if self.dp_size == 1:
             self.enable_dp_attention = False
+            self.enable_dp_lm_head = False

         # Data parallelism attention
         if self.enable_dp_attention:
@@ -1488,6 +1490,12 @@ class ServerArgs:
                 "minimum_tokens",
             ],
         )
+        parser.add_argument(
"--load-watch-interval", + type=float, + default=ServerArgs.load_watch_interval, + help="The interval of load watching in seconds.", + ) parser.add_argument( "--prefill-round-robin-balance", default=ServerArgs.prefill_round_robin_balance, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 914d371b7..1e60208fc 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1160,7 +1160,7 @@ def pytorch_profile(name, func, *args, data_size=-1): def get_zmq_socket( context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool -): +) -> zmq.Socket: mem = psutil.virtual_memory() total_mem = mem.total / 1024**3 available_mem = mem.available / 1024**3 diff --git a/python/sglang/utils.py b/python/sglang/utils.py index f6bf20c42..23849af54 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -472,6 +472,10 @@ def wait_for_server(base_url: str, timeout: int = None) -> None: class TypeBasedDispatcher: def __init__(self, mapping: List[Tuple[Type, Callable]]): self._mapping = mapping + self._fallback_fn = None + + def add_fallback_fn(self, fallback_fn: Callable): + self._fallback_fn = fallback_fn def __iadd__(self, other: "TypeBasedDispatcher"): self._mapping.extend(other._mapping) @@ -481,6 +485,9 @@ class TypeBasedDispatcher: for ty, fn in self._mapping: if isinstance(obj, ty): return fn(obj) + + if self._fallback_fn is not None: + return self._fallback_fn(obj) raise ValueError(f"Invalid object: {obj}")