[4/N]DP refactor: support watching mode get_load and shortest queue strategy (#10201)
This commit is contained in:
@@ -79,6 +79,8 @@ from sglang.srt.managers.io_struct import (
|
||||
FreezeGCReq,
|
||||
GetInternalStateReq,
|
||||
GetInternalStateReqOutput,
|
||||
GetLoadReqInput,
|
||||
GetLoadReqOutput,
|
||||
GetWeightsByNameReqInput,
|
||||
HealthCheckOutput,
|
||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||
@@ -577,6 +579,7 @@ class Scheduler(
|
||||
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
|
||||
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
|
||||
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
|
||||
(GetLoadReqInput, self.get_load),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -2279,39 +2282,50 @@ class Scheduler(
|
||||
if_success = False
|
||||
return if_success
|
||||
|
||||
def get_load(self):
|
||||
def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
|
||||
# TODO(lsyin): use dynamically maintained num_waiting_tokens
|
||||
|
||||
if self.is_hybrid:
|
||||
load_full = (
|
||||
num_tokens_full = (
|
||||
self.full_tokens_per_layer
|
||||
- self.token_to_kv_pool_allocator.full_available_size()
|
||||
- self.tree_cache.full_evictable_size()
|
||||
)
|
||||
load_swa = (
|
||||
num_tokens_swa = (
|
||||
self.swa_tokens_per_layer
|
||||
- self.token_to_kv_pool_allocator.swa_available_size()
|
||||
- self.tree_cache.swa_evictable_size()
|
||||
)
|
||||
load = max(load_full, load_swa)
|
||||
num_tokens = max(num_tokens_full, num_tokens_swa)
|
||||
else:
|
||||
load = (
|
||||
num_tokens = (
|
||||
self.max_total_num_tokens
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
- self.tree_cache.evictable_size()
|
||||
)
|
||||
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
|
||||
|
||||
# Tokens in waiting queue, bootstrap queue, prealloc queue
|
||||
num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
|
||||
num_waiting_reqs = len(self.waiting_queue)
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
load += sum(
|
||||
num_tokens += sum(
|
||||
len(req.origin_input_ids)
|
||||
for req in self.disagg_prefill_bootstrap_queue.queue
|
||||
)
|
||||
num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
load += sum(
|
||||
num_tokens += sum(
|
||||
len(req.req.origin_input_ids)
|
||||
for req in self.disagg_decode_prealloc_queue.queue
|
||||
)
|
||||
num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)
|
||||
|
||||
return load
|
||||
return GetLoadReqOutput(
|
||||
dp_rank=self.dp_rank,
|
||||
num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
|
||||
num_waiting_reqs=num_waiting_reqs,
|
||||
num_tokens=num_tokens,
|
||||
)
|
||||
|
||||
def get_internal_state(self, recv_req: GetInternalStateReq):
|
||||
ret = dict(global_server_args_dict)
|
||||
@@ -2337,8 +2351,6 @@ class Scheduler(
|
||||
if RECORD_STEP_TIME:
|
||||
ret["step_time_dict"] = self.step_time_dict
|
||||
|
||||
ret["load"] = self.get_load()
|
||||
|
||||
return GetInternalStateReqOutput(internal_state=ret)
|
||||
|
||||
def set_internal_state(self, recv_req: SetInternalStateReq):
|
||||
|
||||
Reference in New Issue
Block a user