Simplify Router arguments passing and build it in docker image (#9964)
This commit is contained in:
@@ -36,7 +36,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
|||||||
ibverbs-providers infiniband-diags perftest \
|
ibverbs-providers infiniband-diags perftest \
|
||||||
libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \
|
libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \
|
||||||
libboost-all-dev libssl-dev \
|
libboost-all-dev libssl-dev \
|
||||||
libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler-grpc \
|
libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler protobuf-compiler-grpc \
|
||||||
pybind11-dev \
|
pybind11-dev \
|
||||||
libhiredis-dev libcurl4-openssl-dev \
|
libhiredis-dev libcurl4-openssl-dev \
|
||||||
libczmq4 libczmq-dev \
|
libczmq4 libczmq-dev \
|
||||||
@@ -218,6 +218,19 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1
|
|||||||
&& cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \
|
&& cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \
|
||||||
&& rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz
|
&& rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz
|
||||||
|
|
||||||
|
# Install Rust toolchain for sgl-router
|
||||||
|
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||||
|
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
|
||||||
|
&& rustc --version && cargo --version
|
||||||
|
|
||||||
|
# Build and install sgl-router
|
||||||
|
RUN python3 -m pip install --no-cache-dir setuptools-rust \
|
||||||
|
&& cd /sgl-workspace/sglang/sgl-router \
|
||||||
|
&& cargo build --release \
|
||||||
|
&& python3 -m pip install --no-cache-dir . \
|
||||||
|
&& rm -rf /root/.cache
|
||||||
|
|
||||||
|
|
||||||
# Add yank script
|
# Add yank script
|
||||||
COPY --chown=root:root <<-"EOF" /usr/local/bin/yank
|
COPY --chown=root:root <<-"EOF" /usr/local/bin/yank
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ uv pip install mooncake-transfer-engine
|
|||||||
```bash
|
```bash
|
||||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0
|
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0
|
||||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-ib-device mlx5_roce0
|
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-ib-device mlx5_roce0
|
||||||
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
||||||
```
|
```
|
||||||
|
|
||||||
### DeepSeek Multi-Node
|
### DeepSeek Multi-Node
|
||||||
@@ -100,7 +100,7 @@ pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
|
|||||||
```bash
|
```bash
|
||||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl
|
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl
|
||||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl
|
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl
|
||||||
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
||||||
```
|
```
|
||||||
|
|
||||||
### DeepSeek Multi-Node
|
### DeepSeek Multi-Node
|
||||||
@@ -137,7 +137,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
|
|||||||
```bash
|
```bash
|
||||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend ascend
|
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend ascend
|
||||||
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend ascend
|
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend ascend
|
||||||
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
|
||||||
```
|
```
|
||||||
|
|
||||||
### DeepSeek Multi-Node
|
### DeepSeek Multi-Node
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ The most sophisticated policy that combines cache optimization with load balanci
|
|||||||
|
|
||||||
3. **Cache Management**:
|
3. **Cache Management**:
|
||||||
- Maintains approximate radix trees per worker
|
- Maintains approximate radix trees per worker
|
||||||
- Periodically evicts LRU entries based on `--eviction-interval` and `--max-tree-size`
|
- Periodically evicts LRU entries based on `--eviction-interval-secs` and `--max-tree-size`
|
||||||
|
|
||||||
### Data Parallelism Aware Routing
|
### Data Parallelism Aware Routing
|
||||||
|
|
||||||
@@ -296,7 +296,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
|||||||
### Core Settings
|
### Core Settings
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
| Parameter | Type | Default | Description |
|
||||||
|-----------------------------|------|-------------|-----------------------------------------------------------------|
|
| --------------------------- | ---- | ----------- | --------------------------------------------------------------- |
|
||||||
| `--host` | str | 127.0.0.1 | Router server host address |
|
| `--host` | str | 127.0.0.1 | Router server host address |
|
||||||
| `--port` | int | 30000 | Router server port |
|
| `--port` | int | 30000 | Router server port |
|
||||||
| `--worker-urls` | list | [] | Worker URLs for separate launch mode |
|
| `--worker-urls` | list | [] | Worker URLs for separate launch mode |
|
||||||
@@ -307,18 +307,18 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
|||||||
|
|
||||||
### Cache-Aware Routing Parameters
|
### Cache-Aware Routing Parameters
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
| Parameter | Type | Default | Description |
|
||||||
|---------------------------|-------|----------|--------------------------------------------------------|
|
| -------------------------- | ----- | -------- | ------------------------------------------------------ |
|
||||||
| `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) |
|
| `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) |
|
||||||
| `--balance-abs-threshold` | int | 32 | Absolute load difference threshold |
|
| `--balance-abs-threshold` | int | 32 | Absolute load difference threshold |
|
||||||
| `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold |
|
| `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold |
|
||||||
| `--eviction-interval` | int | 60 | Seconds between cache eviction cycles |
|
| `--eviction-interval-secs` | int | 60 | Seconds between cache eviction cycles |
|
||||||
| `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree |
|
| `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree |
|
||||||
|
|
||||||
### Fault Tolerance Parameters
|
### Fault Tolerance Parameters
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
| Parameter | Type | Default | Description |
|
||||||
|------------------------------|-------|---------|---------------------------------------|
|
| ---------------------------- | ----- | ------- | ------------------------------------- |
|
||||||
| `--retry-max-retries` | int | 3 | Maximum retry attempts per request |
|
| `--retry-max-retries` | int | 3 | Maximum retry attempts per request |
|
||||||
| `--retry-initial-backoff-ms` | int | 100 | Initial retry backoff in milliseconds |
|
| `--retry-initial-backoff-ms` | int | 100 | Initial retry backoff in milliseconds |
|
||||||
| `--retry-max-backoff-ms` | int | 10000 | Maximum retry backoff in milliseconds |
|
| `--retry-max-backoff-ms` | int | 10000 | Maximum retry backoff in milliseconds |
|
||||||
@@ -334,7 +334,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
|||||||
### Prefill-Decode Disaggregation Parameters
|
### Prefill-Decode Disaggregation Parameters
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
| Parameter | Type | Default | Description |
|
||||||
|-----------------------------------|------|---------|-------------------------------------------------------|
|
| --------------------------------- | ---- | ------- | ----------------------------------------------------- |
|
||||||
| `--pd-disaggregation` | flag | False | Enable PD disaggregated mode |
|
| `--pd-disaggregation` | flag | False | Enable PD disaggregated mode |
|
||||||
| `--prefill` | list | [] | Prefill server URLs with optional bootstrap ports |
|
| `--prefill` | list | [] | Prefill server URLs with optional bootstrap ports |
|
||||||
| `--decode` | list | [] | Decode server URLs |
|
| `--decode` | list | [] | Decode server URLs |
|
||||||
@@ -346,7 +346,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
|||||||
### Kubernetes Integration
|
### Kubernetes Integration
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
| Parameter | Type | Default | Description |
|
||||||
|---------------------------------|------|--------------------------|------------------------------------------------------|
|
| ------------------------------- | ---- | ------------------------ | ---------------------------------------------------- |
|
||||||
| `--service-discovery` | flag | False | Enable Kubernetes service discovery |
|
| `--service-discovery` | flag | False | Enable Kubernetes service discovery |
|
||||||
| `--selector` | list | [] | Label selector for workers (key1=value1 key2=value2) |
|
| `--selector` | list | [] | Label selector for workers (key1=value1 key2=value2) |
|
||||||
| `--prefill-selector` | list | [] | Label selector for prefill servers in PD mode |
|
| `--prefill-selector` | list | [] | Label selector for prefill servers in PD mode |
|
||||||
@@ -358,7 +358,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
|||||||
### Observability
|
### Observability
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
| Parameter | Type | Default | Description |
|
||||||
|------------------------|------|-----------|-------------------------------------------------------|
|
| ---------------------- | ---- | --------- | ----------------------------------------------------- |
|
||||||
| `--prometheus-port` | int | 29000 | Prometheus metrics port |
|
| `--prometheus-port` | int | 29000 | Prometheus metrics port |
|
||||||
| `--prometheus-host` | str | 127.0.0.1 | Prometheus metrics host |
|
| `--prometheus-host` | str | 127.0.0.1 | Prometheus metrics host |
|
||||||
| `--log-dir` | str | None | Directory for log files |
|
| `--log-dir` | str | None | Directory for log files |
|
||||||
@@ -368,7 +368,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
|
|||||||
### CORS Configuration
|
### CORS Configuration
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
| Parameter | Type | Default | Description |
|
||||||
|--------------------------|------|---------|----------------------|
|
| ------------------------ | ---- | ------- | -------------------- |
|
||||||
| `--cors-allowed-origins` | list | [] | Allowed CORS origins |
|
| `--cors-allowed-origins` | list | [] | Allowed CORS origins |
|
||||||
|
|
||||||
## Advanced Features
|
## Advanced Features
|
||||||
@@ -429,7 +429,7 @@ python -m sglang_router.launch_router \
|
|||||||
|
|
||||||
2. **High latency**: Check if cache-aware routing is causing imbalance. Try adjusting `--balance-abs-threshold` and `--balance-rel-threshold`.
|
2. **High latency**: Check if cache-aware routing is causing imbalance. Try adjusting `--balance-abs-threshold` and `--balance-rel-threshold`.
|
||||||
|
|
||||||
3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval` for more aggressive cache cleanup.
|
3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval-secs` for more aggressive cache cleanup.
|
||||||
|
|
||||||
4. **Circuit breaker triggering frequently**: Increase `--cb-failure-threshold` or extend `--cb-window-duration-secs`.
|
4. **Circuit breaker triggering frequently**: Increase `--cb-failure-threshold` or extend `--cb-window-duration-secs`.
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ spec:
|
|||||||
command:
|
command:
|
||||||
- python
|
- python
|
||||||
- -m
|
- -m
|
||||||
- sglang.srt.disaggregation.mini_lb
|
- sglang_router.launch_router
|
||||||
|
- --pd-disaggregation
|
||||||
- --prefill
|
- --prefill
|
||||||
- http://deepseekr10528-prefill-main:30000
|
- http://deepseekr10528-prefill-main:30000
|
||||||
- --decode
|
- --decode
|
||||||
|
|||||||
@@ -714,7 +714,8 @@ spec:
|
|||||||
command:
|
command:
|
||||||
- python
|
- python
|
||||||
- -m
|
- -m
|
||||||
- sglang.srt.disaggregation.mini_lb
|
- sglang_router.launch_router
|
||||||
|
- --pd-disaggregation
|
||||||
- --prefill
|
- --prefill
|
||||||
- http://deepseekr10528-prefill-main:30000
|
- http://deepseekr10528-prefill-main:30000
|
||||||
- --decode
|
- --decode
|
||||||
|
|||||||
@@ -1,118 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import dataclasses
|
|
||||||
|
|
||||||
from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class LBArgs:
|
|
||||||
host: str = "0.0.0.0"
|
|
||||||
port: int = 8000
|
|
||||||
policy: str = "random"
|
|
||||||
prefill_infos: list = dataclasses.field(default_factory=list)
|
|
||||||
decode_infos: list = dataclasses.field(default_factory=list)
|
|
||||||
log_interval: int = 5
|
|
||||||
timeout: int = 600
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
|
||||||
parser.add_argument(
|
|
||||||
"--host",
|
|
||||||
type=str,
|
|
||||||
default=LBArgs.host,
|
|
||||||
help=f"Host to bind the server (default: {LBArgs.host})",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--port",
|
|
||||||
type=int,
|
|
||||||
default=LBArgs.port,
|
|
||||||
help=f"Port to bind the server (default: {LBArgs.port})",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--policy",
|
|
||||||
type=str,
|
|
||||||
default=LBArgs.policy,
|
|
||||||
choices=["random", "po2"],
|
|
||||||
help=f"Policy to use for load balancing (default: {LBArgs.policy})",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--prefill",
|
|
||||||
type=str,
|
|
||||||
default=[],
|
|
||||||
nargs="+",
|
|
||||||
help="URLs for prefill servers",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--decode",
|
|
||||||
type=str,
|
|
||||||
default=[],
|
|
||||||
nargs="+",
|
|
||||||
help="URLs for decode servers",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--prefill-bootstrap-ports",
|
|
||||||
type=int,
|
|
||||||
nargs="+",
|
|
||||||
help="Bootstrap ports for prefill servers",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--log-interval",
|
|
||||||
type=int,
|
|
||||||
default=LBArgs.log_interval,
|
|
||||||
help=f"Log interval in seconds (default: {LBArgs.log_interval})",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--timeout",
|
|
||||||
type=int,
|
|
||||||
default=LBArgs.timeout,
|
|
||||||
help=f"Timeout in seconds (default: {LBArgs.timeout})",
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
|
|
||||||
bootstrap_ports = args.prefill_bootstrap_ports
|
|
||||||
if bootstrap_ports is None:
|
|
||||||
bootstrap_ports = [None] * len(args.prefill)
|
|
||||||
elif len(bootstrap_ports) == 1:
|
|
||||||
bootstrap_ports = bootstrap_ports * len(args.prefill)
|
|
||||||
else:
|
|
||||||
if len(bootstrap_ports) != len(args.prefill):
|
|
||||||
raise ValueError(
|
|
||||||
"Number of prefill URLs must match number of bootstrap ports"
|
|
||||||
)
|
|
||||||
|
|
||||||
prefill_infos = [
|
|
||||||
(url, port) for url, port in zip(args.prefill, bootstrap_ports)
|
|
||||||
]
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
host=args.host,
|
|
||||||
port=args.port,
|
|
||||||
policy=args.policy,
|
|
||||||
prefill_infos=prefill_infos,
|
|
||||||
decode_infos=args.decode,
|
|
||||||
log_interval=args.log_interval,
|
|
||||||
timeout=args.timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="PD Disaggregation Load Balancer Server"
|
|
||||||
)
|
|
||||||
LBArgs.add_cli_args(parser)
|
|
||||||
args = parser.parse_args()
|
|
||||||
lb_args = LBArgs.from_cli_args(args)
|
|
||||||
|
|
||||||
prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
|
|
||||||
run(
|
|
||||||
prefill_configs,
|
|
||||||
lb_args.decode_infos,
|
|
||||||
lb_args.host,
|
|
||||||
lb_args.port,
|
|
||||||
lb_args.timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,445 +1,6 @@
|
|||||||
"""
|
raise RuntimeError(
|
||||||
Minimal HTTP load balancer for prefill and decode servers for testing.
|
"""The 'mini_lb' module has been relocated to the 'sglang_router' package.
|
||||||
"""
|
We recommend installing 'sglang-router' with Rust support for optimal performance.
|
||||||
|
If you encounter issues building the router with Rust, set the environment variable
|
||||||
import asyncio
|
'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'."""
|
||||||
import dataclasses
|
)
|
||||||
import logging
|
|
||||||
import random
|
|
||||||
import urllib
|
|
||||||
from http import HTTPStatus
|
|
||||||
from itertools import chain
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import orjson
|
|
||||||
import uvicorn
|
|
||||||
from fastapi import FastAPI, HTTPException
|
|
||||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
|
||||||
from sglang.srt.utils import maybe_wrap_ipv6_address
|
|
||||||
|
|
||||||
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
|
||||||
1024 * 64
|
|
||||||
) # 64KB, to prevent aiohttp's "Chunk too big" error
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logger():
|
|
||||||
logger = logging.getLogger("pdlb")
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
|
||||||
)
|
|
||||||
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
|
|
||||||
return logger
|
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class PrefillConfig:
|
|
||||||
url: str
|
|
||||||
bootstrap_port: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
class MiniLoadBalancer:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
prefill_configs: List[PrefillConfig],
|
|
||||||
decode_servers: List[str],
|
|
||||||
timeout: int,
|
|
||||||
):
|
|
||||||
self.prefill_configs = prefill_configs
|
|
||||||
self.prefill_servers = [p.url for p in prefill_configs]
|
|
||||||
self.decode_servers = decode_servers
|
|
||||||
self.timeout = timeout
|
|
||||||
|
|
||||||
def add_prefill_server(self, new_prefill_config: PrefillConfig):
|
|
||||||
self.prefill_configs.append(new_prefill_config)
|
|
||||||
self.prefill_servers.append(new_prefill_config.url)
|
|
||||||
|
|
||||||
def add_decode_server(self, new_decode_server: str):
|
|
||||||
self.decode_servers.append(new_decode_server)
|
|
||||||
|
|
||||||
def select_pair(self):
|
|
||||||
# TODO: return some message instead of panic
|
|
||||||
assert len(self.prefill_configs) > 0, "No prefill servers available"
|
|
||||||
assert len(self.decode_servers) > 0, "No decode servers available"
|
|
||||||
|
|
||||||
prefill_config = random.choice(self.prefill_configs)
|
|
||||||
decode_server = random.choice(self.decode_servers)
|
|
||||||
return prefill_config.url, prefill_config.bootstrap_port, decode_server
|
|
||||||
|
|
||||||
async def generate(
|
|
||||||
self, modified_request, prefill_server, decode_server, endpoint
|
|
||||||
) -> ORJSONResponse:
|
|
||||||
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
|
||||||
timeout=aiohttp.ClientTimeout(
|
|
||||||
total=self.timeout
|
|
||||||
) # Add timeout for request reliability
|
|
||||||
) as session:
|
|
||||||
tasks = [
|
|
||||||
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
|
||||||
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Wait for both responses to complete. Prefill should end first.
|
|
||||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
if "return_logprob" in modified_request:
|
|
||||||
|
|
||||||
prefill_json = await prefill_response.json()
|
|
||||||
ret_json = await decode_response.json()
|
|
||||||
|
|
||||||
# merge `meta_info.input_token_logprobs` from prefill to decode
|
|
||||||
if "meta_info" in ret_json:
|
|
||||||
if "input_token_logprobs" in ret_json["meta_info"]:
|
|
||||||
ret_json["meta_info"]["input_token_logprobs"] = (
|
|
||||||
prefill_json["meta_info"]["input_token_logprobs"]
|
|
||||||
+ ret_json["meta_info"]["input_token_logprobs"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ret_json = await decode_response.json()
|
|
||||||
|
|
||||||
return ORJSONResponse(
|
|
||||||
content=ret_json,
|
|
||||||
status_code=decode_response.status,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def generate_stream(
|
|
||||||
self, modified_request, prefill_server, decode_server, endpoint="generate"
|
|
||||||
):
|
|
||||||
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
|
||||||
|
|
||||||
async def stream_results():
|
|
||||||
async with aiohttp.ClientSession(
|
|
||||||
timeout=aiohttp.ClientTimeout(
|
|
||||||
total=self.timeout
|
|
||||||
) # Add timeout for request reliability
|
|
||||||
) as session:
|
|
||||||
# Create the tasks for both prefill and decode requests
|
|
||||||
tasks = [
|
|
||||||
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
|
||||||
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
|
||||||
]
|
|
||||||
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
|
||||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
if modified_request.get("return_logprob", False):
|
|
||||||
prefill_chunks = []
|
|
||||||
async for chunk in prefill_response.content:
|
|
||||||
prefill_chunks.append(chunk)
|
|
||||||
|
|
||||||
first_prefill_chunk = (
|
|
||||||
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
|
|
||||||
)
|
|
||||||
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
|
|
||||||
|
|
||||||
async for chunk in decode_response.content:
|
|
||||||
# Note: This is inefficient
|
|
||||||
# merge prefill input_token_logprobs, output_token_logprobs to decode
|
|
||||||
decoded_chunk = chunk.decode("utf-8")
|
|
||||||
if (
|
|
||||||
decoded_chunk
|
|
||||||
and decoded_chunk.startswith("data:")
|
|
||||||
and "[DONE]" not in decoded_chunk
|
|
||||||
):
|
|
||||||
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
|
|
||||||
ret_json["meta_info"]["input_token_logprobs"] = (
|
|
||||||
first_prefill_chunk_json["meta_info"][
|
|
||||||
"input_token_logprobs"
|
|
||||||
]
|
|
||||||
+ ret_json["meta_info"]["input_token_logprobs"]
|
|
||||||
)
|
|
||||||
|
|
||||||
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
|
|
||||||
else:
|
|
||||||
yield chunk
|
|
||||||
else:
|
|
||||||
async for chunk in decode_response.content.iter_chunked(
|
|
||||||
AIOHTTP_STREAM_READ_CHUNK_SIZE
|
|
||||||
):
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
stream_results(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
load_balancer: Optional[MiniLoadBalancer] = None
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health_check():
|
|
||||||
return Response(status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health_generate")
|
|
||||||
async def health_generate():
|
|
||||||
prefill_servers, decode_servers = (
|
|
||||||
load_balancer.prefill_servers,
|
|
||||||
load_balancer.decode_servers,
|
|
||||||
)
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
# Create the tasks
|
|
||||||
tasks = []
|
|
||||||
for server in chain(prefill_servers, decode_servers):
|
|
||||||
tasks.append(session.get(f"{server}/health_generate"))
|
|
||||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
|
||||||
await response
|
|
||||||
return Response(status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/flush_cache")
|
|
||||||
async def flush_cache():
|
|
||||||
prefill_servers, decode_servers = (
|
|
||||||
load_balancer.prefill_servers,
|
|
||||||
load_balancer.decode_servers,
|
|
||||||
)
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
# Create the tasks
|
|
||||||
tasks = []
|
|
||||||
for server in chain(prefill_servers, decode_servers):
|
|
||||||
tasks.append(session.post(f"{server}/flush_cache"))
|
|
||||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
|
||||||
await response
|
|
||||||
return Response(status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_server_info")
|
|
||||||
async def get_server_info():
|
|
||||||
prefill_servers, decode_servers = (
|
|
||||||
load_balancer.prefill_servers,
|
|
||||||
load_balancer.decode_servers,
|
|
||||||
)
|
|
||||||
prefill_infos = []
|
|
||||||
decode_infos = []
|
|
||||||
all_internal_states = []
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
for server in chain(prefill_servers):
|
|
||||||
server_info = await session.get(f"{server}/get_server_info")
|
|
||||||
prefill_infos.append(await server_info.json())
|
|
||||||
for server in chain(decode_servers):
|
|
||||||
server_info = await session.get(f"{server}/get_server_info")
|
|
||||||
info_json = await server_info.json()
|
|
||||||
decode_infos.append(info_json)
|
|
||||||
# Extract internal_states from decode servers
|
|
||||||
if "internal_states" in info_json:
|
|
||||||
all_internal_states.extend(info_json["internal_states"])
|
|
||||||
|
|
||||||
# Return format expected by bench_one_batch_server.py
|
|
||||||
if all_internal_states:
|
|
||||||
return {
|
|
||||||
"internal_states": all_internal_states,
|
|
||||||
"prefill": prefill_infos,
|
|
||||||
"decode": decode_infos,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Fallback with dummy data if no internal states found
|
|
||||||
return {
|
|
||||||
"internal_states": [
|
|
||||||
{
|
|
||||||
"last_gen_throughput": 0.0,
|
|
||||||
"avg_spec_accept_length": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"prefill": prefill_infos,
|
|
||||||
"decode": decode_infos,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_model_info")
|
|
||||||
async def get_model_info():
|
|
||||||
global load_balancer
|
|
||||||
|
|
||||||
if not load_balancer or not load_balancer.prefill_servers:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
|
||||||
detail="There is no server registered",
|
|
||||||
)
|
|
||||||
|
|
||||||
target_server_url = load_balancer.prefill_servers[0]
|
|
||||||
endpoint_url = f"{target_server_url}/get_model_info"
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
try:
|
|
||||||
async with session.get(endpoint_url) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
error_text = await response.text()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.BAD_GATEWAY,
|
|
||||||
detail=(
|
|
||||||
f"Failed to get model info from {target_server_url}"
|
|
||||||
f"Status: {response.status}, Response: {error_text}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
model_info_json = await response.json()
|
|
||||||
return ORJSONResponse(content=model_info_json)
|
|
||||||
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
|
||||||
detail=f"Failed to get model info from backend",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/generate")
|
|
||||||
async def handle_generate_request(request_data: dict):
|
|
||||||
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
|
||||||
|
|
||||||
# Parse and transform prefill_server for bootstrap data
|
|
||||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
|
||||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
|
||||||
modified_request = request_data.copy()
|
|
||||||
|
|
||||||
batch_size = _get_request_batch_size(modified_request)
|
|
||||||
if batch_size is not None:
|
|
||||||
modified_request.update(
|
|
||||||
{
|
|
||||||
"bootstrap_host": [hostname] * batch_size,
|
|
||||||
"bootstrap_port": [bootstrap_port] * batch_size,
|
|
||||||
"bootstrap_room": [
|
|
||||||
_generate_bootstrap_room() for _ in range(batch_size)
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
modified_request.update(
|
|
||||||
{
|
|
||||||
"bootstrap_host": hostname,
|
|
||||||
"bootstrap_port": bootstrap_port,
|
|
||||||
"bootstrap_room": _generate_bootstrap_room(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if request_data.get("stream", False):
|
|
||||||
return await load_balancer.generate_stream(
|
|
||||||
modified_request, prefill_server, decode_server, "generate"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return await load_balancer.generate(
|
|
||||||
modified_request, prefill_server, decode_server, "generate"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
|
||||||
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
|
||||||
|
|
||||||
# Parse and transform prefill_server for bootstrap data
|
|
||||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
|
||||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
|
||||||
modified_request = request_data.copy()
|
|
||||||
modified_request.update(
|
|
||||||
{
|
|
||||||
"bootstrap_host": hostname,
|
|
||||||
"bootstrap_port": bootstrap_port,
|
|
||||||
"bootstrap_room": _generate_bootstrap_room(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if request_data.get("stream", False):
|
|
||||||
return await load_balancer.generate_stream(
|
|
||||||
modified_request,
|
|
||||||
prefill_server,
|
|
||||||
decode_server,
|
|
||||||
endpoint=endpoint_name,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return await load_balancer.generate(
|
|
||||||
modified_request,
|
|
||||||
prefill_server,
|
|
||||||
decode_server,
|
|
||||||
endpoint=endpoint_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
|
||||||
async def handle_chat_completion_request(request_data: dict):
|
|
||||||
return await _forward_to_backend(request_data, "v1/chat/completions")
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
|
||||||
async def handle_completion_request(request_data: dict):
|
|
||||||
return await _forward_to_backend(request_data, "v1/completions")
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_bootstrap_room():
|
|
||||||
return random.randint(0, 2**63 - 1)
|
|
||||||
|
|
||||||
|
|
||||||
# We may utilize `GenerateReqInput`'s logic later
|
|
||||||
def _get_request_batch_size(request):
|
|
||||||
if (text := request.get("text")) is not None:
|
|
||||||
return None if isinstance(text, str) else len(text)
|
|
||||||
if (input_ids := request.get("input_ids")) is not None:
|
|
||||||
return None if isinstance(input_ids[0], int) else len(input_ids)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models")
|
|
||||||
async def get_models():
|
|
||||||
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
try:
|
|
||||||
response = await session.get(f"{prefill_server}/v1/models")
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=response.status,
|
|
||||||
detail=f"Prefill server error: Status {response.status}",
|
|
||||||
)
|
|
||||||
return ORJSONResponse(content=await response.json())
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/register")
|
|
||||||
async def register(obj: PDRegistryRequest):
|
|
||||||
if obj.mode == "prefill":
|
|
||||||
load_balancer.add_prefill_server(
|
|
||||||
PrefillConfig(obj.registry_url, obj.bootstrap_port)
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
|
|
||||||
)
|
|
||||||
elif obj.mode == "decode":
|
|
||||||
load_balancer.add_decode_server(obj.registry_url)
|
|
||||||
logger.info(f"Registered decode server: {obj.registry_url}")
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Invalid mode. Must be either PREFILL or DECODE.",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
|
|
||||||
f"#Decode servers: {len(load_balancer.decode_servers)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return Response(status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
def run(prefill_configs, decode_addrs, host, port, timeout):
|
|
||||||
global load_balancer
|
|
||||||
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs, timeout=timeout)
|
|
||||||
uvicorn.run(app, host=host, port=port)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
|
|
||||||
from sglang.srt.disaggregation.launch_lb import main
|
|
||||||
|
|
||||||
main()
|
|
||||||
|
|||||||
@@ -1,21 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import threading
|
|
||||||
import warnings
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.utils import get_ip, is_npu
|
from sglang.srt.utils import is_npu
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
@@ -305,49 +301,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
|
|||||||
return (num_kv_indices + page_size - 1) // page_size
|
return (num_kv_indices + page_size - 1) // page_size
|
||||||
|
|
||||||
|
|
||||||
#########################
|
|
||||||
# PDLB Registry
|
|
||||||
#########################
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class PDRegistryRequest:
|
|
||||||
"""A request to register a machine itself to the LB."""
|
|
||||||
|
|
||||||
mode: str
|
|
||||||
registry_url: str
|
|
||||||
bootstrap_port: Optional[int] = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.mode == "prefill" and self.bootstrap_port is None:
|
|
||||||
raise ValueError("Bootstrap port must be set in PREFILL mode.")
|
|
||||||
elif self.mode == "decode" and self.bootstrap_port is not None:
|
|
||||||
raise ValueError("Bootstrap port must not be set in DECODE mode.")
|
|
||||||
elif self.mode not in ["prefill", "decode"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def register_disaggregation_server(
|
|
||||||
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
|
|
||||||
):
|
|
||||||
boostrap_port = bootstrap_port if mode == "prefill" else None
|
|
||||||
registry_request = PDRegistryRequest(
|
|
||||||
mode=mode,
|
|
||||||
registry_url=f"http://{get_ip()}:{server_port}",
|
|
||||||
bootstrap_port=boostrap_port,
|
|
||||||
)
|
|
||||||
res = requests.post(
|
|
||||||
f"{pdlb_url}/register",
|
|
||||||
json=dataclasses.asdict(registry_request),
|
|
||||||
)
|
|
||||||
if res.status_code != 200:
|
|
||||||
warnings.warn(
|
|
||||||
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
#########################
|
#########################
|
||||||
# Misc
|
# Misc
|
||||||
#########################
|
#########################
|
||||||
|
|||||||
@@ -47,11 +47,7 @@ from fastapi.exceptions import RequestValidationError
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
||||||
FAKE_BOOTSTRAP_HOST,
|
|
||||||
DisaggregationMode,
|
|
||||||
register_disaggregation_server,
|
|
||||||
)
|
|
||||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
@@ -1405,13 +1401,5 @@ def _wait_and_warmup(
|
|||||||
if server_args.debug_tensor_dump_input_file:
|
if server_args.debug_tensor_dump_input_file:
|
||||||
kill_process_tree(os.getpid())
|
kill_process_tree(os.getpid())
|
||||||
|
|
||||||
if server_args.pdlb_url is not None:
|
|
||||||
register_disaggregation_server(
|
|
||||||
server_args.disaggregation_mode,
|
|
||||||
server_args.port,
|
|
||||||
server_args.disaggregation_bootstrap_port,
|
|
||||||
server_args.pdlb_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
if launch_callback is not None:
|
if launch_callback is not None:
|
||||||
launch_callback()
|
launch_callback()
|
||||||
|
|||||||
@@ -367,7 +367,6 @@ class ServerArgs:
|
|||||||
disaggregation_prefill_pp: Optional[int] = 1
|
disaggregation_prefill_pp: Optional[int] = 1
|
||||||
disaggregation_ib_device: Optional[str] = None
|
disaggregation_ib_device: Optional[str] = None
|
||||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||||
pdlb_url: Optional[str] = None
|
|
||||||
|
|
||||||
# For model weight update
|
# For model weight update
|
||||||
custom_weight_loader: Optional[List[str]] = None
|
custom_weight_loader: Optional[List[str]] = None
|
||||||
@@ -2071,12 +2070,6 @@ class ServerArgs:
|
|||||||
default=ServerArgs.num_reserved_decode_tokens,
|
default=ServerArgs.num_reserved_decode_tokens,
|
||||||
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
|
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--pdlb-url",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Custom weight loader
|
# Custom weight loader
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -466,6 +466,25 @@ def try_cached_model(model_repo: str):
|
|||||||
return model_dir if model_dir else model_repo
|
return model_dir if model_dir else model_repo
|
||||||
|
|
||||||
|
|
||||||
|
def popen_with_error_check(command: list[str], allow_exit: bool = False):
|
||||||
|
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||||
|
|
||||||
|
def _run_and_check():
|
||||||
|
stdout, stderr = process.communicate()
|
||||||
|
|
||||||
|
while process.poll() is None:
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
if not allow_exit or process.returncode != 0:
|
||||||
|
raise Exception(
|
||||||
|
f"{command} exited with code {process.returncode}\n{stdout=}\n{stderr=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
t = threading.Thread(target=_run_and_check)
|
||||||
|
t.start()
|
||||||
|
return process
|
||||||
|
|
||||||
|
|
||||||
def popen_launch_server(
|
def popen_launch_server(
|
||||||
model: str,
|
model: str,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
|
|||||||
@@ -45,6 +45,10 @@ fi
|
|||||||
# Install the main package
|
# Install the main package
|
||||||
$PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX
|
$PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX
|
||||||
|
|
||||||
|
# Install router for pd-disagg test
|
||||||
|
SGLANG_ROUTER_BUILD_NO_RUST=1 $PIP_CMD install -e "sgl-router" $PIP_INSTALL_SUFFIX
|
||||||
|
|
||||||
|
|
||||||
if [ "$IS_BLACKWELL" = "1" ]; then
|
if [ "$IS_BLACKWELL" = "1" ]; then
|
||||||
# TODO auto determine sgl-kernel version
|
# TODO auto determine sgl-kernel version
|
||||||
SGL_KERNEL_VERSION=0.3.8
|
SGL_KERNEL_VERSION=0.3.8
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
# a lightweihgt wrapper on router with argument type and comments
|
|
||||||
# no wrapper on policy type => direct export
|
|
||||||
from sglang_router.router import Router
|
|
||||||
from sglang_router.version import __version__
|
from sglang_router.version import __version__
|
||||||
from sglang_router_rs import PolicyType
|
|
||||||
|
|
||||||
__all__ = ["Router", "PolicyType", "__version__"]
|
__all__ = ["__version__"]
|
||||||
|
|||||||
@@ -1,654 +1,22 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import dataclasses
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from typing import Dict, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from sglang_router import Router
|
import setproctitle
|
||||||
from sglang_router_rs import PolicyType
|
from sglang_router.mini_lb import MiniLoadBalancer
|
||||||
|
from sglang_router.router_args import RouterArgs
|
||||||
|
|
||||||
|
logger = logging.getLogger("router")
|
||||||
|
|
||||||
def setup_logger():
|
try:
|
||||||
logger = logging.getLogger("router")
|
from sglang_router.router import Router
|
||||||
logger.setLevel(logging.INFO)
|
except ImportError:
|
||||||
|
Router = None
|
||||||
formatter = logging.Formatter(
|
logger.warning(
|
||||||
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
"Rust Router is not installed, only python MiniLB (debugging only) is available"
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
|
|
||||||
return logger
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class RouterArgs:
|
|
||||||
# Worker configuration
|
|
||||||
worker_urls: List[str] = dataclasses.field(default_factory=list)
|
|
||||||
host: str = "127.0.0.1"
|
|
||||||
port: int = 30000
|
|
||||||
|
|
||||||
# PD-specific configuration
|
|
||||||
pd_disaggregation: bool = False # Enable PD disaggregated mode
|
|
||||||
prefill_urls: List[tuple] = dataclasses.field(
|
|
||||||
default_factory=list
|
|
||||||
) # List of (url, bootstrap_port)
|
|
||||||
decode_urls: List[str] = dataclasses.field(default_factory=list)
|
|
||||||
|
|
||||||
# Routing policy
|
|
||||||
policy: str = "cache_aware"
|
|
||||||
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
|
|
||||||
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
|
|
||||||
worker_startup_timeout_secs: int = 600
|
|
||||||
worker_startup_check_interval: int = 30
|
|
||||||
cache_threshold: float = 0.3
|
|
||||||
balance_abs_threshold: int = 64
|
|
||||||
balance_rel_threshold: float = 1.5
|
|
||||||
eviction_interval: int = 120
|
|
||||||
max_tree_size: int = 2**26
|
|
||||||
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
|
|
||||||
dp_aware: bool = False
|
|
||||||
api_key: Optional[str] = None
|
|
||||||
log_dir: Optional[str] = None
|
|
||||||
log_level: Optional[str] = None
|
|
||||||
# Service discovery configuration
|
|
||||||
service_discovery: bool = False
|
|
||||||
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
|
||||||
service_discovery_port: int = 80
|
|
||||||
service_discovery_namespace: Optional[str] = None
|
|
||||||
# PD service discovery configuration
|
|
||||||
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
|
||||||
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
|
||||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
|
|
||||||
# Prometheus configuration
|
|
||||||
prometheus_port: Optional[int] = None
|
|
||||||
prometheus_host: Optional[str] = None
|
|
||||||
# Request ID headers configuration
|
|
||||||
request_id_headers: Optional[List[str]] = None
|
|
||||||
# Request timeout in seconds
|
|
||||||
request_timeout_secs: int = 1800
|
|
||||||
# Max concurrent requests for rate limiting
|
|
||||||
max_concurrent_requests: int = 256
|
|
||||||
# Queue size for pending requests when max concurrent limit reached
|
|
||||||
queue_size: int = 100
|
|
||||||
# Maximum time (in seconds) a request can wait in queue before timing out
|
|
||||||
queue_timeout_secs: int = 60
|
|
||||||
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
|
|
||||||
rate_limit_tokens_per_second: Optional[int] = None
|
|
||||||
# CORS allowed origins
|
|
||||||
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
|
|
||||||
# Retry configuration
|
|
||||||
retry_max_retries: int = 5
|
|
||||||
retry_initial_backoff_ms: int = 50
|
|
||||||
retry_max_backoff_ms: int = 30_000
|
|
||||||
retry_backoff_multiplier: float = 1.5
|
|
||||||
retry_jitter_factor: float = 0.2
|
|
||||||
disable_retries: bool = False
|
|
||||||
# Health check configuration
|
|
||||||
health_failure_threshold: int = 3
|
|
||||||
health_success_threshold: int = 2
|
|
||||||
health_check_timeout_secs: int = 5
|
|
||||||
health_check_interval_secs: int = 60
|
|
||||||
health_check_endpoint: str = "/health"
|
|
||||||
# Circuit breaker configuration
|
|
||||||
cb_failure_threshold: int = 10
|
|
||||||
cb_success_threshold: int = 3
|
|
||||||
cb_timeout_duration_secs: int = 60
|
|
||||||
cb_window_duration_secs: int = 120
|
|
||||||
disable_circuit_breaker: bool = False
|
|
||||||
# Tokenizer configuration
|
|
||||||
model_path: Optional[str] = None
|
|
||||||
tokenizer_path: Optional[str] = None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def add_cli_args(
|
|
||||||
parser: argparse.ArgumentParser,
|
|
||||||
use_router_prefix: bool = False,
|
|
||||||
exclude_host_port: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Add router-specific arguments to an argument parser.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
parser: The argument parser to add arguments to
|
|
||||||
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
|
|
||||||
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
|
|
||||||
"""
|
|
||||||
prefix = "router-" if use_router_prefix else ""
|
|
||||||
|
|
||||||
# Worker configuration
|
|
||||||
if not exclude_host_port:
|
|
||||||
parser.add_argument(
|
|
||||||
"--host",
|
|
||||||
type=str,
|
|
||||||
default=RouterArgs.host,
|
|
||||||
help="Host address to bind the router server",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--port",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.port,
|
|
||||||
help="Port number to bind the router server",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--worker-urls",
|
|
||||||
type=str,
|
|
||||||
nargs="*",
|
|
||||||
default=[],
|
|
||||||
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Routing policy configuration
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}policy",
|
|
||||||
type=str,
|
|
||||||
default=RouterArgs.policy,
|
|
||||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
|
||||||
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}prefill-policy",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
|
||||||
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}decode-policy",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
|
||||||
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
|
|
||||||
)
|
|
||||||
|
|
||||||
# PD-specific arguments
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}pd-disaggregation",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable PD (Prefill-Decode) disaggregated mode",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}prefill",
|
|
||||||
nargs="+",
|
|
||||||
action="append",
|
|
||||||
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
|
|
||||||
"Format: --prefill URL [BOOTSTRAP_PORT]. "
|
|
||||||
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}decode",
|
|
||||||
nargs=1,
|
|
||||||
action="append",
|
|
||||||
metavar=("URL",),
|
|
||||||
help="Decode server URL. Can be specified multiple times.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}worker-startup-timeout-secs",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.worker_startup_timeout_secs,
|
|
||||||
help="Timeout in seconds for worker startup",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}worker-startup-check-interval",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.worker_startup_check_interval,
|
|
||||||
help="Interval in seconds between checks for worker startup",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}cache-threshold",
|
|
||||||
type=float,
|
|
||||||
default=RouterArgs.cache_threshold,
|
|
||||||
help="Cache threshold (0.0-1.0) for cache-aware routing",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}balance-abs-threshold",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.balance_abs_threshold,
|
|
||||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}balance-rel-threshold",
|
|
||||||
type=float,
|
|
||||||
default=RouterArgs.balance_rel_threshold,
|
|
||||||
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}eviction-interval",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.eviction_interval,
|
|
||||||
help="Interval in seconds between cache eviction operations",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}max-tree-size",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.max_tree_size,
|
|
||||||
help="Maximum size of the approximation tree for cache-aware routing",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}max-payload-size",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.max_payload_size,
|
|
||||||
help="Maximum payload size in bytes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}dp-aware",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable data parallelism aware schedule",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}api-key",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}log-dir",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Directory to store log files. If not specified, logs are only output to console.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}log-level",
|
|
||||||
type=str,
|
|
||||||
default="info",
|
|
||||||
choices=["debug", "info", "warning", "error", "critical"],
|
|
||||||
help="Set the logging level. If not specified, defaults to INFO.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}service-discovery",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable Kubernetes service discovery",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}selector",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}service-discovery-port",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.service_discovery_port,
|
|
||||||
help="Port to use for discovered worker pods",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}service-discovery-namespace",
|
|
||||||
type=str,
|
|
||||||
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}prefill-selector",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}decode-selector",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
|
|
||||||
)
|
|
||||||
# Prometheus configuration
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}prometheus-port",
|
|
||||||
type=int,
|
|
||||||
default=29000,
|
|
||||||
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}prometheus-host",
|
|
||||||
type=str,
|
|
||||||
default="127.0.0.1",
|
|
||||||
help="Host address to bind the Prometheus metrics server",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}request-id-headers",
|
|
||||||
type=str,
|
|
||||||
nargs="*",
|
|
||||||
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}request-timeout-secs",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.request_timeout_secs,
|
|
||||||
help="Request timeout in seconds",
|
|
||||||
)
|
|
||||||
# Retry configuration
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}retry-max-retries",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.retry_max_retries,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}retry-initial-backoff-ms",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.retry_initial_backoff_ms,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}retry-max-backoff-ms",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.retry_max_backoff_ms,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}retry-backoff-multiplier",
|
|
||||||
type=float,
|
|
||||||
default=RouterArgs.retry_backoff_multiplier,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}retry-jitter-factor",
|
|
||||||
type=float,
|
|
||||||
default=RouterArgs.retry_jitter_factor,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}disable-retries",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable retries (equivalent to setting retry_max_retries=1)",
|
|
||||||
)
|
|
||||||
# Circuit breaker configuration
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}cb-failure-threshold",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.cb_failure_threshold,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}cb-success-threshold",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.cb_success_threshold,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}cb-timeout-duration-secs",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.cb_timeout_duration_secs,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}cb-window-duration-secs",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.cb_window_duration_secs,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}disable-circuit-breaker",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
|
|
||||||
)
|
|
||||||
# Health check configuration
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}health-failure-threshold",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.health_failure_threshold,
|
|
||||||
help="Number of consecutive health check failures before marking worker unhealthy",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}health-success-threshold",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.health_success_threshold,
|
|
||||||
help="Number of consecutive health check successes before marking worker healthy",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}health-check-timeout-secs",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.health_check_timeout_secs,
|
|
||||||
help="Timeout in seconds for health check requests",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}health-check-interval-secs",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.health_check_interval_secs,
|
|
||||||
help="Interval in seconds between runtime health checks",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}health-check-endpoint",
|
|
||||||
type=str,
|
|
||||||
default=RouterArgs.health_check_endpoint,
|
|
||||||
help="Health check endpoint path",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}max-concurrent-requests",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.max_concurrent_requests,
|
|
||||||
help="Maximum number of concurrent requests allowed (for rate limiting)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}queue-size",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.queue_size,
|
|
||||||
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}queue-timeout-secs",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.queue_timeout_secs,
|
|
||||||
help="Maximum time (in seconds) a request can wait in queue before timing out",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}rate-limit-tokens-per-second",
|
|
||||||
type=int,
|
|
||||||
default=RouterArgs.rate_limit_tokens_per_second,
|
|
||||||
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}cors-allowed-origins",
|
|
||||||
type=str,
|
|
||||||
nargs="*",
|
|
||||||
default=[],
|
|
||||||
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
|
||||||
)
|
|
||||||
# Tokenizer configuration
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}model-path",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{prefix}tokenizer-path",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_cli_args(
|
|
||||||
cls, args: argparse.Namespace, use_router_prefix: bool = False
|
|
||||||
) -> "RouterArgs":
|
|
||||||
"""
|
|
||||||
Create RouterArgs instance from parsed command line arguments.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: Parsed command line arguments
|
|
||||||
use_router_prefix: If True, look for arguments with 'router-' prefix
|
|
||||||
"""
|
|
||||||
prefix = "router_" if use_router_prefix else ""
|
|
||||||
worker_urls = getattr(args, "worker_urls", [])
|
|
||||||
|
|
||||||
# Parse PD URLs
|
|
||||||
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
|
|
||||||
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
worker_urls=worker_urls,
|
|
||||||
host=args.host,
|
|
||||||
port=args.port,
|
|
||||||
pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
|
|
||||||
prefill_urls=prefill_urls,
|
|
||||||
decode_urls=decode_urls,
|
|
||||||
policy=getattr(args, f"{prefix}policy"),
|
|
||||||
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
|
|
||||||
decode_policy=getattr(args, f"{prefix}decode_policy", None),
|
|
||||||
worker_startup_timeout_secs=getattr(
|
|
||||||
args, f"{prefix}worker_startup_timeout_secs"
|
|
||||||
),
|
|
||||||
worker_startup_check_interval=getattr(
|
|
||||||
args, f"{prefix}worker_startup_check_interval"
|
|
||||||
),
|
|
||||||
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
|
|
||||||
balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
|
|
||||||
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
|
|
||||||
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
|
|
||||||
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
|
|
||||||
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
|
|
||||||
dp_aware=getattr(args, f"{prefix}dp_aware", False),
|
|
||||||
api_key=getattr(args, f"{prefix}api_key", None),
|
|
||||||
log_dir=getattr(args, f"{prefix}log_dir", None),
|
|
||||||
log_level=getattr(args, f"{prefix}log_level", None),
|
|
||||||
service_discovery=getattr(args, f"{prefix}service_discovery", False),
|
|
||||||
selector=cls._parse_selector(getattr(args, f"{prefix}selector", None)),
|
|
||||||
service_discovery_port=getattr(args, f"{prefix}service_discovery_port"),
|
|
||||||
service_discovery_namespace=getattr(
|
|
||||||
args, f"{prefix}service_discovery_namespace", None
|
|
||||||
),
|
|
||||||
prefill_selector=cls._parse_selector(
|
|
||||||
getattr(args, f"{prefix}prefill_selector", None)
|
|
||||||
),
|
|
||||||
decode_selector=cls._parse_selector(
|
|
||||||
getattr(args, f"{prefix}decode_selector", None)
|
|
||||||
),
|
|
||||||
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
|
|
||||||
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
|
||||||
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
|
||||||
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
|
|
||||||
request_timeout_secs=getattr(
|
|
||||||
args, f"{prefix}request_timeout_secs", RouterArgs.request_timeout_secs
|
|
||||||
),
|
|
||||||
max_concurrent_requests=getattr(
|
|
||||||
args,
|
|
||||||
f"{prefix}max_concurrent_requests",
|
|
||||||
RouterArgs.max_concurrent_requests,
|
|
||||||
),
|
|
||||||
queue_size=getattr(
|
|
||||||
args,
|
|
||||||
f"{prefix}queue_size",
|
|
||||||
RouterArgs.queue_size,
|
|
||||||
),
|
|
||||||
queue_timeout_secs=getattr(
|
|
||||||
args,
|
|
||||||
f"{prefix}queue_timeout_secs",
|
|
||||||
RouterArgs.queue_timeout_secs,
|
|
||||||
),
|
|
||||||
rate_limit_tokens_per_second=getattr(
|
|
||||||
args,
|
|
||||||
f"{prefix}rate_limit_tokens_per_second",
|
|
||||||
RouterArgs.rate_limit_tokens_per_second,
|
|
||||||
),
|
|
||||||
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
|
|
||||||
retry_max_retries=getattr(args, f"{prefix}retry_max_retries"),
|
|
||||||
retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"),
|
|
||||||
retry_max_backoff_ms=getattr(args, f"{prefix}retry_max_backoff_ms"),
|
|
||||||
retry_backoff_multiplier=getattr(args, f"{prefix}retry_backoff_multiplier"),
|
|
||||||
retry_jitter_factor=getattr(args, f"{prefix}retry_jitter_factor"),
|
|
||||||
cb_failure_threshold=getattr(args, f"{prefix}cb_failure_threshold"),
|
|
||||||
cb_success_threshold=getattr(args, f"{prefix}cb_success_threshold"),
|
|
||||||
cb_timeout_duration_secs=getattr(args, f"{prefix}cb_timeout_duration_secs"),
|
|
||||||
cb_window_duration_secs=getattr(args, f"{prefix}cb_window_duration_secs"),
|
|
||||||
disable_retries=getattr(args, f"{prefix}disable_retries", False),
|
|
||||||
disable_circuit_breaker=getattr(
|
|
||||||
args, f"{prefix}disable_circuit_breaker", False
|
|
||||||
),
|
|
||||||
health_failure_threshold=getattr(
|
|
||||||
args,
|
|
||||||
f"{prefix}health_failure_threshold",
|
|
||||||
RouterArgs.health_failure_threshold,
|
|
||||||
),
|
|
||||||
health_success_threshold=getattr(
|
|
||||||
args,
|
|
||||||
f"{prefix}health_success_threshold",
|
|
||||||
RouterArgs.health_success_threshold,
|
|
||||||
),
|
|
||||||
health_check_timeout_secs=getattr(
|
|
||||||
args,
|
|
||||||
f"{prefix}health_check_timeout_secs",
|
|
||||||
RouterArgs.health_check_timeout_secs,
|
|
||||||
),
|
|
||||||
health_check_interval_secs=getattr(
|
|
||||||
args,
|
|
||||||
f"{prefix}health_check_interval_secs",
|
|
||||||
RouterArgs.health_check_interval_secs,
|
|
||||||
),
|
|
||||||
health_check_endpoint=getattr(
|
|
||||||
args, f"{prefix}health_check_endpoint", RouterArgs.health_check_endpoint
|
|
||||||
),
|
|
||||||
model_path=getattr(args, f"{prefix}model_path", None),
|
|
||||||
tokenizer_path=getattr(args, f"{prefix}tokenizer_path", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_selector(selector_list):
|
|
||||||
if not selector_list:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
selector = {}
|
|
||||||
for item in selector_list:
|
|
||||||
if "=" in item:
|
|
||||||
key, value = item.split("=", 1)
|
|
||||||
selector[key] = value
|
|
||||||
return selector
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_prefill_urls(prefill_list):
|
|
||||||
"""Parse prefill URLs from --prefill arguments.
|
|
||||||
|
|
||||||
Format: --prefill URL [BOOTSTRAP_PORT]
|
|
||||||
Example:
|
|
||||||
--prefill http://prefill1:8080 9000 # With bootstrap port
|
|
||||||
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
|
|
||||||
--prefill http://prefill3:8080 # Defaults to no bootstrap port
|
|
||||||
"""
|
|
||||||
if not prefill_list:
|
|
||||||
return []
|
|
||||||
|
|
||||||
prefill_urls = []
|
|
||||||
for prefill_args in prefill_list:
|
|
||||||
|
|
||||||
url = prefill_args[0]
|
|
||||||
|
|
||||||
# Handle optional bootstrap port
|
|
||||||
if len(prefill_args) >= 2:
|
|
||||||
bootstrap_port_str = prefill_args[1]
|
|
||||||
# Handle 'none' as None
|
|
||||||
if bootstrap_port_str.lower() == "none":
|
|
||||||
bootstrap_port = None
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
bootstrap_port = int(bootstrap_port_str)
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# No bootstrap port specified, default to None
|
|
||||||
bootstrap_port = None
|
|
||||||
|
|
||||||
prefill_urls.append((url, bootstrap_port))
|
|
||||||
|
|
||||||
return prefill_urls
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_decode_urls(decode_list):
|
|
||||||
"""Parse decode URLs from --decode arguments.
|
|
||||||
|
|
||||||
Format: --decode URL
|
|
||||||
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
|
||||||
"""
|
|
||||||
if not decode_list:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# decode_list is a list of single-element lists due to nargs=1
|
|
||||||
return [url[0] for url in decode_list]
|
|
||||||
|
|
||||||
|
|
||||||
def policy_from_str(policy_str: str) -> PolicyType:
|
|
||||||
"""Convert policy string to PolicyType enum."""
|
|
||||||
policy_map = {
|
|
||||||
"random": PolicyType.Random,
|
|
||||||
"round_robin": PolicyType.RoundRobin,
|
|
||||||
"cache_aware": PolicyType.CacheAware,
|
|
||||||
"power_of_two": PolicyType.PowerOfTwo,
|
|
||||||
}
|
|
||||||
return policy_map[policy_str]
|
|
||||||
|
|
||||||
|
|
||||||
def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||||
"""
|
"""
|
||||||
@@ -661,7 +29,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
Returns:
|
Returns:
|
||||||
Router instance if successful, None if failed
|
Router instance if successful, None if failed
|
||||||
"""
|
"""
|
||||||
logger = logging.getLogger("router")
|
setproctitle.setproctitle("sglang::router")
|
||||||
try:
|
try:
|
||||||
# Convert to RouterArgs if needed
|
# Convert to RouterArgs if needed
|
||||||
if not isinstance(args, RouterArgs):
|
if not isinstance(args, RouterArgs):
|
||||||
@@ -669,120 +37,15 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
else:
|
else:
|
||||||
router_args = args
|
router_args = args
|
||||||
|
|
||||||
# Validate configuration based on mode
|
if router_args.mini_lb:
|
||||||
if router_args.pd_disaggregation:
|
mini_lb = MiniLoadBalancer(router_args)
|
||||||
# Validate PD configuration - skip URL requirements if using service discovery
|
mini_lb.start()
|
||||||
if not router_args.service_discovery:
|
else:
|
||||||
if not router_args.prefill_urls:
|
if Router is None:
|
||||||
raise ValueError("PD disaggregation mode requires --prefill")
|
raise RuntimeError("Rust Router is not installed")
|
||||||
if not router_args.decode_urls:
|
router_args._validate_router_args()
|
||||||
raise ValueError("PD disaggregation mode requires --decode")
|
router = Router.from_args(router_args)
|
||||||
|
router.start()
|
||||||
# Warn about policy usage in PD mode
|
|
||||||
if (
|
|
||||||
router_args.prefill_policy
|
|
||||||
and router_args.decode_policy
|
|
||||||
and router_args.policy
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"Both --prefill-policy and --decode-policy are specified. "
|
|
||||||
"The main --policy flag will be ignored for PD mode."
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
router_args.prefill_policy
|
|
||||||
and not router_args.decode_policy
|
|
||||||
and router_args.policy
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
|
|
||||||
f"and --policy '{router_args.policy}' for decode nodes."
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
router_args.decode_policy
|
|
||||||
and not router_args.prefill_policy
|
|
||||||
and router_args.policy
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
f"Using --policy '{router_args.policy}' for prefill nodes "
|
|
||||||
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create router with unified constructor
|
|
||||||
router = Router(
|
|
||||||
worker_urls=(
|
|
||||||
[]
|
|
||||||
if router_args.service_discovery or router_args.pd_disaggregation
|
|
||||||
else router_args.worker_urls
|
|
||||||
),
|
|
||||||
host=router_args.host,
|
|
||||||
port=router_args.port,
|
|
||||||
policy=policy_from_str(router_args.policy),
|
|
||||||
worker_startup_timeout_secs=router_args.worker_startup_timeout_secs,
|
|
||||||
worker_startup_check_interval=router_args.worker_startup_check_interval,
|
|
||||||
cache_threshold=router_args.cache_threshold,
|
|
||||||
balance_abs_threshold=router_args.balance_abs_threshold,
|
|
||||||
balance_rel_threshold=router_args.balance_rel_threshold,
|
|
||||||
eviction_interval_secs=router_args.eviction_interval,
|
|
||||||
max_tree_size=router_args.max_tree_size,
|
|
||||||
max_payload_size=router_args.max_payload_size,
|
|
||||||
dp_aware=router_args.dp_aware,
|
|
||||||
api_key=router_args.api_key,
|
|
||||||
log_dir=router_args.log_dir,
|
|
||||||
log_level=router_args.log_level,
|
|
||||||
service_discovery=router_args.service_discovery,
|
|
||||||
selector=router_args.selector,
|
|
||||||
service_discovery_port=router_args.service_discovery_port,
|
|
||||||
service_discovery_namespace=router_args.service_discovery_namespace,
|
|
||||||
prefill_selector=router_args.prefill_selector,
|
|
||||||
decode_selector=router_args.decode_selector,
|
|
||||||
prometheus_port=router_args.prometheus_port,
|
|
||||||
prometheus_host=router_args.prometheus_host,
|
|
||||||
request_timeout_secs=router_args.request_timeout_secs,
|
|
||||||
pd_disaggregation=router_args.pd_disaggregation,
|
|
||||||
prefill_urls=(
|
|
||||||
router_args.prefill_urls if router_args.pd_disaggregation else None
|
|
||||||
),
|
|
||||||
decode_urls=(
|
|
||||||
router_args.decode_urls if router_args.pd_disaggregation else None
|
|
||||||
),
|
|
||||||
prefill_policy=(
|
|
||||||
policy_from_str(router_args.prefill_policy)
|
|
||||||
if router_args.prefill_policy
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
decode_policy=(
|
|
||||||
policy_from_str(router_args.decode_policy)
|
|
||||||
if router_args.decode_policy
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
request_id_headers=router_args.request_id_headers,
|
|
||||||
max_concurrent_requests=router_args.max_concurrent_requests,
|
|
||||||
queue_size=router_args.queue_size,
|
|
||||||
queue_timeout_secs=router_args.queue_timeout_secs,
|
|
||||||
rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second,
|
|
||||||
cors_allowed_origins=router_args.cors_allowed_origins,
|
|
||||||
retry_max_retries=router_args.retry_max_retries,
|
|
||||||
retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
|
|
||||||
retry_max_backoff_ms=router_args.retry_max_backoff_ms,
|
|
||||||
retry_backoff_multiplier=router_args.retry_backoff_multiplier,
|
|
||||||
retry_jitter_factor=router_args.retry_jitter_factor,
|
|
||||||
cb_failure_threshold=router_args.cb_failure_threshold,
|
|
||||||
cb_success_threshold=router_args.cb_success_threshold,
|
|
||||||
cb_timeout_duration_secs=router_args.cb_timeout_duration_secs,
|
|
||||||
cb_window_duration_secs=router_args.cb_window_duration_secs,
|
|
||||||
disable_retries=router_args.disable_retries,
|
|
||||||
disable_circuit_breaker=router_args.disable_circuit_breaker,
|
|
||||||
health_failure_threshold=router_args.health_failure_threshold,
|
|
||||||
health_success_threshold=router_args.health_success_threshold,
|
|
||||||
health_check_timeout_secs=router_args.health_check_timeout_secs,
|
|
||||||
health_check_interval_secs=router_args.health_check_interval_secs,
|
|
||||||
health_check_endpoint=router_args.health_check_endpoint,
|
|
||||||
model_path=router_args.model_path,
|
|
||||||
tokenizer_path=router_args.tokenizer_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
router.start()
|
|
||||||
return router
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error starting router: {e}")
|
logger.error(f"Error starting router: {e}")
|
||||||
|
|||||||
395
sgl-router/py_src/sglang_router/mini_lb.py
Normal file
395
sgl-router/py_src/sglang_router/mini_lb.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
"""
|
||||||
|
Minimal HTTP load balancer for prefill and decode servers for testing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import ipaddress
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import urllib
|
||||||
|
from http import HTTPStatus
|
||||||
|
from itertools import chain
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import orjson
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
from sglang_router.router_args import RouterArgs
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
||||||
|
1024 * 64
|
||||||
|
) # 64KB, to prevent aiohttp's "Chunk too big" error
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_wrap_ipv6_address(address: str) -> str:
|
||||||
|
try:
|
||||||
|
ipaddress.IPv6Address(address)
|
||||||
|
return f"[{address}]"
|
||||||
|
except ValueError:
|
||||||
|
return address
|
||||||
|
|
||||||
|
|
||||||
|
class MiniLoadBalancer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
router_args: RouterArgs,
|
||||||
|
):
|
||||||
|
self._validate_router_args(router_args)
|
||||||
|
|
||||||
|
self.host = router_args.host
|
||||||
|
self.port = router_args.port
|
||||||
|
self.timeout = router_args.request_timeout_secs
|
||||||
|
self.prefill_urls = [url[0] for url in router_args.prefill_urls]
|
||||||
|
self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
|
||||||
|
self.decode_urls = router_args.decode_urls
|
||||||
|
|
||||||
|
def _validate_router_args(self, router_args: RouterArgs):
|
||||||
|
logger.warning(
|
||||||
|
"\x1b[33mMiniLB is only for debugging purposes, it only supports random policy!\033[0m"
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: too many arguments unsupported, just validate some important ones
|
||||||
|
if router_args.policy != "random":
|
||||||
|
logger.warning("[MiniLB] Overriding policy to random")
|
||||||
|
router_args.policy = "random"
|
||||||
|
|
||||||
|
if not router_args.pd_disaggregation:
|
||||||
|
raise ValueError("MiniLB only supports PD disaggregation mode")
|
||||||
|
|
||||||
|
if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"MiniLB requires at least one prefill and one decode server"
|
||||||
|
)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
global lb
|
||||||
|
lb = self
|
||||||
|
uvicorn.run(app, host=self.host, port=self.port)
|
||||||
|
|
||||||
|
def select_pair(self):
|
||||||
|
assert len(self.prefill_urls) > 0, "No prefill servers available"
|
||||||
|
assert len(self.decode_urls) > 0, "No decode servers available"
|
||||||
|
pidx = random.randint(0, len(self.prefill_urls) - 1)
|
||||||
|
didx = random.randint(0, len(self.decode_urls) - 1)
|
||||||
|
return (
|
||||||
|
self.prefill_urls[pidx],
|
||||||
|
self.prefill_bootstrap_ports[pidx],
|
||||||
|
self.decode_urls[didx],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self, modified_request, prefill_server, decode_server, endpoint
|
||||||
|
) -> ORJSONResponse:
|
||||||
|
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(
|
||||||
|
timeout=aiohttp.ClientTimeout(
|
||||||
|
total=self.timeout
|
||||||
|
) # Add timeout for request reliability
|
||||||
|
) as session:
|
||||||
|
tasks = [
|
||||||
|
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
||||||
|
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Wait for both responses to complete. Prefill should end first.
|
||||||
|
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
if "return_logprob" in modified_request:
|
||||||
|
|
||||||
|
prefill_json = await prefill_response.json()
|
||||||
|
ret_json = await decode_response.json()
|
||||||
|
|
||||||
|
# merge `meta_info.input_token_logprobs` from prefill to decode
|
||||||
|
if "meta_info" in ret_json:
|
||||||
|
if "input_token_logprobs" in ret_json["meta_info"]:
|
||||||
|
ret_json["meta_info"]["input_token_logprobs"] = (
|
||||||
|
prefill_json["meta_info"]["input_token_logprobs"]
|
||||||
|
+ ret_json["meta_info"]["input_token_logprobs"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ret_json = await decode_response.json()
|
||||||
|
|
||||||
|
return ORJSONResponse(
|
||||||
|
content=ret_json,
|
||||||
|
status_code=decode_response.status,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate_stream(
|
||||||
|
self, modified_request, prefill_server, decode_server, endpoint="generate"
|
||||||
|
):
|
||||||
|
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
||||||
|
|
||||||
|
async def stream_results():
|
||||||
|
async with aiohttp.ClientSession(
|
||||||
|
timeout=aiohttp.ClientTimeout(
|
||||||
|
total=self.timeout
|
||||||
|
) # Add timeout for request reliability
|
||||||
|
) as session:
|
||||||
|
# Create the tasks for both prefill and decode requests
|
||||||
|
tasks = [
|
||||||
|
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
||||||
|
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
||||||
|
]
|
||||||
|
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
||||||
|
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
if modified_request.get("return_logprob", False):
|
||||||
|
prefill_chunks = []
|
||||||
|
async for chunk in prefill_response.content:
|
||||||
|
prefill_chunks.append(chunk)
|
||||||
|
|
||||||
|
first_prefill_chunk = (
|
||||||
|
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
|
||||||
|
)
|
||||||
|
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
|
||||||
|
|
||||||
|
async for chunk in decode_response.content:
|
||||||
|
# Note: This is inefficient
|
||||||
|
# merge prefill input_token_logprobs, output_token_logprobs to decode
|
||||||
|
decoded_chunk = chunk.decode("utf-8")
|
||||||
|
if (
|
||||||
|
decoded_chunk
|
||||||
|
and decoded_chunk.startswith("data:")
|
||||||
|
and "[DONE]" not in decoded_chunk
|
||||||
|
):
|
||||||
|
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
|
||||||
|
ret_json["meta_info"]["input_token_logprobs"] = (
|
||||||
|
first_prefill_chunk_json["meta_info"][
|
||||||
|
"input_token_logprobs"
|
||||||
|
]
|
||||||
|
+ ret_json["meta_info"]["input_token_logprobs"]
|
||||||
|
)
|
||||||
|
|
||||||
|
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
|
||||||
|
else:
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
async for chunk in decode_response.content.iter_chunked(
|
||||||
|
AIOHTTP_STREAM_READ_CHUNK_SIZE
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
stream_results(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
lb: Optional[MiniLoadBalancer] = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health_generate")
|
||||||
|
async def health_generate():
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Create the tasks
|
||||||
|
tasks = []
|
||||||
|
for server in chain(lb.prefill_urls, lb.decode_urls):
|
||||||
|
tasks.append(session.get(f"{server}/health_generate"))
|
||||||
|
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||||
|
await response
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/flush_cache")
|
||||||
|
async def flush_cache():
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Create the tasks
|
||||||
|
tasks = []
|
||||||
|
for server in chain(lb.prefill_urls, lb.decode_urls):
|
||||||
|
tasks.append(session.post(f"{server}/flush_cache"))
|
||||||
|
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||||
|
await response
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_server_info")
|
||||||
|
async def get_server_info():
|
||||||
|
prefill_infos = []
|
||||||
|
decode_infos = []
|
||||||
|
all_internal_states = []
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
for server in lb.prefill_urls:
|
||||||
|
server_info = await session.get(f"{server}/get_server_info")
|
||||||
|
prefill_infos.append(await server_info.json())
|
||||||
|
for server in lb.decode_urls:
|
||||||
|
server_info = await session.get(f"{server}/get_server_info")
|
||||||
|
info_json = await server_info.json()
|
||||||
|
decode_infos.append(info_json)
|
||||||
|
# Extract internal_states from decode servers
|
||||||
|
if "internal_states" in info_json:
|
||||||
|
all_internal_states.extend(info_json["internal_states"])
|
||||||
|
|
||||||
|
# Return format expected by bench_one_batch_server.py
|
||||||
|
if all_internal_states:
|
||||||
|
return {
|
||||||
|
"internal_states": all_internal_states,
|
||||||
|
"prefill": prefill_infos,
|
||||||
|
"decode": decode_infos,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Fallback with dummy data if no internal states found
|
||||||
|
return {
|
||||||
|
"internal_states": [
|
||||||
|
{
|
||||||
|
"last_gen_throughput": 0.0,
|
||||||
|
"avg_spec_accept_length": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"prefill": prefill_infos,
|
||||||
|
"decode": decode_infos,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_model_info")
|
||||||
|
async def get_model_info():
|
||||||
|
if not lb or not lb.prefill_urls:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||||
|
detail="There is no server registered",
|
||||||
|
)
|
||||||
|
|
||||||
|
target_server_url = lb.prefill_urls[0]
|
||||||
|
endpoint_url = f"{target_server_url}/get_model_info"
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.get(endpoint_url) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
error_text = await response.text()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.BAD_GATEWAY,
|
||||||
|
detail=(
|
||||||
|
f"Failed to get model info from {target_server_url}"
|
||||||
|
f"Status: {response.status}, Response: {error_text}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
model_info_json = await response.json()
|
||||||
|
return ORJSONResponse(content=model_info_json)
|
||||||
|
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||||
|
detail=f"Failed to get model info from backend",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/generate")
|
||||||
|
async def handle_generate_request(request_data: dict):
|
||||||
|
prefill_server, bootstrap_port, decode_server = lb.select_pair()
|
||||||
|
|
||||||
|
# Parse and transform prefill_server for bootstrap data
|
||||||
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||||
|
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||||
|
modified_request = request_data.copy()
|
||||||
|
|
||||||
|
batch_size = _get_request_batch_size(modified_request)
|
||||||
|
if batch_size is not None:
|
||||||
|
modified_request.update(
|
||||||
|
{
|
||||||
|
"bootstrap_host": [hostname] * batch_size,
|
||||||
|
"bootstrap_port": [bootstrap_port] * batch_size,
|
||||||
|
"bootstrap_room": [
|
||||||
|
_generate_bootstrap_room() for _ in range(batch_size)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
modified_request.update(
|
||||||
|
{
|
||||||
|
"bootstrap_host": hostname,
|
||||||
|
"bootstrap_port": bootstrap_port,
|
||||||
|
"bootstrap_room": _generate_bootstrap_room(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if request_data.get("stream", False):
|
||||||
|
return await lb.generate_stream(
|
||||||
|
modified_request, prefill_server, decode_server, "generate"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await lb.generate(
|
||||||
|
modified_request, prefill_server, decode_server, "generate"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
||||||
|
prefill_server, bootstrap_port, decode_server = lb.select_pair()
|
||||||
|
|
||||||
|
# Parse and transform prefill_server for bootstrap data
|
||||||
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||||
|
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||||
|
modified_request = request_data.copy()
|
||||||
|
modified_request.update(
|
||||||
|
{
|
||||||
|
"bootstrap_host": hostname,
|
||||||
|
"bootstrap_port": bootstrap_port,
|
||||||
|
"bootstrap_room": _generate_bootstrap_room(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if request_data.get("stream", False):
|
||||||
|
return await lb.generate_stream(
|
||||||
|
modified_request,
|
||||||
|
prefill_server,
|
||||||
|
decode_server,
|
||||||
|
endpoint=endpoint_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await lb.generate(
|
||||||
|
modified_request,
|
||||||
|
prefill_server,
|
||||||
|
decode_server,
|
||||||
|
endpoint=endpoint_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def handle_chat_completion_request(request_data: dict):
|
||||||
|
return await _forward_to_backend(request_data, "v1/chat/completions")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/completions")
|
||||||
|
async def handle_completion_request(request_data: dict):
|
||||||
|
return await _forward_to_backend(request_data, "v1/completions")
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_bootstrap_room():
|
||||||
|
return random.randint(0, 2**63 - 1)
|
||||||
|
|
||||||
|
|
||||||
|
# We may utilize `GenerateReqInput`'s logic later
|
||||||
|
def _get_request_batch_size(request):
|
||||||
|
if (text := request.get("text")) is not None:
|
||||||
|
return None if isinstance(text, str) else len(text)
|
||||||
|
if (input_ids := request.get("input_ids")) is not None:
|
||||||
|
return None if isinstance(input_ids[0], int) else len(input_ids)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/models")
|
||||||
|
async def get_models():
|
||||||
|
prefill_server = lb.prefill_urls[0] # Get the first prefill server
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
response = await session.get(f"{prefill_server}/v1/models")
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=response.status,
|
||||||
|
detail=f"Prefill server error: Status {response.status}",
|
||||||
|
)
|
||||||
|
return ORJSONResponse(content=await response.json())
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -1,9 +1,23 @@
|
|||||||
from typing import Dict, List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from sglang_router.router_args import RouterArgs
|
||||||
from sglang_router_rs import PolicyType
|
from sglang_router_rs import PolicyType
|
||||||
from sglang_router_rs import Router as _Router
|
from sglang_router_rs import Router as _Router
|
||||||
|
|
||||||
|
|
||||||
|
def policy_from_str(policy_str: Optional[str]) -> PolicyType:
|
||||||
|
"""Convert policy string to PolicyType enum."""
|
||||||
|
if policy_str is None:
|
||||||
|
return None
|
||||||
|
policy_map = {
|
||||||
|
"random": PolicyType.Random,
|
||||||
|
"round_robin": PolicyType.RoundRobin,
|
||||||
|
"cache_aware": PolicyType.CacheAware,
|
||||||
|
"power_of_two": PolicyType.PowerOfTwo,
|
||||||
|
}
|
||||||
|
return policy_map[policy_str]
|
||||||
|
|
||||||
|
|
||||||
class Router:
|
class Router:
|
||||||
"""
|
"""
|
||||||
A high-performance router for distributing requests across worker nodes.
|
A high-performance router for distributing requests across worker nodes.
|
||||||
@@ -78,130 +92,34 @@ class Router:
|
|||||||
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
|
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, router: _Router):
|
||||||
self,
|
self._router = router
|
||||||
worker_urls: List[str],
|
|
||||||
policy: PolicyType = PolicyType.RoundRobin,
|
|
||||||
host: str = "127.0.0.1",
|
|
||||||
port: int = 3001,
|
|
||||||
worker_startup_timeout_secs: int = 600,
|
|
||||||
worker_startup_check_interval: int = 30,
|
|
||||||
cache_threshold: float = 0.3,
|
|
||||||
balance_abs_threshold: int = 64,
|
|
||||||
balance_rel_threshold: float = 1.5,
|
|
||||||
eviction_interval_secs: int = 120,
|
|
||||||
max_tree_size: int = 2**26,
|
|
||||||
max_payload_size: int = 512 * 1024 * 1024, # 512MB
|
|
||||||
dp_aware: bool = False,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
log_dir: Optional[str] = None,
|
|
||||||
log_level: Optional[str] = None,
|
|
||||||
service_discovery: bool = False,
|
|
||||||
selector: Dict[str, str] = None,
|
|
||||||
service_discovery_port: int = 80,
|
|
||||||
service_discovery_namespace: Optional[str] = None,
|
|
||||||
prefill_selector: Dict[str, str] = None,
|
|
||||||
decode_selector: Dict[str, str] = None,
|
|
||||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
|
|
||||||
prometheus_port: Optional[int] = None,
|
|
||||||
prometheus_host: Optional[str] = None,
|
|
||||||
request_timeout_secs: int = 1800,
|
|
||||||
request_id_headers: Optional[List[str]] = None,
|
|
||||||
pd_disaggregation: bool = False,
|
|
||||||
prefill_urls: Optional[List[tuple]] = None,
|
|
||||||
decode_urls: Optional[List[str]] = None,
|
|
||||||
prefill_policy: Optional[PolicyType] = None,
|
|
||||||
decode_policy: Optional[PolicyType] = None,
|
|
||||||
max_concurrent_requests: int = 256,
|
|
||||||
queue_size: int = 100,
|
|
||||||
queue_timeout_secs: int = 60,
|
|
||||||
rate_limit_tokens_per_second: Optional[int] = None,
|
|
||||||
cors_allowed_origins: List[str] = None,
|
|
||||||
retry_max_retries: int = 5,
|
|
||||||
retry_initial_backoff_ms: int = 50,
|
|
||||||
retry_max_backoff_ms: int = 30_000,
|
|
||||||
retry_backoff_multiplier: float = 1.5,
|
|
||||||
retry_jitter_factor: float = 0.2,
|
|
||||||
cb_failure_threshold: int = 10,
|
|
||||||
cb_success_threshold: int = 3,
|
|
||||||
cb_timeout_duration_secs: int = 60,
|
|
||||||
cb_window_duration_secs: int = 120,
|
|
||||||
disable_retries: bool = False,
|
|
||||||
disable_circuit_breaker: bool = False,
|
|
||||||
health_failure_threshold: int = 3,
|
|
||||||
health_success_threshold: int = 2,
|
|
||||||
health_check_timeout_secs: int = 5,
|
|
||||||
health_check_interval_secs: int = 60,
|
|
||||||
health_check_endpoint: str = "/health",
|
|
||||||
model_path: Optional[str] = None,
|
|
||||||
tokenizer_path: Optional[str] = None,
|
|
||||||
):
|
|
||||||
if selector is None:
|
|
||||||
selector = {}
|
|
||||||
if prefill_selector is None:
|
|
||||||
prefill_selector = {}
|
|
||||||
if decode_selector is None:
|
|
||||||
decode_selector = {}
|
|
||||||
if cors_allowed_origins is None:
|
|
||||||
cors_allowed_origins = []
|
|
||||||
|
|
||||||
self._router = _Router(
|
@staticmethod
|
||||||
worker_urls=worker_urls,
|
def from_args(args: RouterArgs) -> "Router":
|
||||||
policy=policy,
|
"""Create a router from a RouterArgs instance."""
|
||||||
host=host,
|
|
||||||
port=port,
|
args_dict = vars(args)
|
||||||
worker_startup_timeout_secs=worker_startup_timeout_secs,
|
# Convert RouterArgs to _Router parameters
|
||||||
worker_startup_check_interval=worker_startup_check_interval,
|
args_dict["worker_urls"] = (
|
||||||
cache_threshold=cache_threshold,
|
[]
|
||||||
balance_abs_threshold=balance_abs_threshold,
|
if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
|
||||||
balance_rel_threshold=balance_rel_threshold,
|
else args_dict["worker_urls"]
|
||||||
eviction_interval_secs=eviction_interval_secs,
|
|
||||||
max_tree_size=max_tree_size,
|
|
||||||
max_payload_size=max_payload_size,
|
|
||||||
dp_aware=dp_aware,
|
|
||||||
api_key=api_key,
|
|
||||||
log_dir=log_dir,
|
|
||||||
log_level=log_level,
|
|
||||||
service_discovery=service_discovery,
|
|
||||||
selector=selector,
|
|
||||||
service_discovery_port=service_discovery_port,
|
|
||||||
service_discovery_namespace=service_discovery_namespace,
|
|
||||||
prefill_selector=prefill_selector,
|
|
||||||
decode_selector=decode_selector,
|
|
||||||
bootstrap_port_annotation=bootstrap_port_annotation,
|
|
||||||
prometheus_port=prometheus_port,
|
|
||||||
prometheus_host=prometheus_host,
|
|
||||||
request_timeout_secs=request_timeout_secs,
|
|
||||||
request_id_headers=request_id_headers,
|
|
||||||
pd_disaggregation=pd_disaggregation,
|
|
||||||
prefill_urls=prefill_urls,
|
|
||||||
decode_urls=decode_urls,
|
|
||||||
prefill_policy=prefill_policy,
|
|
||||||
decode_policy=decode_policy,
|
|
||||||
max_concurrent_requests=max_concurrent_requests,
|
|
||||||
queue_size=queue_size,
|
|
||||||
queue_timeout_secs=queue_timeout_secs,
|
|
||||||
rate_limit_tokens_per_second=rate_limit_tokens_per_second,
|
|
||||||
cors_allowed_origins=cors_allowed_origins,
|
|
||||||
retry_max_retries=retry_max_retries,
|
|
||||||
retry_initial_backoff_ms=retry_initial_backoff_ms,
|
|
||||||
retry_max_backoff_ms=retry_max_backoff_ms,
|
|
||||||
retry_backoff_multiplier=retry_backoff_multiplier,
|
|
||||||
retry_jitter_factor=retry_jitter_factor,
|
|
||||||
cb_failure_threshold=cb_failure_threshold,
|
|
||||||
cb_success_threshold=cb_success_threshold,
|
|
||||||
cb_timeout_duration_secs=cb_timeout_duration_secs,
|
|
||||||
cb_window_duration_secs=cb_window_duration_secs,
|
|
||||||
disable_retries=disable_retries,
|
|
||||||
disable_circuit_breaker=disable_circuit_breaker,
|
|
||||||
health_failure_threshold=health_failure_threshold,
|
|
||||||
health_success_threshold=health_success_threshold,
|
|
||||||
health_check_timeout_secs=health_check_timeout_secs,
|
|
||||||
health_check_interval_secs=health_check_interval_secs,
|
|
||||||
health_check_endpoint=health_check_endpoint,
|
|
||||||
model_path=model_path,
|
|
||||||
tokenizer_path=tokenizer_path,
|
|
||||||
)
|
)
|
||||||
|
args_dict["policy"] = policy_from_str(args_dict["policy"])
|
||||||
|
args_dict["prefill_urls"] = (
|
||||||
|
args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
|
||||||
|
)
|
||||||
|
args_dict["decode_urls"] = (
|
||||||
|
args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
|
||||||
|
)
|
||||||
|
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
|
||||||
|
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
|
||||||
|
|
||||||
|
# remoge mini_lb parameter
|
||||||
|
args_dict.pop("mini_lb")
|
||||||
|
|
||||||
|
return Router(_Router(**args_dict))
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
"""Start the router server.
|
"""Start the router server.
|
||||||
|
|||||||
577
sgl-router/py_src/sglang_router/router_args.py
Normal file
577
sgl-router/py_src/sglang_router/router_args.py
Normal file
@@ -0,0 +1,577 @@
|
|||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class RouterArgs:
|
||||||
|
# Worker configuration
|
||||||
|
worker_urls: List[str] = dataclasses.field(default_factory=list)
|
||||||
|
host: str = "127.0.0.1"
|
||||||
|
port: int = 30000
|
||||||
|
|
||||||
|
# PD-specific configuration
|
||||||
|
mini_lb: bool = False
|
||||||
|
pd_disaggregation: bool = False # Enable PD disaggregated mode
|
||||||
|
prefill_urls: List[tuple] = dataclasses.field(
|
||||||
|
default_factory=list
|
||||||
|
) # List of (url, bootstrap_port)
|
||||||
|
decode_urls: List[str] = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
|
# Routing policy
|
||||||
|
policy: str = "cache_aware"
|
||||||
|
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
|
||||||
|
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
|
||||||
|
worker_startup_timeout_secs: int = 600
|
||||||
|
worker_startup_check_interval: int = 30
|
||||||
|
cache_threshold: float = 0.3
|
||||||
|
balance_abs_threshold: int = 64
|
||||||
|
balance_rel_threshold: float = 1.5
|
||||||
|
eviction_interval_secs: int = 120
|
||||||
|
max_tree_size: int = 2**26
|
||||||
|
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
|
||||||
|
dp_aware: bool = False
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
log_dir: Optional[str] = None
|
||||||
|
log_level: Optional[str] = None
|
||||||
|
# Service discovery configuration
|
||||||
|
service_discovery: bool = False
|
||||||
|
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||||
|
service_discovery_port: int = 80
|
||||||
|
service_discovery_namespace: Optional[str] = None
|
||||||
|
# PD service discovery configuration
|
||||||
|
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||||
|
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||||
|
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
|
||||||
|
# Prometheus configuration
|
||||||
|
prometheus_port: Optional[int] = None
|
||||||
|
prometheus_host: Optional[str] = None
|
||||||
|
# Request ID headers configuration
|
||||||
|
request_id_headers: Optional[List[str]] = None
|
||||||
|
# Request timeout in seconds
|
||||||
|
request_timeout_secs: int = 1800
|
||||||
|
# Max concurrent requests for rate limiting
|
||||||
|
max_concurrent_requests: int = 256
|
||||||
|
# Queue size for pending requests when max concurrent limit reached
|
||||||
|
queue_size: int = 100
|
||||||
|
# Maximum time (in seconds) a request can wait in queue before timing out
|
||||||
|
queue_timeout_secs: int = 60
|
||||||
|
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
|
||||||
|
rate_limit_tokens_per_second: Optional[int] = None
|
||||||
|
# CORS allowed origins
|
||||||
|
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
|
||||||
|
# Retry configuration
|
||||||
|
retry_max_retries: int = 5
|
||||||
|
retry_initial_backoff_ms: int = 50
|
||||||
|
retry_max_backoff_ms: int = 30_000
|
||||||
|
retry_backoff_multiplier: float = 1.5
|
||||||
|
retry_jitter_factor: float = 0.2
|
||||||
|
disable_retries: bool = False
|
||||||
|
# Health check configuration
|
||||||
|
health_failure_threshold: int = 3
|
||||||
|
health_success_threshold: int = 2
|
||||||
|
health_check_timeout_secs: int = 5
|
||||||
|
health_check_interval_secs: int = 60
|
||||||
|
health_check_endpoint: str = "/health"
|
||||||
|
# Circuit breaker configuration
|
||||||
|
cb_failure_threshold: int = 10
|
||||||
|
cb_success_threshold: int = 3
|
||||||
|
cb_timeout_duration_secs: int = 60
|
||||||
|
cb_window_duration_secs: int = 120
|
||||||
|
disable_circuit_breaker: bool = False
|
||||||
|
# Tokenizer configuration
|
||||||
|
model_path: Optional[str] = None
|
||||||
|
tokenizer_path: Optional[str] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_cli_args(
|
||||||
|
parser: argparse.ArgumentParser,
|
||||||
|
use_router_prefix: bool = False,
|
||||||
|
exclude_host_port: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add router-specific arguments to an argument parser.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser: The argument parser to add arguments to
|
||||||
|
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
|
||||||
|
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
|
||||||
|
"""
|
||||||
|
prefix = "router-" if use_router_prefix else ""
|
||||||
|
|
||||||
|
# Worker configuration
|
||||||
|
if not exclude_host_port:
|
||||||
|
parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
default=RouterArgs.host,
|
||||||
|
help="Host address to bind the router server",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.port,
|
||||||
|
help="Port number to bind the router server",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--worker-urls",
|
||||||
|
type=str,
|
||||||
|
nargs="*",
|
||||||
|
default=[],
|
||||||
|
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Routing policy configuration
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}policy",
|
||||||
|
type=str,
|
||||||
|
default=RouterArgs.policy,
|
||||||
|
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||||
|
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}prefill-policy",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||||
|
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}decode-policy",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||||
|
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
|
||||||
|
)
|
||||||
|
|
||||||
|
# PD-specific arguments
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}mini-lb",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable MiniLB",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}pd-disaggregation",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}prefill",
|
||||||
|
nargs="+",
|
||||||
|
action="append",
|
||||||
|
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
|
||||||
|
"Format: --prefill URL [BOOTSTRAP_PORT]. "
|
||||||
|
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}decode",
|
||||||
|
nargs=1,
|
||||||
|
action="append",
|
||||||
|
metavar=("URL",),
|
||||||
|
help="Decode server URL. Can be specified multiple times.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}worker-startup-timeout-secs",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.worker_startup_timeout_secs,
|
||||||
|
help="Timeout in seconds for worker startup",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}worker-startup-check-interval",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.worker_startup_check_interval,
|
||||||
|
help="Interval in seconds between checks for worker startup",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}cache-threshold",
|
||||||
|
type=float,
|
||||||
|
default=RouterArgs.cache_threshold,
|
||||||
|
help="Cache threshold (0.0-1.0) for cache-aware routing",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}balance-abs-threshold",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.balance_abs_threshold,
|
||||||
|
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}balance-rel-threshold",
|
||||||
|
type=float,
|
||||||
|
default=RouterArgs.balance_rel_threshold,
|
||||||
|
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}eviction-interval-secs",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.eviction_interval_secs,
|
||||||
|
help="Interval in seconds between cache eviction operations",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}max-tree-size",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.max_tree_size,
|
||||||
|
help="Maximum size of the approximation tree for cache-aware routing",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}max-payload-size",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.max_payload_size,
|
||||||
|
help="Maximum payload size in bytes",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}dp-aware",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable data parallelism aware schedule",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}log-dir",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Directory to store log files. If not specified, logs are only output to console.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}log-level",
|
||||||
|
type=str,
|
||||||
|
default="info",
|
||||||
|
choices=["debug", "info", "warning", "error", "critical"],
|
||||||
|
help="Set the logging level. If not specified, defaults to INFO.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}service-discovery",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable Kubernetes service discovery",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}selector",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default={},
|
||||||
|
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}service-discovery-port",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.service_discovery_port,
|
||||||
|
help="Port to use for discovered worker pods",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}service-discovery-namespace",
|
||||||
|
type=str,
|
||||||
|
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}prefill-selector",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default={},
|
||||||
|
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}decode-selector",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default={},
|
||||||
|
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
|
||||||
|
)
|
||||||
|
# Prometheus configuration
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}prometheus-port",
|
||||||
|
type=int,
|
||||||
|
default=29000,
|
||||||
|
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}prometheus-host",
|
||||||
|
type=str,
|
||||||
|
default="127.0.0.1",
|
||||||
|
help="Host address to bind the Prometheus metrics server",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}request-id-headers",
|
||||||
|
type=str,
|
||||||
|
nargs="*",
|
||||||
|
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}request-timeout-secs",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.request_timeout_secs,
|
||||||
|
help="Request timeout in seconds",
|
||||||
|
)
|
||||||
|
# Retry configuration
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}retry-max-retries",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.retry_max_retries,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}retry-initial-backoff-ms",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.retry_initial_backoff_ms,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}retry-max-backoff-ms",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.retry_max_backoff_ms,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}retry-backoff-multiplier",
|
||||||
|
type=float,
|
||||||
|
default=RouterArgs.retry_backoff_multiplier,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}retry-jitter-factor",
|
||||||
|
type=float,
|
||||||
|
default=RouterArgs.retry_jitter_factor,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}disable-retries",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable retries (equivalent to setting retry_max_retries=1)",
|
||||||
|
)
|
||||||
|
# Circuit breaker configuration
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}cb-failure-threshold",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.cb_failure_threshold,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}cb-success-threshold",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.cb_success_threshold,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}cb-timeout-duration-secs",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.cb_timeout_duration_secs,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}cb-window-duration-secs",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.cb_window_duration_secs,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}disable-circuit-breaker",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
|
||||||
|
)
|
||||||
|
# Health check configuration
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}health-failure-threshold",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.health_failure_threshold,
|
||||||
|
help="Number of consecutive health check failures before marking worker unhealthy",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}health-success-threshold",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.health_success_threshold,
|
||||||
|
help="Number of consecutive health check successes before marking worker healthy",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}health-check-timeout-secs",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.health_check_timeout_secs,
|
||||||
|
help="Timeout in seconds for health check requests",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}health-check-interval-secs",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.health_check_interval_secs,
|
||||||
|
help="Interval in seconds between runtime health checks",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}health-check-endpoint",
|
||||||
|
type=str,
|
||||||
|
default=RouterArgs.health_check_endpoint,
|
||||||
|
help="Health check endpoint path",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}max-concurrent-requests",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.max_concurrent_requests,
|
||||||
|
help="Maximum number of concurrent requests allowed (for rate limiting)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}queue-size",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.queue_size,
|
||||||
|
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}queue-timeout-secs",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.queue_timeout_secs,
|
||||||
|
help="Maximum time (in seconds) a request can wait in queue before timing out",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}rate-limit-tokens-per-second",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.rate_limit_tokens_per_second,
|
||||||
|
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}cors-allowed-origins",
|
||||||
|
type=str,
|
||||||
|
nargs="*",
|
||||||
|
default=[],
|
||||||
|
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
||||||
|
)
|
||||||
|
# Tokenizer configuration
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}model-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}tokenizer-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cli_args(
|
||||||
|
cls, args: argparse.Namespace, use_router_prefix: bool = False
|
||||||
|
) -> "RouterArgs":
|
||||||
|
"""
|
||||||
|
Create RouterArgs instance from parsed command line arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Parsed command line arguments
|
||||||
|
use_router_prefix: If True, look for arguments with 'router-' prefix
|
||||||
|
"""
|
||||||
|
prefix = "router_" if use_router_prefix else ""
|
||||||
|
cli_args_dict = vars(args)
|
||||||
|
args_dict = {}
|
||||||
|
|
||||||
|
for attr in dataclasses.fields(cls):
|
||||||
|
# Auto strip prefix from args
|
||||||
|
if f"{prefix}{attr.name}" in cli_args_dict:
|
||||||
|
args_dict[attr.name] = cli_args_dict[f"{prefix}{attr.name}"]
|
||||||
|
elif attr.name in cli_args_dict:
|
||||||
|
args_dict[attr.name] = cli_args_dict[attr.name]
|
||||||
|
|
||||||
|
# parse special arguments and remove "--prefill" and "--decode" from cli_args_dict
|
||||||
|
args_dict["prefill_urls"] = cls._parse_prefill_urls(
|
||||||
|
cli_args_dict.get(f"{prefix}prefill", None)
|
||||||
|
)
|
||||||
|
args_dict["decode_urls"] = cls._parse_decode_urls(
|
||||||
|
cli_args_dict.get(f"{prefix}decode", None)
|
||||||
|
)
|
||||||
|
args_dict["selector"] = cls._parse_selector(
|
||||||
|
cli_args_dict.get(f"{prefix}selector", None)
|
||||||
|
)
|
||||||
|
args_dict["prefill_selector"] = cls._parse_selector(
|
||||||
|
cli_args_dict.get(f"{prefix}prefill_selector", None)
|
||||||
|
)
|
||||||
|
args_dict["decode_selector"] = cls._parse_selector(
|
||||||
|
cli_args_dict.get(f"{prefix}decode_selector", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mooncake-specific annotation
|
||||||
|
args_dict["bootstrap_port_annotation"] = "sglang.ai/bootstrap-port"
|
||||||
|
|
||||||
|
return cls(**args_dict)
|
||||||
|
|
||||||
|
def _validate_router_args(self):
|
||||||
|
# Validate configuration based on mode
|
||||||
|
if self.pd_disaggregation:
|
||||||
|
# Validate PD configuration - skip URL requirements if using service discovery
|
||||||
|
if not self.service_discovery:
|
||||||
|
if not self.prefill_urls:
|
||||||
|
raise ValueError("PD disaggregation mode requires --prefill")
|
||||||
|
if not self.decode_urls:
|
||||||
|
raise ValueError("PD disaggregation mode requires --decode")
|
||||||
|
|
||||||
|
# Warn about policy usage in PD mode
|
||||||
|
if self.prefill_policy and self.decode_policy and self.policy:
|
||||||
|
logger.warning(
|
||||||
|
"Both --prefill-policy and --decode-policy are specified. "
|
||||||
|
"The main --policy flag will be ignored for PD mode."
|
||||||
|
)
|
||||||
|
elif self.prefill_policy and not self.decode_policy and self.policy:
|
||||||
|
logger.info(
|
||||||
|
f"Using --prefill-policy '{self.prefill_policy}' for prefill nodes "
|
||||||
|
f"and --policy '{self.policy}' for decode nodes."
|
||||||
|
)
|
||||||
|
elif self.decode_policy and not self.prefill_policy and self.policy:
|
||||||
|
logger.info(
|
||||||
|
f"Using --policy '{self.policy}' for prefill nodes "
|
||||||
|
f"and --decode-policy '{self.decode_policy}' for decode nodes."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_selector(selector_list):
|
||||||
|
if not selector_list:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
selector = {}
|
||||||
|
for item in selector_list:
|
||||||
|
if "=" in item:
|
||||||
|
key, value = item.split("=", 1)
|
||||||
|
selector[key] = value
|
||||||
|
return selector
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_prefill_urls(prefill_list):
|
||||||
|
"""Parse prefill URLs from --prefill arguments.
|
||||||
|
|
||||||
|
Format: --prefill URL [BOOTSTRAP_PORT]
|
||||||
|
Example:
|
||||||
|
--prefill http://prefill1:8080 9000 # With bootstrap port
|
||||||
|
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
|
||||||
|
--prefill http://prefill3:8080 # Defaults to no bootstrap port
|
||||||
|
"""
|
||||||
|
if not prefill_list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
prefill_urls = []
|
||||||
|
for prefill_args in prefill_list:
|
||||||
|
|
||||||
|
url = prefill_args[0]
|
||||||
|
|
||||||
|
# Handle optional bootstrap port
|
||||||
|
if len(prefill_args) >= 2:
|
||||||
|
bootstrap_port_str = prefill_args[1]
|
||||||
|
# Handle 'none' as None
|
||||||
|
if bootstrap_port_str.lower() == "none":
|
||||||
|
bootstrap_port = None
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
bootstrap_port = int(bootstrap_port_str)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No bootstrap port specified, default to None
|
||||||
|
bootstrap_port = None
|
||||||
|
|
||||||
|
prefill_urls.append((url, bootstrap_port))
|
||||||
|
|
||||||
|
return prefill_urls
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_decode_urls(decode_list):
|
||||||
|
"""Parse decode URLs from --decode arguments.
|
||||||
|
|
||||||
|
Format: --decode URL
|
||||||
|
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||||
|
"""
|
||||||
|
if not decode_list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# decode_list is a list of single-element lists due to nargs=1
|
||||||
|
return [url[0] for url in decode_list]
|
||||||
@@ -33,7 +33,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
cache_threshold=0.5,
|
cache_threshold=0.5,
|
||||||
balance_abs_threshold=32,
|
balance_abs_threshold=32,
|
||||||
balance_rel_threshold=1.0001,
|
balance_rel_threshold=1.0001,
|
||||||
eviction_interval=60,
|
eviction_interval_secs=60,
|
||||||
max_tree_size=2**24,
|
max_tree_size=2**24,
|
||||||
max_payload_size=256 * 1024 * 1024, # 256MB
|
max_payload_size=256 * 1024 * 1024, # 256MB
|
||||||
verbose=False,
|
verbose=False,
|
||||||
@@ -176,9 +176,8 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
"""Test basic PD router functionality without actually starting servers."""
|
"""Test basic PD router functionality without actually starting servers."""
|
||||||
# This test just verifies the PD router can be created and configured
|
# This test just verifies the PD router can be created and configured
|
||||||
# without actually starting it (which would require real prefill/decode servers)
|
# without actually starting it (which would require real prefill/decode servers)
|
||||||
from sglang_router import Router
|
|
||||||
from sglang_router.launch_router import RouterArgs
|
from sglang_router.launch_router import RouterArgs
|
||||||
from sglang_router_rs import PolicyType
|
from sglang_router.router import PolicyType, Router
|
||||||
|
|
||||||
# Test RouterArgs parsing for PD mode
|
# Test RouterArgs parsing for PD mode
|
||||||
# Simulate the parsed args structure from argparse with action="append"
|
# Simulate the parsed args structure from argparse with action="append"
|
||||||
@@ -209,18 +208,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
|
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
|
||||||
|
|
||||||
# Test Router creation in PD mode
|
# Test Router creation in PD mode
|
||||||
router = Router(
|
router = Router.from_args(router_args)
|
||||||
worker_urls=[], # Empty for PD mode
|
|
||||||
pd_disaggregation=True,
|
|
||||||
prefill_urls=[
|
|
||||||
("http://prefill1:8080", 9000),
|
|
||||||
("http://prefill2:8080", None),
|
|
||||||
],
|
|
||||||
decode_urls=["http://decode1:8081", "http://decode2:8081"],
|
|
||||||
policy=PolicyType.CacheAware,
|
|
||||||
host="127.0.0.1",
|
|
||||||
port=3001,
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(router)
|
self.assertIsNotNone(router)
|
||||||
|
|
||||||
def test_policy_validation(self):
|
def test_policy_validation(self):
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ def popen_launch_router(
|
|||||||
port,
|
port,
|
||||||
"--dp",
|
"--dp",
|
||||||
str(dp_size),
|
str(dp_size),
|
||||||
"--router-eviction-interval",
|
"--router-eviction-interval-secs",
|
||||||
"5",
|
"5",
|
||||||
"--router-policy",
|
"--router-policy",
|
||||||
policy,
|
policy,
|
||||||
|
|||||||
@@ -28,8 +28,3 @@ find = { where = ["py_src"] }
|
|||||||
# workaround for https://github.com/pypa/twine/issues/1216
|
# workaround for https://github.com/pypa/twine/issues/1216
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
license-files = []
|
license-files = []
|
||||||
|
|
||||||
[[tool.setuptools-rust.ext-modules]]
|
|
||||||
target = "sglang_router_rs"
|
|
||||||
path = "Cargo.toml"
|
|
||||||
binding = "PyO3"
|
|
||||||
|
|||||||
21
sgl-router/setup.py
Normal file
21
sgl-router/setup.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from setuptools import setup
|
||||||
|
from setuptools_rust import Binding, RustExtension
|
||||||
|
|
||||||
|
no_rust = os.environ.get("SGLANG_ROUTER_BUILD_NO_RUST") == "1"
|
||||||
|
|
||||||
|
rust_extensions = []
|
||||||
|
if not no_rust:
|
||||||
|
rust_extensions.append(
|
||||||
|
RustExtension(
|
||||||
|
target="sglang_router_rs",
|
||||||
|
path="Cargo.toml",
|
||||||
|
binding=Binding.PyO3,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
setup(
|
||||||
|
rust_extensions=rust_extensions,
|
||||||
|
zip_safe=False,
|
||||||
|
)
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
@@ -18,6 +17,7 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
CustomTestCase,
|
CustomTestCase,
|
||||||
popen_launch_pd_server,
|
popen_launch_pd_server,
|
||||||
|
popen_with_error_check,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -47,7 +47,9 @@ class TestDisaggregationAccuracy(CustomTestCase):
|
|||||||
lb_command = [
|
lb_command = [
|
||||||
"python3",
|
"python3",
|
||||||
"-m",
|
"-m",
|
||||||
"sglang.srt.disaggregation.mini_lb",
|
"sglang_router.launch_router",
|
||||||
|
"--pd-disaggregation",
|
||||||
|
"--mini-lb", # FIXME: remove this
|
||||||
"--prefill",
|
"--prefill",
|
||||||
cls.prefill_url,
|
cls.prefill_url,
|
||||||
"--decode",
|
"--decode",
|
||||||
@@ -59,9 +61,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
print("Starting load balancer:", " ".join(lb_command))
|
print("Starting load balancer:", " ".join(lb_command))
|
||||||
cls.process_lb = subprocess.Popen(
|
cls.process_lb = popen_with_error_check(lb_command)
|
||||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
|
||||||
)
|
|
||||||
cls.wait_server_ready(cls.lb_url + "/health")
|
cls.wait_server_ready(cls.lb_url + "/health")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -228,7 +228,9 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
|
|||||||
lb_command = [
|
lb_command = [
|
||||||
"python3",
|
"python3",
|
||||||
"-m",
|
"-m",
|
||||||
"sglang.srt.disaggregation.mini_lb",
|
"sglang_router.launch_router",
|
||||||
|
"--pd-disaggregation",
|
||||||
|
"--mini-lb", # FIXME: remove this
|
||||||
"--prefill",
|
"--prefill",
|
||||||
cls.prefill_url,
|
cls.prefill_url,
|
||||||
"--decode",
|
"--decode",
|
||||||
@@ -240,9 +242,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
print("Starting load balancer:", " ".join(lb_command))
|
print("Starting load balancer:", " ".join(lb_command))
|
||||||
cls.process_lb = subprocess.Popen(
|
cls.process_lb = popen_with_error_check(lb_command)
|
||||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
|
||||||
)
|
|
||||||
cls.wait_server_ready(cls.lb_url + "/health")
|
cls.wait_server_ready(cls.lb_url + "/health")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -383,7 +383,9 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
|
|||||||
lb_command = [
|
lb_command = [
|
||||||
"python3",
|
"python3",
|
||||||
"-m",
|
"-m",
|
||||||
"sglang.srt.disaggregation.mini_lb",
|
"sglang_router.launch_router",
|
||||||
|
"--pd-disaggregation",
|
||||||
|
"--mini-lb", # FIXME: remove this
|
||||||
"--prefill",
|
"--prefill",
|
||||||
cls.prefill_url,
|
cls.prefill_url,
|
||||||
"--decode",
|
"--decode",
|
||||||
@@ -395,9 +397,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
print("Starting load balancer:", " ".join(lb_command))
|
print("Starting load balancer:", " ".join(lb_command))
|
||||||
cls.process_lb = subprocess.Popen(
|
cls.process_lb = popen_with_error_check(lb_command)
|
||||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
|
||||||
)
|
|
||||||
cls.wait_server_ready(cls.lb_url + "/health")
|
cls.wait_server_ready(cls.lb_url + "/health")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -509,7 +509,9 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
|
|||||||
lb_command = [
|
lb_command = [
|
||||||
"python3",
|
"python3",
|
||||||
"-m",
|
"-m",
|
||||||
"sglang.srt.disaggregation.mini_lb",
|
"sglang_router.launch_router",
|
||||||
|
"--pd-disaggregation",
|
||||||
|
"--mini-lb", # FIXME: remove this
|
||||||
"--prefill",
|
"--prefill",
|
||||||
cls.prefill_url,
|
cls.prefill_url,
|
||||||
"--decode",
|
"--decode",
|
||||||
@@ -521,9 +523,7 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
print("Starting load balancer:", " ".join(lb_command))
|
print("Starting load balancer:", " ".join(lb_command))
|
||||||
cls.process_lb = subprocess.Popen(
|
cls.process_lb = popen_with_error_check(lb_command)
|
||||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
|
||||||
)
|
|
||||||
cls.wait_server_ready(cls.lb_url + "/health")
|
cls.wait_server_ready(cls.lb_url + "/health")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
CustomTestCase,
|
CustomTestCase,
|
||||||
popen_launch_pd_server,
|
popen_launch_pd_server,
|
||||||
run_with_timeout,
|
popen_with_error_check,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -49,7 +49,9 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
|
|||||||
lb_command = [
|
lb_command = [
|
||||||
"python3",
|
"python3",
|
||||||
"-m",
|
"-m",
|
||||||
"sglang.srt.disaggregation.mini_lb",
|
"sglang_router.launch_router",
|
||||||
|
"--pd-disaggregation",
|
||||||
|
"--mini-lb", # FIXME: remove this
|
||||||
"--prefill",
|
"--prefill",
|
||||||
cls.prefill_url,
|
cls.prefill_url,
|
||||||
"--decode",
|
"--decode",
|
||||||
@@ -61,9 +63,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
print("Starting load balancer:", " ".join(lb_command))
|
print("Starting load balancer:", " ".join(lb_command))
|
||||||
cls.process_lb = subprocess.Popen(
|
cls.process_lb = popen_with_error_check(lb_command)
|
||||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
|
||||||
)
|
|
||||||
cls.wait_server_ready(cls.lb_url + "/health")
|
cls.wait_server_ready(cls.lb_url + "/health")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -183,7 +183,9 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
|
|||||||
lb_command = [
|
lb_command = [
|
||||||
"python3",
|
"python3",
|
||||||
"-m",
|
"-m",
|
||||||
"sglang.srt.disaggregation.mini_lb",
|
"sglang_router.launch_router",
|
||||||
|
"--pd-disaggregation",
|
||||||
|
"--mini-lb", # FIXME: remove this
|
||||||
"--prefill",
|
"--prefill",
|
||||||
cls.prefill_url,
|
cls.prefill_url,
|
||||||
"--decode",
|
"--decode",
|
||||||
@@ -195,9 +197,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
print("Starting load balancer:", " ".join(lb_command))
|
print("Starting load balancer:", " ".join(lb_command))
|
||||||
cls.process_lb = subprocess.Popen(
|
cls.process_lb = popen_with_error_check(lb_command)
|
||||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
|
||||||
)
|
|
||||||
cls.wait_server_ready(cls.lb_url + "/health")
|
cls.wait_server_ready(cls.lb_url + "/health")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -49,7 +49,9 @@ class TestPDPPAccuracy(unittest.TestCase):
|
|||||||
lb_command = [
|
lb_command = [
|
||||||
"python3",
|
"python3",
|
||||||
"-m",
|
"-m",
|
||||||
"sglang.srt.disaggregation.mini_lb",
|
"sglang_router.launch_router",
|
||||||
|
"--pd-disaggregation",
|
||||||
|
"--mini-lb", # FIXME: remove this
|
||||||
"--prefill",
|
"--prefill",
|
||||||
cls.prefill_url,
|
cls.prefill_url,
|
||||||
"--decode",
|
"--decode",
|
||||||
|
|||||||
Reference in New Issue
Block a user