Simplify Router arguments passing and build it in docker image (#9964)
This commit is contained in:
@@ -1,118 +0,0 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
|
||||
|
||||
|
||||
@dataclasses.dataclass
class LBArgs:
    """CLI arguments for the PD-disaggregation mini load balancer."""

    host: str = "0.0.0.0"
    port: int = 8000
    policy: str = "random"
    # (url, bootstrap_port) pairs for the prefill fleet.
    prefill_infos: list = dataclasses.field(default_factory=list)
    # Base URLs of the decode servers.
    decode_infos: list = dataclasses.field(default_factory=list)
    log_interval: int = 5
    timeout: int = 600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all load-balancer options on *parser*; defaults come from the class."""
        parser.add_argument(
            "--host",
            type=str,
            default=LBArgs.host,
            help=f"Host to bind the server (default: {LBArgs.host})",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=LBArgs.port,
            help=f"Port to bind the server (default: {LBArgs.port})",
        )
        parser.add_argument(
            "--policy",
            type=str,
            default=LBArgs.policy,
            choices=["random", "po2"],
            help=f"Policy to use for load balancing (default: {LBArgs.policy})",
        )
        parser.add_argument(
            "--prefill",
            type=str,
            default=[],
            nargs="+",
            help="URLs for prefill servers",
        )
        parser.add_argument(
            "--decode",
            type=str,
            default=[],
            nargs="+",
            help="URLs for decode servers",
        )
        parser.add_argument(
            "--prefill-bootstrap-ports",
            type=int,
            nargs="+",
            help="Bootstrap ports for prefill servers",
        )
        parser.add_argument(
            "--log-interval",
            type=int,
            default=LBArgs.log_interval,
            help=f"Log interval in seconds (default: {LBArgs.log_interval})",
        )
        parser.add_argument(
            "--timeout",
            type=int,
            default=LBArgs.timeout,
            help=f"Timeout in seconds (default: {LBArgs.timeout})",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
        """Build an LBArgs from parsed CLI args, pairing prefill URLs with ports.

        A single bootstrap port is broadcast to every prefill URL; otherwise
        the count of ports must match the count of prefill URLs.
        """
        num_prefill = len(args.prefill)
        ports = args.prefill_bootstrap_ports
        if ports is None:
            ports = [None] * num_prefill
        elif len(ports) == 1:
            ports = ports * num_prefill
        elif len(ports) != num_prefill:
            raise ValueError(
                "Number of prefill URLs must match number of bootstrap ports"
            )

        return cls(
            host=args.host,
            port=args.port,
            policy=args.policy,
            prefill_infos=list(zip(args.prefill, ports)),
            decode_infos=args.decode,
            log_interval=args.log_interval,
            timeout=args.timeout,
        )
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse LB arguments and launch the mini load balancer."""
    parser = argparse.ArgumentParser(
        description="PD Disaggregation Load Balancer Server"
    )
    LBArgs.add_cli_args(parser)
    lb_args = LBArgs.from_cli_args(parser.parse_args())

    # Re-wrap the (url, port) pairs into the PrefillConfig objects run() expects.
    prefill_configs = [
        PrefillConfig(url, port) for url, port in lb_args.prefill_infos
    ]
    run(
        prefill_configs,
        lb_args.decode_infos,
        lb_args.host,
        lb_args.port,
        lb_args.timeout,
    )


if __name__ == "__main__":
    main()
|
||||
@@ -1,445 +1,6 @@
|
||||
"""
|
||||
Minimal HTTP load balancer for prefill and decode servers for testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import random
|
||||
import urllib
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
||||
from sglang.srt.utils import maybe_wrap_ipv6_address
|
||||
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
||||
1024 * 64
|
||||
) # 64KB, to prevent aiohttp's "Chunk too big" error
|
||||
|
||||
|
||||
def setup_logger():
    """Create (or fetch) the "pdlb" logger with a single stream handler.

    Returns:
        logging.Logger: INFO-level logger emitting "[PDLB (Python)]"-prefixed
        records to stderr.
    """
    logger = logging.getLogger("pdlb")
    logger.setLevel(logging.INFO)

    # Guard against duplicate handlers when called more than once (e.g. on
    # module re-import), which would double every log line.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


logger = setup_logger()
|
||||
|
||||
|
||||
@dataclasses.dataclass
class PrefillConfig:
    """Address of one prefill server plus its optional bootstrap port."""

    # Base URL of the prefill server, e.g. "http://host:port".
    url: str
    # Port used for the disaggregation bootstrap handshake; None if unused.
    bootstrap_port: Optional[int] = None
|
||||
|
||||
|
||||
class MiniLoadBalancer:
    """Random-pairing load balancer that proxies each request to one
    prefill server and one decode server simultaneously."""

    def __init__(
        self,
        prefill_configs: List[PrefillConfig],
        decode_servers: List[str],
        timeout: int,
    ):
        """Store the backend fleets and the per-request total timeout (seconds)."""
        self.prefill_configs = prefill_configs
        # Cached URL list kept in sync with prefill_configs (see add_prefill_server).
        self.prefill_servers = [p.url for p in prefill_configs]
        self.decode_servers = decode_servers
        self.timeout = timeout

    def add_prefill_server(self, new_prefill_config: PrefillConfig):
        """Append a prefill backend, keeping the URL cache consistent."""
        self.prefill_configs.append(new_prefill_config)
        self.prefill_servers.append(new_prefill_config.url)

    def add_decode_server(self, new_decode_server: str):
        """Append a decode backend URL."""
        self.decode_servers.append(new_decode_server)

    def select_pair(self):
        """Pick a random (prefill_url, bootstrap_port, decode_url) triple.

        NOTE: despite the "po2" CLI policy option, selection here is purely
        random.
        """
        # TODO: return some message instead of panic
        assert len(self.prefill_configs) > 0, "No prefill servers available"
        assert len(self.decode_servers) > 0, "No decode servers available"

        prefill_config = random.choice(self.prefill_configs)
        decode_server = random.choice(self.decode_servers)
        return prefill_config.url, prefill_config.bootstrap_port, decode_server

    async def generate(
        self, modified_request, prefill_server, decode_server, endpoint
    ) -> ORJSONResponse:
        """Non-streaming proxy: POST the request to both servers, return the
        decode server's JSON (optionally merged with prefill logprobs)."""
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=self.timeout
            )  # Add timeout for request reliability
        ) as session:
            tasks = [
                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                session.post(f"{decode_server}/{endpoint}", json=modified_request),
            ]

            # Wait for both responses to complete. Prefill should end first.
            prefill_response, decode_response = await asyncio.gather(*tasks)

            if "return_logprob" in modified_request:

                prefill_json = await prefill_response.json()
                ret_json = await decode_response.json()

                # merge `meta_info.input_token_logprobs` from prefill to decode
                if "meta_info" in ret_json:
                    if "input_token_logprobs" in ret_json["meta_info"]:
                        ret_json["meta_info"]["input_token_logprobs"] = (
                            prefill_json["meta_info"]["input_token_logprobs"]
                            + ret_json["meta_info"]["input_token_logprobs"]
                        )
            else:
                ret_json = await decode_response.json()

            # The decode server's status code is what the client sees.
            return ORJSONResponse(
                content=ret_json,
                status_code=decode_response.status,
            )

    async def generate_stream(
        self, modified_request, prefill_server, decode_server, endpoint="generate"
    ):
        """Streaming proxy: forward the decode server's SSE stream, optionally
        splicing the prefill server's input_token_logprobs into each event."""
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async def stream_results():
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(
                    total=self.timeout
                )  # Add timeout for request reliability
            ) as session:
                # Create the tasks for both prefill and decode requests
                tasks = [
                    session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                    session.post(f"{decode_server}/{endpoint}", json=modified_request),
                ]
                # Wait for both responses to complete. Since this is streaming, they return immediately.
                prefill_response, decode_response = await asyncio.gather(*tasks)

                if modified_request.get("return_logprob", False):
                    # Drain the full prefill stream first; only its first
                    # chunk carries the input logprobs we need to merge.
                    prefill_chunks = []
                    async for chunk in prefill_response.content:
                        prefill_chunks.append(chunk)

                    # [5:] strips the SSE "data:" prefix before JSON parsing.
                    first_prefill_chunk = (
                        prefill_chunks[0].decode("utf-8")[5:].strip("\n")
                    )
                    first_prefill_chunk_json = orjson.loads(first_prefill_chunk)

                    async for chunk in decode_response.content:
                        # Note: This is inefficient
                        # merge prefill input_token_logprobs, output_token_logprobs to decode
                        decoded_chunk = chunk.decode("utf-8")
                        if (
                            decoded_chunk
                            and decoded_chunk.startswith("data:")
                            and "[DONE]" not in decoded_chunk
                        ):
                            ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
                            ret_json["meta_info"]["input_token_logprobs"] = (
                                first_prefill_chunk_json["meta_info"][
                                    "input_token_logprobs"
                                ]
                                + ret_json["meta_info"]["input_token_logprobs"]
                            )

                            yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
                        else:
                            # Pass non-data / [DONE] lines through untouched.
                            yield chunk
                else:
                    # Fast path: relay decode output verbatim in bounded
                    # chunks (avoids aiohttp's "Chunk too big" error).
                    async for chunk in decode_response.content.iter_chunked(
                        AIOHTTP_STREAM_READ_CHUNK_SIZE
                    ):
                        yield chunk

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
        )
|
||||
|
||||
|
||||
app = FastAPI()
# Module-level singleton; populated by run() before uvicorn starts serving.
load_balancer: Optional[MiniLoadBalancer] = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/health_generate")
|
||||
async def health_generate():
|
||||
prefill_servers, decode_servers = (
|
||||
load_balancer.prefill_servers,
|
||||
load_balancer.decode_servers,
|
||||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = []
|
||||
for server in chain(prefill_servers, decode_servers):
|
||||
tasks.append(session.get(f"{server}/health_generate"))
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
await response
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/flush_cache")
|
||||
async def flush_cache():
|
||||
prefill_servers, decode_servers = (
|
||||
load_balancer.prefill_servers,
|
||||
load_balancer.decode_servers,
|
||||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = []
|
||||
for server in chain(prefill_servers, decode_servers):
|
||||
tasks.append(session.post(f"{server}/flush_cache"))
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
await response
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/get_server_info")
|
||||
async def get_server_info():
|
||||
prefill_servers, decode_servers = (
|
||||
load_balancer.prefill_servers,
|
||||
load_balancer.decode_servers,
|
||||
)
|
||||
prefill_infos = []
|
||||
decode_infos = []
|
||||
all_internal_states = []
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for server in chain(prefill_servers):
|
||||
server_info = await session.get(f"{server}/get_server_info")
|
||||
prefill_infos.append(await server_info.json())
|
||||
for server in chain(decode_servers):
|
||||
server_info = await session.get(f"{server}/get_server_info")
|
||||
info_json = await server_info.json()
|
||||
decode_infos.append(info_json)
|
||||
# Extract internal_states from decode servers
|
||||
if "internal_states" in info_json:
|
||||
all_internal_states.extend(info_json["internal_states"])
|
||||
|
||||
# Return format expected by bench_one_batch_server.py
|
||||
if all_internal_states:
|
||||
return {
|
||||
"internal_states": all_internal_states,
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
else:
|
||||
# Fallback with dummy data if no internal states found
|
||||
return {
|
||||
"internal_states": [
|
||||
{
|
||||
"last_gen_throughput": 0.0,
|
||||
"avg_spec_accept_length": None,
|
||||
}
|
||||
],
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/get_model_info")
|
||||
async def get_model_info():
|
||||
global load_balancer
|
||||
|
||||
if not load_balancer or not load_balancer.prefill_servers:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
detail="There is no server registered",
|
||||
)
|
||||
|
||||
target_server_url = load_balancer.prefill_servers[0]
|
||||
endpoint_url = f"{target_server_url}/get_model_info"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(endpoint_url) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_GATEWAY,
|
||||
detail=(
|
||||
f"Failed to get model info from {target_server_url}"
|
||||
f"Status: {response.status}, Response: {error_text}"
|
||||
),
|
||||
)
|
||||
|
||||
model_info_json = await response.json()
|
||||
return ORJSONResponse(content=model_info_json)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
detail=f"Failed to get model info from backend",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def handle_generate_request(request_data: dict):
|
||||
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
modified_request = request_data.copy()
|
||||
|
||||
batch_size = _get_request_batch_size(modified_request)
|
||||
if batch_size is not None:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": [hostname] * batch_size,
|
||||
"bootstrap_port": [bootstrap_port] * batch_size,
|
||||
"bootstrap_room": [
|
||||
_generate_bootstrap_room() for _ in range(batch_size)
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_port": bootstrap_port,
|
||||
"bootstrap_room": _generate_bootstrap_room(),
|
||||
}
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await load_balancer.generate_stream(
|
||||
modified_request, prefill_server, decode_server, "generate"
|
||||
)
|
||||
else:
|
||||
return await load_balancer.generate(
|
||||
modified_request, prefill_server, decode_server, "generate"
|
||||
)
|
||||
|
||||
|
||||
async def _forward_to_backend(request_data: dict, endpoint_name: str):
    """Proxy an OpenAI-style request to one prefill/decode pair.

    Picks a backend pair, injects bootstrap routing fields into a copy of
    the request, and dispatches to the streaming or non-streaming path.
    """
    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

    # The decode side needs the prefill *host* (not the full URL) to reach
    # the bootstrap endpoint.
    hostname = maybe_wrap_ipv6_address(
        urllib.parse.urlparse(prefill_server).hostname
    )
    modified_request = dict(request_data)
    modified_request["bootstrap_host"] = hostname
    modified_request["bootstrap_port"] = bootstrap_port
    modified_request["bootstrap_room"] = _generate_bootstrap_room()

    handler = (
        load_balancer.generate_stream
        if request_data.get("stream", False)
        else load_balancer.generate
    )
    return await handler(
        modified_request,
        prefill_server,
        decode_server,
        endpoint=endpoint_name,
    )
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def handle_chat_completion_request(request_data: dict):
|
||||
return await _forward_to_backend(request_data, "v1/chat/completions")
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def handle_completion_request(request_data: dict):
|
||||
return await _forward_to_backend(request_data, "v1/completions")
|
||||
|
||||
|
||||
def _generate_bootstrap_room():
    """Return a random 63-bit room id (in [0, 2**63 - 1]) for a bootstrap handshake."""
    return random.randint(0, 2**63 - 1)
|
||||
|
||||
|
||||
# We may utilize `GenerateReqInput`'s logic later
|
||||
def _get_request_batch_size(request):
|
||||
if (text := request.get("text")) is not None:
|
||||
return None if isinstance(text, str) else len(text)
|
||||
if (input_ids := request.get("input_ids")) is not None:
|
||||
return None if isinstance(input_ids[0], int) else len(input_ids)
|
||||
return None
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def get_models():
|
||||
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
response = await session.get(f"{prefill_server}/v1/models")
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Prefill server error: Status {response.status}",
|
||||
)
|
||||
return ORJSONResponse(content=await response.json())
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/register")
|
||||
async def register(obj: PDRegistryRequest):
|
||||
if obj.mode == "prefill":
|
||||
load_balancer.add_prefill_server(
|
||||
PrefillConfig(obj.registry_url, obj.bootstrap_port)
|
||||
)
|
||||
logger.info(
|
||||
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
|
||||
)
|
||||
elif obj.mode == "decode":
|
||||
load_balancer.add_decode_server(obj.registry_url)
|
||||
logger.info(f"Registered decode server: {obj.registry_url}")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid mode. Must be either PREFILL or DECODE.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
|
||||
f"#Decode servers: {len(load_balancer.decode_servers)}"
|
||||
)
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
def run(prefill_configs, decode_addrs, host, port, timeout):
    """Build the global MiniLoadBalancer and serve the FastAPI app.

    Args:
        prefill_configs: list of PrefillConfig for the prefill fleet.
        decode_addrs: list of decode server base URLs.
        host: address to bind the HTTP server to.
        port: port to bind the HTTP server to.
        timeout: total per-request timeout (seconds) for proxied calls.
    """
    global load_balancer
    load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs, timeout=timeout)
    # Blocks until the server shuts down.
    uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
    from sglang.srt.disaggregation.launch_lb import main

    main()
    # NOTE(review): main() blocks while the server runs, so this raise only
    # fires after the server exits — presumably intended as a relocation
    # notice. Confirm whether it should replace the main() call instead.
    raise RuntimeError(
        """The 'mini_lb' module has been relocated to the 'sglang_router' package.
We recommend installing 'sglang-router' with Rust support for optimal performance.
If you encounter issues building the router with Rust, set the environment variable
'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'."""
    )
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import warnings
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.utils import get_ip, is_npu
|
||||
from sglang.srt.utils import is_npu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -305,49 +301,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||
return (num_kv_indices + page_size - 1) // page_size
|
||||
|
||||
|
||||
#########################
|
||||
# PDLB Registry
|
||||
#########################
|
||||
|
||||
|
||||
@dataclasses.dataclass
class PDRegistryRequest:
    """A request to register a machine itself to the LB."""

    # "prefill" or "decode".
    mode: str
    # Base URL the LB should use to reach this server.
    registry_url: str
    # Required in prefill mode, forbidden in decode mode.
    bootstrap_port: Optional[int] = None

    def __post_init__(self):
        # Validate the mode/bootstrap_port combination eagerly so a bad
        # registration fails at construction time rather than at the LB.
        if self.mode == "prefill" and self.bootstrap_port is None:
            raise ValueError("Bootstrap port must be set in PREFILL mode.")
        if self.mode == "decode" and self.bootstrap_port is not None:
            raise ValueError("Bootstrap port must not be set in DECODE mode.")
        if self.mode not in ["prefill", "decode"]:
            raise ValueError(
                f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
            )
|
||||
|
||||
|
||||
def register_disaggregation_server(
    mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
):
    """Register this server with the PD load balancer at *pdlb_url*.

    Best-effort: a non-200 reply only emits a warning so server startup is
    never blocked by a missing or unhealthy LB.

    Args:
        mode: "prefill" or "decode".
        server_port: HTTP port this server listens on.
        bootstrap_port: KV bootstrap port; only meaningful in prefill mode.
        pdlb_url: Base URL of the load balancer.
    """
    # Decode servers must not advertise a bootstrap port (PDRegistryRequest
    # rejects it), so drop it here. (Also fixes the "boostrap_port" typo.)
    bootstrap_port = bootstrap_port if mode == "prefill" else None
    registry_request = PDRegistryRequest(
        mode=mode,
        registry_url=f"http://{get_ip()}:{server_port}",
        bootstrap_port=bootstrap_port,
    )
    res = requests.post(
        f"{pdlb_url}/register",
        json=dataclasses.asdict(registry_request),
    )
    if res.status_code != 200:
        warnings.warn(
            f"Failed to register disaggregation server: {res.status_code} {res.text}"
        )
|
||||
|
||||
|
||||
#########################
|
||||
# Misc
|
||||
#########################
|
||||
|
||||
@@ -47,11 +47,7 @@ from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
DisaggregationMode,
|
||||
register_disaggregation_server,
|
||||
)
|
||||
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
@@ -1405,13 +1401,5 @@ def _wait_and_warmup(
|
||||
if server_args.debug_tensor_dump_input_file:
|
||||
kill_process_tree(os.getpid())
|
||||
|
||||
if server_args.pdlb_url is not None:
|
||||
register_disaggregation_server(
|
||||
server_args.disaggregation_mode,
|
||||
server_args.port,
|
||||
server_args.disaggregation_bootstrap_port,
|
||||
server_args.pdlb_url,
|
||||
)
|
||||
|
||||
if launch_callback is not None:
|
||||
launch_callback()
|
||||
|
||||
@@ -367,7 +367,6 @@ class ServerArgs:
|
||||
disaggregation_prefill_pp: Optional[int] = 1
|
||||
disaggregation_ib_device: Optional[str] = None
|
||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||
pdlb_url: Optional[str] = None
|
||||
|
||||
# For model weight update
|
||||
custom_weight_loader: Optional[List[str]] = None
|
||||
@@ -2071,12 +2070,6 @@ class ServerArgs:
|
||||
default=ServerArgs.num_reserved_decode_tokens,
|
||||
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pdlb-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
|
||||
)
|
||||
|
||||
# Custom weight loader
|
||||
parser.add_argument(
|
||||
|
||||
@@ -466,6 +466,25 @@ def try_cached_model(model_repo: str):
|
||||
return model_dir if model_dir else model_repo
|
||||
|
||||
|
||||
def popen_with_error_check(command: list[str], allow_exit: bool = False):
    """Launch *command* and raise (in a watcher thread) if it exits badly.

    Args:
        command: argv list passed to subprocess.Popen.
        allow_exit: if True, a clean exit (returncode 0) is tolerated;
            otherwise any exit raises in the watcher thread.

    Returns:
        The running subprocess.Popen handle (stdout/stderr captured).
    """
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    def _run_and_check():
        # communicate() blocks until the process exits and drains the pipes
        # (so the child cannot deadlock on a full pipe buffer). The original
        # poll/sleep loop after it could never run and has been removed.
        stdout, stderr = process.communicate()

        if not allow_exit or process.returncode != 0:
            raise Exception(
                f"{command} exited with code {process.returncode}\n{stdout=}\n{stderr=}"
            )

    t = threading.Thread(target=_run_and_check)
    t.start()
    return process
|
||||
|
||||
|
||||
def popen_launch_server(
|
||||
model: str,
|
||||
base_url: str,
|
||||
|
||||
Reference in New Issue
Block a user