PD Rust LB (PO2) (#6437)
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
@@ -103,7 +103,7 @@ class GenerateReqInput:
|
||||
|
||||
# For disaggregated inference
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||
bootstrap_port: Optional[Union[List[int], int]] = None
|
||||
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||
|
||||
def contains_mm_input(self) -> bool:
|
||||
|
||||
@@ -1911,6 +1911,27 @@ class Scheduler(
|
||||
if_success = False
|
||||
return if_success
|
||||
|
||||
def get_load(self):
|
||||
# TODO(lsyin): use dynamically maintained num_waiting_tokens
|
||||
load = (
|
||||
self.max_total_num_tokens
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
- self.tree_cache.evictable_size()
|
||||
)
|
||||
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
load += sum(
|
||||
len(req.origin_input_ids)
|
||||
for req in self.disagg_prefill_bootstrap_queue.queue
|
||||
)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
load += sum(
|
||||
len(req.req.origin_input_ids)
|
||||
for req in self.disagg_decode_prealloc_queue.queue
|
||||
)
|
||||
|
||||
return load
|
||||
|
||||
def get_internal_state(self, recv_req: GetInternalStateReq):
|
||||
ret = dict(global_server_args_dict)
|
||||
ret["last_gen_throughput"] = self.last_gen_throughput
|
||||
@@ -1920,9 +1941,10 @@ class Scheduler(
|
||||
)
|
||||
if RECORD_STEP_TIME:
|
||||
ret["step_time_dict"] = self.step_time_dict
|
||||
return GetInternalStateReqOutput(
|
||||
internal_state=ret,
|
||||
)
|
||||
|
||||
ret["load"] = self.get_load()
|
||||
|
||||
return GetInternalStateReqOutput(internal_state=ret)
|
||||
|
||||
def set_internal_state(self, recv_req: SetInternalStateReq):
|
||||
server_args_dict = recv_req.server_args
|
||||
|
||||
@@ -395,6 +395,9 @@ class TokenizerManager:
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
)
|
||||
|
||||
self.current_load = 0
|
||||
self.current_load_lock = asyncio.Lock()
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
@@ -983,6 +986,14 @@ class TokenizerManager:
|
||||
# Many DP ranks
|
||||
return [res.internal_state for res in responses]
|
||||
|
||||
async def get_load(self) -> dict:
|
||||
# TODO(lsyin): fake load report server
|
||||
if not self.current_load_lock.locked():
|
||||
async with self.current_load_lock:
|
||||
internal_state = await self.get_internal_state()
|
||||
self.current_load = internal_state[0]["load"]
|
||||
return {"load": self.current_load}
|
||||
|
||||
async def set_internal_state(
|
||||
self, obj: SetInternalStateReq
|
||||
) -> SetInternalStateReqOutput:
|
||||
|
||||
Reference in New Issue
Block a user