Better PD initialization (#5751)
This commit is contained in:
@@ -3,10 +3,12 @@ Minimal HTTP load balancer for prefill and decode servers for testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import random
|
||||
import urllib
|
||||
from itertools import chain
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
@@ -14,11 +16,32 @@ import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
||||
|
||||
|
||||
def setup_logger():
|
||||
logger = logging.getLogger("pdlb")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PrefillConfig:
|
||||
def __init__(self, url: str, bootstrap_port: int):
|
||||
self.url = url
|
||||
self.bootstrap_port = bootstrap_port
|
||||
url: str
|
||||
bootstrap_port: Optional[int] = None
|
||||
|
||||
|
||||
class MiniLoadBalancer:
|
||||
@@ -28,6 +51,10 @@ class MiniLoadBalancer:
|
||||
self.decode_servers = decode_servers
|
||||
|
||||
def select_pair(self):
|
||||
# TODO: return some message instead of panic
|
||||
assert len(self.prefill_configs) > 0, "No prefill servers available"
|
||||
assert len(self.decode_servers) > 0, "No decode servers available"
|
||||
|
||||
prefill_config = random.choice(self.prefill_configs)
|
||||
decode_server = random.choice(self.decode_servers)
|
||||
return prefill_config.url, prefill_config.bootstrap_port, decode_server
|
||||
@@ -47,7 +74,7 @@ class MiniLoadBalancer:
|
||||
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
||||
]
|
||||
# Wait for both responses to complete. Prefill should end first.
|
||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||
_, decode_response = await asyncio.gather(*tasks)
|
||||
|
||||
return ORJSONResponse(
|
||||
content=await decode_response.json(),
|
||||
@@ -268,6 +295,32 @@ async def get_models():
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/register")
|
||||
async def register(obj: PDRegistryRequest):
|
||||
if obj.mode == "prefill":
|
||||
load_balancer.prefill_configs.append(
|
||||
PrefillConfig(obj.registry_url, obj.bootstrap_port)
|
||||
)
|
||||
logger.info(
|
||||
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
|
||||
)
|
||||
elif obj.mode == "decode":
|
||||
load_balancer.decode_servers.append(obj.registry_url)
|
||||
logger.info(f"Registered decode server: {obj.registry_url}")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid mode. Must be either PREFILL or DECODE.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
|
||||
f"#Decode servers: {len(load_balancer.decode_servers)}"
|
||||
)
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
def run(prefill_configs, decode_addrs, host, port):
|
||||
global load_balancer
|
||||
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
|
||||
@@ -279,15 +332,16 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
|
||||
parser.add_argument(
|
||||
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
|
||||
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-bootstrap-ports",
|
||||
help="Comma-separated bootstrap ports for prefill servers",
|
||||
default="8998",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode", required=True, help="Comma-separated URLs for decode servers"
|
||||
type=int,
|
||||
nargs="+",
|
||||
help="Bootstrap ports for prefill servers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
|
||||
@@ -297,22 +351,19 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
prefill_urls = args.prefill.split(",")
|
||||
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
|
||||
|
||||
if len(bootstrap_ports) == 1:
|
||||
bootstrap_ports = bootstrap_ports * len(prefill_urls)
|
||||
bootstrap_ports = args.prefill_bootstrap_ports
|
||||
if bootstrap_ports is None:
|
||||
bootstrap_ports = [None] * len(args.prefill)
|
||||
elif len(bootstrap_ports) == 1:
|
||||
bootstrap_ports = bootstrap_ports * len(args.prefill)
|
||||
else:
|
||||
if len(bootstrap_ports) != len(prefill_urls):
|
||||
if len(bootstrap_ports) != len(args.prefill):
|
||||
raise ValueError(
|
||||
"Number of prefill URLs must match number of bootstrap ports"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
prefill_configs = []
|
||||
for url, port in zip(prefill_urls, bootstrap_ports):
|
||||
prefill_configs.append(PrefillConfig(url, port))
|
||||
prefill_configs = [
|
||||
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
|
||||
]
|
||||
|
||||
decode_addrs = args.decode.split(",")
|
||||
|
||||
run(prefill_configs, decode_addrs, args.host, args.port)
|
||||
run(prefill_configs, args.decode, args.host, args.port)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import warnings
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.utils import get_ip
|
||||
|
||||
|
||||
class DisaggregationMode(Enum):
|
||||
NULL = "null"
|
||||
@@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
||||
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||
# ceil(num_kv_indices / page_size)
|
||||
return (num_kv_indices + page_size - 1) // page_size
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PDRegistryRequest:
|
||||
"""A request to register a machine itself to the LB."""
|
||||
|
||||
mode: str
|
||||
registry_url: str
|
||||
bootstrap_port: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mode == "prefill" and self.bootstrap_port is None:
|
||||
raise ValueError("Bootstrap port must be set in PREFILL mode.")
|
||||
elif self.mode == "decode" and self.bootstrap_port is not None:
|
||||
raise ValueError("Bootstrap port must not be set in DECODE mode.")
|
||||
elif self.mode not in ["prefill", "decode"]:
|
||||
raise ValueError(
|
||||
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
|
||||
)
|
||||
|
||||
|
||||
def register_disaggregation_server(
|
||||
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
|
||||
):
|
||||
boostrap_port = bootstrap_port if mode == "prefill" else None
|
||||
registry_request = PDRegistryRequest(
|
||||
mode=mode,
|
||||
registry_url=f"http://{get_ip()}:{server_port}",
|
||||
bootstrap_port=boostrap_port,
|
||||
)
|
||||
res = requests.post(
|
||||
f"{pdlb_url}/register",
|
||||
json=dataclasses.asdict(registry_request),
|
||||
)
|
||||
if res.status_code != 200:
|
||||
warnings.warn(
|
||||
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
||||
)
|
||||
|
||||
@@ -42,7 +42,10 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import FakeBootstrapHost
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FakeBootstrapHost,
|
||||
register_disaggregation_server,
|
||||
)
|
||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||
from sglang.srt.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import (
|
||||
@@ -871,5 +874,13 @@ def _wait_and_warmup(
|
||||
if server_args.debug_tensor_dump_input_file:
|
||||
kill_process_tree(os.getpid())
|
||||
|
||||
if server_args.pdlb_url is not None:
|
||||
register_disaggregation_server(
|
||||
server_args.disaggregation_mode,
|
||||
server_args.port,
|
||||
server_args.disaggregation_bootstrap_port,
|
||||
server_args.pdlb_url,
|
||||
)
|
||||
|
||||
if launch_callback is not None:
|
||||
launch_callback()
|
||||
|
||||
@@ -925,6 +925,10 @@ class Scheduler(
|
||||
)
|
||||
custom_logit_processor = None
|
||||
|
||||
if recv_req.bootstrap_port is None:
|
||||
# Use default bootstrap port
|
||||
recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
|
||||
|
||||
req = Req(
|
||||
recv_req.rid,
|
||||
recv_req.input_text,
|
||||
|
||||
@@ -198,6 +198,7 @@ class ServerArgs:
|
||||
disaggregation_bootstrap_port: int = 8998
|
||||
disaggregation_transfer_backend: str = "mooncake"
|
||||
disaggregation_ib_device: Optional[str] = None
|
||||
pdlb_url: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Expert parallelism
|
||||
@@ -1254,6 +1255,12 @@ class ServerArgs:
|
||||
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
|
||||
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pdlb-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user