PD Rust LB (PO2) (#6437)

Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
Liangsheng Yin
2025-05-29 20:50:10 +08:00
committed by GitHub
parent 1dc6864f17
commit 78689d3393
17 changed files with 1017 additions and 42 deletions

View File

@@ -0,0 +1,140 @@
import argparse
import dataclasses
@dataclasses.dataclass
class LBArgs:
    """Configuration for the PD (prefill/decode) disaggregation load balancer.

    Attributes:
        rust_lb: Use the Rust load balancer instead of the Python mini LB.
        host: Host address to bind the server to.
        port: Port to bind the server to.
        policy: Load-balancing policy, either "random" or "po2".
        prefill_infos: ``(url, bootstrap_port)`` pairs for prefill servers;
            ``bootstrap_port`` may be ``None`` when not specified.
        decode_infos: URLs for decode servers.
        log_interval: Logging interval in seconds.
        timeout: Timeout in seconds.
    """

    rust_lb: bool = False
    host: str = "0.0.0.0"
    port: int = 8000
    policy: str = "random"
    prefill_infos: list = dataclasses.field(default_factory=list)
    decode_infos: list = dataclasses.field(default_factory=list)
    log_interval: int = 5
    timeout: int = 600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register the load balancer's command-line flags on *parser*."""
        parser.add_argument(
            "--rust-lb",
            action="store_true",
            help="Use Rust load balancer",
        )
        parser.add_argument(
            "--host",
            type=str,
            default=LBArgs.host,
            help=f"Host to bind the server (default: {LBArgs.host})",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=LBArgs.port,
            help=f"Port to bind the server (default: {LBArgs.port})",
        )
        parser.add_argument(
            "--policy",
            type=str,
            default=LBArgs.policy,
            choices=["random", "po2"],
            help=f"Policy to use for load balancing (default: {LBArgs.policy})",
        )
        parser.add_argument(
            "--prefill",
            type=str,
            default=[],
            nargs="+",
            help="URLs for prefill servers",
        )
        parser.add_argument(
            "--decode",
            type=str,
            default=[],
            nargs="+",
            help="URLs for decode servers",
        )
        parser.add_argument(
            "--prefill-bootstrap-ports",
            type=int,
            nargs="+",
            help="Bootstrap ports for prefill servers",
        )
        parser.add_argument(
            "--log-interval",
            type=int,
            default=LBArgs.log_interval,
            help=f"Log interval in seconds (default: {LBArgs.log_interval})",
        )
        parser.add_argument(
            "--timeout",
            type=int,
            default=LBArgs.timeout,
            help=f"Timeout in seconds (default: {LBArgs.timeout})",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
        """Build an ``LBArgs`` from parsed command-line arguments.

        A single ``--prefill-bootstrap-ports`` value is broadcast to every
        prefill URL; otherwise the counts must match exactly.

        Raises:
            ValueError: if the number of bootstrap ports matches neither one
                nor the number of prefill URLs.
        """
        bootstrap_ports = args.prefill_bootstrap_ports
        if bootstrap_ports is None:
            bootstrap_ports = [None] * len(args.prefill)
        elif len(bootstrap_ports) == 1:
            # Broadcast a single bootstrap port to every prefill server.
            bootstrap_ports = bootstrap_ports * len(args.prefill)
        elif len(bootstrap_ports) != len(args.prefill):
            raise ValueError(
                "Number of prefill URLs must match number of bootstrap ports"
            )

        prefill_infos = list(zip(args.prefill, bootstrap_ports))
        return cls(
            rust_lb=args.rust_lb,
            host=args.host,
            port=args.port,
            policy=args.policy,
            prefill_infos=prefill_infos,
            decode_infos=args.decode,
            log_interval=args.log_interval,
            timeout=args.timeout,
        )

    def __post_init__(self):
        # Raise (not assert) so validation still runs under `python -O`.
        if not self.rust_lb and self.policy != "random":
            raise ValueError(
                "Only random policy is supported for Python load balancer"
            )
def main():
    """CLI entry point for the PD disaggregation load balancer."""
    parser = argparse.ArgumentParser(
        description="PD Disaggregation Load Balancer Server"
    )
    LBArgs.add_cli_args(parser)
    lb_args = LBArgs.from_cli_args(parser.parse_args())

    if not lb_args.rust_lb:
        # Python fallback: run the mini load balancer.
        from sglang.srt.disaggregation.mini_lb import PrefillConfig, run

        prefill_configs = [
            PrefillConfig(url, port) for url, port in lb_args.prefill_infos
        ]
        run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
        return

    # Rust load balancer path.
    from sgl_pdlb._rust import LoadBalancer as RustLB

    RustLB(
        host=lb_args.host,
        port=lb_args.port,
        policy=lb_args.policy,
        prefill_infos=lb_args.prefill_infos,
        decode_infos=lb_args.decode_infos,
        log_interval=lb_args.log_interval,
        timeout=lb_args.timeout,
    ).start()


if __name__ == "__main__":
    main()

View File

@@ -377,42 +377,7 @@ def run(prefill_configs, decode_addrs, host, port):
# NOTE(review): this span is a rendered diff hunk (@@ -377,42 +377,7 @@) with
# the removed lines inlined — the standalone argparse CLI below was deleted by
# this commit in favor of delegating to the unified entry point; only the
# import of `main` and the final `main()` call remain in the new file.
if __name__ == "__main__":
import argparse
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
from sglang.srt.disaggregation.launch_lb import main
# --- removed by this commit: superseded by launch_lb's unified CLI ---
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
parser.add_argument(
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
)
parser.add_argument(
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
)
parser.add_argument(
"--prefill-bootstrap-ports",
type=int,
nargs="+",
help="Bootstrap ports for prefill servers",
)
parser.add_argument(
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
)
args = parser.parse_args()
bootstrap_ports = args.prefill_bootstrap_ports
if bootstrap_ports is None:
bootstrap_ports = [None] * len(args.prefill)
elif len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(args.prefill)
else:
if len(bootstrap_ports) != len(args.prefill):
raise ValueError(
"Number of prefill URLs must match number of bootstrap ports"
)
prefill_configs = [
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
]
run(prefill_configs, args.decode, args.host, args.port)
# --- end of removed span ---
main()

View File

@@ -229,6 +229,11 @@ async def get_server_info():
}
@app.get("/get_load")
async def get_load():
    """HTTP endpoint: report the current load from the tokenizer manager."""
    manager = _global_state.tokenizer_manager
    return await manager.get_load()
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
res = await _global_state.tokenizer_manager.set_internal_state(obj)

View File

@@ -103,7 +103,7 @@ class GenerateReqInput:
# For disaggregated inference
bootstrap_host: Optional[Union[List[str], str]] = None
bootstrap_port: Optional[Union[List[int], int]] = None
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None
def contains_mm_input(self) -> bool:

View File

@@ -1911,6 +1911,27 @@ class Scheduler(
if_success = False
return if_success
def get_load(self):
    """Estimate this scheduler's current load in tokens.

    The load is the number of KV-cache tokens in use (capacity minus free
    minus evictable) plus the input tokens of every request that is still
    waiting to be scheduled, including the disaggregation queues.
    """
    # TODO(lsyin): use dynamically maintained num_waiting_tokens
    kv_in_use = (
        self.max_total_num_tokens
        - self.token_to_kv_pool_allocator.available_size()
        - self.tree_cache.evictable_size()
    )
    waiting_tokens = sum(len(req.origin_input_ids) for req in self.waiting_queue)

    disagg_tokens = 0
    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        disagg_tokens = sum(
            len(req.origin_input_ids)
            for req in self.disagg_prefill_bootstrap_queue.queue
        )
    elif self.disaggregation_mode == DisaggregationMode.DECODE:
        # The decode prealloc queue wraps requests one level deeper (req.req).
        disagg_tokens = sum(
            len(req.req.origin_input_ids)
            for req in self.disagg_decode_prealloc_queue.queue
        )

    return kv_in_use + waiting_tokens + disagg_tokens
def get_internal_state(self, recv_req: GetInternalStateReq):
# NOTE(review): rendered diff fragment — the `@@` hunk header below elides
# intervening lines, and both the old multi-line return and the new
# `ret["load"]` + single-line return are shown; only the latter pair is the
# post-commit code.
ret = dict(global_server_args_dict)
ret["last_gen_throughput"] = self.last_gen_throughput
@@ -1920,9 +1941,10 @@ class Scheduler(
)
if RECORD_STEP_TIME:
ret["step_time_dict"] = self.step_time_dict
# removed by this commit:
return GetInternalStateReqOutput(
internal_state=ret,
)
# added by this commit: include the load estimate in the reported state.
ret["load"] = self.get_load()
return GetInternalStateReqOutput(internal_state=ret)
def set_internal_state(self, recv_req: SetInternalStateReq):
server_args_dict = recv_req.server_args

View File

@@ -395,6 +395,9 @@ class TokenizerManager:
self.server_args.disaggregation_bootstrap_port
)
self.current_load = 0
self.current_load_lock = asyncio.Lock()
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -983,6 +986,14 @@ class TokenizerManager:
# Many DP ranks
return [res.internal_state for res in responses]
async def get_load(self) -> dict:
    """Return the current load, refreshing it unless a refresh is in flight.

    When another coroutine already holds the lock (a refresh is underway),
    the cached value is served immediately instead of waiting.
    """
    # TODO(lsyin): fake load report server
    refresh_in_progress = self.current_load_lock.locked()
    if not refresh_in_progress:
        async with self.current_load_lock:
            state = await self.get_internal_state()
            self.current_load = state[0]["load"]
    return {"load": self.current_load}
async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput: