[router] migrate router from actix to axum (#8479)
This commit is contained in:
@@ -68,6 +68,12 @@ class RouterArgs:
|
||||
prometheus_host: Optional[str] = None
|
||||
# Request ID headers configuration
|
||||
request_id_headers: Optional[List[str]] = None
|
||||
# Request timeout in seconds
|
||||
request_timeout_secs: int = 600
|
||||
# Max concurrent requests for rate limiting
|
||||
max_concurrent_requests: int = 64
|
||||
# CORS allowed origins
|
||||
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
@@ -276,6 +282,25 @@ class RouterArgs:
|
||||
nargs="*",
|
||||
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}request-timeout-secs",
|
||||
type=int,
|
||||
default=RouterArgs.request_timeout_secs,
|
||||
help="Request timeout in seconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}max-concurrent-requests",
|
||||
type=int,
|
||||
default=RouterArgs.max_concurrent_requests,
|
||||
help="Maximum number of concurrent requests allowed (for rate limiting)",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}cors-allowed-origins",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(
|
||||
@@ -337,6 +362,15 @@ class RouterArgs:
|
||||
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
||||
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
||||
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
|
||||
request_timeout_secs=getattr(
|
||||
args, f"{prefix}request_timeout_secs", RouterArgs.request_timeout_secs
|
||||
),
|
||||
max_concurrent_requests=getattr(
|
||||
args,
|
||||
f"{prefix}max_concurrent_requests",
|
||||
RouterArgs.max_concurrent_requests,
|
||||
),
|
||||
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -490,6 +524,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
decode_selector=router_args.decode_selector,
|
||||
prometheus_port=router_args.prometheus_port,
|
||||
prometheus_host=router_args.prometheus_host,
|
||||
request_timeout_secs=router_args.request_timeout_secs,
|
||||
pd_disaggregation=router_args.pd_disaggregation,
|
||||
prefill_urls=(
|
||||
router_args.prefill_urls if router_args.pd_disaggregation else None
|
||||
@@ -508,6 +543,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
else None
|
||||
),
|
||||
request_id_headers=router_args.request_id_headers,
|
||||
max_concurrent_requests=router_args.max_concurrent_requests,
|
||||
cors_allowed_origins=router_args.cors_allowed_origins,
|
||||
)
|
||||
|
||||
router.start()
|
||||
|
||||
@@ -61,6 +61,11 @@ class Router:
|
||||
request_id_headers: List of HTTP headers to check for request IDs. If not specified,
|
||||
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
|
||||
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
|
||||
bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
|
||||
Default: 'sglang.ai/bootstrap-port'
|
||||
request_timeout_secs: Request timeout in seconds. Default: 600
|
||||
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 64
|
||||
cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: []
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -87,14 +92,18 @@ class Router:
|
||||
service_discovery_namespace: Optional[str] = None,
|
||||
prefill_selector: Dict[str, str] = None,
|
||||
decode_selector: Dict[str, str] = None,
|
||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
|
||||
prometheus_port: Optional[int] = None,
|
||||
prometheus_host: Optional[str] = None,
|
||||
request_timeout_secs: int = 600,
|
||||
request_id_headers: Optional[List[str]] = None,
|
||||
pd_disaggregation: bool = False,
|
||||
prefill_urls: Optional[List[tuple]] = None,
|
||||
decode_urls: Optional[List[str]] = None,
|
||||
prefill_policy: Optional[PolicyType] = None,
|
||||
decode_policy: Optional[PolicyType] = None,
|
||||
request_id_headers: Optional[List[str]] = None,
|
||||
max_concurrent_requests: int = 64,
|
||||
cors_allowed_origins: List[str] = None,
|
||||
):
|
||||
if selector is None:
|
||||
selector = {}
|
||||
@@ -102,6 +111,8 @@ class Router:
|
||||
prefill_selector = {}
|
||||
if decode_selector is None:
|
||||
decode_selector = {}
|
||||
if cors_allowed_origins is None:
|
||||
cors_allowed_origins = []
|
||||
|
||||
self._router = _Router(
|
||||
worker_urls=worker_urls,
|
||||
@@ -126,14 +137,18 @@ class Router:
|
||||
service_discovery_namespace=service_discovery_namespace,
|
||||
prefill_selector=prefill_selector,
|
||||
decode_selector=decode_selector,
|
||||
bootstrap_port_annotation=bootstrap_port_annotation,
|
||||
prometheus_port=prometheus_port,
|
||||
prometheus_host=prometheus_host,
|
||||
request_timeout_secs=request_timeout_secs,
|
||||
request_id_headers=request_id_headers,
|
||||
pd_disaggregation=pd_disaggregation,
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
request_id_headers=request_id_headers,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
cors_allowed_origins=cors_allowed_origins,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user