Better PD initialization (#5751)
This commit is contained in:
@@ -3,10 +3,12 @@ Minimal HTTP load balancer for prefill and decode servers for testing.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
import urllib
|
import urllib
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import orjson
|
import orjson
|
||||||
@@ -14,11 +16,32 @@ import uvicorn
|
|||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger():
|
||||||
|
logger = logging.getLogger("pdlb")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
class PrefillConfig:
|
class PrefillConfig:
|
||||||
def __init__(self, url: str, bootstrap_port: int):
|
url: str
|
||||||
self.url = url
|
bootstrap_port: Optional[int] = None
|
||||||
self.bootstrap_port = bootstrap_port
|
|
||||||
|
|
||||||
|
|
||||||
class MiniLoadBalancer:
|
class MiniLoadBalancer:
|
||||||
@@ -28,6 +51,10 @@ class MiniLoadBalancer:
|
|||||||
self.decode_servers = decode_servers
|
self.decode_servers = decode_servers
|
||||||
|
|
||||||
def select_pair(self):
|
def select_pair(self):
|
||||||
|
# TODO: return some message instead of panic
|
||||||
|
assert len(self.prefill_configs) > 0, "No prefill servers available"
|
||||||
|
assert len(self.decode_servers) > 0, "No decode servers available"
|
||||||
|
|
||||||
prefill_config = random.choice(self.prefill_configs)
|
prefill_config = random.choice(self.prefill_configs)
|
||||||
decode_server = random.choice(self.decode_servers)
|
decode_server = random.choice(self.decode_servers)
|
||||||
return prefill_config.url, prefill_config.bootstrap_port, decode_server
|
return prefill_config.url, prefill_config.bootstrap_port, decode_server
|
||||||
@@ -47,7 +74,7 @@ class MiniLoadBalancer:
|
|||||||
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
||||||
]
|
]
|
||||||
# Wait for both responses to complete. Prefill should end first.
|
# Wait for both responses to complete. Prefill should end first.
|
||||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
_, decode_response = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
return ORJSONResponse(
|
return ORJSONResponse(
|
||||||
content=await decode_response.json(),
|
content=await decode_response.json(),
|
||||||
@@ -268,6 +295,32 @@ async def get_models():
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/register")
|
||||||
|
async def register(obj: PDRegistryRequest):
|
||||||
|
if obj.mode == "prefill":
|
||||||
|
load_balancer.prefill_configs.append(
|
||||||
|
PrefillConfig(obj.registry_url, obj.bootstrap_port)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
|
||||||
|
)
|
||||||
|
elif obj.mode == "decode":
|
||||||
|
load_balancer.decode_servers.append(obj.registry_url)
|
||||||
|
logger.info(f"Registered decode server: {obj.registry_url}")
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Invalid mode. Must be either PREFILL or DECODE.",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
|
||||||
|
f"#Decode servers: {len(load_balancer.decode_servers)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
def run(prefill_configs, decode_addrs, host, port):
|
def run(prefill_configs, decode_addrs, host, port):
|
||||||
global load_balancer
|
global load_balancer
|
||||||
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
|
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
|
||||||
@@ -279,15 +332,16 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
|
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
|
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prefill-bootstrap-ports",
|
"--prefill-bootstrap-ports",
|
||||||
help="Comma-separated bootstrap ports for prefill servers",
|
type=int,
|
||||||
default="8998",
|
nargs="+",
|
||||||
)
|
help="Bootstrap ports for prefill servers",
|
||||||
parser.add_argument(
|
|
||||||
"--decode", required=True, help="Comma-separated URLs for decode servers"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
|
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
|
||||||
@@ -297,22 +351,19 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
prefill_urls = args.prefill.split(",")
|
bootstrap_ports = args.prefill_bootstrap_ports
|
||||||
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
|
if bootstrap_ports is None:
|
||||||
|
bootstrap_ports = [None] * len(args.prefill)
|
||||||
if len(bootstrap_ports) == 1:
|
elif len(bootstrap_ports) == 1:
|
||||||
bootstrap_ports = bootstrap_ports * len(prefill_urls)
|
bootstrap_ports = bootstrap_ports * len(args.prefill)
|
||||||
else:
|
else:
|
||||||
if len(bootstrap_ports) != len(prefill_urls):
|
if len(bootstrap_ports) != len(args.prefill):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Number of prefill URLs must match number of bootstrap ports"
|
"Number of prefill URLs must match number of bootstrap ports"
|
||||||
)
|
)
|
||||||
exit(1)
|
|
||||||
|
|
||||||
prefill_configs = []
|
prefill_configs = [
|
||||||
for url, port in zip(prefill_urls, bootstrap_ports):
|
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
|
||||||
prefill_configs.append(PrefillConfig(url, port))
|
]
|
||||||
|
|
||||||
decode_addrs = args.decode.split(",")
|
run(prefill_configs, args.decode, args.host, args.port)
|
||||||
|
|
||||||
run(prefill_configs, decode_addrs, args.host, args.port)
|
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from sglang.srt.utils import get_ip
|
||||||
|
|
||||||
|
|
||||||
class DisaggregationMode(Enum):
|
class DisaggregationMode(Enum):
|
||||||
NULL = "null"
|
NULL = "null"
|
||||||
@@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
|||||||
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||||
# ceil(num_kv_indices / page_size)
|
# ceil(num_kv_indices / page_size)
|
||||||
return (num_kv_indices + page_size - 1) // page_size
|
return (num_kv_indices + page_size - 1) // page_size
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class PDRegistryRequest:
|
||||||
|
"""A request to register a machine itself to the LB."""
|
||||||
|
|
||||||
|
mode: str
|
||||||
|
registry_url: str
|
||||||
|
bootstrap_port: Optional[int] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.mode == "prefill" and self.bootstrap_port is None:
|
||||||
|
raise ValueError("Bootstrap port must be set in PREFILL mode.")
|
||||||
|
elif self.mode == "decode" and self.bootstrap_port is not None:
|
||||||
|
raise ValueError("Bootstrap port must not be set in DECODE mode.")
|
||||||
|
elif self.mode not in ["prefill", "decode"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_disaggregation_server(
|
||||||
|
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
|
||||||
|
):
|
||||||
|
boostrap_port = bootstrap_port if mode == "prefill" else None
|
||||||
|
registry_request = PDRegistryRequest(
|
||||||
|
mode=mode,
|
||||||
|
registry_url=f"http://{get_ip()}:{server_port}",
|
||||||
|
bootstrap_port=boostrap_port,
|
||||||
|
)
|
||||||
|
res = requests.post(
|
||||||
|
f"{pdlb_url}/register",
|
||||||
|
json=dataclasses.asdict(registry_request),
|
||||||
|
)
|
||||||
|
if res.status_code != 200:
|
||||||
|
warnings.warn(
|
||||||
|
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -42,7 +42,10 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import FakeBootstrapHost
|
from sglang.srt.disaggregation.utils import (
|
||||||
|
FakeBootstrapHost,
|
||||||
|
register_disaggregation_server,
|
||||||
|
)
|
||||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||||
from sglang.srt.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
@@ -871,5 +874,13 @@ def _wait_and_warmup(
|
|||||||
if server_args.debug_tensor_dump_input_file:
|
if server_args.debug_tensor_dump_input_file:
|
||||||
kill_process_tree(os.getpid())
|
kill_process_tree(os.getpid())
|
||||||
|
|
||||||
|
if server_args.pdlb_url is not None:
|
||||||
|
register_disaggregation_server(
|
||||||
|
server_args.disaggregation_mode,
|
||||||
|
server_args.port,
|
||||||
|
server_args.disaggregation_bootstrap_port,
|
||||||
|
server_args.pdlb_url,
|
||||||
|
)
|
||||||
|
|
||||||
if launch_callback is not None:
|
if launch_callback is not None:
|
||||||
launch_callback()
|
launch_callback()
|
||||||
|
|||||||
@@ -925,6 +925,10 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
custom_logit_processor = None
|
custom_logit_processor = None
|
||||||
|
|
||||||
|
if recv_req.bootstrap_port is None:
|
||||||
|
# Use default bootstrap port
|
||||||
|
recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
|
||||||
|
|
||||||
req = Req(
|
req = Req(
|
||||||
recv_req.rid,
|
recv_req.rid,
|
||||||
recv_req.input_text,
|
recv_req.input_text,
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ class ServerArgs:
|
|||||||
disaggregation_bootstrap_port: int = 8998
|
disaggregation_bootstrap_port: int = 8998
|
||||||
disaggregation_transfer_backend: str = "mooncake"
|
disaggregation_transfer_backend: str = "mooncake"
|
||||||
disaggregation_ib_device: Optional[str] = None
|
disaggregation_ib_device: Optional[str] = None
|
||||||
|
pdlb_url: Optional[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
@@ -1254,6 +1255,12 @@ class ServerArgs:
|
|||||||
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
|
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
|
||||||
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
|
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pdlb-url",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
|||||||
Reference in New Issue
Block a user