diff --git a/docker/Dockerfile b/docker/Dockerfile index 4482297e9..4f63091bf 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -36,7 +36,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ ibverbs-providers infiniband-diags perftest \ libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \ libboost-all-dev libssl-dev \ - libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler-grpc \ + libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler protobuf-compiler-grpc \ pybind11-dev \ libhiredis-dev libcurl4-openssl-dev \ libczmq4 libczmq-dev \ @@ -218,6 +218,19 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1 && cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \ && rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz +# Install Rust toolchain for sgl-router +ENV PATH="/root/.cargo/bin:${PATH}" +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \ + && rustc --version && cargo --version + +# Build and install sgl-router +RUN python3 -m pip install --no-cache-dir setuptools-rust \ + && cd /sgl-workspace/sglang/sgl-router \ + && cargo build --release \ + && python3 -m pip install --no-cache-dir . 
\ + && rm -rf /root/.cache + + # Add yank script COPY --chown=root:root <<-"EOF" /usr/local/bin/yank #!/bin/bash diff --git a/docs/advanced_features/pd_disaggregation.md b/docs/advanced_features/pd_disaggregation.md index f7cc0adaf..85a5db07e 100644 --- a/docs/advanced_features/pd_disaggregation.md +++ b/docs/advanced_features/pd_disaggregation.md @@ -36,7 +36,7 @@ uv pip install mooncake-transfer-engine ```bash $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0 $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-ib-device mlx5_roce0 -$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 +$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 ``` ### DeepSeek Multi-Node @@ -100,7 +100,7 @@ pip install . 
--config-settings=setup-args="-Ducx_path=/path/to/ucx" ```bash $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl -$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 +$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 ``` ### DeepSeek Multi-Node @@ -137,7 +137,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true ```bash $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend ascend $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend ascend -$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 +$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 ``` ### DeepSeek Multi-Node diff --git a/docs/advanced_features/router.md b/docs/advanced_features/router.md index 555a0bc4b..4aba99f37 100644 --- a/docs/advanced_features/router.md +++ b/docs/advanced_features/router.md @@ -278,7 +278,7 @@ The most sophisticated policy that combines cache optimization with load balanci 3. 
**Cache Management**: - Maintains approximate radix trees per worker - - Periodically evicts LRU entries based on `--eviction-interval` and `--max-tree-size` + - Periodically evicts LRU entries based on `--eviction-interval-secs` and `--max-tree-size` ### Data Parallelism Aware Routing @@ -296,7 +296,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ### Core Settings | Parameter | Type | Default | Description | -|-----------------------------|------|-------------|-----------------------------------------------------------------| +| --------------------------- | ---- | ----------- | --------------------------------------------------------------- | | `--host` | str | 127.0.0.1 | Router server host address | | `--port` | int | 30000 | Router server port | | `--worker-urls` | list | [] | Worker URLs for separate launch mode | @@ -307,18 +307,18 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ### Cache-Aware Routing Parameters -| Parameter | Type | Default | Description | -|---------------------------|-------|----------|--------------------------------------------------------| -| `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) | -| `--balance-abs-threshold` | int | 32 | Absolute load difference threshold | -| `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold | -| `--eviction-interval` | int | 60 | Seconds between cache eviction cycles | -| `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree | +| Parameter | Type | Default | Description | +| -------------------------- | ----- | -------- | ------------------------------------------------------ | +| `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) | +| `--balance-abs-threshold` | int | 32 | Absolute load difference threshold | +| `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold | +| 
`--eviction-interval-secs` | int | 60 | Seconds between cache eviction cycles | +| `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree | ### Fault Tolerance Parameters | Parameter | Type | Default | Description | -|------------------------------|-------|---------|---------------------------------------| +| ---------------------------- | ----- | ------- | ------------------------------------- | | `--retry-max-retries` | int | 3 | Maximum retry attempts per request | | `--retry-initial-backoff-ms` | int | 100 | Initial retry backoff in milliseconds | | `--retry-max-backoff-ms` | int | 10000 | Maximum retry backoff in milliseconds | @@ -334,7 +334,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ### Prefill-Decode Disaggregation Parameters | Parameter | Type | Default | Description | -|-----------------------------------|------|---------|-------------------------------------------------------| +| --------------------------------- | ---- | ------- | ----------------------------------------------------- | | `--pd-disaggregation` | flag | False | Enable PD disaggregated mode | | `--prefill` | list | [] | Prefill server URLs with optional bootstrap ports | | `--decode` | list | [] | Decode server URLs | @@ -346,7 +346,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ### Kubernetes Integration | Parameter | Type | Default | Description | -|---------------------------------|------|--------------------------|------------------------------------------------------| +| ------------------------------- | ---- | ------------------------ | ---------------------------------------------------- | | `--service-discovery` | flag | False | Enable Kubernetes service discovery | | `--selector` | list | [] | Label selector for workers (key1=value1 key2=value2) | | `--prefill-selector` | list | [] | Label selector for prefill servers in PD mode | @@ -358,7 +358,7 @@ This mode coordinates with SGLang's DP 
controller for optimized request distribu ### Observability | Parameter | Type | Default | Description | -|------------------------|------|-----------|-------------------------------------------------------| +| ---------------------- | ---- | --------- | ----------------------------------------------------- | | `--prometheus-port` | int | 29000 | Prometheus metrics port | | `--prometheus-host` | str | 127.0.0.1 | Prometheus metrics host | | `--log-dir` | str | None | Directory for log files | @@ -368,7 +368,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ### CORS Configuration | Parameter | Type | Default | Description | -|--------------------------|------|---------|----------------------| +| ------------------------ | ---- | ------- | -------------------- | | `--cors-allowed-origins` | list | [] | Allowed CORS origins | ## Advanced Features @@ -429,7 +429,7 @@ python -m sglang_router.launch_router \ 2. **High latency**: Check if cache-aware routing is causing imbalance. Try adjusting `--balance-abs-threshold` and `--balance-rel-threshold`. -3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval` for more aggressive cache cleanup. +3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval-secs` for more aggressive cache cleanup. 4. **Circuit breaker triggering frequently**: Increase `--cb-failure-threshold` or extend `--cb-window-duration-secs`. 
diff --git a/docs/references/multi_node_deployment/lws_pd/lws-examples/lb.yaml b/docs/references/multi_node_deployment/lws_pd/lws-examples/lb.yaml index da7861584..4ca690969 100644 --- a/docs/references/multi_node_deployment/lws_pd/lws-examples/lb.yaml +++ b/docs/references/multi_node_deployment/lws_pd/lws-examples/lb.yaml @@ -27,7 +27,8 @@ spec: command: - python - -m - - sglang.srt.disaggregation.mini_lb + - sglang_router.launch_router + - --pd-disaggregation - --prefill - http://deepseekr10528-prefill-main:30000 - --decode diff --git a/docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md b/docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md index 617017077..eb8454997 100644 --- a/docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md +++ b/docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md @@ -714,7 +714,8 @@ spec: command: - python - -m - - sglang.srt.disaggregation.mini_lb + - sglang_router.launch_router + - --pd-disaggregation - --prefill - http://deepseekr10528-prefill-main:30000 - --decode diff --git a/python/sglang/srt/disaggregation/launch_lb.py b/python/sglang/srt/disaggregation/launch_lb.py deleted file mode 100644 index eb0be6573..000000000 --- a/python/sglang/srt/disaggregation/launch_lb.py +++ /dev/null @@ -1,118 +0,0 @@ -import argparse -import dataclasses - -from sglang.srt.disaggregation.mini_lb import PrefillConfig, run - - -@dataclasses.dataclass -class LBArgs: - host: str = "0.0.0.0" - port: int = 8000 - policy: str = "random" - prefill_infos: list = dataclasses.field(default_factory=list) - decode_infos: list = dataclasses.field(default_factory=list) - log_interval: int = 5 - timeout: int = 600 - - @staticmethod - def add_cli_args(parser: argparse.ArgumentParser): - parser.add_argument( - "--host", - type=str, - default=LBArgs.host, - help=f"Host to bind the server (default: {LBArgs.host})", - ) - parser.add_argument( - "--port", - type=int, - default=LBArgs.port, - help=f"Port to bind the server (default: 
{LBArgs.port})", - ) - parser.add_argument( - "--policy", - type=str, - default=LBArgs.policy, - choices=["random", "po2"], - help=f"Policy to use for load balancing (default: {LBArgs.policy})", - ) - parser.add_argument( - "--prefill", - type=str, - default=[], - nargs="+", - help="URLs for prefill servers", - ) - parser.add_argument( - "--decode", - type=str, - default=[], - nargs="+", - help="URLs for decode servers", - ) - parser.add_argument( - "--prefill-bootstrap-ports", - type=int, - nargs="+", - help="Bootstrap ports for prefill servers", - ) - parser.add_argument( - "--log-interval", - type=int, - default=LBArgs.log_interval, - help=f"Log interval in seconds (default: {LBArgs.log_interval})", - ) - parser.add_argument( - "--timeout", - type=int, - default=LBArgs.timeout, - help=f"Timeout in seconds (default: {LBArgs.timeout})", - ) - - @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs": - bootstrap_ports = args.prefill_bootstrap_ports - if bootstrap_ports is None: - bootstrap_ports = [None] * len(args.prefill) - elif len(bootstrap_ports) == 1: - bootstrap_ports = bootstrap_ports * len(args.prefill) - else: - if len(bootstrap_ports) != len(args.prefill): - raise ValueError( - "Number of prefill URLs must match number of bootstrap ports" - ) - - prefill_infos = [ - (url, port) for url, port in zip(args.prefill, bootstrap_ports) - ] - - return cls( - host=args.host, - port=args.port, - policy=args.policy, - prefill_infos=prefill_infos, - decode_infos=args.decode, - log_interval=args.log_interval, - timeout=args.timeout, - ) - - -def main(): - parser = argparse.ArgumentParser( - description="PD Disaggregation Load Balancer Server" - ) - LBArgs.add_cli_args(parser) - args = parser.parse_args() - lb_args = LBArgs.from_cli_args(args) - - prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos] - run( - prefill_configs, - lb_args.decode_infos, - lb_args.host, - lb_args.port, - lb_args.timeout, - ) - - -if 
__name__ == "__main__": - main() diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py index d29e61853..5aaa2a70e 100644 --- a/python/sglang/srt/disaggregation/mini_lb.py +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -1,445 +1,6 @@ -""" -Minimal HTTP load balancer for prefill and decode servers for testing. -""" - -import asyncio -import dataclasses -import logging -import random -import urllib -from http import HTTPStatus -from itertools import chain -from typing import List, Optional - -import aiohttp -import orjson -import uvicorn -from fastapi import FastAPI, HTTPException -from fastapi.responses import ORJSONResponse, Response, StreamingResponse - -from sglang.srt.disaggregation.utils import PDRegistryRequest -from sglang.srt.utils import maybe_wrap_ipv6_address - -AIOHTTP_STREAM_READ_CHUNK_SIZE = ( - 1024 * 64 -) # 64KB, to prevent aiohttp's "Chunk too big" error - - -def setup_logger(): - logger = logging.getLogger("pdlb") - logger.setLevel(logging.INFO) - - formatter = logging.Formatter( - "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) - - return logger - - -logger = setup_logger() - - -@dataclasses.dataclass -class PrefillConfig: - url: str - bootstrap_port: Optional[int] = None - - -class MiniLoadBalancer: - def __init__( - self, - prefill_configs: List[PrefillConfig], - decode_servers: List[str], - timeout: int, - ): - self.prefill_configs = prefill_configs - self.prefill_servers = [p.url for p in prefill_configs] - self.decode_servers = decode_servers - self.timeout = timeout - - def add_prefill_server(self, new_prefill_config: PrefillConfig): - self.prefill_configs.append(new_prefill_config) - self.prefill_servers.append(new_prefill_config.url) - - def add_decode_server(self, new_decode_server: str): - self.decode_servers.append(new_decode_server) - 
- def select_pair(self): - # TODO: return some message instead of panic - assert len(self.prefill_configs) > 0, "No prefill servers available" - assert len(self.decode_servers) > 0, "No decode servers available" - - prefill_config = random.choice(self.prefill_configs) - decode_server = random.choice(self.decode_servers) - return prefill_config.url, prefill_config.bootstrap_port, decode_server - - async def generate( - self, modified_request, prefill_server, decode_server, endpoint - ) -> ORJSONResponse: - assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" - - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout( - total=self.timeout - ) # Add timeout for request reliability - ) as session: - tasks = [ - session.post(f"{prefill_server}/{endpoint}", json=modified_request), - session.post(f"{decode_server}/{endpoint}", json=modified_request), - ] - - # Wait for both responses to complete. Prefill should end first. - prefill_response, decode_response = await asyncio.gather(*tasks) - - if "return_logprob" in modified_request: - - prefill_json = await prefill_response.json() - ret_json = await decode_response.json() - - # merge `meta_info.input_token_logprobs` from prefill to decode - if "meta_info" in ret_json: - if "input_token_logprobs" in ret_json["meta_info"]: - ret_json["meta_info"]["input_token_logprobs"] = ( - prefill_json["meta_info"]["input_token_logprobs"] - + ret_json["meta_info"]["input_token_logprobs"] - ) - else: - ret_json = await decode_response.json() - - return ORJSONResponse( - content=ret_json, - status_code=decode_response.status, - ) - - async def generate_stream( - self, modified_request, prefill_server, decode_server, endpoint="generate" - ): - assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" - - async def stream_results(): - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout( - total=self.timeout - ) # Add timeout for request reliability - ) as session: - # Create 
the tasks for both prefill and decode requests - tasks = [ - session.post(f"{prefill_server}/{endpoint}", json=modified_request), - session.post(f"{decode_server}/{endpoint}", json=modified_request), - ] - # Wait for both responses to complete. Since this is streaming, they return immediately. - prefill_response, decode_response = await asyncio.gather(*tasks) - - if modified_request.get("return_logprob", False): - prefill_chunks = [] - async for chunk in prefill_response.content: - prefill_chunks.append(chunk) - - first_prefill_chunk = ( - prefill_chunks[0].decode("utf-8")[5:].strip("\n") - ) - first_prefill_chunk_json = orjson.loads(first_prefill_chunk) - - async for chunk in decode_response.content: - # Note: This is inefficient - # merge prefill input_token_logprobs, output_token_logprobs to decode - decoded_chunk = chunk.decode("utf-8") - if ( - decoded_chunk - and decoded_chunk.startswith("data:") - and "[DONE]" not in decoded_chunk - ): - ret_json = orjson.loads(decoded_chunk[5:].strip("\n")) - ret_json["meta_info"]["input_token_logprobs"] = ( - first_prefill_chunk_json["meta_info"][ - "input_token_logprobs" - ] - + ret_json["meta_info"]["input_token_logprobs"] - ) - - yield b"data: " + orjson.dumps(ret_json) + b"\n\n" - else: - yield chunk - else: - async for chunk in decode_response.content.iter_chunked( - AIOHTTP_STREAM_READ_CHUNK_SIZE - ): - yield chunk - - return StreamingResponse( - stream_results(), - media_type="text/event-stream", - ) - - -app = FastAPI() -load_balancer: Optional[MiniLoadBalancer] = None - - -@app.get("/health") -async def health_check(): - return Response(status_code=200) - - -@app.get("/health_generate") -async def health_generate(): - prefill_servers, decode_servers = ( - load_balancer.prefill_servers, - load_balancer.decode_servers, - ) - async with aiohttp.ClientSession() as session: - # Create the tasks - tasks = [] - for server in chain(prefill_servers, decode_servers): - tasks.append(session.get(f"{server}/health_generate")) 
- for i, response in enumerate(asyncio.as_completed(tasks)): - await response - return Response(status_code=200) - - -@app.post("/flush_cache") -async def flush_cache(): - prefill_servers, decode_servers = ( - load_balancer.prefill_servers, - load_balancer.decode_servers, - ) - async with aiohttp.ClientSession() as session: - # Create the tasks - tasks = [] - for server in chain(prefill_servers, decode_servers): - tasks.append(session.post(f"{server}/flush_cache")) - for i, response in enumerate(asyncio.as_completed(tasks)): - await response - return Response(status_code=200) - - -@app.get("/get_server_info") -async def get_server_info(): - prefill_servers, decode_servers = ( - load_balancer.prefill_servers, - load_balancer.decode_servers, - ) - prefill_infos = [] - decode_infos = [] - all_internal_states = [] - - async with aiohttp.ClientSession() as session: - for server in chain(prefill_servers): - server_info = await session.get(f"{server}/get_server_info") - prefill_infos.append(await server_info.json()) - for server in chain(decode_servers): - server_info = await session.get(f"{server}/get_server_info") - info_json = await server_info.json() - decode_infos.append(info_json) - # Extract internal_states from decode servers - if "internal_states" in info_json: - all_internal_states.extend(info_json["internal_states"]) - - # Return format expected by bench_one_batch_server.py - if all_internal_states: - return { - "internal_states": all_internal_states, - "prefill": prefill_infos, - "decode": decode_infos, - } - else: - # Fallback with dummy data if no internal states found - return { - "internal_states": [ - { - "last_gen_throughput": 0.0, - "avg_spec_accept_length": None, - } - ], - "prefill": prefill_infos, - "decode": decode_infos, - } - - -@app.get("/get_model_info") -async def get_model_info(): - global load_balancer - - if not load_balancer or not load_balancer.prefill_servers: - raise HTTPException( - status_code=HTTPStatus.SERVICE_UNAVAILABLE, - 
detail="There is no server registered", - ) - - target_server_url = load_balancer.prefill_servers[0] - endpoint_url = f"{target_server_url}/get_model_info" - - async with aiohttp.ClientSession() as session: - try: - async with session.get(endpoint_url) as response: - if response.status != 200: - error_text = await response.text() - raise HTTPException( - status_code=HTTPStatus.BAD_GATEWAY, - detail=( - f"Failed to get model info from {target_server_url}" - f"Status: {response.status}, Response: {error_text}" - ), - ) - - model_info_json = await response.json() - return ORJSONResponse(content=model_info_json) - - except aiohttp.ClientError as e: - raise HTTPException( - status_code=HTTPStatus.SERVICE_UNAVAILABLE, - detail=f"Failed to get model info from backend", - ) - - -@app.post("/generate") -async def handle_generate_request(request_data: dict): - prefill_server, bootstrap_port, decode_server = load_balancer.select_pair() - - # Parse and transform prefill_server for bootstrap data - parsed_url = urllib.parse.urlparse(prefill_server) - hostname = maybe_wrap_ipv6_address(parsed_url.hostname) - modified_request = request_data.copy() - - batch_size = _get_request_batch_size(modified_request) - if batch_size is not None: - modified_request.update( - { - "bootstrap_host": [hostname] * batch_size, - "bootstrap_port": [bootstrap_port] * batch_size, - "bootstrap_room": [ - _generate_bootstrap_room() for _ in range(batch_size) - ], - } - ) - else: - modified_request.update( - { - "bootstrap_host": hostname, - "bootstrap_port": bootstrap_port, - "bootstrap_room": _generate_bootstrap_room(), - } - ) - - if request_data.get("stream", False): - return await load_balancer.generate_stream( - modified_request, prefill_server, decode_server, "generate" - ) - else: - return await load_balancer.generate( - modified_request, prefill_server, decode_server, "generate" - ) - - -async def _forward_to_backend(request_data: dict, endpoint_name: str): - prefill_server, bootstrap_port, 
decode_server = load_balancer.select_pair() - - # Parse and transform prefill_server for bootstrap data - parsed_url = urllib.parse.urlparse(prefill_server) - hostname = maybe_wrap_ipv6_address(parsed_url.hostname) - modified_request = request_data.copy() - modified_request.update( - { - "bootstrap_host": hostname, - "bootstrap_port": bootstrap_port, - "bootstrap_room": _generate_bootstrap_room(), - } - ) - - if request_data.get("stream", False): - return await load_balancer.generate_stream( - modified_request, - prefill_server, - decode_server, - endpoint=endpoint_name, - ) - else: - return await load_balancer.generate( - modified_request, - prefill_server, - decode_server, - endpoint=endpoint_name, - ) - - -@app.post("/v1/chat/completions") -async def handle_chat_completion_request(request_data: dict): - return await _forward_to_backend(request_data, "v1/chat/completions") - - -@app.post("/v1/completions") -async def handle_completion_request(request_data: dict): - return await _forward_to_backend(request_data, "v1/completions") - - -def _generate_bootstrap_room(): - return random.randint(0, 2**63 - 1) - - -# We may utilize `GenerateReqInput`'s logic later -def _get_request_batch_size(request): - if (text := request.get("text")) is not None: - return None if isinstance(text, str) else len(text) - if (input_ids := request.get("input_ids")) is not None: - return None if isinstance(input_ids[0], int) else len(input_ids) - return None - - -@app.get("/v1/models") -async def get_models(): - prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server - async with aiohttp.ClientSession() as session: - try: - response = await session.get(f"{prefill_server}/v1/models") - if response.status != 200: - raise HTTPException( - status_code=response.status, - detail=f"Prefill server error: Status {response.status}", - ) - return ORJSONResponse(content=await response.json()) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - 
-@app.post("/register") -async def register(obj: PDRegistryRequest): - if obj.mode == "prefill": - load_balancer.add_prefill_server( - PrefillConfig(obj.registry_url, obj.bootstrap_port) - ) - logger.info( - f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}" - ) - elif obj.mode == "decode": - load_balancer.add_decode_server(obj.registry_url) - logger.info(f"Registered decode server: {obj.registry_url}") - else: - raise HTTPException( - status_code=400, - detail="Invalid mode. Must be either PREFILL or DECODE.", - ) - - logger.info( - f"#Prefill servers: {len(load_balancer.prefill_configs)}, " - f"#Decode servers: {len(load_balancer.decode_servers)}" - ) - - return Response(status_code=200) - - -def run(prefill_configs, decode_addrs, host, port, timeout): - global load_balancer - load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs, timeout=timeout) - uvicorn.run(app, host=host, port=port) - - -if __name__ == "__main__": - # FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb - from sglang.srt.disaggregation.launch_lb import main - - main() +raise RuntimeError( + """The 'mini_lb' module has been relocated to the 'sglang_router' package. + We recommend installing 'sglang-router' with Rust support for optimal performance. 
+ If you encounter issues building the router with Rust, set the environment variable + 'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'.""" +) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 534528087..efe867e5a 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -1,21 +1,17 @@ from __future__ import annotations -import dataclasses import os import random -import threading -import warnings from collections import deque from contextlib import nullcontext from enum import Enum -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional import numpy as np -import requests import torch import torch.distributed as dist -from sglang.srt.utils import get_ip, is_npu +from sglang.srt.utils import is_npu if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -305,49 +301,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int): return (num_kv_indices + page_size - 1) // page_size -######################### -# PDLB Registry -######################### - - -@dataclasses.dataclass -class PDRegistryRequest: - """A request to register a machine itself to the LB.""" - - mode: str - registry_url: str - bootstrap_port: Optional[int] = None - - def __post_init__(self): - if self.mode == "prefill" and self.bootstrap_port is None: - raise ValueError("Bootstrap port must be set in PREFILL mode.") - elif self.mode == "decode" and self.bootstrap_port is not None: - raise ValueError("Bootstrap port must not be set in DECODE mode.") - elif self.mode not in ["prefill", "decode"]: - raise ValueError( - f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'." 
- ) - - -def register_disaggregation_server( - mode: str, server_port: int, bootstrap_port: int, pdlb_url: str -): - boostrap_port = bootstrap_port if mode == "prefill" else None - registry_request = PDRegistryRequest( - mode=mode, - registry_url=f"http://{get_ip()}:{server_port}", - bootstrap_port=boostrap_port, - ) - res = requests.post( - f"{pdlb_url}/register", - json=dataclasses.asdict(registry_request), - ) - if res.status_code != 200: - warnings.warn( - f"Failed to register disaggregation server: {res.status_code} {res.text}" - ) - - ######################### # Misc ######################### diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 1e7afe26b..dc91d7e84 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -47,11 +47,7 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse -from sglang.srt.disaggregation.utils import ( - FAKE_BOOTSTRAP_HOST, - DisaggregationMode, - register_disaggregation_server, -) +from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.entrypoints.openai.protocol import ( ChatCompletionRequest, @@ -1405,13 +1401,5 @@ def _wait_and_warmup( if server_args.debug_tensor_dump_input_file: kill_process_tree(os.getpid()) - if server_args.pdlb_url is not None: - register_disaggregation_server( - server_args.disaggregation_mode, - server_args.port, - server_args.disaggregation_bootstrap_port, - server_args.pdlb_url, - ) - if launch_callback is not None: launch_callback() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9466f02ce..aaf9a49f5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -367,7 +367,6 @@ class 
ServerArgs: disaggregation_prefill_pp: Optional[int] = 1 disaggregation_ib_device: Optional[str] = None num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD - pdlb_url: Optional[str] = None # For model weight update custom_weight_loader: Optional[List[str]] = None @@ -2071,12 +2070,6 @@ class ServerArgs: default=ServerArgs.num_reserved_decode_tokens, help="Number of decode tokens that will have memory reserved when adding new request to the running batch.", ) - parser.add_argument( - "--pdlb-url", - type=str, - default=None, - help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.", - ) # Custom weight loader parser.add_argument( diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 48830b1bc..953fb76df 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -466,6 +466,25 @@ def try_cached_model(model_repo: str): return model_dir if model_dir else model_repo +def popen_with_error_check(command: list[str], allow_exit: bool = False): + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + def _run_and_check(): + stdout, stderr = process.communicate() + + while process.poll() is None: + time.sleep(5) + + if not allow_exit or process.returncode != 0: + raise Exception( + f"{command} exited with code {process.returncode}\n{stdout=}\n{stderr=}" + ) + + t = threading.Thread(target=_run_and_check) + t.start() + return process + + def popen_launch_server( model: str, base_url: str, diff --git a/scripts/ci/ci_install_dependency.sh b/scripts/ci/ci_install_dependency.sh index 95fa01413..199fcbaf0 100755 --- a/scripts/ci/ci_install_dependency.sh +++ b/scripts/ci/ci_install_dependency.sh @@ -45,6 +45,10 @@ fi # Install the main package $PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX +# Install router for pd-disagg 
test +SGLANG_ROUTER_BUILD_NO_RUST=1 $PIP_CMD install -e "sgl-router" $PIP_INSTALL_SUFFIX + + if [ "$IS_BLACKWELL" = "1" ]; then # TODO auto determine sgl-kernel version SGL_KERNEL_VERSION=0.3.8 diff --git a/sgl-router/py_src/sglang_router/__init__.py b/sgl-router/py_src/sglang_router/__init__.py index 081740479..9c7fa208e 100644 --- a/sgl-router/py_src/sglang_router/__init__.py +++ b/sgl-router/py_src/sglang_router/__init__.py @@ -1,7 +1,3 @@ -# a lightweihgt wrapper on router with argument type and comments -# no wrapper on policy type => direct export -from sglang_router.router import Router from sglang_router.version import __version__ -from sglang_router_rs import PolicyType -__all__ = ["Router", "PolicyType", "__version__"] +__all__ = ["__version__"] diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index e0522592f..506842f84 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -1,654 +1,22 @@ import argparse -import dataclasses import logging import sys -from typing import Dict, List, Optional +from typing import List, Optional -from sglang_router import Router -from sglang_router_rs import PolicyType +import setproctitle +from sglang_router.mini_lb import MiniLoadBalancer +from sglang_router.router_args import RouterArgs +logger = logging.getLogger("router") -def setup_logger(): - logger = logging.getLogger("router") - logger.setLevel(logging.INFO) - - formatter = logging.Formatter( - "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", +try: + from sglang_router.router import Router +except ImportError: + Router = None + logger.warning( + "Rust Router is not installed, only python MiniLB (debugging only) is available" ) - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) - - return logger - - -@dataclasses.dataclass -class RouterArgs: - # 
Worker configuration - worker_urls: List[str] = dataclasses.field(default_factory=list) - host: str = "127.0.0.1" - port: int = 30000 - - # PD-specific configuration - pd_disaggregation: bool = False # Enable PD disaggregated mode - prefill_urls: List[tuple] = dataclasses.field( - default_factory=list - ) # List of (url, bootstrap_port) - decode_urls: List[str] = dataclasses.field(default_factory=list) - - # Routing policy - policy: str = "cache_aware" - prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode - decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode - worker_startup_timeout_secs: int = 600 - worker_startup_check_interval: int = 30 - cache_threshold: float = 0.3 - balance_abs_threshold: int = 64 - balance_rel_threshold: float = 1.5 - eviction_interval: int = 120 - max_tree_size: int = 2**26 - max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches - dp_aware: bool = False - api_key: Optional[str] = None - log_dir: Optional[str] = None - log_level: Optional[str] = None - # Service discovery configuration - service_discovery: bool = False - selector: Dict[str, str] = dataclasses.field(default_factory=dict) - service_discovery_port: int = 80 - service_discovery_namespace: Optional[str] = None - # PD service discovery configuration - prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict) - decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict) - bootstrap_port_annotation: str = "sglang.ai/bootstrap-port" - # Prometheus configuration - prometheus_port: Optional[int] = None - prometheus_host: Optional[str] = None - # Request ID headers configuration - request_id_headers: Optional[List[str]] = None - # Request timeout in seconds - request_timeout_secs: int = 1800 - # Max concurrent requests for rate limiting - max_concurrent_requests: int = 256 - # Queue size for pending requests when max concurrent limit reached - queue_size: int = 100 - # 
Maximum time (in seconds) a request can wait in queue before timing out - queue_timeout_secs: int = 60 - # Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests - rate_limit_tokens_per_second: Optional[int] = None - # CORS allowed origins - cors_allowed_origins: List[str] = dataclasses.field(default_factory=list) - # Retry configuration - retry_max_retries: int = 5 - retry_initial_backoff_ms: int = 50 - retry_max_backoff_ms: int = 30_000 - retry_backoff_multiplier: float = 1.5 - retry_jitter_factor: float = 0.2 - disable_retries: bool = False - # Health check configuration - health_failure_threshold: int = 3 - health_success_threshold: int = 2 - health_check_timeout_secs: int = 5 - health_check_interval_secs: int = 60 - health_check_endpoint: str = "/health" - # Circuit breaker configuration - cb_failure_threshold: int = 10 - cb_success_threshold: int = 3 - cb_timeout_duration_secs: int = 60 - cb_window_duration_secs: int = 120 - disable_circuit_breaker: bool = False - # Tokenizer configuration - model_path: Optional[str] = None - tokenizer_path: Optional[str] = None - - @staticmethod - def add_cli_args( - parser: argparse.ArgumentParser, - use_router_prefix: bool = False, - exclude_host_port: bool = False, - ): - """ - Add router-specific arguments to an argument parser. 
- - Args: - parser: The argument parser to add arguments to - use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts - exclude_host_port: If True, don't add host and port arguments (used when inheriting from server) - """ - prefix = "router-" if use_router_prefix else "" - - # Worker configuration - if not exclude_host_port: - parser.add_argument( - "--host", - type=str, - default=RouterArgs.host, - help="Host address to bind the router server", - ) - parser.add_argument( - "--port", - type=int, - default=RouterArgs.port, - help="Port number to bind the router server", - ) - - parser.add_argument( - "--worker-urls", - type=str, - nargs="*", - default=[], - help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)", - ) - - # Routing policy configuration - parser.add_argument( - f"--{prefix}policy", - type=str, - default=RouterArgs.policy, - choices=["random", "round_robin", "cache_aware", "power_of_two"], - help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden", - ) - parser.add_argument( - f"--{prefix}prefill-policy", - type=str, - default=None, - choices=["random", "round_robin", "cache_aware", "power_of_two"], - help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy", - ) - parser.add_argument( - f"--{prefix}decode-policy", - type=str, - default=None, - choices=["random", "round_robin", "cache_aware", "power_of_two"], - help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy", - ) - - # PD-specific arguments - parser.add_argument( - f"--{prefix}pd-disaggregation", - action="store_true", - help="Enable PD (Prefill-Decode) disaggregated mode", - ) - parser.add_argument( - f"--{prefix}prefill", - nargs="+", - action="append", - help="Prefill server URL and optional bootstrap port. Can be specified multiple times. " - "Format: --prefill URL [BOOTSTRAP_PORT]. 
" - "BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).", - ) - parser.add_argument( - f"--{prefix}decode", - nargs=1, - action="append", - metavar=("URL",), - help="Decode server URL. Can be specified multiple times.", - ) - parser.add_argument( - f"--{prefix}worker-startup-timeout-secs", - type=int, - default=RouterArgs.worker_startup_timeout_secs, - help="Timeout in seconds for worker startup", - ) - parser.add_argument( - f"--{prefix}worker-startup-check-interval", - type=int, - default=RouterArgs.worker_startup_check_interval, - help="Interval in seconds between checks for worker startup", - ) - parser.add_argument( - f"--{prefix}cache-threshold", - type=float, - default=RouterArgs.cache_threshold, - help="Cache threshold (0.0-1.0) for cache-aware routing", - ) - parser.add_argument( - f"--{prefix}balance-abs-threshold", - type=int, - default=RouterArgs.balance_abs_threshold, - help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware", - ) - parser.add_argument( - f"--{prefix}balance-rel-threshold", - type=float, - default=RouterArgs.balance_rel_threshold, - help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. 
Otherwise, use cache aware", - ) - parser.add_argument( - f"--{prefix}eviction-interval", - type=int, - default=RouterArgs.eviction_interval, - help="Interval in seconds between cache eviction operations", - ) - parser.add_argument( - f"--{prefix}max-tree-size", - type=int, - default=RouterArgs.max_tree_size, - help="Maximum size of the approximation tree for cache-aware routing", - ) - parser.add_argument( - f"--{prefix}max-payload-size", - type=int, - default=RouterArgs.max_payload_size, - help="Maximum payload size in bytes", - ) - parser.add_argument( - f"--{prefix}dp-aware", - action="store_true", - help="Enable data parallelism aware schedule", - ) - parser.add_argument( - f"--{prefix}api-key", - type=str, - default=None, - help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.", - ) - parser.add_argument( - f"--{prefix}log-dir", - type=str, - default=None, - help="Directory to store log files. If not specified, logs are only output to console.", - ) - parser.add_argument( - f"--{prefix}log-level", - type=str, - default="info", - choices=["debug", "info", "warning", "error", "critical"], - help="Set the logging level. If not specified, defaults to INFO.", - ) - parser.add_argument( - f"--{prefix}service-discovery", - action="store_true", - help="Enable Kubernetes service discovery", - ) - parser.add_argument( - f"--{prefix}selector", - type=str, - nargs="+", - help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)", - ) - parser.add_argument( - f"--{prefix}service-discovery-port", - type=int, - default=RouterArgs.service_discovery_port, - help="Port to use for discovered worker pods", - ) - parser.add_argument( - f"--{prefix}service-discovery-namespace", - type=str, - help="Kubernetes namespace to watch for pods. 
If not provided, watches all namespaces (requires cluster-wide permissions)", - ) - parser.add_argument( - f"--{prefix}prefill-selector", - type=str, - nargs="+", - help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)", - ) - parser.add_argument( - f"--{prefix}decode-selector", - type=str, - nargs="+", - help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)", - ) - # Prometheus configuration - parser.add_argument( - f"--{prefix}prometheus-port", - type=int, - default=29000, - help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled", - ) - parser.add_argument( - f"--{prefix}prometheus-host", - type=str, - default="127.0.0.1", - help="Host address to bind the Prometheus metrics server", - ) - parser.add_argument( - f"--{prefix}request-id-headers", - type=str, - nargs="*", - help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.", - ) - parser.add_argument( - f"--{prefix}request-timeout-secs", - type=int, - default=RouterArgs.request_timeout_secs, - help="Request timeout in seconds", - ) - # Retry configuration - parser.add_argument( - f"--{prefix}retry-max-retries", - type=int, - default=RouterArgs.retry_max_retries, - ) - parser.add_argument( - f"--{prefix}retry-initial-backoff-ms", - type=int, - default=RouterArgs.retry_initial_backoff_ms, - ) - parser.add_argument( - f"--{prefix}retry-max-backoff-ms", - type=int, - default=RouterArgs.retry_max_backoff_ms, - ) - parser.add_argument( - f"--{prefix}retry-backoff-multiplier", - type=float, - default=RouterArgs.retry_backoff_multiplier, - ) - parser.add_argument( - f"--{prefix}retry-jitter-factor", - type=float, - default=RouterArgs.retry_jitter_factor, - ) - parser.add_argument( - f"--{prefix}disable-retries", - action="store_true", - help="Disable retries (equivalent to setting retry_max_retries=1)", - ) - # Circuit breaker configuration - 
parser.add_argument( - f"--{prefix}cb-failure-threshold", - type=int, - default=RouterArgs.cb_failure_threshold, - ) - parser.add_argument( - f"--{prefix}cb-success-threshold", - type=int, - default=RouterArgs.cb_success_threshold, - ) - parser.add_argument( - f"--{prefix}cb-timeout-duration-secs", - type=int, - default=RouterArgs.cb_timeout_duration_secs, - ) - parser.add_argument( - f"--{prefix}cb-window-duration-secs", - type=int, - default=RouterArgs.cb_window_duration_secs, - ) - parser.add_argument( - f"--{prefix}disable-circuit-breaker", - action="store_true", - help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)", - ) - # Health check configuration - parser.add_argument( - f"--{prefix}health-failure-threshold", - type=int, - default=RouterArgs.health_failure_threshold, - help="Number of consecutive health check failures before marking worker unhealthy", - ) - parser.add_argument( - f"--{prefix}health-success-threshold", - type=int, - default=RouterArgs.health_success_threshold, - help="Number of consecutive health check successes before marking worker healthy", - ) - parser.add_argument( - f"--{prefix}health-check-timeout-secs", - type=int, - default=RouterArgs.health_check_timeout_secs, - help="Timeout in seconds for health check requests", - ) - parser.add_argument( - f"--{prefix}health-check-interval-secs", - type=int, - default=RouterArgs.health_check_interval_secs, - help="Interval in seconds between runtime health checks", - ) - parser.add_argument( - f"--{prefix}health-check-endpoint", - type=str, - default=RouterArgs.health_check_endpoint, - help="Health check endpoint path", - ) - parser.add_argument( - f"--{prefix}max-concurrent-requests", - type=int, - default=RouterArgs.max_concurrent_requests, - help="Maximum number of concurrent requests allowed (for rate limiting)", - ) - parser.add_argument( - f"--{prefix}queue-size", - type=int, - default=RouterArgs.queue_size, - help="Queue size for pending requests when 
max concurrent limit reached (0 = no queue, return 429 immediately)", - ) - parser.add_argument( - f"--{prefix}queue-timeout-secs", - type=int, - default=RouterArgs.queue_timeout_secs, - help="Maximum time (in seconds) a request can wait in queue before timing out", - ) - parser.add_argument( - f"--{prefix}rate-limit-tokens-per-second", - type=int, - default=RouterArgs.rate_limit_tokens_per_second, - help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests", - ) - parser.add_argument( - f"--{prefix}cors-allowed-origins", - type=str, - nargs="*", - default=[], - help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)", - ) - # Tokenizer configuration - parser.add_argument( - f"--{prefix}model-path", - type=str, - default=None, - help="Model path for loading tokenizer (HuggingFace model ID or local path)", - ) - parser.add_argument( - f"--{prefix}tokenizer-path", - type=str, - default=None, - help="Explicit tokenizer path (overrides model_path tokenizer if provided)", - ) - - @classmethod - def from_cli_args( - cls, args: argparse.Namespace, use_router_prefix: bool = False - ) -> "RouterArgs": - """ - Create RouterArgs instance from parsed command line arguments. 
- - Args: - args: Parsed command line arguments - use_router_prefix: If True, look for arguments with 'router-' prefix - """ - prefix = "router_" if use_router_prefix else "" - worker_urls = getattr(args, "worker_urls", []) - - # Parse PD URLs - prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None)) - decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None)) - - return cls( - worker_urls=worker_urls, - host=args.host, - port=args.port, - pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False), - prefill_urls=prefill_urls, - decode_urls=decode_urls, - policy=getattr(args, f"{prefix}policy"), - prefill_policy=getattr(args, f"{prefix}prefill_policy", None), - decode_policy=getattr(args, f"{prefix}decode_policy", None), - worker_startup_timeout_secs=getattr( - args, f"{prefix}worker_startup_timeout_secs" - ), - worker_startup_check_interval=getattr( - args, f"{prefix}worker_startup_check_interval" - ), - cache_threshold=getattr(args, f"{prefix}cache_threshold"), - balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"), - balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"), - eviction_interval=getattr(args, f"{prefix}eviction_interval"), - max_tree_size=getattr(args, f"{prefix}max_tree_size"), - max_payload_size=getattr(args, f"{prefix}max_payload_size"), - dp_aware=getattr(args, f"{prefix}dp_aware", False), - api_key=getattr(args, f"{prefix}api_key", None), - log_dir=getattr(args, f"{prefix}log_dir", None), - log_level=getattr(args, f"{prefix}log_level", None), - service_discovery=getattr(args, f"{prefix}service_discovery", False), - selector=cls._parse_selector(getattr(args, f"{prefix}selector", None)), - service_discovery_port=getattr(args, f"{prefix}service_discovery_port"), - service_discovery_namespace=getattr( - args, f"{prefix}service_discovery_namespace", None - ), - prefill_selector=cls._parse_selector( - getattr(args, f"{prefix}prefill_selector", None) - ), - 
decode_selector=cls._parse_selector( - getattr(args, f"{prefix}decode_selector", None) - ), - bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation - prometheus_port=getattr(args, f"{prefix}prometheus_port", None), - prometheus_host=getattr(args, f"{prefix}prometheus_host", None), - request_id_headers=getattr(args, f"{prefix}request_id_headers", None), - request_timeout_secs=getattr( - args, f"{prefix}request_timeout_secs", RouterArgs.request_timeout_secs - ), - max_concurrent_requests=getattr( - args, - f"{prefix}max_concurrent_requests", - RouterArgs.max_concurrent_requests, - ), - queue_size=getattr( - args, - f"{prefix}queue_size", - RouterArgs.queue_size, - ), - queue_timeout_secs=getattr( - args, - f"{prefix}queue_timeout_secs", - RouterArgs.queue_timeout_secs, - ), - rate_limit_tokens_per_second=getattr( - args, - f"{prefix}rate_limit_tokens_per_second", - RouterArgs.rate_limit_tokens_per_second, - ), - cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []), - retry_max_retries=getattr(args, f"{prefix}retry_max_retries"), - retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"), - retry_max_backoff_ms=getattr(args, f"{prefix}retry_max_backoff_ms"), - retry_backoff_multiplier=getattr(args, f"{prefix}retry_backoff_multiplier"), - retry_jitter_factor=getattr(args, f"{prefix}retry_jitter_factor"), - cb_failure_threshold=getattr(args, f"{prefix}cb_failure_threshold"), - cb_success_threshold=getattr(args, f"{prefix}cb_success_threshold"), - cb_timeout_duration_secs=getattr(args, f"{prefix}cb_timeout_duration_secs"), - cb_window_duration_secs=getattr(args, f"{prefix}cb_window_duration_secs"), - disable_retries=getattr(args, f"{prefix}disable_retries", False), - disable_circuit_breaker=getattr( - args, f"{prefix}disable_circuit_breaker", False - ), - health_failure_threshold=getattr( - args, - f"{prefix}health_failure_threshold", - RouterArgs.health_failure_threshold, - ), - 
health_success_threshold=getattr( - args, - f"{prefix}health_success_threshold", - RouterArgs.health_success_threshold, - ), - health_check_timeout_secs=getattr( - args, - f"{prefix}health_check_timeout_secs", - RouterArgs.health_check_timeout_secs, - ), - health_check_interval_secs=getattr( - args, - f"{prefix}health_check_interval_secs", - RouterArgs.health_check_interval_secs, - ), - health_check_endpoint=getattr( - args, f"{prefix}health_check_endpoint", RouterArgs.health_check_endpoint - ), - model_path=getattr(args, f"{prefix}model_path", None), - tokenizer_path=getattr(args, f"{prefix}tokenizer_path", None), - ) - - @staticmethod - def _parse_selector(selector_list): - if not selector_list: - return {} - - selector = {} - for item in selector_list: - if "=" in item: - key, value = item.split("=", 1) - selector[key] = value - return selector - - @staticmethod - def _parse_prefill_urls(prefill_list): - """Parse prefill URLs from --prefill arguments. - - Format: --prefill URL [BOOTSTRAP_PORT] - Example: - --prefill http://prefill1:8080 9000 # With bootstrap port - --prefill http://prefill2:8080 none # Explicitly no bootstrap port - --prefill http://prefill3:8080 # Defaults to no bootstrap port - """ - if not prefill_list: - return [] - - prefill_urls = [] - for prefill_args in prefill_list: - - url = prefill_args[0] - - # Handle optional bootstrap port - if len(prefill_args) >= 2: - bootstrap_port_str = prefill_args[1] - # Handle 'none' as None - if bootstrap_port_str.lower() == "none": - bootstrap_port = None - else: - try: - bootstrap_port = int(bootstrap_port_str) - except ValueError: - raise ValueError( - f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'" - ) - else: - # No bootstrap port specified, default to None - bootstrap_port = None - - prefill_urls.append((url, bootstrap_port)) - - return prefill_urls - - @staticmethod - def _parse_decode_urls(decode_list): - """Parse decode URLs from --decode arguments. 
- - Format: --decode URL - Example: --decode http://decode1:8081 --decode http://decode2:8081 - """ - if not decode_list: - return [] - - # decode_list is a list of single-element lists due to nargs=1 - return [url[0] for url in decode_list] - - -def policy_from_str(policy_str: str) -> PolicyType: - """Convert policy string to PolicyType enum.""" - policy_map = { - "random": PolicyType.Random, - "round_robin": PolicyType.RoundRobin, - "cache_aware": PolicyType.CacheAware, - "power_of_two": PolicyType.PowerOfTwo, - } - return policy_map[policy_str] - def launch_router(args: argparse.Namespace) -> Optional[Router]: """ @@ -661,7 +29,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: Returns: Router instance if successful, None if failed """ - logger = logging.getLogger("router") + setproctitle.setproctitle("sglang::router") try: # Convert to RouterArgs if needed if not isinstance(args, RouterArgs): @@ -669,120 +37,15 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: else: router_args = args - # Validate configuration based on mode - if router_args.pd_disaggregation: - # Validate PD configuration - skip URL requirements if using service discovery - if not router_args.service_discovery: - if not router_args.prefill_urls: - raise ValueError("PD disaggregation mode requires --prefill") - if not router_args.decode_urls: - raise ValueError("PD disaggregation mode requires --decode") - - # Warn about policy usage in PD mode - if ( - router_args.prefill_policy - and router_args.decode_policy - and router_args.policy - ): - logger.warning( - "Both --prefill-policy and --decode-policy are specified. " - "The main --policy flag will be ignored for PD mode." - ) - elif ( - router_args.prefill_policy - and not router_args.decode_policy - and router_args.policy - ): - logger.info( - f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes " - f"and --policy '{router_args.policy}' for decode nodes." 
- ) - elif ( - router_args.decode_policy - and not router_args.prefill_policy - and router_args.policy - ): - logger.info( - f"Using --policy '{router_args.policy}' for prefill nodes " - f"and --decode-policy '{router_args.decode_policy}' for decode nodes." - ) - - # Create router with unified constructor - router = Router( - worker_urls=( - [] - if router_args.service_discovery or router_args.pd_disaggregation - else router_args.worker_urls - ), - host=router_args.host, - port=router_args.port, - policy=policy_from_str(router_args.policy), - worker_startup_timeout_secs=router_args.worker_startup_timeout_secs, - worker_startup_check_interval=router_args.worker_startup_check_interval, - cache_threshold=router_args.cache_threshold, - balance_abs_threshold=router_args.balance_abs_threshold, - balance_rel_threshold=router_args.balance_rel_threshold, - eviction_interval_secs=router_args.eviction_interval, - max_tree_size=router_args.max_tree_size, - max_payload_size=router_args.max_payload_size, - dp_aware=router_args.dp_aware, - api_key=router_args.api_key, - log_dir=router_args.log_dir, - log_level=router_args.log_level, - service_discovery=router_args.service_discovery, - selector=router_args.selector, - service_discovery_port=router_args.service_discovery_port, - service_discovery_namespace=router_args.service_discovery_namespace, - prefill_selector=router_args.prefill_selector, - decode_selector=router_args.decode_selector, - prometheus_port=router_args.prometheus_port, - prometheus_host=router_args.prometheus_host, - request_timeout_secs=router_args.request_timeout_secs, - pd_disaggregation=router_args.pd_disaggregation, - prefill_urls=( - router_args.prefill_urls if router_args.pd_disaggregation else None - ), - decode_urls=( - router_args.decode_urls if router_args.pd_disaggregation else None - ), - prefill_policy=( - policy_from_str(router_args.prefill_policy) - if router_args.prefill_policy - else None - ), - decode_policy=( - 
policy_from_str(router_args.decode_policy) - if router_args.decode_policy - else None - ), - request_id_headers=router_args.request_id_headers, - max_concurrent_requests=router_args.max_concurrent_requests, - queue_size=router_args.queue_size, - queue_timeout_secs=router_args.queue_timeout_secs, - rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second, - cors_allowed_origins=router_args.cors_allowed_origins, - retry_max_retries=router_args.retry_max_retries, - retry_initial_backoff_ms=router_args.retry_initial_backoff_ms, - retry_max_backoff_ms=router_args.retry_max_backoff_ms, - retry_backoff_multiplier=router_args.retry_backoff_multiplier, - retry_jitter_factor=router_args.retry_jitter_factor, - cb_failure_threshold=router_args.cb_failure_threshold, - cb_success_threshold=router_args.cb_success_threshold, - cb_timeout_duration_secs=router_args.cb_timeout_duration_secs, - cb_window_duration_secs=router_args.cb_window_duration_secs, - disable_retries=router_args.disable_retries, - disable_circuit_breaker=router_args.disable_circuit_breaker, - health_failure_threshold=router_args.health_failure_threshold, - health_success_threshold=router_args.health_success_threshold, - health_check_timeout_secs=router_args.health_check_timeout_secs, - health_check_interval_secs=router_args.health_check_interval_secs, - health_check_endpoint=router_args.health_check_endpoint, - model_path=router_args.model_path, - tokenizer_path=router_args.tokenizer_path, - ) - - router.start() - return router + if router_args.mini_lb: + mini_lb = MiniLoadBalancer(router_args) + mini_lb.start() + else: + if Router is None: + raise RuntimeError("Rust Router is not installed") + router_args._validate_router_args() + router = Router.from_args(router_args) + router.start() except Exception as e: logger.error(f"Error starting router: {e}") diff --git a/sgl-router/py_src/sglang_router/mini_lb.py b/sgl-router/py_src/sglang_router/mini_lb.py new file mode 100644 index 000000000..920d5c38f 
--- /dev/null +++ b/sgl-router/py_src/sglang_router/mini_lb.py @@ -0,0 +1,395 @@ +""" +Minimal HTTP load balancer for prefill and decode servers for testing. +""" + +import asyncio +import ipaddress +import logging +import random +import urllib.parse +from http import HTTPStatus +from itertools import chain +from typing import Optional + +import aiohttp +import orjson +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.responses import ORJSONResponse, Response, StreamingResponse +from sglang_router.router_args import RouterArgs + +logger = logging.getLogger(__name__) + +AIOHTTP_STREAM_READ_CHUNK_SIZE = ( + 1024 * 64 +) # 64KB, to prevent aiohttp's "Chunk too big" error + + +def maybe_wrap_ipv6_address(address: str) -> str: + try: + ipaddress.IPv6Address(address) + return f"[{address}]" + except ValueError: + return address + + +class MiniLoadBalancer: + def __init__( + self, + router_args: RouterArgs, + ): + self._validate_router_args(router_args) + + self.host = router_args.host + self.port = router_args.port + self.timeout = router_args.request_timeout_secs + self.prefill_urls = [url[0] for url in router_args.prefill_urls] + self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls] + self.decode_urls = router_args.decode_urls + + def _validate_router_args(self, router_args: RouterArgs): + logger.warning( + "\x1b[33mMiniLB is only for debugging purposes, it only supports random policy!\033[0m" + ) + + # NOTE: too many arguments unsupported, just validate some important ones + if router_args.policy != "random": + logger.warning("[MiniLB] Overriding policy to random") + router_args.policy = "random" + + if not router_args.pd_disaggregation: + raise ValueError("MiniLB only supports PD disaggregation mode") + + if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0: + raise ValueError( + "MiniLB requires at least one prefill and one decode server" + ) + + def start(self): + global lb + lb = self + uvicorn.run(app, 
host=self.host, port=self.port) + + def select_pair(self): + assert len(self.prefill_urls) > 0, "No prefill servers available" + assert len(self.decode_urls) > 0, "No decode servers available" + pidx = random.randint(0, len(self.prefill_urls) - 1) + didx = random.randint(0, len(self.decode_urls) - 1) + return ( + self.prefill_urls[pidx], + self.prefill_bootstrap_ports[pidx], + self.decode_urls[didx], + ) + + async def generate( + self, modified_request, prefill_server, decode_server, endpoint + ) -> ORJSONResponse: + assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=self.timeout + ) # Add timeout for request reliability + ) as session: + tasks = [ + session.post(f"{prefill_server}/{endpoint}", json=modified_request), + session.post(f"{decode_server}/{endpoint}", json=modified_request), + ] + + # Wait for both responses to complete. Prefill should end first. + prefill_response, decode_response = await asyncio.gather(*tasks) + + if "return_logprob" in modified_request: + + prefill_json = await prefill_response.json() + ret_json = await decode_response.json() + + # merge `meta_info.input_token_logprobs` from prefill to decode + if "meta_info" in ret_json: + if "input_token_logprobs" in ret_json["meta_info"]: + ret_json["meta_info"]["input_token_logprobs"] = ( + prefill_json["meta_info"]["input_token_logprobs"] + + ret_json["meta_info"]["input_token_logprobs"] + ) + else: + ret_json = await decode_response.json() + + return ORJSONResponse( + content=ret_json, + status_code=decode_response.status, + ) + + async def generate_stream( + self, modified_request, prefill_server, decode_server, endpoint="generate" + ): + assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" + + async def stream_results(): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=self.timeout + ) # Add timeout for request reliability + ) as session: 
+ # Create the tasks for both prefill and decode requests + tasks = [ + session.post(f"{prefill_server}/{endpoint}", json=modified_request), + session.post(f"{decode_server}/{endpoint}", json=modified_request), + ] + # Wait for both responses to complete. Since this is streaming, they return immediately. + prefill_response, decode_response = await asyncio.gather(*tasks) + + if modified_request.get("return_logprob", False): + prefill_chunks = [] + async for chunk in prefill_response.content: + prefill_chunks.append(chunk) + + first_prefill_chunk = ( + prefill_chunks[0].decode("utf-8")[5:].strip("\n") + ) + first_prefill_chunk_json = orjson.loads(first_prefill_chunk) + + async for chunk in decode_response.content: + # Note: This is inefficient + # merge prefill input_token_logprobs, output_token_logprobs to decode + decoded_chunk = chunk.decode("utf-8") + if ( + decoded_chunk + and decoded_chunk.startswith("data:") + and "[DONE]" not in decoded_chunk + ): + ret_json = orjson.loads(decoded_chunk[5:].strip("\n")) + ret_json["meta_info"]["input_token_logprobs"] = ( + first_prefill_chunk_json["meta_info"][ + "input_token_logprobs" + ] + + ret_json["meta_info"]["input_token_logprobs"] + ) + + yield b"data: " + orjson.dumps(ret_json) + b"\n\n" + else: + yield chunk + else: + async for chunk in decode_response.content.iter_chunked( + AIOHTTP_STREAM_READ_CHUNK_SIZE + ): + yield chunk + + return StreamingResponse( + stream_results(), + media_type="text/event-stream", + ) + + +app = FastAPI() +lb: Optional[MiniLoadBalancer] = None + + +@app.get("/health") +async def health_check(): + return Response(status_code=200) + + +@app.get("/health_generate") +async def health_generate(): + async with aiohttp.ClientSession() as session: + # Create the tasks + tasks = [] + for server in chain(lb.prefill_urls, lb.decode_urls): + tasks.append(session.get(f"{server}/health_generate")) + for i, response in enumerate(asyncio.as_completed(tasks)): + await response + return 
Response(status_code=200) + + +@app.post("/flush_cache") +async def flush_cache(): + async with aiohttp.ClientSession() as session: + # Create the tasks + tasks = [] + for server in chain(lb.prefill_urls, lb.decode_urls): + tasks.append(session.post(f"{server}/flush_cache")) + for i, response in enumerate(asyncio.as_completed(tasks)): + await response + return Response(status_code=200) + + +@app.get("/get_server_info") +async def get_server_info(): + prefill_infos = [] + decode_infos = [] + all_internal_states = [] + + async with aiohttp.ClientSession() as session: + for server in lb.prefill_urls: + server_info = await session.get(f"{server}/get_server_info") + prefill_infos.append(await server_info.json()) + for server in lb.decode_urls: + server_info = await session.get(f"{server}/get_server_info") + info_json = await server_info.json() + decode_infos.append(info_json) + # Extract internal_states from decode servers + if "internal_states" in info_json: + all_internal_states.extend(info_json["internal_states"]) + + # Return format expected by bench_one_batch_server.py + if all_internal_states: + return { + "internal_states": all_internal_states, + "prefill": prefill_infos, + "decode": decode_infos, + } + else: + # Fallback with dummy data if no internal states found + return { + "internal_states": [ + { + "last_gen_throughput": 0.0, + "avg_spec_accept_length": None, + } + ], + "prefill": prefill_infos, + "decode": decode_infos, + } + + +@app.get("/get_model_info") +async def get_model_info(): + if not lb or not lb.prefill_urls: + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + detail="There is no server registered", + ) + + target_server_url = lb.prefill_urls[0] + endpoint_url = f"{target_server_url}/get_model_info" + + async with aiohttp.ClientSession() as session: + try: + async with session.get(endpoint_url) as response: + if response.status != 200: + error_text = await response.text() + raise HTTPException( + 
status_code=HTTPStatus.BAD_GATEWAY,
+                        detail=(
+                            f"Failed to get model info from {target_server_url}. "
+                            f"Status: {response.status}, Response: {error_text}"
+                        ),
+                    )
+
+                model_info_json = await response.json()
+                return ORJSONResponse(content=model_info_json)
+
+        except aiohttp.ClientError as e:
+            raise HTTPException(
+                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+                detail=f"Failed to get model info from backend",
+            )
+
+
+@app.post("/generate")
+async def handle_generate_request(request_data: dict):
+    prefill_server, bootstrap_port, decode_server = lb.select_pair()
+
+    # Parse and transform prefill_server for bootstrap data
+    parsed_url = urllib.parse.urlparse(prefill_server)
+    hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
+    modified_request = request_data.copy()
+
+    batch_size = _get_request_batch_size(modified_request)
+    if batch_size is not None:
+        modified_request.update(
+            {
+                "bootstrap_host": [hostname] * batch_size,
+                "bootstrap_port": [bootstrap_port] * batch_size,
+                "bootstrap_room": [
+                    _generate_bootstrap_room() for _ in range(batch_size)
+                ],
+            }
+        )
+    else:
+        modified_request.update(
+            {
+                "bootstrap_host": hostname,
+                "bootstrap_port": bootstrap_port,
+                "bootstrap_room": _generate_bootstrap_room(),
+            }
+        )
+
+    if request_data.get("stream", False):
+        return await lb.generate_stream(
+            modified_request, prefill_server, decode_server, "generate"
+        )
+    else:
+        return await lb.generate(
+            modified_request, prefill_server, decode_server, "generate"
+        )
+
+
+async def _forward_to_backend(request_data: dict, endpoint_name: str):
+    prefill_server, bootstrap_port, decode_server = lb.select_pair()
+
+    # Parse and transform prefill_server for bootstrap data
+    parsed_url = urllib.parse.urlparse(prefill_server)
+    hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
+    modified_request = request_data.copy()
+    modified_request.update(
+        {
+            "bootstrap_host": hostname,
+            "bootstrap_port": bootstrap_port,
+            "bootstrap_room": _generate_bootstrap_room(),
+        }
+ ) + + if request_data.get("stream", False): + return await lb.generate_stream( + modified_request, + prefill_server, + decode_server, + endpoint=endpoint_name, + ) + else: + return await lb.generate( + modified_request, + prefill_server, + decode_server, + endpoint=endpoint_name, + ) + + +@app.post("/v1/chat/completions") +async def handle_chat_completion_request(request_data: dict): + return await _forward_to_backend(request_data, "v1/chat/completions") + + +@app.post("/v1/completions") +async def handle_completion_request(request_data: dict): + return await _forward_to_backend(request_data, "v1/completions") + + +def _generate_bootstrap_room(): + return random.randint(0, 2**63 - 1) + + +# We may utilize `GenerateReqInput`'s logic later +def _get_request_batch_size(request): + if (text := request.get("text")) is not None: + return None if isinstance(text, str) else len(text) + if (input_ids := request.get("input_ids")) is not None: + return None if isinstance(input_ids[0], int) else len(input_ids) + return None + + +@app.get("/v1/models") +async def get_models(): + prefill_server = lb.prefill_urls[0] # Get the first prefill server + async with aiohttp.ClientSession() as session: + try: + response = await session.get(f"{prefill_server}/v1/models") + if response.status != 200: + raise HTTPException( + status_code=response.status, + detail=f"Prefill server error: Status {response.status}", + ) + return ORJSONResponse(content=await response.json()) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index de504bafc..72a99ffbb 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -1,9 +1,23 @@ -from typing import Dict, List, Optional +from typing import Optional +from sglang_router.router_args import RouterArgs from sglang_router_rs import PolicyType from sglang_router_rs import Router as 
_Router +def policy_from_str(policy_str: Optional[str]) -> PolicyType: + """Convert policy string to PolicyType enum.""" + if policy_str is None: + return None + policy_map = { + "random": PolicyType.Random, + "round_robin": PolicyType.RoundRobin, + "cache_aware": PolicyType.CacheAware, + "power_of_two": PolicyType.PowerOfTwo, + } + return policy_map[policy_str] + + class Router: """ A high-performance router for distributing requests across worker nodes. @@ -78,130 +92,34 @@ class Router: tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None """ - def __init__( - self, - worker_urls: List[str], - policy: PolicyType = PolicyType.RoundRobin, - host: str = "127.0.0.1", - port: int = 3001, - worker_startup_timeout_secs: int = 600, - worker_startup_check_interval: int = 30, - cache_threshold: float = 0.3, - balance_abs_threshold: int = 64, - balance_rel_threshold: float = 1.5, - eviction_interval_secs: int = 120, - max_tree_size: int = 2**26, - max_payload_size: int = 512 * 1024 * 1024, # 512MB - dp_aware: bool = False, - api_key: Optional[str] = None, - log_dir: Optional[str] = None, - log_level: Optional[str] = None, - service_discovery: bool = False, - selector: Dict[str, str] = None, - service_discovery_port: int = 80, - service_discovery_namespace: Optional[str] = None, - prefill_selector: Dict[str, str] = None, - decode_selector: Dict[str, str] = None, - bootstrap_port_annotation: str = "sglang.ai/bootstrap-port", - prometheus_port: Optional[int] = None, - prometheus_host: Optional[str] = None, - request_timeout_secs: int = 1800, - request_id_headers: Optional[List[str]] = None, - pd_disaggregation: bool = False, - prefill_urls: Optional[List[tuple]] = None, - decode_urls: Optional[List[str]] = None, - prefill_policy: Optional[PolicyType] = None, - decode_policy: Optional[PolicyType] = None, - max_concurrent_requests: int = 256, - queue_size: int = 100, - queue_timeout_secs: int = 60, - rate_limit_tokens_per_second: 
Optional[int] = None, - cors_allowed_origins: List[str] = None, - retry_max_retries: int = 5, - retry_initial_backoff_ms: int = 50, - retry_max_backoff_ms: int = 30_000, - retry_backoff_multiplier: float = 1.5, - retry_jitter_factor: float = 0.2, - cb_failure_threshold: int = 10, - cb_success_threshold: int = 3, - cb_timeout_duration_secs: int = 60, - cb_window_duration_secs: int = 120, - disable_retries: bool = False, - disable_circuit_breaker: bool = False, - health_failure_threshold: int = 3, - health_success_threshold: int = 2, - health_check_timeout_secs: int = 5, - health_check_interval_secs: int = 60, - health_check_endpoint: str = "/health", - model_path: Optional[str] = None, - tokenizer_path: Optional[str] = None, - ): - if selector is None: - selector = {} - if prefill_selector is None: - prefill_selector = {} - if decode_selector is None: - decode_selector = {} - if cors_allowed_origins is None: - cors_allowed_origins = [] + def __init__(self, router: _Router): + self._router = router - self._router = _Router( - worker_urls=worker_urls, - policy=policy, - host=host, - port=port, - worker_startup_timeout_secs=worker_startup_timeout_secs, - worker_startup_check_interval=worker_startup_check_interval, - cache_threshold=cache_threshold, - balance_abs_threshold=balance_abs_threshold, - balance_rel_threshold=balance_rel_threshold, - eviction_interval_secs=eviction_interval_secs, - max_tree_size=max_tree_size, - max_payload_size=max_payload_size, - dp_aware=dp_aware, - api_key=api_key, - log_dir=log_dir, - log_level=log_level, - service_discovery=service_discovery, - selector=selector, - service_discovery_port=service_discovery_port, - service_discovery_namespace=service_discovery_namespace, - prefill_selector=prefill_selector, - decode_selector=decode_selector, - bootstrap_port_annotation=bootstrap_port_annotation, - prometheus_port=prometheus_port, - prometheus_host=prometheus_host, - request_timeout_secs=request_timeout_secs, - 
request_id_headers=request_id_headers, - pd_disaggregation=pd_disaggregation, - prefill_urls=prefill_urls, - decode_urls=decode_urls, - prefill_policy=prefill_policy, - decode_policy=decode_policy, - max_concurrent_requests=max_concurrent_requests, - queue_size=queue_size, - queue_timeout_secs=queue_timeout_secs, - rate_limit_tokens_per_second=rate_limit_tokens_per_second, - cors_allowed_origins=cors_allowed_origins, - retry_max_retries=retry_max_retries, - retry_initial_backoff_ms=retry_initial_backoff_ms, - retry_max_backoff_ms=retry_max_backoff_ms, - retry_backoff_multiplier=retry_backoff_multiplier, - retry_jitter_factor=retry_jitter_factor, - cb_failure_threshold=cb_failure_threshold, - cb_success_threshold=cb_success_threshold, - cb_timeout_duration_secs=cb_timeout_duration_secs, - cb_window_duration_secs=cb_window_duration_secs, - disable_retries=disable_retries, - disable_circuit_breaker=disable_circuit_breaker, - health_failure_threshold=health_failure_threshold, - health_success_threshold=health_success_threshold, - health_check_timeout_secs=health_check_timeout_secs, - health_check_interval_secs=health_check_interval_secs, - health_check_endpoint=health_check_endpoint, - model_path=model_path, - tokenizer_path=tokenizer_path, + @staticmethod + def from_args(args: RouterArgs) -> "Router": + """Create a router from a RouterArgs instance.""" + + args_dict = vars(args) + # Convert RouterArgs to _Router parameters + args_dict["worker_urls"] = ( + [] + if args_dict["service_discovery"] or args_dict["pd_disaggregation"] + else args_dict["worker_urls"] ) + args_dict["policy"] = policy_from_str(args_dict["policy"]) + args_dict["prefill_urls"] = ( + args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None + ) + args_dict["decode_urls"] = ( + args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None + ) + args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"]) + args_dict["decode_policy"] = 
policy_from_str(args_dict["decode_policy"])
+
+        # remove mini_lb parameter
+        args_dict.pop("mini_lb")
+
+        return Router(_Router(**args_dict))
 
     def start(self) -> None:
         """Start the router server.
diff --git a/sgl-router/py_src/sglang_router/router_args.py b/sgl-router/py_src/sglang_router/router_args.py
new file mode 100644
index 000000000..ad0a2ac9f
--- /dev/null
+++ b/sgl-router/py_src/sglang_router/router_args.py
@@ -0,0 +1,577 @@
+import argparse
+import dataclasses
+import logging
+from typing import Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class RouterArgs:
+    # Worker configuration
+    worker_urls: List[str] = dataclasses.field(default_factory=list)
+    host: str = "127.0.0.1"
+    port: int = 30000
+
+    # PD-specific configuration
+    mini_lb: bool = False
+    pd_disaggregation: bool = False  # Enable PD disaggregated mode
+    prefill_urls: List[tuple] = dataclasses.field(
+        default_factory=list
+    )  # List of (url, bootstrap_port)
+    decode_urls: List[str] = dataclasses.field(default_factory=list)
+
+    # Routing policy
+    policy: str = "cache_aware"
+    prefill_policy: Optional[str] = None  # Specific policy for prefill nodes in PD mode
+    decode_policy: Optional[str] = None  # Specific policy for decode nodes in PD mode
+    worker_startup_timeout_secs: int = 600
+    worker_startup_check_interval: int = 30
+    cache_threshold: float = 0.3
+    balance_abs_threshold: int = 64
+    balance_rel_threshold: float = 1.5
+    eviction_interval_secs: int = 120
+    max_tree_size: int = 2**26
+    max_payload_size: int = 512 * 1024 * 1024  # 512MB default for large batches
+    dp_aware: bool = False
+    api_key: Optional[str] = None
+    log_dir: Optional[str] = None
+    log_level: Optional[str] = None
+    # Service discovery configuration
+    service_discovery: bool = False
+    selector: Dict[str, str] = dataclasses.field(default_factory=dict)
+    service_discovery_port: int = 80
+    service_discovery_namespace: Optional[str] = None
+    # PD service discovery configuration
+    
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict) + decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict) + bootstrap_port_annotation: str = "sglang.ai/bootstrap-port" + # Prometheus configuration + prometheus_port: Optional[int] = None + prometheus_host: Optional[str] = None + # Request ID headers configuration + request_id_headers: Optional[List[str]] = None + # Request timeout in seconds + request_timeout_secs: int = 1800 + # Max concurrent requests for rate limiting + max_concurrent_requests: int = 256 + # Queue size for pending requests when max concurrent limit reached + queue_size: int = 100 + # Maximum time (in seconds) a request can wait in queue before timing out + queue_timeout_secs: int = 60 + # Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests + rate_limit_tokens_per_second: Optional[int] = None + # CORS allowed origins + cors_allowed_origins: List[str] = dataclasses.field(default_factory=list) + # Retry configuration + retry_max_retries: int = 5 + retry_initial_backoff_ms: int = 50 + retry_max_backoff_ms: int = 30_000 + retry_backoff_multiplier: float = 1.5 + retry_jitter_factor: float = 0.2 + disable_retries: bool = False + # Health check configuration + health_failure_threshold: int = 3 + health_success_threshold: int = 2 + health_check_timeout_secs: int = 5 + health_check_interval_secs: int = 60 + health_check_endpoint: str = "/health" + # Circuit breaker configuration + cb_failure_threshold: int = 10 + cb_success_threshold: int = 3 + cb_timeout_duration_secs: int = 60 + cb_window_duration_secs: int = 120 + disable_circuit_breaker: bool = False + # Tokenizer configuration + model_path: Optional[str] = None + tokenizer_path: Optional[str] = None + + @staticmethod + def add_cli_args( + parser: argparse.ArgumentParser, + use_router_prefix: bool = False, + exclude_host_port: bool = False, + ): + """ + Add router-specific arguments to an argument parser. 
+ + Args: + parser: The argument parser to add arguments to + use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts + exclude_host_port: If True, don't add host and port arguments (used when inheriting from server) + """ + prefix = "router-" if use_router_prefix else "" + + # Worker configuration + if not exclude_host_port: + parser.add_argument( + "--host", + type=str, + default=RouterArgs.host, + help="Host address to bind the router server", + ) + parser.add_argument( + "--port", + type=int, + default=RouterArgs.port, + help="Port number to bind the router server", + ) + + parser.add_argument( + "--worker-urls", + type=str, + nargs="*", + default=[], + help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)", + ) + + # Routing policy configuration + parser.add_argument( + f"--{prefix}policy", + type=str, + default=RouterArgs.policy, + choices=["random", "round_robin", "cache_aware", "power_of_two"], + help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden", + ) + parser.add_argument( + f"--{prefix}prefill-policy", + type=str, + default=None, + choices=["random", "round_robin", "cache_aware", "power_of_two"], + help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy", + ) + parser.add_argument( + f"--{prefix}decode-policy", + type=str, + default=None, + choices=["random", "round_robin", "cache_aware", "power_of_two"], + help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy", + ) + + # PD-specific arguments + parser.add_argument( + f"--{prefix}mini-lb", + action="store_true", + help="Enable MiniLB", + ) + parser.add_argument( + f"--{prefix}pd-disaggregation", + action="store_true", + help="Enable PD (Prefill-Decode) disaggregated mode", + ) + parser.add_argument( + f"--{prefix}prefill", + nargs="+", + action="append", + help="Prefill server URL and optional bootstrap port. 
Can be specified multiple times. " + "Format: --prefill URL [BOOTSTRAP_PORT]. " + "BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).", + ) + parser.add_argument( + f"--{prefix}decode", + nargs=1, + action="append", + metavar=("URL",), + help="Decode server URL. Can be specified multiple times.", + ) + parser.add_argument( + f"--{prefix}worker-startup-timeout-secs", + type=int, + default=RouterArgs.worker_startup_timeout_secs, + help="Timeout in seconds for worker startup", + ) + parser.add_argument( + f"--{prefix}worker-startup-check-interval", + type=int, + default=RouterArgs.worker_startup_check_interval, + help="Interval in seconds between checks for worker startup", + ) + parser.add_argument( + f"--{prefix}cache-threshold", + type=float, + default=RouterArgs.cache_threshold, + help="Cache threshold (0.0-1.0) for cache-aware routing", + ) + parser.add_argument( + f"--{prefix}balance-abs-threshold", + type=int, + default=RouterArgs.balance_abs_threshold, + help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware", + ) + parser.add_argument( + f"--{prefix}balance-rel-threshold", + type=float, + default=RouterArgs.balance_rel_threshold, + help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. 
Otherwise, use cache aware",
+        )
+        parser.add_argument(
+            f"--{prefix}eviction-interval-secs",
+            type=int,
+            default=RouterArgs.eviction_interval_secs,
+            help="Interval in seconds between cache eviction operations",
+        )
+        parser.add_argument(
+            f"--{prefix}max-tree-size",
+            type=int,
+            default=RouterArgs.max_tree_size,
+            help="Maximum size of the approximation tree for cache-aware routing",
+        )
+        parser.add_argument(
+            f"--{prefix}max-payload-size",
+            type=int,
+            default=RouterArgs.max_payload_size,
+            help="Maximum payload size in bytes",
+        )
+        parser.add_argument(
+            f"--{prefix}dp-aware",
+            action="store_true",
+            help="Enable data parallelism aware schedule",
+        )
+        parser.add_argument(
+            f"--{prefix}api-key",
+            type=str,
+            default=None,
+            help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enabled.",
+        )
+        parser.add_argument(
+            f"--{prefix}log-dir",
+            type=str,
+            default=None,
+            help="Directory to store log files. If not specified, logs are only output to console.",
+        )
+        parser.add_argument(
+            f"--{prefix}log-level",
+            type=str,
+            default="info",
+            choices=["debug", "info", "warning", "error", "critical"],
+            help="Set the logging level. If not specified, defaults to INFO.",
+        )
+        parser.add_argument(
+            f"--{prefix}service-discovery",
+            action="store_true",
+            help="Enable Kubernetes service discovery",
+        )
+        parser.add_argument(
+            f"--{prefix}selector",
+            type=str,
+            nargs="+",
+            default={},
+            help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
+        )
+        parser.add_argument(
+            f"--{prefix}service-discovery-port",
+            type=int,
+            default=RouterArgs.service_discovery_port,
+            help="Port to use for discovered worker pods",
+        )
+        parser.add_argument(
+            f"--{prefix}service-discovery-namespace",
+            type=str,
+            help="Kubernetes namespace to watch for pods. 
If not provided, watches all namespaces (requires cluster-wide permissions)", + ) + parser.add_argument( + f"--{prefix}prefill-selector", + type=str, + nargs="+", + default={}, + help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)", + ) + parser.add_argument( + f"--{prefix}decode-selector", + type=str, + nargs="+", + default={}, + help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)", + ) + # Prometheus configuration + parser.add_argument( + f"--{prefix}prometheus-port", + type=int, + default=29000, + help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled", + ) + parser.add_argument( + f"--{prefix}prometheus-host", + type=str, + default="127.0.0.1", + help="Host address to bind the Prometheus metrics server", + ) + parser.add_argument( + f"--{prefix}request-id-headers", + type=str, + nargs="*", + help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). 
If not specified, uses common defaults.", + ) + parser.add_argument( + f"--{prefix}request-timeout-secs", + type=int, + default=RouterArgs.request_timeout_secs, + help="Request timeout in seconds", + ) + # Retry configuration + parser.add_argument( + f"--{prefix}retry-max-retries", + type=int, + default=RouterArgs.retry_max_retries, + ) + parser.add_argument( + f"--{prefix}retry-initial-backoff-ms", + type=int, + default=RouterArgs.retry_initial_backoff_ms, + ) + parser.add_argument( + f"--{prefix}retry-max-backoff-ms", + type=int, + default=RouterArgs.retry_max_backoff_ms, + ) + parser.add_argument( + f"--{prefix}retry-backoff-multiplier", + type=float, + default=RouterArgs.retry_backoff_multiplier, + ) + parser.add_argument( + f"--{prefix}retry-jitter-factor", + type=float, + default=RouterArgs.retry_jitter_factor, + ) + parser.add_argument( + f"--{prefix}disable-retries", + action="store_true", + help="Disable retries (equivalent to setting retry_max_retries=1)", + ) + # Circuit breaker configuration + parser.add_argument( + f"--{prefix}cb-failure-threshold", + type=int, + default=RouterArgs.cb_failure_threshold, + ) + parser.add_argument( + f"--{prefix}cb-success-threshold", + type=int, + default=RouterArgs.cb_success_threshold, + ) + parser.add_argument( + f"--{prefix}cb-timeout-duration-secs", + type=int, + default=RouterArgs.cb_timeout_duration_secs, + ) + parser.add_argument( + f"--{prefix}cb-window-duration-secs", + type=int, + default=RouterArgs.cb_window_duration_secs, + ) + parser.add_argument( + f"--{prefix}disable-circuit-breaker", + action="store_true", + help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)", + ) + # Health check configuration + parser.add_argument( + f"--{prefix}health-failure-threshold", + type=int, + default=RouterArgs.health_failure_threshold, + help="Number of consecutive health check failures before marking worker unhealthy", + ) + parser.add_argument( + f"--{prefix}health-success-threshold", + 
type=int, + default=RouterArgs.health_success_threshold, + help="Number of consecutive health check successes before marking worker healthy", + ) + parser.add_argument( + f"--{prefix}health-check-timeout-secs", + type=int, + default=RouterArgs.health_check_timeout_secs, + help="Timeout in seconds for health check requests", + ) + parser.add_argument( + f"--{prefix}health-check-interval-secs", + type=int, + default=RouterArgs.health_check_interval_secs, + help="Interval in seconds between runtime health checks", + ) + parser.add_argument( + f"--{prefix}health-check-endpoint", + type=str, + default=RouterArgs.health_check_endpoint, + help="Health check endpoint path", + ) + parser.add_argument( + f"--{prefix}max-concurrent-requests", + type=int, + default=RouterArgs.max_concurrent_requests, + help="Maximum number of concurrent requests allowed (for rate limiting)", + ) + parser.add_argument( + f"--{prefix}queue-size", + type=int, + default=RouterArgs.queue_size, + help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)", + ) + parser.add_argument( + f"--{prefix}queue-timeout-secs", + type=int, + default=RouterArgs.queue_timeout_secs, + help="Maximum time (in seconds) a request can wait in queue before timing out", + ) + parser.add_argument( + f"--{prefix}rate-limit-tokens-per-second", + type=int, + default=RouterArgs.rate_limit_tokens_per_second, + help="Token bucket refill rate (tokens per second). 
If not set, defaults to max_concurrent_requests", + ) + parser.add_argument( + f"--{prefix}cors-allowed-origins", + type=str, + nargs="*", + default=[], + help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)", + ) + # Tokenizer configuration + parser.add_argument( + f"--{prefix}model-path", + type=str, + default=None, + help="Model path for loading tokenizer (HuggingFace model ID or local path)", + ) + parser.add_argument( + f"--{prefix}tokenizer-path", + type=str, + default=None, + help="Explicit tokenizer path (overrides model_path tokenizer if provided)", + ) + + @classmethod + def from_cli_args( + cls, args: argparse.Namespace, use_router_prefix: bool = False + ) -> "RouterArgs": + """ + Create RouterArgs instance from parsed command line arguments. + + Args: + args: Parsed command line arguments + use_router_prefix: If True, look for arguments with 'router-' prefix + """ + prefix = "router_" if use_router_prefix else "" + cli_args_dict = vars(args) + args_dict = {} + + for attr in dataclasses.fields(cls): + # Auto strip prefix from args + if f"{prefix}{attr.name}" in cli_args_dict: + args_dict[attr.name] = cli_args_dict[f"{prefix}{attr.name}"] + elif attr.name in cli_args_dict: + args_dict[attr.name] = cli_args_dict[attr.name] + + # parse special arguments and remove "--prefill" and "--decode" from cli_args_dict + args_dict["prefill_urls"] = cls._parse_prefill_urls( + cli_args_dict.get(f"{prefix}prefill", None) + ) + args_dict["decode_urls"] = cls._parse_decode_urls( + cli_args_dict.get(f"{prefix}decode", None) + ) + args_dict["selector"] = cls._parse_selector( + cli_args_dict.get(f"{prefix}selector", None) + ) + args_dict["prefill_selector"] = cls._parse_selector( + cli_args_dict.get(f"{prefix}prefill_selector", None) + ) + args_dict["decode_selector"] = cls._parse_selector( + cli_args_dict.get(f"{prefix}decode_selector", None) + ) + + # Mooncake-specific annotation + args_dict["bootstrap_port_annotation"] = "sglang.ai/bootstrap-port" 
+ + return cls(**args_dict) + + def _validate_router_args(self): + # Validate configuration based on mode + if self.pd_disaggregation: + # Validate PD configuration - skip URL requirements if using service discovery + if not self.service_discovery: + if not self.prefill_urls: + raise ValueError("PD disaggregation mode requires --prefill") + if not self.decode_urls: + raise ValueError("PD disaggregation mode requires --decode") + + # Warn about policy usage in PD mode + if self.prefill_policy and self.decode_policy and self.policy: + logger.warning( + "Both --prefill-policy and --decode-policy are specified. " + "The main --policy flag will be ignored for PD mode." + ) + elif self.prefill_policy and not self.decode_policy and self.policy: + logger.info( + f"Using --prefill-policy '{self.prefill_policy}' for prefill nodes " + f"and --policy '{self.policy}' for decode nodes." + ) + elif self.decode_policy and not self.prefill_policy and self.policy: + logger.info( + f"Using --policy '{self.policy}' for prefill nodes " + f"and --decode-policy '{self.decode_policy}' for decode nodes." + ) + + @staticmethod + def _parse_selector(selector_list): + if not selector_list: + return {} + + selector = {} + for item in selector_list: + if "=" in item: + key, value = item.split("=", 1) + selector[key] = value + return selector + + @staticmethod + def _parse_prefill_urls(prefill_list): + """Parse prefill URLs from --prefill arguments. 
+ + Format: --prefill URL [BOOTSTRAP_PORT] + Example: + --prefill http://prefill1:8080 9000 # With bootstrap port + --prefill http://prefill2:8080 none # Explicitly no bootstrap port + --prefill http://prefill3:8080 # Defaults to no bootstrap port + """ + if not prefill_list: + return [] + + prefill_urls = [] + for prefill_args in prefill_list: + + url = prefill_args[0] + + # Handle optional bootstrap port + if len(prefill_args) >= 2: + bootstrap_port_str = prefill_args[1] + # Handle 'none' as None + if bootstrap_port_str.lower() == "none": + bootstrap_port = None + else: + try: + bootstrap_port = int(bootstrap_port_str) + except ValueError: + raise ValueError( + f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'" + ) + else: + # No bootstrap port specified, default to None + bootstrap_port = None + + prefill_urls.append((url, bootstrap_port)) + + return prefill_urls + + @staticmethod + def _parse_decode_urls(decode_list): + """Parse decode URLs from --decode arguments. 
+ + Format: --decode URL + Example: --decode http://decode1:8081 --decode http://decode2:8081 + """ + if not decode_list: + return [] + + # decode_list is a list of single-element lists due to nargs=1 + return [url[0] for url in decode_list] diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index cc234e756..031ad5d08 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -33,7 +33,7 @@ class TestLaunchRouter(unittest.TestCase): cache_threshold=0.5, balance_abs_threshold=32, balance_rel_threshold=1.0001, - eviction_interval=60, + eviction_interval_secs=60, max_tree_size=2**24, max_payload_size=256 * 1024 * 1024, # 256MB verbose=False, @@ -176,9 +176,8 @@ class TestLaunchRouter(unittest.TestCase): """Test basic PD router functionality without actually starting servers.""" # This test just verifies the PD router can be created and configured # without actually starting it (which would require real prefill/decode servers) - from sglang_router import Router from sglang_router.launch_router import RouterArgs - from sglang_router_rs import PolicyType + from sglang_router.router import PolicyType, Router # Test RouterArgs parsing for PD mode # Simulate the parsed args structure from argparse with action="append" @@ -209,18 +208,7 @@ class TestLaunchRouter(unittest.TestCase): self.assertEqual(router_args.decode_urls[1], "http://decode2:8081") # Test Router creation in PD mode - router = Router( - worker_urls=[], # Empty for PD mode - pd_disaggregation=True, - prefill_urls=[ - ("http://prefill1:8080", 9000), - ("http://prefill2:8080", None), - ], - decode_urls=["http://decode1:8081", "http://decode2:8081"], - policy=PolicyType.CacheAware, - host="127.0.0.1", - port=3001, - ) + router = Router.from_args(router_args) self.assertIsNotNone(router) def test_policy_validation(self): diff --git a/sgl-router/py_test/test_launch_server.py b/sgl-router/py_test/test_launch_server.py index 
f805ff117..cdad0b9a1 100644 --- a/sgl-router/py_test/test_launch_server.py +++ b/sgl-router/py_test/test_launch_server.py @@ -77,7 +77,7 @@ def popen_launch_router( port, "--dp", str(dp_size), - "--router-eviction-interval", + "--router-eviction-interval-secs", "5", "--router-policy", policy, diff --git a/sgl-router/pyproject.toml b/sgl-router/pyproject.toml index 40f7cd15a..bd0314aec 100644 --- a/sgl-router/pyproject.toml +++ b/sgl-router/pyproject.toml @@ -28,8 +28,3 @@ find = { where = ["py_src"] } # workaround for https://github.com/pypa/twine/issues/1216 [tool.setuptools] license-files = [] - -[[tool.setuptools-rust.ext-modules]] -target = "sglang_router_rs" -path = "Cargo.toml" -binding = "PyO3" diff --git a/sgl-router/setup.py b/sgl-router/setup.py new file mode 100644 index 000000000..730a91ceb --- /dev/null +++ b/sgl-router/setup.py @@ -0,0 +1,21 @@ +import os + +from setuptools import setup +from setuptools_rust import Binding, RustExtension + +no_rust = os.environ.get("SGLANG_ROUTER_BUILD_NO_RUST") == "1" + +rust_extensions = [] +if not no_rust: + rust_extensions.append( + RustExtension( + target="sglang_router_rs", + path="Cargo.toml", + binding=Binding.PyO3, + ) + ) + +setup( + rust_extensions=rust_extensions, + zip_safe=False, +) diff --git a/test/srt/test_disaggregation.py b/test/srt/test_disaggregation.py index 68848aade..1a7cb99ed 100644 --- a/test/srt/test_disaggregation.py +++ b/test/srt/test_disaggregation.py @@ -1,6 +1,5 @@ import json import os -import subprocess import time import unittest from types import SimpleNamespace @@ -18,6 +17,7 @@ from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, CustomTestCase, popen_launch_pd_server, + popen_with_error_check, ) @@ -47,7 +47,9 @@ class TestDisaggregationAccuracy(CustomTestCase): lb_command = [ "python3", "-m", - "sglang.srt.disaggregation.mini_lb", + "sglang_router.launch_router", + "--pd-disaggregation", + "--mini-lb", # FIXME: remove this "--prefill", cls.prefill_url, "--decode", @@ 
-59,9 +61,7 @@ class TestDisaggregationAccuracy(CustomTestCase): ] print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = subprocess.Popen( - lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + cls.process_lb = popen_with_error_check(lb_command) cls.wait_server_ready(cls.lb_url + "/health") @classmethod @@ -228,7 +228,9 @@ class TestDisaggregationMooncakeFailure(CustomTestCase): lb_command = [ "python3", "-m", - "sglang.srt.disaggregation.mini_lb", + "sglang_router.launch_router", + "--pd-disaggregation", + "--mini-lb", # FIXME: remove this "--prefill", cls.prefill_url, "--decode", @@ -240,9 +242,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase): ] print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = subprocess.Popen( - lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + cls.process_lb = popen_with_error_check(lb_command) cls.wait_server_ready(cls.lb_url + "/health") @classmethod @@ -383,7 +383,9 @@ class TestDisaggregationMooncakeSpec(CustomTestCase): lb_command = [ "python3", "-m", - "sglang.srt.disaggregation.mini_lb", + "sglang_router.launch_router", + "--pd-disaggregation", + "--mini-lb", # FIXME: remove this "--prefill", cls.prefill_url, "--decode", @@ -395,9 +397,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase): ] print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = subprocess.Popen( - lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + cls.process_lb = popen_with_error_check(lb_command) cls.wait_server_ready(cls.lb_url + "/health") @classmethod @@ -509,7 +509,9 @@ class TestDisaggregationSimulatedRetract(CustomTestCase): lb_command = [ "python3", "-m", - "sglang.srt.disaggregation.mini_lb", + "sglang_router.launch_router", + "--pd-disaggregation", + "--mini-lb", # FIXME: remove this "--prefill", cls.prefill_url, "--decode", @@ -521,9 +523,7 @@ class TestDisaggregationSimulatedRetract(CustomTestCase): ] print("Starting load balancer:", " 
".join(lb_command)) - cls.process_lb = subprocess.Popen( - lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + cls.process_lb = popen_with_error_check(lb_command) cls.wait_server_ready(cls.lb_url + "/health") @classmethod diff --git a/test/srt/test_disaggregation_different_tp.py b/test/srt/test_disaggregation_different_tp.py index fdc332040..911afbe9b 100644 --- a/test/srt/test_disaggregation_different_tp.py +++ b/test/srt/test_disaggregation_different_tp.py @@ -15,7 +15,7 @@ from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, CustomTestCase, popen_launch_pd_server, - run_with_timeout, + popen_with_error_check, ) @@ -49,7 +49,9 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): lb_command = [ "python3", "-m", - "sglang.srt.disaggregation.mini_lb", + "sglang_router.launch_router", + "--pd-disaggregation", + "--mini-lb", # FIXME: remove this "--prefill", cls.prefill_url, "--decode", @@ -61,9 +63,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): ] print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = subprocess.Popen( - lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + cls.process_lb = popen_with_error_check(lb_command) cls.wait_server_ready(cls.lb_url + "/health") @classmethod @@ -183,7 +183,9 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase): lb_command = [ "python3", "-m", - "sglang.srt.disaggregation.mini_lb", + "sglang_router.launch_router", + "--pd-disaggregation", + "--mini-lb", # FIXME: remove this "--prefill", cls.prefill_url, "--decode", @@ -195,9 +197,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase): ] print("Starting load balancer:", " ".join(lb_command)) - cls.process_lb = subprocess.Popen( - lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + cls.process_lb = popen_with_error_check(lb_command) cls.wait_server_ready(cls.lb_url + "/health") @classmethod diff --git a/test/srt/test_disaggregation_pp.py 
b/test/srt/test_disaggregation_pp.py index 6c04d0cce..ece959a7d 100644 --- a/test/srt/test_disaggregation_pp.py +++ b/test/srt/test_disaggregation_pp.py @@ -49,7 +49,9 @@ class TestPDPPAccuracy(unittest.TestCase): lb_command = [ "python3", "-m", - "sglang.srt.disaggregation.mini_lb", + "sglang_router.launch_router", + "--pd-disaggregation", + "--mini-lb", # FIXME: remove this "--prefill", cls.prefill_url, "--decode",