Simplify Router arguments passing and build it in docker image (#9964)
This commit is contained in:
@@ -36,7 +36,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
ibverbs-providers infiniband-diags perftest \
|
||||
libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \
|
||||
libboost-all-dev libssl-dev \
|
||||
libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler-grpc \
|
||||
libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler protobuf-compiler-grpc \
|
||||
pybind11-dev \
|
||||
libhiredis-dev libcurl4-openssl-dev \
|
||||
libczmq4 libczmq-dev \
|
||||
@@ -218,6 +218,19 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1
|
||||
&& cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \
|
||||
&& rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz
|
||||
|
||||
# Install Rust toolchain for sgl-router
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
|
||||
&& rustc --version && cargo --version
|
||||
|
||||
# Build and install sgl-router
|
||||
RUN python3 -m pip install --no-cache-dir setuptools-rust \
|
||||
&& cd /sgl-workspace/sglang/sgl-router \
|
||||
&& cargo build --release \
|
||||
&& python3 -m pip install --no-cache-dir . \
|
||||
&& rm -rf /root/.cache
|
||||
|
||||
|
||||
# Add yank script
|
||||
COPY --chown=root:root <<-"EOF" /usr/local/bin/yank
|
||||
#!/bin/bash
|
||||
|
||||
@@ -36,7 +36,7 @@ uv pip install mooncake-transfer-engine
|
||||
```bash
|
||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0
|
||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-ib-device mlx5_roce0
|
||||
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
||||
$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### DeepSeek Multi-Node
|
||||
@@ -100,7 +100,7 @@ pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
|
||||
```bash
|
||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl
|
||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl
|
||||
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
||||
$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### DeepSeek Multi-Node
|
||||
@@ -137,7 +137,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
|
||||
```bash
|
||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend ascend
|
||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend ascend
|
||||
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
||||
$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### DeepSeek Multi-Node
|
||||
|
||||
@@ -278,7 +278,7 @@ The most sophisticated policy that combines cache optimization with load balanci
|
||||
|
||||
3. **Cache Management**:
|
||||
- Maintains approximate radix trees per worker
|
||||
- Periodically evicts LRU entries based on `--eviction-interval` and `--max-tree-size`
|
||||
- Periodically evicts LRU entries based on `--eviction-interval-secs` and `--max-tree-size`
|
||||
|
||||
### Data Parallelism Aware Routing
|
||||
|
||||
@@ -296,7 +296,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
||||
### Core Settings
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------------------------|------|-------------|-----------------------------------------------------------------|
|
||||
| --------------------------- | ---- | ----------- | --------------------------------------------------------------- |
|
||||
| `--host` | str | 127.0.0.1 | Router server host address |
|
||||
| `--port` | int | 30000 | Router server port |
|
||||
| `--worker-urls` | list | [] | Worker URLs for separate launch mode |
|
||||
@@ -307,18 +307,18 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
||||
|
||||
### Cache-Aware Routing Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|---------------------------|-------|----------|--------------------------------------------------------|
|
||||
| `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) |
|
||||
| `--balance-abs-threshold` | int | 32 | Absolute load difference threshold |
|
||||
| `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold |
|
||||
| `--eviction-interval` | int | 60 | Seconds between cache eviction cycles |
|
||||
| `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree |
|
||||
| Parameter | Type | Default | Description |
|
||||
| -------------------------- | ----- | -------- | ------------------------------------------------------ |
|
||||
| `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) |
|
||||
| `--balance-abs-threshold` | int | 32 | Absolute load difference threshold |
|
||||
| `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold |
|
||||
| `--eviction-interval-secs` | int | 60 | Seconds between cache eviction cycles |
|
||||
| `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree |
|
||||
|
||||
### Fault Tolerance Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|------------------------------|-------|---------|---------------------------------------|
|
||||
| ---------------------------- | ----- | ------- | ------------------------------------- |
|
||||
| `--retry-max-retries` | int | 3 | Maximum retry attempts per request |
|
||||
| `--retry-initial-backoff-ms` | int | 100 | Initial retry backoff in milliseconds |
|
||||
| `--retry-max-backoff-ms` | int | 10000 | Maximum retry backoff in milliseconds |
|
||||
@@ -334,7 +334,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
||||
### Prefill-Decode Disaggregation Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------------------------------|------|---------|-------------------------------------------------------|
|
||||
| --------------------------------- | ---- | ------- | ----------------------------------------------------- |
|
||||
| `--pd-disaggregation` | flag | False | Enable PD disaggregated mode |
|
||||
| `--prefill` | list | [] | Prefill server URLs with optional bootstrap ports |
|
||||
| `--decode` | list | [] | Decode server URLs |
|
||||
@@ -346,7 +346,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
||||
### Kubernetes Integration
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|---------------------------------|------|--------------------------|------------------------------------------------------|
|
||||
| ------------------------------- | ---- | ------------------------ | ---------------------------------------------------- |
|
||||
| `--service-discovery` | flag | False | Enable Kubernetes service discovery |
|
||||
| `--selector` | list | [] | Label selector for workers (key1=value1 key2=value2) |
|
||||
| `--prefill-selector` | list | [] | Label selector for prefill servers in PD mode |
|
||||
@@ -358,7 +358,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
||||
### Observability
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|------------------------|------|-----------|-------------------------------------------------------|
|
||||
| ---------------------- | ---- | --------- | ----------------------------------------------------- |
|
||||
| `--prometheus-port` | int | 29000 | Prometheus metrics port |
|
||||
| `--prometheus-host` | str | 127.0.0.1 | Prometheus metrics host |
|
||||
| `--log-dir` | str | None | Directory for log files |
|
||||
@@ -368,7 +368,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
||||
### CORS Configuration
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|--------------------------|------|---------|----------------------|
|
||||
| ------------------------ | ---- | ------- | -------------------- |
|
||||
| `--cors-allowed-origins` | list | [] | Allowed CORS origins |
|
||||
|
||||
## Advanced Features
|
||||
@@ -429,7 +429,7 @@ python -m sglang_router.launch_router \
|
||||
|
||||
2. **High latency**: Check if cache-aware routing is causing imbalance. Try adjusting `--balance-abs-threshold` and `--balance-rel-threshold`.
|
||||
|
||||
3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval` for more aggressive cache cleanup.
|
||||
3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval-secs` for more aggressive cache cleanup.
|
||||
|
||||
4. **Circuit breaker triggering frequently**: Increase `--cb-failure-threshold` or extend `--cb-window-duration-secs`.
|
||||
|
||||
|
||||
@@ -27,7 +27,8 @@ spec:
|
||||
command:
|
||||
- python
|
||||
- -m
|
||||
- sglang.srt.disaggregation.mini_lb
|
||||
- sglang_router.launch_router
|
||||
- --pd-disaggregation
|
||||
- --prefill
|
||||
- http://deepseekr10528-prefill-main:30000
|
||||
- --decode
|
||||
|
||||
@@ -714,7 +714,8 @@ spec:
|
||||
command:
|
||||
- python
|
||||
- -m
|
||||
- sglang.srt.disaggregation.mini_lb
|
||||
- sglang_router.launch_router
|
||||
- --pd-disaggregation
|
||||
- --prefill
|
||||
- http://deepseekr10528-prefill-main:30000
|
||||
- --decode
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LBArgs:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
policy: str = "random"
|
||||
prefill_infos: list = dataclasses.field(default_factory=list)
|
||||
decode_infos: list = dataclasses.field(default_factory=list)
|
||||
log_interval: int = 5
|
||||
timeout: int = 600
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=LBArgs.host,
|
||||
help=f"Host to bind the server (default: {LBArgs.host})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=LBArgs.port,
|
||||
help=f"Port to bind the server (default: {LBArgs.port})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy",
|
||||
type=str,
|
||||
default=LBArgs.policy,
|
||||
choices=["random", "po2"],
|
||||
help=f"Policy to use for load balancing (default: {LBArgs.policy})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill",
|
||||
type=str,
|
||||
default=[],
|
||||
nargs="+",
|
||||
help="URLs for prefill servers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode",
|
||||
type=str,
|
||||
default=[],
|
||||
nargs="+",
|
||||
help="URLs for decode servers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-bootstrap-ports",
|
||||
type=int,
|
||||
nargs="+",
|
||||
help="Bootstrap ports for prefill servers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-interval",
|
||||
type=int,
|
||||
default=LBArgs.log_interval,
|
||||
help=f"Log interval in seconds (default: {LBArgs.log_interval})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=LBArgs.timeout,
|
||||
help=f"Timeout in seconds (default: {LBArgs.timeout})",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
|
||||
bootstrap_ports = args.prefill_bootstrap_ports
|
||||
if bootstrap_ports is None:
|
||||
bootstrap_ports = [None] * len(args.prefill)
|
||||
elif len(bootstrap_ports) == 1:
|
||||
bootstrap_ports = bootstrap_ports * len(args.prefill)
|
||||
else:
|
||||
if len(bootstrap_ports) != len(args.prefill):
|
||||
raise ValueError(
|
||||
"Number of prefill URLs must match number of bootstrap ports"
|
||||
)
|
||||
|
||||
prefill_infos = [
|
||||
(url, port) for url, port in zip(args.prefill, bootstrap_ports)
|
||||
]
|
||||
|
||||
return cls(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
policy=args.policy,
|
||||
prefill_infos=prefill_infos,
|
||||
decode_infos=args.decode,
|
||||
log_interval=args.log_interval,
|
||||
timeout=args.timeout,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="PD Disaggregation Load Balancer Server"
|
||||
)
|
||||
LBArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
lb_args = LBArgs.from_cli_args(args)
|
||||
|
||||
prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
|
||||
run(
|
||||
prefill_configs,
|
||||
lb_args.decode_infos,
|
||||
lb_args.host,
|
||||
lb_args.port,
|
||||
lb_args.timeout,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,445 +1,6 @@
|
||||
"""
|
||||
Minimal HTTP load balancer for prefill and decode servers for testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import random
|
||||
import urllib
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
||||
from sglang.srt.utils import maybe_wrap_ipv6_address
|
||||
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
||||
1024 * 64
|
||||
) # 64KB, to prevent aiohttp's "Chunk too big" error
|
||||
|
||||
|
||||
def setup_logger():
|
||||
logger = logging.getLogger("pdlb")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PrefillConfig:
|
||||
url: str
|
||||
bootstrap_port: Optional[int] = None
|
||||
|
||||
|
||||
class MiniLoadBalancer:
|
||||
def __init__(
|
||||
self,
|
||||
prefill_configs: List[PrefillConfig],
|
||||
decode_servers: List[str],
|
||||
timeout: int,
|
||||
):
|
||||
self.prefill_configs = prefill_configs
|
||||
self.prefill_servers = [p.url for p in prefill_configs]
|
||||
self.decode_servers = decode_servers
|
||||
self.timeout = timeout
|
||||
|
||||
def add_prefill_server(self, new_prefill_config: PrefillConfig):
|
||||
self.prefill_configs.append(new_prefill_config)
|
||||
self.prefill_servers.append(new_prefill_config.url)
|
||||
|
||||
def add_decode_server(self, new_decode_server: str):
|
||||
self.decode_servers.append(new_decode_server)
|
||||
|
||||
def select_pair(self):
|
||||
# TODO: return some message instead of panic
|
||||
assert len(self.prefill_configs) > 0, "No prefill servers available"
|
||||
assert len(self.decode_servers) > 0, "No decode servers available"
|
||||
|
||||
prefill_config = random.choice(self.prefill_configs)
|
||||
decode_server = random.choice(self.decode_servers)
|
||||
return prefill_config.url, prefill_config.bootstrap_port, decode_server
|
||||
|
||||
async def generate(
|
||||
self, modified_request, prefill_server, decode_server, endpoint
|
||||
) -> ORJSONResponse:
|
||||
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.timeout
|
||||
) # Add timeout for request reliability
|
||||
) as session:
|
||||
tasks = [
|
||||
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
||||
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
||||
]
|
||||
|
||||
# Wait for both responses to complete. Prefill should end first.
|
||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||
|
||||
if "return_logprob" in modified_request:
|
||||
|
||||
prefill_json = await prefill_response.json()
|
||||
ret_json = await decode_response.json()
|
||||
|
||||
# merge `meta_info.input_token_logprobs` from prefill to decode
|
||||
if "meta_info" in ret_json:
|
||||
if "input_token_logprobs" in ret_json["meta_info"]:
|
||||
ret_json["meta_info"]["input_token_logprobs"] = (
|
||||
prefill_json["meta_info"]["input_token_logprobs"]
|
||||
+ ret_json["meta_info"]["input_token_logprobs"]
|
||||
)
|
||||
else:
|
||||
ret_json = await decode_response.json()
|
||||
|
||||
return ORJSONResponse(
|
||||
content=ret_json,
|
||||
status_code=decode_response.status,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self, modified_request, prefill_server, decode_server, endpoint="generate"
|
||||
):
|
||||
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
||||
|
||||
async def stream_results():
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.timeout
|
||||
) # Add timeout for request reliability
|
||||
) as session:
|
||||
# Create the tasks for both prefill and decode requests
|
||||
tasks = [
|
||||
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
||||
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
||||
]
|
||||
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||
|
||||
if modified_request.get("return_logprob", False):
|
||||
prefill_chunks = []
|
||||
async for chunk in prefill_response.content:
|
||||
prefill_chunks.append(chunk)
|
||||
|
||||
first_prefill_chunk = (
|
||||
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
|
||||
)
|
||||
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
|
||||
|
||||
async for chunk in decode_response.content:
|
||||
# Note: This is inefficient
|
||||
# merge prefill input_token_logprobs, output_token_logprobs to decode
|
||||
decoded_chunk = chunk.decode("utf-8")
|
||||
if (
|
||||
decoded_chunk
|
||||
and decoded_chunk.startswith("data:")
|
||||
and "[DONE]" not in decoded_chunk
|
||||
):
|
||||
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
|
||||
ret_json["meta_info"]["input_token_logprobs"] = (
|
||||
first_prefill_chunk_json["meta_info"][
|
||||
"input_token_logprobs"
|
||||
]
|
||||
+ ret_json["meta_info"]["input_token_logprobs"]
|
||||
)
|
||||
|
||||
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
|
||||
else:
|
||||
yield chunk
|
||||
else:
|
||||
async for chunk in decode_response.content.iter_chunked(
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE
|
||||
):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
stream_results(),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
load_balancer: Optional[MiniLoadBalancer] = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/health_generate")
|
||||
async def health_generate():
|
||||
prefill_servers, decode_servers = (
|
||||
load_balancer.prefill_servers,
|
||||
load_balancer.decode_servers,
|
||||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = []
|
||||
for server in chain(prefill_servers, decode_servers):
|
||||
tasks.append(session.get(f"{server}/health_generate"))
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
await response
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/flush_cache")
|
||||
async def flush_cache():
|
||||
prefill_servers, decode_servers = (
|
||||
load_balancer.prefill_servers,
|
||||
load_balancer.decode_servers,
|
||||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = []
|
||||
for server in chain(prefill_servers, decode_servers):
|
||||
tasks.append(session.post(f"{server}/flush_cache"))
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
await response
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/get_server_info")
|
||||
async def get_server_info():
|
||||
prefill_servers, decode_servers = (
|
||||
load_balancer.prefill_servers,
|
||||
load_balancer.decode_servers,
|
||||
)
|
||||
prefill_infos = []
|
||||
decode_infos = []
|
||||
all_internal_states = []
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for server in chain(prefill_servers):
|
||||
server_info = await session.get(f"{server}/get_server_info")
|
||||
prefill_infos.append(await server_info.json())
|
||||
for server in chain(decode_servers):
|
||||
server_info = await session.get(f"{server}/get_server_info")
|
||||
info_json = await server_info.json()
|
||||
decode_infos.append(info_json)
|
||||
# Extract internal_states from decode servers
|
||||
if "internal_states" in info_json:
|
||||
all_internal_states.extend(info_json["internal_states"])
|
||||
|
||||
# Return format expected by bench_one_batch_server.py
|
||||
if all_internal_states:
|
||||
return {
|
||||
"internal_states": all_internal_states,
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
else:
|
||||
# Fallback with dummy data if no internal states found
|
||||
return {
|
||||
"internal_states": [
|
||||
{
|
||||
"last_gen_throughput": 0.0,
|
||||
"avg_spec_accept_length": None,
|
||||
}
|
||||
],
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/get_model_info")
|
||||
async def get_model_info():
|
||||
global load_balancer
|
||||
|
||||
if not load_balancer or not load_balancer.prefill_servers:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
detail="There is no server registered",
|
||||
)
|
||||
|
||||
target_server_url = load_balancer.prefill_servers[0]
|
||||
endpoint_url = f"{target_server_url}/get_model_info"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(endpoint_url) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_GATEWAY,
|
||||
detail=(
|
||||
f"Failed to get model info from {target_server_url}"
|
||||
f"Status: {response.status}, Response: {error_text}"
|
||||
),
|
||||
)
|
||||
|
||||
model_info_json = await response.json()
|
||||
return ORJSONResponse(content=model_info_json)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
detail=f"Failed to get model info from backend",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def handle_generate_request(request_data: dict):
|
||||
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
modified_request = request_data.copy()
|
||||
|
||||
batch_size = _get_request_batch_size(modified_request)
|
||||
if batch_size is not None:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": [hostname] * batch_size,
|
||||
"bootstrap_port": [bootstrap_port] * batch_size,
|
||||
"bootstrap_room": [
|
||||
_generate_bootstrap_room() for _ in range(batch_size)
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_port": bootstrap_port,
|
||||
"bootstrap_room": _generate_bootstrap_room(),
|
||||
}
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await load_balancer.generate_stream(
|
||||
modified_request, prefill_server, decode_server, "generate"
|
||||
)
|
||||
else:
|
||||
return await load_balancer.generate(
|
||||
modified_request, prefill_server, decode_server, "generate"
|
||||
)
|
||||
|
||||
|
||||
async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
||||
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
modified_request = request_data.copy()
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_port": bootstrap_port,
|
||||
"bootstrap_room": _generate_bootstrap_room(),
|
||||
}
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await load_balancer.generate_stream(
|
||||
modified_request,
|
||||
prefill_server,
|
||||
decode_server,
|
||||
endpoint=endpoint_name,
|
||||
)
|
||||
else:
|
||||
return await load_balancer.generate(
|
||||
modified_request,
|
||||
prefill_server,
|
||||
decode_server,
|
||||
endpoint=endpoint_name,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def handle_chat_completion_request(request_data: dict):
|
||||
return await _forward_to_backend(request_data, "v1/chat/completions")
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def handle_completion_request(request_data: dict):
|
||||
return await _forward_to_backend(request_data, "v1/completions")
|
||||
|
||||
|
||||
def _generate_bootstrap_room():
|
||||
return random.randint(0, 2**63 - 1)
|
||||
|
||||
|
||||
# We may utilize `GenerateReqInput`'s logic later
|
||||
def _get_request_batch_size(request):
|
||||
if (text := request.get("text")) is not None:
|
||||
return None if isinstance(text, str) else len(text)
|
||||
if (input_ids := request.get("input_ids")) is not None:
|
||||
return None if isinstance(input_ids[0], int) else len(input_ids)
|
||||
return None
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def get_models():
|
||||
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
response = await session.get(f"{prefill_server}/v1/models")
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Prefill server error: Status {response.status}",
|
||||
)
|
||||
return ORJSONResponse(content=await response.json())
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/register")
|
||||
async def register(obj: PDRegistryRequest):
|
||||
if obj.mode == "prefill":
|
||||
load_balancer.add_prefill_server(
|
||||
PrefillConfig(obj.registry_url, obj.bootstrap_port)
|
||||
)
|
||||
logger.info(
|
||||
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
|
||||
)
|
||||
elif obj.mode == "decode":
|
||||
load_balancer.add_decode_server(obj.registry_url)
|
||||
logger.info(f"Registered decode server: {obj.registry_url}")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid mode. Must be either PREFILL or DECODE.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
|
||||
f"#Decode servers: {len(load_balancer.decode_servers)}"
|
||||
)
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
def run(prefill_configs, decode_addrs, host, port, timeout):
|
||||
global load_balancer
|
||||
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs, timeout=timeout)
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
|
||||
from sglang.srt.disaggregation.launch_lb import main
|
||||
|
||||
main()
|
||||
raise RuntimeError(
|
||||
"""The 'mini_lb' module has been relocated to the 'sglang_router' package.
|
||||
We recommend installing 'sglang-router' with Rust support for optimal performance.
|
||||
If you encounter issues building the router with Rust, set the environment variable
|
||||
'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'."""
|
||||
)
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import warnings
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.utils import get_ip, is_npu
|
||||
from sglang.srt.utils import is_npu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -305,49 +301,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||
return (num_kv_indices + page_size - 1) // page_size
|
||||
|
||||
|
||||
#########################
|
||||
# PDLB Registry
|
||||
#########################
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PDRegistryRequest:
|
||||
"""A request to register a machine itself to the LB."""
|
||||
|
||||
mode: str
|
||||
registry_url: str
|
||||
bootstrap_port: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mode == "prefill" and self.bootstrap_port is None:
|
||||
raise ValueError("Bootstrap port must be set in PREFILL mode.")
|
||||
elif self.mode == "decode" and self.bootstrap_port is not None:
|
||||
raise ValueError("Bootstrap port must not be set in DECODE mode.")
|
||||
elif self.mode not in ["prefill", "decode"]:
|
||||
raise ValueError(
|
||||
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
|
||||
)
|
||||
|
||||
|
||||
def register_disaggregation_server(
|
||||
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
|
||||
):
|
||||
boostrap_port = bootstrap_port if mode == "prefill" else None
|
||||
registry_request = PDRegistryRequest(
|
||||
mode=mode,
|
||||
registry_url=f"http://{get_ip()}:{server_port}",
|
||||
bootstrap_port=boostrap_port,
|
||||
)
|
||||
res = requests.post(
|
||||
f"{pdlb_url}/register",
|
||||
json=dataclasses.asdict(registry_request),
|
||||
)
|
||||
if res.status_code != 200:
|
||||
warnings.warn(
|
||||
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
||||
)
|
||||
|
||||
|
||||
#########################
|
||||
# Misc
|
||||
#########################
|
||||
|
||||
@@ -47,11 +47,7 @@ from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
DisaggregationMode,
|
||||
register_disaggregation_server,
|
||||
)
|
||||
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
@@ -1405,13 +1401,5 @@ def _wait_and_warmup(
|
||||
if server_args.debug_tensor_dump_input_file:
|
||||
kill_process_tree(os.getpid())
|
||||
|
||||
if server_args.pdlb_url is not None:
|
||||
register_disaggregation_server(
|
||||
server_args.disaggregation_mode,
|
||||
server_args.port,
|
||||
server_args.disaggregation_bootstrap_port,
|
||||
server_args.pdlb_url,
|
||||
)
|
||||
|
||||
if launch_callback is not None:
|
||||
launch_callback()
|
||||
|
||||
@@ -367,7 +367,6 @@ class ServerArgs:
|
||||
disaggregation_prefill_pp: Optional[int] = 1
|
||||
disaggregation_ib_device: Optional[str] = None
|
||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||
pdlb_url: Optional[str] = None
|
||||
|
||||
# For model weight update
|
||||
custom_weight_loader: Optional[List[str]] = None
|
||||
@@ -2071,12 +2070,6 @@ class ServerArgs:
|
||||
default=ServerArgs.num_reserved_decode_tokens,
|
||||
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pdlb-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
|
||||
)
|
||||
|
||||
# Custom weight loader
|
||||
parser.add_argument(
|
||||
|
||||
@@ -466,6 +466,25 @@ def try_cached_model(model_repo: str):
|
||||
return model_dir if model_dir else model_repo
|
||||
|
||||
|
||||
def popen_with_error_check(command: list[str], allow_exit: bool = False):
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
||||
def _run_and_check():
|
||||
stdout, stderr = process.communicate()
|
||||
|
||||
while process.poll() is None:
|
||||
time.sleep(5)
|
||||
|
||||
if not allow_exit or process.returncode != 0:
|
||||
raise Exception(
|
||||
f"{command} exited with code {process.returncode}\n{stdout=}\n{stderr=}"
|
||||
)
|
||||
|
||||
t = threading.Thread(target=_run_and_check)
|
||||
t.start()
|
||||
return process
|
||||
|
||||
|
||||
def popen_launch_server(
|
||||
model: str,
|
||||
base_url: str,
|
||||
|
||||
@@ -45,6 +45,10 @@ fi
|
||||
# Install the main package
|
||||
$PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX
|
||||
|
||||
# Install router for pd-disagg test
|
||||
SGLANG_ROUTER_BUILD_NO_RUST=1 $PIP_CMD install -e "sgl-router" $PIP_INSTALL_SUFFIX
|
||||
|
||||
|
||||
if [ "$IS_BLACKWELL" = "1" ]; then
|
||||
# TODO auto determine sgl-kernel version
|
||||
SGL_KERNEL_VERSION=0.3.8
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
# a lightweihgt wrapper on router with argument type and comments
|
||||
# no wrapper on policy type => direct export
|
||||
from sglang_router.router import Router
|
||||
from sglang_router.version import __version__
|
||||
from sglang_router_rs import PolicyType
|
||||
|
||||
__all__ = ["Router", "PolicyType", "__version__"]
|
||||
__all__ = ["__version__"]
|
||||
|
||||
@@ -1,654 +1,22 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang_router import Router
|
||||
from sglang_router_rs import PolicyType
|
||||
import setproctitle
|
||||
from sglang_router.mini_lb import MiniLoadBalancer
|
||||
from sglang_router.router_args import RouterArgs
|
||||
|
||||
logger = logging.getLogger("router")
|
||||
|
||||
def setup_logger():
|
||||
logger = logging.getLogger("router")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
try:
|
||||
from sglang_router.router import Router
|
||||
except ImportError:
|
||||
Router = None
|
||||
logger.warning(
|
||||
"Rust Router is not installed, only python MiniLB (debugging only) is available"
|
||||
)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RouterArgs:
|
||||
# Worker configuration
|
||||
worker_urls: List[str] = dataclasses.field(default_factory=list)
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
|
||||
# PD-specific configuration
|
||||
pd_disaggregation: bool = False # Enable PD disaggregated mode
|
||||
prefill_urls: List[tuple] = dataclasses.field(
|
||||
default_factory=list
|
||||
) # List of (url, bootstrap_port)
|
||||
decode_urls: List[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
# Routing policy
|
||||
policy: str = "cache_aware"
|
||||
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
|
||||
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
|
||||
worker_startup_timeout_secs: int = 600
|
||||
worker_startup_check_interval: int = 30
|
||||
cache_threshold: float = 0.3
|
||||
balance_abs_threshold: int = 64
|
||||
balance_rel_threshold: float = 1.5
|
||||
eviction_interval: int = 120
|
||||
max_tree_size: int = 2**26
|
||||
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
|
||||
dp_aware: bool = False
|
||||
api_key: Optional[str] = None
|
||||
log_dir: Optional[str] = None
|
||||
log_level: Optional[str] = None
|
||||
# Service discovery configuration
|
||||
service_discovery: bool = False
|
||||
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
service_discovery_port: int = 80
|
||||
service_discovery_namespace: Optional[str] = None
|
||||
# PD service discovery configuration
|
||||
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
|
||||
# Prometheus configuration
|
||||
prometheus_port: Optional[int] = None
|
||||
prometheus_host: Optional[str] = None
|
||||
# Request ID headers configuration
|
||||
request_id_headers: Optional[List[str]] = None
|
||||
# Request timeout in seconds
|
||||
request_timeout_secs: int = 1800
|
||||
# Max concurrent requests for rate limiting
|
||||
max_concurrent_requests: int = 256
|
||||
# Queue size for pending requests when max concurrent limit reached
|
||||
queue_size: int = 100
|
||||
# Maximum time (in seconds) a request can wait in queue before timing out
|
||||
queue_timeout_secs: int = 60
|
||||
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
|
||||
rate_limit_tokens_per_second: Optional[int] = None
|
||||
# CORS allowed origins
|
||||
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
|
||||
# Retry configuration
|
||||
retry_max_retries: int = 5
|
||||
retry_initial_backoff_ms: int = 50
|
||||
retry_max_backoff_ms: int = 30_000
|
||||
retry_backoff_multiplier: float = 1.5
|
||||
retry_jitter_factor: float = 0.2
|
||||
disable_retries: bool = False
|
||||
# Health check configuration
|
||||
health_failure_threshold: int = 3
|
||||
health_success_threshold: int = 2
|
||||
health_check_timeout_secs: int = 5
|
||||
health_check_interval_secs: int = 60
|
||||
health_check_endpoint: str = "/health"
|
||||
# Circuit breaker configuration
|
||||
cb_failure_threshold: int = 10
|
||||
cb_success_threshold: int = 3
|
||||
cb_timeout_duration_secs: int = 60
|
||||
cb_window_duration_secs: int = 120
|
||||
disable_circuit_breaker: bool = False
|
||||
# Tokenizer configuration
|
||||
model_path: Optional[str] = None
|
||||
tokenizer_path: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
use_router_prefix: bool = False,
|
||||
exclude_host_port: bool = False,
|
||||
):
|
||||
"""
|
||||
Add router-specific arguments to an argument parser.
|
||||
|
||||
Args:
|
||||
parser: The argument parser to add arguments to
|
||||
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
|
||||
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
|
||||
"""
|
||||
prefix = "router-" if use_router_prefix else ""
|
||||
|
||||
# Worker configuration
|
||||
if not exclude_host_port:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=RouterArgs.host,
|
||||
help="Host address to bind the router server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=RouterArgs.port,
|
||||
help="Port number to bind the router server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--worker-urls",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
|
||||
)
|
||||
|
||||
# Routing policy configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}policy",
|
||||
type=str,
|
||||
default=RouterArgs.policy,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
|
||||
# PD-specific arguments
|
||||
parser.add_argument(
|
||||
f"--{prefix}pd-disaggregation",
|
||||
action="store_true",
|
||||
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill",
|
||||
nargs="+",
|
||||
action="append",
|
||||
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
|
||||
"Format: --prefill URL [BOOTSTRAP_PORT]. "
|
||||
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode",
|
||||
nargs=1,
|
||||
action="append",
|
||||
metavar=("URL",),
|
||||
help="Decode server URL. Can be specified multiple times.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.worker_startup_timeout_secs,
|
||||
help="Timeout in seconds for worker startup",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-check-interval",
|
||||
type=int,
|
||||
default=RouterArgs.worker_startup_check_interval,
|
||||
help="Interval in seconds between checks for worker startup",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cache-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.cache_threshold,
|
||||
help="Cache threshold (0.0-1.0) for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-abs-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.balance_abs_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-rel-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.balance_rel_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}eviction-interval",
|
||||
type=int,
|
||||
default=RouterArgs.eviction_interval,
|
||||
help="Interval in seconds between cache eviction operations",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-tree-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_tree_size,
|
||||
help="Maximum size of the approximation tree for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-payload-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_payload_size,
|
||||
help="Maximum payload size in bytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}dp-aware",
|
||||
action="store_true",
|
||||
help="Enable data parallelism aware schedule",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to store log files. If not specified, logs are only output to console.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-level",
|
||||
type=str,
|
||||
default="info",
|
||||
choices=["debug", "info", "warning", "error", "critical"],
|
||||
help="Set the logging level. If not specified, defaults to INFO.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery",
|
||||
action="store_true",
|
||||
help="Enable Kubernetes service discovery",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery-port",
|
||||
type=int,
|
||||
default=RouterArgs.service_discovery_port,
|
||||
help="Port to use for discovered worker pods",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery-namespace",
|
||||
type=str,
|
||||
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
# Prometheus configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-port",
|
||||
type=int,
|
||||
default=29000,
|
||||
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Host address to bind the Prometheus metrics server",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-id-headers",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.request_timeout_secs,
|
||||
help="Request timeout in seconds",
|
||||
)
|
||||
# Retry configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-max-retries",
|
||||
type=int,
|
||||
default=RouterArgs.retry_max_retries,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-initial-backoff-ms",
|
||||
type=int,
|
||||
default=RouterArgs.retry_initial_backoff_ms,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-max-backoff-ms",
|
||||
type=int,
|
||||
default=RouterArgs.retry_max_backoff_ms,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-backoff-multiplier",
|
||||
type=float,
|
||||
default=RouterArgs.retry_backoff_multiplier,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-jitter-factor",
|
||||
type=float,
|
||||
default=RouterArgs.retry_jitter_factor,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}disable-retries",
|
||||
action="store_true",
|
||||
help="Disable retries (equivalent to setting retry_max_retries=1)",
|
||||
)
|
||||
# Circuit breaker configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-failure-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.cb_failure_threshold,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-success-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.cb_success_threshold,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-timeout-duration-secs",
|
||||
type=int,
|
||||
default=RouterArgs.cb_timeout_duration_secs,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-window-duration-secs",
|
||||
type=int,
|
||||
default=RouterArgs.cb_window_duration_secs,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}disable-circuit-breaker",
|
||||
action="store_true",
|
||||
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
|
||||
)
|
||||
# Health check configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-failure-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.health_failure_threshold,
|
||||
help="Number of consecutive health check failures before marking worker unhealthy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-success-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.health_success_threshold,
|
||||
help="Number of consecutive health check successes before marking worker healthy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.health_check_timeout_secs,
|
||||
help="Timeout in seconds for health check requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-interval-secs",
|
||||
type=int,
|
||||
default=RouterArgs.health_check_interval_secs,
|
||||
help="Interval in seconds between runtime health checks",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-endpoint",
|
||||
type=str,
|
||||
default=RouterArgs.health_check_endpoint,
|
||||
help="Health check endpoint path",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-concurrent-requests",
|
||||
type=int,
|
||||
default=RouterArgs.max_concurrent_requests,
|
||||
help="Maximum number of concurrent requests allowed (for rate limiting)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-size",
|
||||
type=int,
|
||||
default=RouterArgs.queue_size,
|
||||
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.queue_timeout_secs,
|
||||
help="Maximum time (in seconds) a request can wait in queue before timing out",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}rate-limit-tokens-per-second",
|
||||
type=int,
|
||||
default=RouterArgs.rate_limit_tokens_per_second,
|
||||
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cors-allowed-origins",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
||||
)
|
||||
# Tokenizer configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}tokenizer-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(
|
||||
cls, args: argparse.Namespace, use_router_prefix: bool = False
|
||||
) -> "RouterArgs":
|
||||
"""
|
||||
Create RouterArgs instance from parsed command line arguments.
|
||||
|
||||
Args:
|
||||
args: Parsed command line arguments
|
||||
use_router_prefix: If True, look for arguments with 'router-' prefix
|
||||
"""
|
||||
prefix = "router_" if use_router_prefix else ""
|
||||
worker_urls = getattr(args, "worker_urls", [])
|
||||
|
||||
# Parse PD URLs
|
||||
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
|
||||
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
|
||||
|
||||
return cls(
|
||||
worker_urls=worker_urls,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
policy=getattr(args, f"{prefix}policy"),
|
||||
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
|
||||
decode_policy=getattr(args, f"{prefix}decode_policy", None),
|
||||
worker_startup_timeout_secs=getattr(
|
||||
args, f"{prefix}worker_startup_timeout_secs"
|
||||
),
|
||||
worker_startup_check_interval=getattr(
|
||||
args, f"{prefix}worker_startup_check_interval"
|
||||
),
|
||||
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
|
||||
balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
|
||||
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
|
||||
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
|
||||
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
|
||||
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
|
||||
dp_aware=getattr(args, f"{prefix}dp_aware", False),
|
||||
api_key=getattr(args, f"{prefix}api_key", None),
|
||||
log_dir=getattr(args, f"{prefix}log_dir", None),
|
||||
log_level=getattr(args, f"{prefix}log_level", None),
|
||||
service_discovery=getattr(args, f"{prefix}service_discovery", False),
|
||||
selector=cls._parse_selector(getattr(args, f"{prefix}selector", None)),
|
||||
service_discovery_port=getattr(args, f"{prefix}service_discovery_port"),
|
||||
service_discovery_namespace=getattr(
|
||||
args, f"{prefix}service_discovery_namespace", None
|
||||
),
|
||||
prefill_selector=cls._parse_selector(
|
||||
getattr(args, f"{prefix}prefill_selector", None)
|
||||
),
|
||||
decode_selector=cls._parse_selector(
|
||||
getattr(args, f"{prefix}decode_selector", None)
|
||||
),
|
||||
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
|
||||
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
||||
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
||||
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
|
||||
request_timeout_secs=getattr(
|
||||
args, f"{prefix}request_timeout_secs", RouterArgs.request_timeout_secs
|
||||
),
|
||||
max_concurrent_requests=getattr(
|
||||
args,
|
||||
f"{prefix}max_concurrent_requests",
|
||||
RouterArgs.max_concurrent_requests,
|
||||
),
|
||||
queue_size=getattr(
|
||||
args,
|
||||
f"{prefix}queue_size",
|
||||
RouterArgs.queue_size,
|
||||
),
|
||||
queue_timeout_secs=getattr(
|
||||
args,
|
||||
f"{prefix}queue_timeout_secs",
|
||||
RouterArgs.queue_timeout_secs,
|
||||
),
|
||||
rate_limit_tokens_per_second=getattr(
|
||||
args,
|
||||
f"{prefix}rate_limit_tokens_per_second",
|
||||
RouterArgs.rate_limit_tokens_per_second,
|
||||
),
|
||||
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
|
||||
retry_max_retries=getattr(args, f"{prefix}retry_max_retries"),
|
||||
retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"),
|
||||
retry_max_backoff_ms=getattr(args, f"{prefix}retry_max_backoff_ms"),
|
||||
retry_backoff_multiplier=getattr(args, f"{prefix}retry_backoff_multiplier"),
|
||||
retry_jitter_factor=getattr(args, f"{prefix}retry_jitter_factor"),
|
||||
cb_failure_threshold=getattr(args, f"{prefix}cb_failure_threshold"),
|
||||
cb_success_threshold=getattr(args, f"{prefix}cb_success_threshold"),
|
||||
cb_timeout_duration_secs=getattr(args, f"{prefix}cb_timeout_duration_secs"),
|
||||
cb_window_duration_secs=getattr(args, f"{prefix}cb_window_duration_secs"),
|
||||
disable_retries=getattr(args, f"{prefix}disable_retries", False),
|
||||
disable_circuit_breaker=getattr(
|
||||
args, f"{prefix}disable_circuit_breaker", False
|
||||
),
|
||||
health_failure_threshold=getattr(
|
||||
args,
|
||||
f"{prefix}health_failure_threshold",
|
||||
RouterArgs.health_failure_threshold,
|
||||
),
|
||||
health_success_threshold=getattr(
|
||||
args,
|
||||
f"{prefix}health_success_threshold",
|
||||
RouterArgs.health_success_threshold,
|
||||
),
|
||||
health_check_timeout_secs=getattr(
|
||||
args,
|
||||
f"{prefix}health_check_timeout_secs",
|
||||
RouterArgs.health_check_timeout_secs,
|
||||
),
|
||||
health_check_interval_secs=getattr(
|
||||
args,
|
||||
f"{prefix}health_check_interval_secs",
|
||||
RouterArgs.health_check_interval_secs,
|
||||
),
|
||||
health_check_endpoint=getattr(
|
||||
args, f"{prefix}health_check_endpoint", RouterArgs.health_check_endpoint
|
||||
),
|
||||
model_path=getattr(args, f"{prefix}model_path", None),
|
||||
tokenizer_path=getattr(args, f"{prefix}tokenizer_path", None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_selector(selector_list):
|
||||
if not selector_list:
|
||||
return {}
|
||||
|
||||
selector = {}
|
||||
for item in selector_list:
|
||||
if "=" in item:
|
||||
key, value = item.split("=", 1)
|
||||
selector[key] = value
|
||||
return selector
|
||||
|
||||
@staticmethod
|
||||
def _parse_prefill_urls(prefill_list):
|
||||
"""Parse prefill URLs from --prefill arguments.
|
||||
|
||||
Format: --prefill URL [BOOTSTRAP_PORT]
|
||||
Example:
|
||||
--prefill http://prefill1:8080 9000 # With bootstrap port
|
||||
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
|
||||
--prefill http://prefill3:8080 # Defaults to no bootstrap port
|
||||
"""
|
||||
if not prefill_list:
|
||||
return []
|
||||
|
||||
prefill_urls = []
|
||||
for prefill_args in prefill_list:
|
||||
|
||||
url = prefill_args[0]
|
||||
|
||||
# Handle optional bootstrap port
|
||||
if len(prefill_args) >= 2:
|
||||
bootstrap_port_str = prefill_args[1]
|
||||
# Handle 'none' as None
|
||||
if bootstrap_port_str.lower() == "none":
|
||||
bootstrap_port = None
|
||||
else:
|
||||
try:
|
||||
bootstrap_port = int(bootstrap_port_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
|
||||
)
|
||||
else:
|
||||
# No bootstrap port specified, default to None
|
||||
bootstrap_port = None
|
||||
|
||||
prefill_urls.append((url, bootstrap_port))
|
||||
|
||||
return prefill_urls
|
||||
|
||||
@staticmethod
|
||||
def _parse_decode_urls(decode_list):
|
||||
"""Parse decode URLs from --decode arguments.
|
||||
|
||||
Format: --decode URL
|
||||
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||
"""
|
||||
if not decode_list:
|
||||
return []
|
||||
|
||||
# decode_list is a list of single-element lists due to nargs=1
|
||||
return [url[0] for url in decode_list]
|
||||
|
||||
|
||||
def policy_from_str(policy_str: str) -> PolicyType:
|
||||
"""Convert policy string to PolicyType enum."""
|
||||
policy_map = {
|
||||
"random": PolicyType.Random,
|
||||
"round_robin": PolicyType.RoundRobin,
|
||||
"cache_aware": PolicyType.CacheAware,
|
||||
"power_of_two": PolicyType.PowerOfTwo,
|
||||
}
|
||||
return policy_map[policy_str]
|
||||
|
||||
|
||||
def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
"""
|
||||
@@ -661,7 +29,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
Returns:
|
||||
Router instance if successful, None if failed
|
||||
"""
|
||||
logger = logging.getLogger("router")
|
||||
setproctitle.setproctitle("sglang::router")
|
||||
try:
|
||||
# Convert to RouterArgs if needed
|
||||
if not isinstance(args, RouterArgs):
|
||||
@@ -669,120 +37,15 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
else:
|
||||
router_args = args
|
||||
|
||||
# Validate configuration based on mode
|
||||
if router_args.pd_disaggregation:
|
||||
# Validate PD configuration - skip URL requirements if using service discovery
|
||||
if not router_args.service_discovery:
|
||||
if not router_args.prefill_urls:
|
||||
raise ValueError("PD disaggregation mode requires --prefill")
|
||||
if not router_args.decode_urls:
|
||||
raise ValueError("PD disaggregation mode requires --decode")
|
||||
|
||||
# Warn about policy usage in PD mode
|
||||
if (
|
||||
router_args.prefill_policy
|
||||
and router_args.decode_policy
|
||||
and router_args.policy
|
||||
):
|
||||
logger.warning(
|
||||
"Both --prefill-policy and --decode-policy are specified. "
|
||||
"The main --policy flag will be ignored for PD mode."
|
||||
)
|
||||
elif (
|
||||
router_args.prefill_policy
|
||||
and not router_args.decode_policy
|
||||
and router_args.policy
|
||||
):
|
||||
logger.info(
|
||||
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
|
||||
f"and --policy '{router_args.policy}' for decode nodes."
|
||||
)
|
||||
elif (
|
||||
router_args.decode_policy
|
||||
and not router_args.prefill_policy
|
||||
and router_args.policy
|
||||
):
|
||||
logger.info(
|
||||
f"Using --policy '{router_args.policy}' for prefill nodes "
|
||||
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
|
||||
)
|
||||
|
||||
# Create router with unified constructor
|
||||
router = Router(
|
||||
worker_urls=(
|
||||
[]
|
||||
if router_args.service_discovery or router_args.pd_disaggregation
|
||||
else router_args.worker_urls
|
||||
),
|
||||
host=router_args.host,
|
||||
port=router_args.port,
|
||||
policy=policy_from_str(router_args.policy),
|
||||
worker_startup_timeout_secs=router_args.worker_startup_timeout_secs,
|
||||
worker_startup_check_interval=router_args.worker_startup_check_interval,
|
||||
cache_threshold=router_args.cache_threshold,
|
||||
balance_abs_threshold=router_args.balance_abs_threshold,
|
||||
balance_rel_threshold=router_args.balance_rel_threshold,
|
||||
eviction_interval_secs=router_args.eviction_interval,
|
||||
max_tree_size=router_args.max_tree_size,
|
||||
max_payload_size=router_args.max_payload_size,
|
||||
dp_aware=router_args.dp_aware,
|
||||
api_key=router_args.api_key,
|
||||
log_dir=router_args.log_dir,
|
||||
log_level=router_args.log_level,
|
||||
service_discovery=router_args.service_discovery,
|
||||
selector=router_args.selector,
|
||||
service_discovery_port=router_args.service_discovery_port,
|
||||
service_discovery_namespace=router_args.service_discovery_namespace,
|
||||
prefill_selector=router_args.prefill_selector,
|
||||
decode_selector=router_args.decode_selector,
|
||||
prometheus_port=router_args.prometheus_port,
|
||||
prometheus_host=router_args.prometheus_host,
|
||||
request_timeout_secs=router_args.request_timeout_secs,
|
||||
pd_disaggregation=router_args.pd_disaggregation,
|
||||
prefill_urls=(
|
||||
router_args.prefill_urls if router_args.pd_disaggregation else None
|
||||
),
|
||||
decode_urls=(
|
||||
router_args.decode_urls if router_args.pd_disaggregation else None
|
||||
),
|
||||
prefill_policy=(
|
||||
policy_from_str(router_args.prefill_policy)
|
||||
if router_args.prefill_policy
|
||||
else None
|
||||
),
|
||||
decode_policy=(
|
||||
policy_from_str(router_args.decode_policy)
|
||||
if router_args.decode_policy
|
||||
else None
|
||||
),
|
||||
request_id_headers=router_args.request_id_headers,
|
||||
max_concurrent_requests=router_args.max_concurrent_requests,
|
||||
queue_size=router_args.queue_size,
|
||||
queue_timeout_secs=router_args.queue_timeout_secs,
|
||||
rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second,
|
||||
cors_allowed_origins=router_args.cors_allowed_origins,
|
||||
retry_max_retries=router_args.retry_max_retries,
|
||||
retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
|
||||
retry_max_backoff_ms=router_args.retry_max_backoff_ms,
|
||||
retry_backoff_multiplier=router_args.retry_backoff_multiplier,
|
||||
retry_jitter_factor=router_args.retry_jitter_factor,
|
||||
cb_failure_threshold=router_args.cb_failure_threshold,
|
||||
cb_success_threshold=router_args.cb_success_threshold,
|
||||
cb_timeout_duration_secs=router_args.cb_timeout_duration_secs,
|
||||
cb_window_duration_secs=router_args.cb_window_duration_secs,
|
||||
disable_retries=router_args.disable_retries,
|
||||
disable_circuit_breaker=router_args.disable_circuit_breaker,
|
||||
health_failure_threshold=router_args.health_failure_threshold,
|
||||
health_success_threshold=router_args.health_success_threshold,
|
||||
health_check_timeout_secs=router_args.health_check_timeout_secs,
|
||||
health_check_interval_secs=router_args.health_check_interval_secs,
|
||||
health_check_endpoint=router_args.health_check_endpoint,
|
||||
model_path=router_args.model_path,
|
||||
tokenizer_path=router_args.tokenizer_path,
|
||||
)
|
||||
|
||||
router.start()
|
||||
return router
|
||||
if router_args.mini_lb:
|
||||
mini_lb = MiniLoadBalancer(router_args)
|
||||
mini_lb.start()
|
||||
else:
|
||||
if Router is None:
|
||||
raise RuntimeError("Rust Router is not installed")
|
||||
router_args._validate_router_args()
|
||||
router = Router.from_args(router_args)
|
||||
router.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting router: {e}")
|
||||
|
||||
395
sgl-router/py_src/sglang_router/mini_lb.py
Normal file
395
sgl-router/py_src/sglang_router/mini_lb.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Minimal HTTP load balancer for prefill and decode servers for testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import random
|
||||
import urllib
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
from sglang_router.router_args import RouterArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
||||
1024 * 64
|
||||
) # 64KB, to prevent aiohttp's "Chunk too big" error
|
||||
|
||||
|
||||
def maybe_wrap_ipv6_address(address: str) -> str:
|
||||
try:
|
||||
ipaddress.IPv6Address(address)
|
||||
return f"[{address}]"
|
||||
except ValueError:
|
||||
return address
|
||||
|
||||
|
||||
class MiniLoadBalancer:
|
||||
def __init__(
|
||||
self,
|
||||
router_args: RouterArgs,
|
||||
):
|
||||
self._validate_router_args(router_args)
|
||||
|
||||
self.host = router_args.host
|
||||
self.port = router_args.port
|
||||
self.timeout = router_args.request_timeout_secs
|
||||
self.prefill_urls = [url[0] for url in router_args.prefill_urls]
|
||||
self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
|
||||
self.decode_urls = router_args.decode_urls
|
||||
|
||||
def _validate_router_args(self, router_args: RouterArgs):
|
||||
logger.warning(
|
||||
"\x1b[33mMiniLB is only for debugging purposes, it only supports random policy!\033[0m"
|
||||
)
|
||||
|
||||
# NOTE: too many arguments unsupported, just validate some important ones
|
||||
if router_args.policy != "random":
|
||||
logger.warning("[MiniLB] Overriding policy to random")
|
||||
router_args.policy = "random"
|
||||
|
||||
if not router_args.pd_disaggregation:
|
||||
raise ValueError("MiniLB only supports PD disaggregation mode")
|
||||
|
||||
if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0:
|
||||
raise ValueError(
|
||||
"MiniLB requires at least one prefill and one decode server"
|
||||
)
|
||||
|
||||
def start(self):
|
||||
global lb
|
||||
lb = self
|
||||
uvicorn.run(app, host=self.host, port=self.port)
|
||||
|
||||
def select_pair(self):
|
||||
assert len(self.prefill_urls) > 0, "No prefill servers available"
|
||||
assert len(self.decode_urls) > 0, "No decode servers available"
|
||||
pidx = random.randint(0, len(self.prefill_urls) - 1)
|
||||
didx = random.randint(0, len(self.decode_urls) - 1)
|
||||
return (
|
||||
self.prefill_urls[pidx],
|
||||
self.prefill_bootstrap_ports[pidx],
|
||||
self.decode_urls[didx],
|
||||
)
|
||||
|
||||
async def generate(
|
||||
self, modified_request, prefill_server, decode_server, endpoint
|
||||
) -> ORJSONResponse:
|
||||
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.timeout
|
||||
) # Add timeout for request reliability
|
||||
) as session:
|
||||
tasks = [
|
||||
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
||||
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
||||
]
|
||||
|
||||
# Wait for both responses to complete. Prefill should end first.
|
||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||
|
||||
if "return_logprob" in modified_request:
|
||||
|
||||
prefill_json = await prefill_response.json()
|
||||
ret_json = await decode_response.json()
|
||||
|
||||
# merge `meta_info.input_token_logprobs` from prefill to decode
|
||||
if "meta_info" in ret_json:
|
||||
if "input_token_logprobs" in ret_json["meta_info"]:
|
||||
ret_json["meta_info"]["input_token_logprobs"] = (
|
||||
prefill_json["meta_info"]["input_token_logprobs"]
|
||||
+ ret_json["meta_info"]["input_token_logprobs"]
|
||||
)
|
||||
else:
|
||||
ret_json = await decode_response.json()
|
||||
|
||||
return ORJSONResponse(
|
||||
content=ret_json,
|
||||
status_code=decode_response.status,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self, modified_request, prefill_server, decode_server, endpoint="generate"
|
||||
):
|
||||
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
||||
|
||||
async def stream_results():
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.timeout
|
||||
) # Add timeout for request reliability
|
||||
) as session:
|
||||
# Create the tasks for both prefill and decode requests
|
||||
tasks = [
|
||||
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
||||
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
||||
]
|
||||
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||
|
||||
if modified_request.get("return_logprob", False):
|
||||
prefill_chunks = []
|
||||
async for chunk in prefill_response.content:
|
||||
prefill_chunks.append(chunk)
|
||||
|
||||
first_prefill_chunk = (
|
||||
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
|
||||
)
|
||||
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
|
||||
|
||||
async for chunk in decode_response.content:
|
||||
# Note: This is inefficient
|
||||
# merge prefill input_token_logprobs, output_token_logprobs to decode
|
||||
decoded_chunk = chunk.decode("utf-8")
|
||||
if (
|
||||
decoded_chunk
|
||||
and decoded_chunk.startswith("data:")
|
||||
and "[DONE]" not in decoded_chunk
|
||||
):
|
||||
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
|
||||
ret_json["meta_info"]["input_token_logprobs"] = (
|
||||
first_prefill_chunk_json["meta_info"][
|
||||
"input_token_logprobs"
|
||||
]
|
||||
+ ret_json["meta_info"]["input_token_logprobs"]
|
||||
)
|
||||
|
||||
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
|
||||
else:
|
||||
yield chunk
|
||||
else:
|
||||
async for chunk in decode_response.content.iter_chunked(
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE
|
||||
):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
stream_results(),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
lb: Optional[MiniLoadBalancer] = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/health_generate")
|
||||
async def health_generate():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = []
|
||||
for server in chain(lb.prefill_urls, lb.decode_urls):
|
||||
tasks.append(session.get(f"{server}/health_generate"))
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
await response
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/flush_cache")
|
||||
async def flush_cache():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = []
|
||||
for server in chain(lb.prefill_urls, lb.decode_urls):
|
||||
tasks.append(session.post(f"{server}/flush_cache"))
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
await response
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/get_server_info")
|
||||
async def get_server_info():
|
||||
prefill_infos = []
|
||||
decode_infos = []
|
||||
all_internal_states = []
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for server in lb.prefill_urls:
|
||||
server_info = await session.get(f"{server}/get_server_info")
|
||||
prefill_infos.append(await server_info.json())
|
||||
for server in lb.decode_urls:
|
||||
server_info = await session.get(f"{server}/get_server_info")
|
||||
info_json = await server_info.json()
|
||||
decode_infos.append(info_json)
|
||||
# Extract internal_states from decode servers
|
||||
if "internal_states" in info_json:
|
||||
all_internal_states.extend(info_json["internal_states"])
|
||||
|
||||
# Return format expected by bench_one_batch_server.py
|
||||
if all_internal_states:
|
||||
return {
|
||||
"internal_states": all_internal_states,
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
else:
|
||||
# Fallback with dummy data if no internal states found
|
||||
return {
|
||||
"internal_states": [
|
||||
{
|
||||
"last_gen_throughput": 0.0,
|
||||
"avg_spec_accept_length": None,
|
||||
}
|
||||
],
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/get_model_info")
|
||||
async def get_model_info():
|
||||
if not lb or not lb.prefill_urls:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
detail="There is no server registered",
|
||||
)
|
||||
|
||||
target_server_url = lb.prefill_urls[0]
|
||||
endpoint_url = f"{target_server_url}/get_model_info"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(endpoint_url) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_GATEWAY,
|
||||
detail=(
|
||||
f"Failed to get model info from {target_server_url}"
|
||||
f"Status: {response.status}, Response: {error_text}"
|
||||
),
|
||||
)
|
||||
|
||||
model_info_json = await response.json()
|
||||
return ORJSONResponse(content=model_info_json)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
detail=f"Failed to get model info from backend",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def handle_generate_request(request_data: dict):
|
||||
prefill_server, bootstrap_port, decode_server = lb.select_pair()
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
modified_request = request_data.copy()
|
||||
|
||||
batch_size = _get_request_batch_size(modified_request)
|
||||
if batch_size is not None:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": [hostname] * batch_size,
|
||||
"bootstrap_port": [bootstrap_port] * batch_size,
|
||||
"bootstrap_room": [
|
||||
_generate_bootstrap_room() for _ in range(batch_size)
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_port": bootstrap_port,
|
||||
"bootstrap_room": _generate_bootstrap_room(),
|
||||
}
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await lb.generate_stream(
|
||||
modified_request, prefill_server, decode_server, "generate"
|
||||
)
|
||||
else:
|
||||
return await lb.generate(
|
||||
modified_request, prefill_server, decode_server, "generate"
|
||||
)
|
||||
|
||||
|
||||
async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
||||
prefill_server, bootstrap_port, decode_server = lb.select_pair()
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
modified_request = request_data.copy()
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_port": bootstrap_port,
|
||||
"bootstrap_room": _generate_bootstrap_room(),
|
||||
}
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await lb.generate_stream(
|
||||
modified_request,
|
||||
prefill_server,
|
||||
decode_server,
|
||||
endpoint=endpoint_name,
|
||||
)
|
||||
else:
|
||||
return await lb.generate(
|
||||
modified_request,
|
||||
prefill_server,
|
||||
decode_server,
|
||||
endpoint=endpoint_name,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def handle_chat_completion_request(request_data: dict):
|
||||
return await _forward_to_backend(request_data, "v1/chat/completions")
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def handle_completion_request(request_data: dict):
|
||||
return await _forward_to_backend(request_data, "v1/completions")
|
||||
|
||||
|
||||
def _generate_bootstrap_room():
|
||||
return random.randint(0, 2**63 - 1)
|
||||
|
||||
|
||||
# We may utilize `GenerateReqInput`'s logic later
|
||||
def _get_request_batch_size(request):
|
||||
if (text := request.get("text")) is not None:
|
||||
return None if isinstance(text, str) else len(text)
|
||||
if (input_ids := request.get("input_ids")) is not None:
|
||||
return None if isinstance(input_ids[0], int) else len(input_ids)
|
||||
return None
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def get_models():
|
||||
prefill_server = lb.prefill_urls[0] # Get the first prefill server
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
response = await session.get(f"{prefill_server}/v1/models")
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Prefill server error: Status {response.status}",
|
||||
)
|
||||
return ORJSONResponse(content=await response.json())
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1,9 +1,23 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from sglang_router.router_args import RouterArgs
|
||||
from sglang_router_rs import PolicyType
|
||||
from sglang_router_rs import Router as _Router
|
||||
|
||||
|
||||
def policy_from_str(policy_str: Optional[str]) -> PolicyType:
|
||||
"""Convert policy string to PolicyType enum."""
|
||||
if policy_str is None:
|
||||
return None
|
||||
policy_map = {
|
||||
"random": PolicyType.Random,
|
||||
"round_robin": PolicyType.RoundRobin,
|
||||
"cache_aware": PolicyType.CacheAware,
|
||||
"power_of_two": PolicyType.PowerOfTwo,
|
||||
}
|
||||
return policy_map[policy_str]
|
||||
|
||||
|
||||
class Router:
|
||||
"""
|
||||
A high-performance router for distributing requests across worker nodes.
|
||||
@@ -78,130 +92,34 @@ class Router:
|
||||
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_urls: List[str],
|
||||
policy: PolicyType = PolicyType.RoundRobin,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 3001,
|
||||
worker_startup_timeout_secs: int = 600,
|
||||
worker_startup_check_interval: int = 30,
|
||||
cache_threshold: float = 0.3,
|
||||
balance_abs_threshold: int = 64,
|
||||
balance_rel_threshold: float = 1.5,
|
||||
eviction_interval_secs: int = 120,
|
||||
max_tree_size: int = 2**26,
|
||||
max_payload_size: int = 512 * 1024 * 1024, # 512MB
|
||||
dp_aware: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
log_dir: Optional[str] = None,
|
||||
log_level: Optional[str] = None,
|
||||
service_discovery: bool = False,
|
||||
selector: Dict[str, str] = None,
|
||||
service_discovery_port: int = 80,
|
||||
service_discovery_namespace: Optional[str] = None,
|
||||
prefill_selector: Dict[str, str] = None,
|
||||
decode_selector: Dict[str, str] = None,
|
||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
|
||||
prometheus_port: Optional[int] = None,
|
||||
prometheus_host: Optional[str] = None,
|
||||
request_timeout_secs: int = 1800,
|
||||
request_id_headers: Optional[List[str]] = None,
|
||||
pd_disaggregation: bool = False,
|
||||
prefill_urls: Optional[List[tuple]] = None,
|
||||
decode_urls: Optional[List[str]] = None,
|
||||
prefill_policy: Optional[PolicyType] = None,
|
||||
decode_policy: Optional[PolicyType] = None,
|
||||
max_concurrent_requests: int = 256,
|
||||
queue_size: int = 100,
|
||||
queue_timeout_secs: int = 60,
|
||||
rate_limit_tokens_per_second: Optional[int] = None,
|
||||
cors_allowed_origins: List[str] = None,
|
||||
retry_max_retries: int = 5,
|
||||
retry_initial_backoff_ms: int = 50,
|
||||
retry_max_backoff_ms: int = 30_000,
|
||||
retry_backoff_multiplier: float = 1.5,
|
||||
retry_jitter_factor: float = 0.2,
|
||||
cb_failure_threshold: int = 10,
|
||||
cb_success_threshold: int = 3,
|
||||
cb_timeout_duration_secs: int = 60,
|
||||
cb_window_duration_secs: int = 120,
|
||||
disable_retries: bool = False,
|
||||
disable_circuit_breaker: bool = False,
|
||||
health_failure_threshold: int = 3,
|
||||
health_success_threshold: int = 2,
|
||||
health_check_timeout_secs: int = 5,
|
||||
health_check_interval_secs: int = 60,
|
||||
health_check_endpoint: str = "/health",
|
||||
model_path: Optional[str] = None,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
):
|
||||
if selector is None:
|
||||
selector = {}
|
||||
if prefill_selector is None:
|
||||
prefill_selector = {}
|
||||
if decode_selector is None:
|
||||
decode_selector = {}
|
||||
if cors_allowed_origins is None:
|
||||
cors_allowed_origins = []
|
||||
def __init__(self, router: _Router):
|
||||
self._router = router
|
||||
|
||||
self._router = _Router(
|
||||
worker_urls=worker_urls,
|
||||
policy=policy,
|
||||
host=host,
|
||||
port=port,
|
||||
worker_startup_timeout_secs=worker_startup_timeout_secs,
|
||||
worker_startup_check_interval=worker_startup_check_interval,
|
||||
cache_threshold=cache_threshold,
|
||||
balance_abs_threshold=balance_abs_threshold,
|
||||
balance_rel_threshold=balance_rel_threshold,
|
||||
eviction_interval_secs=eviction_interval_secs,
|
||||
max_tree_size=max_tree_size,
|
||||
max_payload_size=max_payload_size,
|
||||
dp_aware=dp_aware,
|
||||
api_key=api_key,
|
||||
log_dir=log_dir,
|
||||
log_level=log_level,
|
||||
service_discovery=service_discovery,
|
||||
selector=selector,
|
||||
service_discovery_port=service_discovery_port,
|
||||
service_discovery_namespace=service_discovery_namespace,
|
||||
prefill_selector=prefill_selector,
|
||||
decode_selector=decode_selector,
|
||||
bootstrap_port_annotation=bootstrap_port_annotation,
|
||||
prometheus_port=prometheus_port,
|
||||
prometheus_host=prometheus_host,
|
||||
request_timeout_secs=request_timeout_secs,
|
||||
request_id_headers=request_id_headers,
|
||||
pd_disaggregation=pd_disaggregation,
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
queue_size=queue_size,
|
||||
queue_timeout_secs=queue_timeout_secs,
|
||||
rate_limit_tokens_per_second=rate_limit_tokens_per_second,
|
||||
cors_allowed_origins=cors_allowed_origins,
|
||||
retry_max_retries=retry_max_retries,
|
||||
retry_initial_backoff_ms=retry_initial_backoff_ms,
|
||||
retry_max_backoff_ms=retry_max_backoff_ms,
|
||||
retry_backoff_multiplier=retry_backoff_multiplier,
|
||||
retry_jitter_factor=retry_jitter_factor,
|
||||
cb_failure_threshold=cb_failure_threshold,
|
||||
cb_success_threshold=cb_success_threshold,
|
||||
cb_timeout_duration_secs=cb_timeout_duration_secs,
|
||||
cb_window_duration_secs=cb_window_duration_secs,
|
||||
disable_retries=disable_retries,
|
||||
disable_circuit_breaker=disable_circuit_breaker,
|
||||
health_failure_threshold=health_failure_threshold,
|
||||
health_success_threshold=health_success_threshold,
|
||||
health_check_timeout_secs=health_check_timeout_secs,
|
||||
health_check_interval_secs=health_check_interval_secs,
|
||||
health_check_endpoint=health_check_endpoint,
|
||||
model_path=model_path,
|
||||
tokenizer_path=tokenizer_path,
|
||||
@staticmethod
|
||||
def from_args(args: RouterArgs) -> "Router":
|
||||
"""Create a router from a RouterArgs instance."""
|
||||
|
||||
args_dict = vars(args)
|
||||
# Convert RouterArgs to _Router parameters
|
||||
args_dict["worker_urls"] = (
|
||||
[]
|
||||
if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
|
||||
else args_dict["worker_urls"]
|
||||
)
|
||||
args_dict["policy"] = policy_from_str(args_dict["policy"])
|
||||
args_dict["prefill_urls"] = (
|
||||
args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
|
||||
)
|
||||
args_dict["decode_urls"] = (
|
||||
args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
|
||||
)
|
||||
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
|
||||
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
|
||||
|
||||
# remoge mini_lb parameter
|
||||
args_dict.pop("mini_lb")
|
||||
|
||||
return Router(_Router(**args_dict))
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the router server.
|
||||
|
||||
577
sgl-router/py_src/sglang_router/router_args.py
Normal file
577
sgl-router/py_src/sglang_router/router_args.py
Normal file
@@ -0,0 +1,577 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RouterArgs:
|
||||
# Worker configuration
|
||||
worker_urls: List[str] = dataclasses.field(default_factory=list)
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
|
||||
# PD-specific configuration
|
||||
mini_lb: bool = False
|
||||
pd_disaggregation: bool = False # Enable PD disaggregated mode
|
||||
prefill_urls: List[tuple] = dataclasses.field(
|
||||
default_factory=list
|
||||
) # List of (url, bootstrap_port)
|
||||
decode_urls: List[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
# Routing policy
|
||||
policy: str = "cache_aware"
|
||||
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
|
||||
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
|
||||
worker_startup_timeout_secs: int = 600
|
||||
worker_startup_check_interval: int = 30
|
||||
cache_threshold: float = 0.3
|
||||
balance_abs_threshold: int = 64
|
||||
balance_rel_threshold: float = 1.5
|
||||
eviction_interval_secs: int = 120
|
||||
max_tree_size: int = 2**26
|
||||
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
|
||||
dp_aware: bool = False
|
||||
api_key: Optional[str] = None
|
||||
log_dir: Optional[str] = None
|
||||
log_level: Optional[str] = None
|
||||
# Service discovery configuration
|
||||
service_discovery: bool = False
|
||||
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
service_discovery_port: int = 80
|
||||
service_discovery_namespace: Optional[str] = None
|
||||
# PD service discovery configuration
|
||||
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
|
||||
# Prometheus configuration
|
||||
prometheus_port: Optional[int] = None
|
||||
prometheus_host: Optional[str] = None
|
||||
# Request ID headers configuration
|
||||
request_id_headers: Optional[List[str]] = None
|
||||
# Request timeout in seconds
|
||||
request_timeout_secs: int = 1800
|
||||
# Max concurrent requests for rate limiting
|
||||
max_concurrent_requests: int = 256
|
||||
# Queue size for pending requests when max concurrent limit reached
|
||||
queue_size: int = 100
|
||||
# Maximum time (in seconds) a request can wait in queue before timing out
|
||||
queue_timeout_secs: int = 60
|
||||
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
|
||||
rate_limit_tokens_per_second: Optional[int] = None
|
||||
# CORS allowed origins
|
||||
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
|
||||
# Retry configuration
|
||||
retry_max_retries: int = 5
|
||||
retry_initial_backoff_ms: int = 50
|
||||
retry_max_backoff_ms: int = 30_000
|
||||
retry_backoff_multiplier: float = 1.5
|
||||
retry_jitter_factor: float = 0.2
|
||||
disable_retries: bool = False
|
||||
# Health check configuration
|
||||
health_failure_threshold: int = 3
|
||||
health_success_threshold: int = 2
|
||||
health_check_timeout_secs: int = 5
|
||||
health_check_interval_secs: int = 60
|
||||
health_check_endpoint: str = "/health"
|
||||
# Circuit breaker configuration
|
||||
cb_failure_threshold: int = 10
|
||||
cb_success_threshold: int = 3
|
||||
cb_timeout_duration_secs: int = 60
|
||||
cb_window_duration_secs: int = 120
|
||||
disable_circuit_breaker: bool = False
|
||||
# Tokenizer configuration
|
||||
model_path: Optional[str] = None
|
||||
tokenizer_path: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
use_router_prefix: bool = False,
|
||||
exclude_host_port: bool = False,
|
||||
):
|
||||
"""
|
||||
Add router-specific arguments to an argument parser.
|
||||
|
||||
Args:
|
||||
parser: The argument parser to add arguments to
|
||||
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
|
||||
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
|
||||
"""
|
||||
prefix = "router-" if use_router_prefix else ""
|
||||
|
||||
# Worker configuration
|
||||
if not exclude_host_port:
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default=RouterArgs.host,
|
||||
help="Host address to bind the router server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=RouterArgs.port,
|
||||
help="Port number to bind the router server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--worker-urls",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
|
||||
)
|
||||
|
||||
# Routing policy configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}policy",
|
||||
type=str,
|
||||
default=RouterArgs.policy,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
|
||||
)
|
||||
|
||||
# PD-specific arguments
|
||||
parser.add_argument(
|
||||
f"--{prefix}mini-lb",
|
||||
action="store_true",
|
||||
help="Enable MiniLB",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}pd-disaggregation",
|
||||
action="store_true",
|
||||
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill",
|
||||
nargs="+",
|
||||
action="append",
|
||||
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
|
||||
"Format: --prefill URL [BOOTSTRAP_PORT]. "
|
||||
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode",
|
||||
nargs=1,
|
||||
action="append",
|
||||
metavar=("URL",),
|
||||
help="Decode server URL. Can be specified multiple times.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.worker_startup_timeout_secs,
|
||||
help="Timeout in seconds for worker startup",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-check-interval",
|
||||
type=int,
|
||||
default=RouterArgs.worker_startup_check_interval,
|
||||
help="Interval in seconds between checks for worker startup",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cache-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.cache_threshold,
|
||||
help="Cache threshold (0.0-1.0) for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-abs-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.balance_abs_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}balance-rel-threshold",
|
||||
type=float,
|
||||
default=RouterArgs.balance_rel_threshold,
|
||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}eviction-interval-secs",
|
||||
type=int,
|
||||
default=RouterArgs.eviction_interval_secs,
|
||||
help="Interval in seconds between cache eviction operations",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-tree-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_tree_size,
|
||||
help="Maximum size of the approximation tree for cache-aware routing",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-payload-size",
|
||||
type=int,
|
||||
default=RouterArgs.max_payload_size,
|
||||
help="Maximum payload size in bytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}dp-aware",
|
||||
action="store_true",
|
||||
help="Enable data parallelism aware schedule",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to store log files. If not specified, logs are only output to console.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-level",
|
||||
type=str,
|
||||
default="info",
|
||||
choices=["debug", "info", "warning", "error", "critical"],
|
||||
help="Set the logging level. If not specified, defaults to INFO.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery",
|
||||
action="store_true",
|
||||
help="Enable Kubernetes service discovery",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default={},
|
||||
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery-port",
|
||||
type=int,
|
||||
default=RouterArgs.service_discovery_port,
|
||||
help="Port to use for discovered worker pods",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}service-discovery-namespace",
|
||||
type=str,
|
||||
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default={},
|
||||
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode-selector",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default={},
|
||||
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
|
||||
)
|
||||
# Prometheus configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-port",
|
||||
type=int,
|
||||
default=29000,
|
||||
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prometheus-host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Host address to bind the Prometheus metrics server",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-id-headers",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.request_timeout_secs,
|
||||
help="Request timeout in seconds",
|
||||
)
|
||||
# Retry configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-max-retries",
|
||||
type=int,
|
||||
default=RouterArgs.retry_max_retries,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-initial-backoff-ms",
|
||||
type=int,
|
||||
default=RouterArgs.retry_initial_backoff_ms,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-max-backoff-ms",
|
||||
type=int,
|
||||
default=RouterArgs.retry_max_backoff_ms,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-backoff-multiplier",
|
||||
type=float,
|
||||
default=RouterArgs.retry_backoff_multiplier,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}retry-jitter-factor",
|
||||
type=float,
|
||||
default=RouterArgs.retry_jitter_factor,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}disable-retries",
|
||||
action="store_true",
|
||||
help="Disable retries (equivalent to setting retry_max_retries=1)",
|
||||
)
|
||||
# Circuit breaker configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-failure-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.cb_failure_threshold,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-success-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.cb_success_threshold,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-timeout-duration-secs",
|
||||
type=int,
|
||||
default=RouterArgs.cb_timeout_duration_secs,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cb-window-duration-secs",
|
||||
type=int,
|
||||
default=RouterArgs.cb_window_duration_secs,
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}disable-circuit-breaker",
|
||||
action="store_true",
|
||||
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
|
||||
)
|
||||
# Health check configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-failure-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.health_failure_threshold,
|
||||
help="Number of consecutive health check failures before marking worker unhealthy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-success-threshold",
|
||||
type=int,
|
||||
default=RouterArgs.health_success_threshold,
|
||||
help="Number of consecutive health check successes before marking worker healthy",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.health_check_timeout_secs,
|
||||
help="Timeout in seconds for health check requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-interval-secs",
|
||||
type=int,
|
||||
default=RouterArgs.health_check_interval_secs,
|
||||
help="Interval in seconds between runtime health checks",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}health-check-endpoint",
|
||||
type=str,
|
||||
default=RouterArgs.health_check_endpoint,
|
||||
help="Health check endpoint path",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-concurrent-requests",
|
||||
type=int,
|
||||
default=RouterArgs.max_concurrent_requests,
|
||||
help="Maximum number of concurrent requests allowed (for rate limiting)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-size",
|
||||
type=int,
|
||||
default=RouterArgs.queue_size,
|
||||
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}queue-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.queue_timeout_secs,
|
||||
help="Maximum time (in seconds) a request can wait in queue before timing out",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}rate-limit-tokens-per-second",
|
||||
type=int,
|
||||
default=RouterArgs.rate_limit_tokens_per_second,
|
||||
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cors-allowed-origins",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
||||
)
|
||||
# Tokenizer configuration
|
||||
parser.add_argument(
|
||||
f"--{prefix}model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}tokenizer-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(
|
||||
cls, args: argparse.Namespace, use_router_prefix: bool = False
|
||||
) -> "RouterArgs":
|
||||
"""
|
||||
Create RouterArgs instance from parsed command line arguments.
|
||||
|
||||
Args:
|
||||
args: Parsed command line arguments
|
||||
use_router_prefix: If True, look for arguments with 'router-' prefix
|
||||
"""
|
||||
prefix = "router_" if use_router_prefix else ""
|
||||
cli_args_dict = vars(args)
|
||||
args_dict = {}
|
||||
|
||||
for attr in dataclasses.fields(cls):
|
||||
# Auto strip prefix from args
|
||||
if f"{prefix}{attr.name}" in cli_args_dict:
|
||||
args_dict[attr.name] = cli_args_dict[f"{prefix}{attr.name}"]
|
||||
elif attr.name in cli_args_dict:
|
||||
args_dict[attr.name] = cli_args_dict[attr.name]
|
||||
|
||||
# parse special arguments and remove "--prefill" and "--decode" from cli_args_dict
|
||||
args_dict["prefill_urls"] = cls._parse_prefill_urls(
|
||||
cli_args_dict.get(f"{prefix}prefill", None)
|
||||
)
|
||||
args_dict["decode_urls"] = cls._parse_decode_urls(
|
||||
cli_args_dict.get(f"{prefix}decode", None)
|
||||
)
|
||||
args_dict["selector"] = cls._parse_selector(
|
||||
cli_args_dict.get(f"{prefix}selector", None)
|
||||
)
|
||||
args_dict["prefill_selector"] = cls._parse_selector(
|
||||
cli_args_dict.get(f"{prefix}prefill_selector", None)
|
||||
)
|
||||
args_dict["decode_selector"] = cls._parse_selector(
|
||||
cli_args_dict.get(f"{prefix}decode_selector", None)
|
||||
)
|
||||
|
||||
# Mooncake-specific annotation
|
||||
args_dict["bootstrap_port_annotation"] = "sglang.ai/bootstrap-port"
|
||||
|
||||
return cls(**args_dict)
|
||||
|
||||
def _validate_router_args(self):
|
||||
# Validate configuration based on mode
|
||||
if self.pd_disaggregation:
|
||||
# Validate PD configuration - skip URL requirements if using service discovery
|
||||
if not self.service_discovery:
|
||||
if not self.prefill_urls:
|
||||
raise ValueError("PD disaggregation mode requires --prefill")
|
||||
if not self.decode_urls:
|
||||
raise ValueError("PD disaggregation mode requires --decode")
|
||||
|
||||
# Warn about policy usage in PD mode
|
||||
if self.prefill_policy and self.decode_policy and self.policy:
|
||||
logger.warning(
|
||||
"Both --prefill-policy and --decode-policy are specified. "
|
||||
"The main --policy flag will be ignored for PD mode."
|
||||
)
|
||||
elif self.prefill_policy and not self.decode_policy and self.policy:
|
||||
logger.info(
|
||||
f"Using --prefill-policy '{self.prefill_policy}' for prefill nodes "
|
||||
f"and --policy '{self.policy}' for decode nodes."
|
||||
)
|
||||
elif self.decode_policy and not self.prefill_policy and self.policy:
|
||||
logger.info(
|
||||
f"Using --policy '{self.policy}' for prefill nodes "
|
||||
f"and --decode-policy '{self.decode_policy}' for decode nodes."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_selector(selector_list):
|
||||
if not selector_list:
|
||||
return {}
|
||||
|
||||
selector = {}
|
||||
for item in selector_list:
|
||||
if "=" in item:
|
||||
key, value = item.split("=", 1)
|
||||
selector[key] = value
|
||||
return selector
|
||||
|
||||
@staticmethod
|
||||
def _parse_prefill_urls(prefill_list):
|
||||
"""Parse prefill URLs from --prefill arguments.
|
||||
|
||||
Format: --prefill URL [BOOTSTRAP_PORT]
|
||||
Example:
|
||||
--prefill http://prefill1:8080 9000 # With bootstrap port
|
||||
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
|
||||
--prefill http://prefill3:8080 # Defaults to no bootstrap port
|
||||
"""
|
||||
if not prefill_list:
|
||||
return []
|
||||
|
||||
prefill_urls = []
|
||||
for prefill_args in prefill_list:
|
||||
|
||||
url = prefill_args[0]
|
||||
|
||||
# Handle optional bootstrap port
|
||||
if len(prefill_args) >= 2:
|
||||
bootstrap_port_str = prefill_args[1]
|
||||
# Handle 'none' as None
|
||||
if bootstrap_port_str.lower() == "none":
|
||||
bootstrap_port = None
|
||||
else:
|
||||
try:
|
||||
bootstrap_port = int(bootstrap_port_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
|
||||
)
|
||||
else:
|
||||
# No bootstrap port specified, default to None
|
||||
bootstrap_port = None
|
||||
|
||||
prefill_urls.append((url, bootstrap_port))
|
||||
|
||||
return prefill_urls
|
||||
|
||||
@staticmethod
|
||||
def _parse_decode_urls(decode_list):
|
||||
"""Parse decode URLs from --decode arguments.
|
||||
|
||||
Format: --decode URL
|
||||
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||
"""
|
||||
if not decode_list:
|
||||
return []
|
||||
|
||||
# decode_list is a list of single-element lists due to nargs=1
|
||||
return [url[0] for url in decode_list]
|
||||
@@ -33,7 +33,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
cache_threshold=0.5,
|
||||
balance_abs_threshold=32,
|
||||
balance_rel_threshold=1.0001,
|
||||
eviction_interval=60,
|
||||
eviction_interval_secs=60,
|
||||
max_tree_size=2**24,
|
||||
max_payload_size=256 * 1024 * 1024, # 256MB
|
||||
verbose=False,
|
||||
@@ -176,9 +176,8 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
"""Test basic PD router functionality without actually starting servers."""
|
||||
# This test just verifies the PD router can be created and configured
|
||||
# without actually starting it (which would require real prefill/decode servers)
|
||||
from sglang_router import Router
|
||||
from sglang_router.launch_router import RouterArgs
|
||||
from sglang_router_rs import PolicyType
|
||||
from sglang_router.router import PolicyType, Router
|
||||
|
||||
# Test RouterArgs parsing for PD mode
|
||||
# Simulate the parsed args structure from argparse with action="append"
|
||||
@@ -209,18 +208,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
|
||||
|
||||
# Test Router creation in PD mode
|
||||
router = Router(
|
||||
worker_urls=[], # Empty for PD mode
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[
|
||||
("http://prefill1:8080", 9000),
|
||||
("http://prefill2:8080", None),
|
||||
],
|
||||
decode_urls=["http://decode1:8081", "http://decode2:8081"],
|
||||
policy=PolicyType.CacheAware,
|
||||
host="127.0.0.1",
|
||||
port=3001,
|
||||
)
|
||||
router = Router.from_args(router_args)
|
||||
self.assertIsNotNone(router)
|
||||
|
||||
def test_policy_validation(self):
|
||||
|
||||
@@ -77,7 +77,7 @@ def popen_launch_router(
|
||||
port,
|
||||
"--dp",
|
||||
str(dp_size),
|
||||
"--router-eviction-interval",
|
||||
"--router-eviction-interval-secs",
|
||||
"5",
|
||||
"--router-policy",
|
||||
policy,
|
||||
|
||||
@@ -28,8 +28,3 @@ find = { where = ["py_src"] }
|
||||
# workaround for https://github.com/pypa/twine/issues/1216
|
||||
[tool.setuptools]
|
||||
license-files = []
|
||||
|
||||
[[tool.setuptools-rust.ext-modules]]
|
||||
target = "sglang_router_rs"
|
||||
path = "Cargo.toml"
|
||||
binding = "PyO3"
|
||||
|
||||
21
sgl-router/setup.py
Normal file
21
sgl-router/setup.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import os
|
||||
|
||||
from setuptools import setup
|
||||
from setuptools_rust import Binding, RustExtension
|
||||
|
||||
no_rust = os.environ.get("SGLANG_ROUTER_BUILD_NO_RUST") == "1"
|
||||
|
||||
rust_extensions = []
|
||||
if not no_rust:
|
||||
rust_extensions.append(
|
||||
RustExtension(
|
||||
target="sglang_router_rs",
|
||||
path="Cargo.toml",
|
||||
binding=Binding.PyO3,
|
||||
)
|
||||
)
|
||||
|
||||
setup(
|
||||
rust_extensions=rust_extensions,
|
||||
zip_safe=False,
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
@@ -18,6 +17,7 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_pd_server,
|
||||
popen_with_error_check,
|
||||
)
|
||||
|
||||
|
||||
@@ -47,7 +47,9 @@ class TestDisaggregationAccuracy(CustomTestCase):
|
||||
lb_command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.srt.disaggregation.mini_lb",
|
||||
"sglang_router.launch_router",
|
||||
"--pd-disaggregation",
|
||||
"--mini-lb", # FIXME: remove this
|
||||
"--prefill",
|
||||
cls.prefill_url,
|
||||
"--decode",
|
||||
@@ -59,9 +61,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
|
||||
]
|
||||
|
||||
print("Starting load balancer:", " ".join(lb_command))
|
||||
cls.process_lb = subprocess.Popen(
|
||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
cls.process_lb = popen_with_error_check(lb_command)
|
||||
cls.wait_server_ready(cls.lb_url + "/health")
|
||||
|
||||
@classmethod
|
||||
@@ -228,7 +228,9 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
|
||||
lb_command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.srt.disaggregation.mini_lb",
|
||||
"sglang_router.launch_router",
|
||||
"--pd-disaggregation",
|
||||
"--mini-lb", # FIXME: remove this
|
||||
"--prefill",
|
||||
cls.prefill_url,
|
||||
"--decode",
|
||||
@@ -240,9 +242,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
|
||||
]
|
||||
|
||||
print("Starting load balancer:", " ".join(lb_command))
|
||||
cls.process_lb = subprocess.Popen(
|
||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
cls.process_lb = popen_with_error_check(lb_command)
|
||||
cls.wait_server_ready(cls.lb_url + "/health")
|
||||
|
||||
@classmethod
|
||||
@@ -383,7 +383,9 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
|
||||
lb_command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.srt.disaggregation.mini_lb",
|
||||
"sglang_router.launch_router",
|
||||
"--pd-disaggregation",
|
||||
"--mini-lb", # FIXME: remove this
|
||||
"--prefill",
|
||||
cls.prefill_url,
|
||||
"--decode",
|
||||
@@ -395,9 +397,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
|
||||
]
|
||||
|
||||
print("Starting load balancer:", " ".join(lb_command))
|
||||
cls.process_lb = subprocess.Popen(
|
||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
cls.process_lb = popen_with_error_check(lb_command)
|
||||
cls.wait_server_ready(cls.lb_url + "/health")
|
||||
|
||||
@classmethod
|
||||
@@ -509,7 +509,9 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
|
||||
lb_command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.srt.disaggregation.mini_lb",
|
||||
"sglang_router.launch_router",
|
||||
"--pd-disaggregation",
|
||||
"--mini-lb", # FIXME: remove this
|
||||
"--prefill",
|
||||
cls.prefill_url,
|
||||
"--decode",
|
||||
@@ -521,9 +523,7 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
|
||||
]
|
||||
|
||||
print("Starting load balancer:", " ".join(lb_command))
|
||||
cls.process_lb = subprocess.Popen(
|
||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
cls.process_lb = popen_with_error_check(lb_command)
|
||||
cls.wait_server_ready(cls.lb_url + "/health")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -15,7 +15,7 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_pd_server,
|
||||
run_with_timeout,
|
||||
popen_with_error_check,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,7 +49,9 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
|
||||
lb_command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.srt.disaggregation.mini_lb",
|
||||
"sglang_router.launch_router",
|
||||
"--pd-disaggregation",
|
||||
"--mini-lb", # FIXME: remove this
|
||||
"--prefill",
|
||||
cls.prefill_url,
|
||||
"--decode",
|
||||
@@ -61,9 +63,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
|
||||
]
|
||||
|
||||
print("Starting load balancer:", " ".join(lb_command))
|
||||
cls.process_lb = subprocess.Popen(
|
||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
cls.process_lb = popen_with_error_check(lb_command)
|
||||
cls.wait_server_ready(cls.lb_url + "/health")
|
||||
|
||||
@classmethod
|
||||
@@ -183,7 +183,9 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
|
||||
lb_command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.srt.disaggregation.mini_lb",
|
||||
"sglang_router.launch_router",
|
||||
"--pd-disaggregation",
|
||||
"--mini-lb", # FIXME: remove this
|
||||
"--prefill",
|
||||
cls.prefill_url,
|
||||
"--decode",
|
||||
@@ -195,9 +197,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
|
||||
]
|
||||
|
||||
print("Starting load balancer:", " ".join(lb_command))
|
||||
cls.process_lb = subprocess.Popen(
|
||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
cls.process_lb = popen_with_error_check(lb_command)
|
||||
cls.wait_server_ready(cls.lb_url + "/health")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -49,7 +49,9 @@ class TestPDPPAccuracy(unittest.TestCase):
|
||||
lb_command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.srt.disaggregation.mini_lb",
|
||||
"sglang_router.launch_router",
|
||||
"--pd-disaggregation",
|
||||
"--mini-lb", # FIXME: remove this
|
||||
"--prefill",
|
||||
cls.prefill_url,
|
||||
"--decode",
|
||||
|
||||
Reference in New Issue
Block a user