[4/N]DP refactor: support watching mode get_load and shortest queue strategy (#10201)
This commit is contained in:
@@ -27,7 +27,7 @@ import tempfile
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, AsyncIterator, Callable, Dict, List, Optional
|
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import setproctitle
|
import setproctitle
|
||||||
|
|
||||||
@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.multi_tokenizer_mixin import (
|
from sglang.srt.managers.multi_tokenizer_mixin import (
|
||||||
MultiTokenizerManager,
|
MultiTokenizerManager,
|
||||||
|
MultiTokenizerRouter,
|
||||||
get_main_process_id,
|
get_main_process_id,
|
||||||
monkey_patch_uvicorn_multiprocessing,
|
monkey_patch_uvicorn_multiprocessing,
|
||||||
read_from_shared_memory,
|
read_from_shared_memory,
|
||||||
@@ -127,7 +128,9 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
|||||||
# Store global states
|
# Store global states
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class _GlobalState:
|
class _GlobalState:
|
||||||
tokenizer_manager: TokenizerManager
|
tokenizer_manager: Union[
|
||||||
|
TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
|
||||||
|
]
|
||||||
template_manager: TemplateManager
|
template_manager: TemplateManager
|
||||||
scheduler_info: Dict
|
scheduler_info: Dict
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import struct
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from collections import deque
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from multiprocessing import shared_memory
|
from multiprocessing import shared_memory
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
@@ -34,6 +35,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BlockReqInput,
|
BlockReqInput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
|
WatchLoadUpdateReq,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
@@ -46,7 +48,7 @@ from sglang.srt.utils import (
|
|||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
kill_itself_when_parent_died,
|
kill_itself_when_parent_died,
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -67,6 +69,42 @@ class LoadBalanceMethod(Enum):
|
|||||||
raise ValueError(f"Invalid load balance method: {method}") from exc
|
raise ValueError(f"Invalid load balance method: {method}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
class DPBudget:
|
||||||
|
def __init__(self):
|
||||||
|
# TODO: support minimum tokens method
|
||||||
|
self.budget_queue = deque()
|
||||||
|
|
||||||
|
def update_budget(self, load_update: WatchLoadUpdateReq):
|
||||||
|
"""Update the budget queue.
|
||||||
|
Use num_reqs instead of num_waiting_reqs to balance decode running batch.
|
||||||
|
"""
|
||||||
|
loads = load_update.loads
|
||||||
|
self.budget_queue.clear()
|
||||||
|
|
||||||
|
num_reqs = [load.num_reqs for load in loads]
|
||||||
|
if not num_reqs:
|
||||||
|
return
|
||||||
|
|
||||||
|
max_num_reqs = max(num_reqs)
|
||||||
|
if all(x == max_num_reqs for x in num_reqs):
|
||||||
|
return
|
||||||
|
|
||||||
|
while any(x != num_reqs[0] for x in num_reqs):
|
||||||
|
min_load = min(num_reqs)
|
||||||
|
min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
|
||||||
|
second_min_load = min(x for x in num_reqs if x > min_load)
|
||||||
|
self.budget_queue.extend(
|
||||||
|
[loads[i].dp_rank for i in min_indices] * (second_min_load - min_load)
|
||||||
|
)
|
||||||
|
for idx in min_indices:
|
||||||
|
num_reqs[idx] = second_min_load
|
||||||
|
|
||||||
|
def dispatch(self):
|
||||||
|
if self.budget_queue:
|
||||||
|
return self.budget_queue.popleft()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class DataParallelController:
|
class DataParallelController:
|
||||||
"""A controller that dispatches requests to multiple data parallel workers."""
|
"""A controller that dispatches requests to multiple data parallel workers."""
|
||||||
|
|
||||||
@@ -104,6 +142,9 @@ class DataParallelController:
|
|||||||
}
|
}
|
||||||
self.dispatching = dispatch_lookup[self.load_balance_method]
|
self.dispatching = dispatch_lookup[self.load_balance_method]
|
||||||
|
|
||||||
|
# Load balance budget
|
||||||
|
self.dp_budget = DPBudget()
|
||||||
|
|
||||||
# Launch data parallel workers
|
# Launch data parallel workers
|
||||||
self.scheduler_procs = []
|
self.scheduler_procs = []
|
||||||
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
|
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
|
||||||
@@ -127,6 +168,31 @@ class DataParallelController:
|
|||||||
|
|
||||||
self.max_req_input_len = None
|
self.max_req_input_len = None
|
||||||
|
|
||||||
|
self.init_dispatcher()
|
||||||
|
|
||||||
|
def send_to_all_workers(self, obj):
|
||||||
|
for worker in self.workers:
|
||||||
|
worker.send_pyobj(obj)
|
||||||
|
|
||||||
|
def send_control_message(self, obj):
|
||||||
|
# Send control messages to first worker of tp group
|
||||||
|
for worker in self.workers[:: self.control_message_step]:
|
||||||
|
worker.send_pyobj(obj)
|
||||||
|
|
||||||
|
def handle_load_update_req(self, obj):
|
||||||
|
self.dp_budget.update_budget(obj)
|
||||||
|
|
||||||
|
def init_dispatcher(self):
|
||||||
|
self._request_dispatcher = TypeBasedDispatcher(
|
||||||
|
[
|
||||||
|
(TokenizedGenerateReqInput, self.dispatching),
|
||||||
|
(TokenizedEmbeddingReqInput, self.dispatching),
|
||||||
|
(BlockReqInput, self.send_to_all_workers),
|
||||||
|
(WatchLoadUpdateReq, self.handle_load_update_req),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self._request_dispatcher.add_fallback_fn(self.send_control_message)
|
||||||
|
|
||||||
def launch_dp_schedulers(self, server_args, port_args):
|
def launch_dp_schedulers(self, server_args, port_args):
|
||||||
base_gpu_id = 0
|
base_gpu_id = 0
|
||||||
|
|
||||||
@@ -291,10 +357,14 @@ class DataParallelController:
|
|||||||
else:
|
else:
|
||||||
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
||||||
|
|
||||||
def shortest_queue_scheduler(self, input_requests):
|
def shortest_queue_scheduler(self, req):
|
||||||
if self.maybe_external_dp_rank_routing(req):
|
if self.maybe_external_dp_rank_routing(req):
|
||||||
return
|
return
|
||||||
raise NotImplementedError()
|
target_worker = self.dp_budget.dispatch()
|
||||||
|
if target_worker is None:
|
||||||
|
self.round_robin_scheduler(req)
|
||||||
|
else:
|
||||||
|
self.workers[target_worker].send_pyobj(req)
|
||||||
|
|
||||||
def minimum_tokens_scheduler(self, req):
|
def minimum_tokens_scheduler(self, req):
|
||||||
if self.maybe_external_dp_rank_routing(req):
|
if self.maybe_external_dp_rank_routing(req):
|
||||||
@@ -333,22 +403,7 @@ class DataParallelController:
|
|||||||
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||||
except zmq.ZMQError:
|
except zmq.ZMQError:
|
||||||
break
|
break
|
||||||
|
self._request_dispatcher(recv_req)
|
||||||
if isinstance(
|
|
||||||
recv_req,
|
|
||||||
(
|
|
||||||
TokenizedGenerateReqInput,
|
|
||||||
TokenizedEmbeddingReqInput,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
self.dispatching(recv_req)
|
|
||||||
elif isinstance(recv_req, BlockReqInput):
|
|
||||||
for worker in self.workers:
|
|
||||||
worker.send_pyobj(recv_req)
|
|
||||||
else:
|
|
||||||
# Send other control messages to first worker of tp group
|
|
||||||
for worker in self.workers[:: self.control_message_step]:
|
|
||||||
worker.send_pyobj(recv_req)
|
|
||||||
|
|
||||||
|
|
||||||
def run_data_parallel_controller_process(
|
def run_data_parallel_controller_process(
|
||||||
|
|||||||
@@ -297,7 +297,7 @@ def run_detokenizer_process(
|
|||||||
else:
|
else:
|
||||||
manager.event_loop()
|
manager.event_loop()
|
||||||
except Exception:
|
except Exception:
|
||||||
manager.socket_mapping.clear_all_sockets()
|
manager.maybe_clear_socket_mapping()
|
||||||
traceback = get_exception_traceback()
|
traceback = get_exception_traceback()
|
||||||
logger.error(f"DetokenizerManager hit an exception: {traceback}")
|
logger.error(f"DetokenizerManager hit an exception: {traceback}")
|
||||||
parent_process.send_signal(signal.SIGQUIT)
|
parent_process.send_signal(signal.SIGQUIT)
|
||||||
|
|||||||
@@ -1374,3 +1374,21 @@ class BlockReqType(Enum):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BlockReqInput:
|
class BlockReqInput:
|
||||||
type: BlockReqType
|
type: BlockReqType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GetLoadReqInput:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GetLoadReqOutput:
|
||||||
|
dp_rank: int
|
||||||
|
num_reqs: int
|
||||||
|
num_waiting_reqs: int
|
||||||
|
num_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WatchLoadUpdateReq:
|
||||||
|
loads: List[GetLoadReqOutput]
|
||||||
|
|||||||
@@ -354,6 +354,10 @@ class MultiHttpWorkerDetokenizerMixin:
|
|||||||
worker_ids = []
|
worker_ids = []
|
||||||
return worker_ids
|
return worker_ids
|
||||||
|
|
||||||
|
def maybe_clear_socket_mapping(self):
|
||||||
|
if hasattr(self, "socket_mapping"):
|
||||||
|
self.socket_mapping.clear_all_sockets()
|
||||||
|
|
||||||
def multi_http_worker_event_loop(self):
|
def multi_http_worker_event_loop(self):
|
||||||
"""The event loop that handles requests, for multi multi-http-worker mode"""
|
"""The event loop that handles requests, for multi multi-http-worker mode"""
|
||||||
self.socket_mapping = SocketMapping()
|
self.socket_mapping = SocketMapping()
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
FreezeGCReq,
|
FreezeGCReq,
|
||||||
GetInternalStateReq,
|
GetInternalStateReq,
|
||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
|
GetLoadReqInput,
|
||||||
|
GetLoadReqOutput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
HealthCheckOutput,
|
HealthCheckOutput,
|
||||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
@@ -577,6 +579,7 @@ class Scheduler(
|
|||||||
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
|
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
|
||||||
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
|
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
|
||||||
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
|
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
|
||||||
|
(GetLoadReqInput, self.get_load),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2279,39 +2282,50 @@ class Scheduler(
|
|||||||
if_success = False
|
if_success = False
|
||||||
return if_success
|
return if_success
|
||||||
|
|
||||||
def get_load(self):
|
def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
|
||||||
# TODO(lsyin): use dynamically maintained num_waiting_tokens
|
# TODO(lsyin): use dynamically maintained num_waiting_tokens
|
||||||
|
|
||||||
if self.is_hybrid:
|
if self.is_hybrid:
|
||||||
load_full = (
|
num_tokens_full = (
|
||||||
self.full_tokens_per_layer
|
self.full_tokens_per_layer
|
||||||
- self.token_to_kv_pool_allocator.full_available_size()
|
- self.token_to_kv_pool_allocator.full_available_size()
|
||||||
- self.tree_cache.full_evictable_size()
|
- self.tree_cache.full_evictable_size()
|
||||||
)
|
)
|
||||||
load_swa = (
|
num_tokens_swa = (
|
||||||
self.swa_tokens_per_layer
|
self.swa_tokens_per_layer
|
||||||
- self.token_to_kv_pool_allocator.swa_available_size()
|
- self.token_to_kv_pool_allocator.swa_available_size()
|
||||||
- self.tree_cache.swa_evictable_size()
|
- self.tree_cache.swa_evictable_size()
|
||||||
)
|
)
|
||||||
load = max(load_full, load_swa)
|
num_tokens = max(num_tokens_full, num_tokens_swa)
|
||||||
else:
|
else:
|
||||||
load = (
|
num_tokens = (
|
||||||
self.max_total_num_tokens
|
self.max_total_num_tokens
|
||||||
- self.token_to_kv_pool_allocator.available_size()
|
- self.token_to_kv_pool_allocator.available_size()
|
||||||
- self.tree_cache.evictable_size()
|
- self.tree_cache.evictable_size()
|
||||||
)
|
)
|
||||||
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
|
|
||||||
|
# Tokens in waiting queue, bootstrap queue, prealloc queue
|
||||||
|
num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
|
||||||
|
num_waiting_reqs = len(self.waiting_queue)
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
load += sum(
|
num_tokens += sum(
|
||||||
len(req.origin_input_ids)
|
len(req.origin_input_ids)
|
||||||
for req in self.disagg_prefill_bootstrap_queue.queue
|
for req in self.disagg_prefill_bootstrap_queue.queue
|
||||||
)
|
)
|
||||||
|
num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
load += sum(
|
num_tokens += sum(
|
||||||
len(req.req.origin_input_ids)
|
len(req.req.origin_input_ids)
|
||||||
for req in self.disagg_decode_prealloc_queue.queue
|
for req in self.disagg_decode_prealloc_queue.queue
|
||||||
)
|
)
|
||||||
|
num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)
|
||||||
|
|
||||||
return load
|
return GetLoadReqOutput(
|
||||||
|
dp_rank=self.dp_rank,
|
||||||
|
num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
|
||||||
|
num_waiting_reqs=num_waiting_reqs,
|
||||||
|
num_tokens=num_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
def get_internal_state(self, recv_req: GetInternalStateReq):
|
def get_internal_state(self, recv_req: GetInternalStateReq):
|
||||||
ret = dict(global_server_args_dict)
|
ret = dict(global_server_args_dict)
|
||||||
@@ -2337,8 +2351,6 @@ class Scheduler(
|
|||||||
if RECORD_STEP_TIME:
|
if RECORD_STEP_TIME:
|
||||||
ret["step_time_dict"] = self.step_time_dict
|
ret["step_time_dict"] = self.step_time_dict
|
||||||
|
|
||||||
ret["load"] = self.get_load()
|
|
||||||
|
|
||||||
return GetInternalStateReqOutput(internal_state=ret)
|
return GetInternalStateReqOutput(internal_state=ret)
|
||||||
|
|
||||||
def set_internal_state(self, recv_req: SetInternalStateReq):
|
def set_internal_state(self, recv_req: SetInternalStateReq):
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ class SchedulerMetricsMixin:
|
|||||||
self.server_args.load_balance_method == "minimum_tokens"
|
self.server_args.load_balance_method == "minimum_tokens"
|
||||||
and self.forward_ct % 40 == 0
|
and self.forward_ct % 40 == 0
|
||||||
):
|
):
|
||||||
holding_tokens = self.get_load()
|
holding_tokens = self.get_load().num_tokens
|
||||||
|
|
||||||
new_recv_dp_balance_id_list, holding_token_list = (
|
new_recv_dp_balance_id_list, holding_token_list = (
|
||||||
self.gather_dp_balance_info(holding_tokens)
|
self.gather_dp_balance_info(holding_tokens)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
@@ -18,6 +19,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
|
import zmq
|
||||||
|
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
ClearHiCacheReqInput,
|
ClearHiCacheReqInput,
|
||||||
@@ -28,6 +30,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
FlushCacheReqOutput,
|
FlushCacheReqOutput,
|
||||||
GetInternalStateReq,
|
GetInternalStateReq,
|
||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
|
GetLoadReqInput,
|
||||||
|
GetLoadReqOutput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
GetWeightsByNameReqOutput,
|
GetWeightsByNameReqOutput,
|
||||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
@@ -75,14 +79,17 @@ class _Communicator(Generic[T]):
|
|||||||
|
|
||||||
enable_multi_tokenizer = False
|
enable_multi_tokenizer = False
|
||||||
|
|
||||||
def __init__(self, sender, fan_out: int):
|
def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
|
||||||
self._sender = sender
|
self._sender = sender
|
||||||
self._fan_out = fan_out
|
self._fan_out = fan_out
|
||||||
|
self._mode = mode
|
||||||
self._result_event: Optional[asyncio.Event] = None
|
self._result_event: Optional[asyncio.Event] = None
|
||||||
self._result_values: Optional[List[T]] = None
|
self._result_values: Optional[List[T]] = None
|
||||||
self._ready_queue: Deque[asyncio.Future] = deque()
|
self._ready_queue: Deque[asyncio.Future] = deque()
|
||||||
|
|
||||||
async def __call__(self, obj):
|
assert mode in ["queueing", "watching"]
|
||||||
|
|
||||||
|
async def queueing_call(self, obj: T):
|
||||||
ready_event = asyncio.Event()
|
ready_event = asyncio.Event()
|
||||||
if self._result_event is not None or len(self._ready_queue) > 0:
|
if self._result_event is not None or len(self._ready_queue) > 0:
|
||||||
self._ready_queue.append(ready_event)
|
self._ready_queue.append(ready_event)
|
||||||
@@ -106,6 +113,28 @@ class _Communicator(Generic[T]):
|
|||||||
|
|
||||||
return result_values
|
return result_values
|
||||||
|
|
||||||
|
async def watching_call(self, obj):
|
||||||
|
if self._result_event is None:
|
||||||
|
assert self._result_values is None
|
||||||
|
self._result_values = []
|
||||||
|
self._result_event = asyncio.Event()
|
||||||
|
|
||||||
|
if obj:
|
||||||
|
if _Communicator.enable_multi_tokenizer:
|
||||||
|
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
|
||||||
|
self._sender.send_pyobj(obj)
|
||||||
|
|
||||||
|
await self._result_event.wait()
|
||||||
|
result_values = copy.deepcopy(self._result_values)
|
||||||
|
self._result_event = self._result_values = None
|
||||||
|
return result_values
|
||||||
|
|
||||||
|
async def __call__(self, obj):
|
||||||
|
if self._mode == "queueing":
|
||||||
|
return await self.queueing_call(obj)
|
||||||
|
else:
|
||||||
|
return await self.watching_call(obj)
|
||||||
|
|
||||||
def handle_recv(self, recv_obj: T):
|
def handle_recv(self, recv_obj: T):
|
||||||
self._result_values.append(recv_obj)
|
self._result_values.append(recv_obj)
|
||||||
if len(self._result_values) == self._fan_out:
|
if len(self._result_values) == self._fan_out:
|
||||||
@@ -165,6 +194,9 @@ class TokenizerCommunicatorMixin:
|
|||||||
self.update_lora_adapter_communicator = _Communicator(
|
self.update_lora_adapter_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
|
self.get_load_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size, mode="watching"
|
||||||
|
)
|
||||||
|
|
||||||
self._result_dispatcher += self._get_communicator_dispatcher()
|
self._result_dispatcher += self._get_communicator_dispatcher()
|
||||||
|
|
||||||
@@ -235,6 +267,10 @@ class TokenizerCommunicatorMixin:
|
|||||||
LoRAUpdateResult,
|
LoRAUpdateResult,
|
||||||
self.update_lora_adapter_communicator.handle_recv,
|
self.update_lora_adapter_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
GetLoadReqOutput,
|
||||||
|
self.get_load_communicator.handle_recv,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -528,10 +564,6 @@ class TokenizerCommunicatorMixin:
|
|||||||
)
|
)
|
||||||
return [res.updated for res in responses]
|
return [res.updated for res in responses]
|
||||||
|
|
||||||
async def get_load(self: TokenizerManager) -> dict:
|
async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
|
||||||
# TODO(lsyin): fake load report server
|
req = GetLoadReqInput()
|
||||||
if not self.current_load_lock.locked():
|
return await self.get_load_communicator(req)
|
||||||
async with self.current_load_lock:
|
|
||||||
internal_state = await self.get_internal_state()
|
|
||||||
self.current_load = internal_state[0]["load"]
|
|
||||||
return {"load": self.current_load}
|
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
FreezeGCReq,
|
FreezeGCReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
|
GetLoadReqInput,
|
||||||
HealthCheckOutput,
|
HealthCheckOutput,
|
||||||
MultiTokenizerWrapper,
|
MultiTokenizerWrapper,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
@@ -73,6 +74,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightFromDiskReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
|
WatchLoadUpdateReq,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.mm_utils import TensorTransportMode
|
from sglang.srt.managers.mm_utils import TensorTransportMode
|
||||||
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
||||||
@@ -1240,6 +1242,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
self.asyncio_tasks.add(
|
self.asyncio_tasks.add(
|
||||||
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
|
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
|
||||||
)
|
)
|
||||||
|
self.asyncio_tasks.add(
|
||||||
|
loop.create_task(print_exception_wrapper(self.watch_load_thread))
|
||||||
|
)
|
||||||
|
|
||||||
def dump_requests_before_crash(self):
|
def dump_requests_before_crash(self):
|
||||||
if self.crash_dump_performed:
|
if self.crash_dump_performed:
|
||||||
@@ -1844,6 +1849,20 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
async def watch_load_thread(self):
|
||||||
|
# Only for dp_controller when dp_size > 1
|
||||||
|
if (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
or self.server_args.load_balance_method == "round_robin"
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(self.server_args.load_watch_interval)
|
||||||
|
loads = await self.get_load_communicator(GetLoadReqInput())
|
||||||
|
load_udpate_req = WatchLoadUpdateReq(loads=loads)
|
||||||
|
self.send_to_scheduler.send_pyobj(load_udpate_req)
|
||||||
|
|
||||||
|
|
||||||
class ServerStatus(Enum):
|
class ServerStatus(Enum):
|
||||||
Up = "Up"
|
Up = "Up"
|
||||||
|
|||||||
@@ -233,6 +233,7 @@ class ServerArgs:
|
|||||||
# Data parallelism
|
# Data parallelism
|
||||||
dp_size: int = 1
|
dp_size: int = 1
|
||||||
load_balance_method: str = "round_robin"
|
load_balance_method: str = "round_robin"
|
||||||
|
load_watch_interval: float = 0.1
|
||||||
# FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
|
# FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
|
||||||
prefill_round_robin_balance: bool = False
|
prefill_round_robin_balance: bool = False
|
||||||
|
|
||||||
@@ -663,6 +664,7 @@ class ServerArgs:
|
|||||||
|
|
||||||
if self.dp_size == 1:
|
if self.dp_size == 1:
|
||||||
self.enable_dp_attention = False
|
self.enable_dp_attention = False
|
||||||
|
self.enable_dp_lm_head = False
|
||||||
|
|
||||||
# Data parallelism attention
|
# Data parallelism attention
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
@@ -1488,6 +1490,12 @@ class ServerArgs:
|
|||||||
"minimum_tokens",
|
"minimum_tokens",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--load-watch-interval",
|
||||||
|
type=float,
|
||||||
|
default=ServerArgs.load_watch_interval,
|
||||||
|
help="The interval of load watching in seconds.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prefill-round-robin-balance",
|
"--prefill-round-robin-balance",
|
||||||
default=ServerArgs.prefill_round_robin_balance,
|
default=ServerArgs.prefill_round_robin_balance,
|
||||||
|
|||||||
@@ -1160,7 +1160,7 @@ def pytorch_profile(name, func, *args, data_size=-1):
|
|||||||
|
|
||||||
def get_zmq_socket(
|
def get_zmq_socket(
|
||||||
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
|
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
|
||||||
):
|
) -> zmq.Socket:
|
||||||
mem = psutil.virtual_memory()
|
mem = psutil.virtual_memory()
|
||||||
total_mem = mem.total / 1024**3
|
total_mem = mem.total / 1024**3
|
||||||
available_mem = mem.available / 1024**3
|
available_mem = mem.available / 1024**3
|
||||||
|
|||||||
@@ -472,6 +472,10 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
|
|||||||
class TypeBasedDispatcher:
|
class TypeBasedDispatcher:
|
||||||
def __init__(self, mapping: List[Tuple[Type, Callable]]):
|
def __init__(self, mapping: List[Tuple[Type, Callable]]):
|
||||||
self._mapping = mapping
|
self._mapping = mapping
|
||||||
|
self._fallback_fn = None
|
||||||
|
|
||||||
|
def add_fallback_fn(self, fallback_fn: Callable):
|
||||||
|
self._fallback_fn = fallback_fn
|
||||||
|
|
||||||
def __iadd__(self, other: "TypeBasedDispatcher"):
|
def __iadd__(self, other: "TypeBasedDispatcher"):
|
||||||
self._mapping.extend(other._mapping)
|
self._mapping.extend(other._mapping)
|
||||||
@@ -481,6 +485,9 @@ class TypeBasedDispatcher:
|
|||||||
for ty, fn in self._mapping:
|
for ty, fn in self._mapping:
|
||||||
if isinstance(obj, ty):
|
if isinstance(obj, ty):
|
||||||
return fn(obj)
|
return fn(obj)
|
||||||
|
|
||||||
|
if self._fallback_fn is not None:
|
||||||
|
return self._fallback_fn(obj)
|
||||||
raise ValueError(f"Invalid object: {obj}")
|
raise ValueError(f"Invalid object: {obj}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user