Simplify Router arguments passing and build it in docker image (#9964)
This commit is contained in:
@@ -1,9 +1,23 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from sglang_router.router_args import RouterArgs
|
||||
from sglang_router_rs import PolicyType
|
||||
from sglang_router_rs import Router as _Router
|
||||
|
||||
|
||||
def policy_from_str(policy_str: Optional[str]) -> PolicyType:
|
||||
"""Convert policy string to PolicyType enum."""
|
||||
if policy_str is None:
|
||||
return None
|
||||
policy_map = {
|
||||
"random": PolicyType.Random,
|
||||
"round_robin": PolicyType.RoundRobin,
|
||||
"cache_aware": PolicyType.CacheAware,
|
||||
"power_of_two": PolicyType.PowerOfTwo,
|
||||
}
|
||||
return policy_map[policy_str]
|
||||
|
||||
|
||||
class Router:
|
||||
"""
|
||||
A high-performance router for distributing requests across worker nodes.
|
||||
@@ -78,130 +92,34 @@ class Router:
|
||||
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_urls: List[str],
|
||||
policy: PolicyType = PolicyType.RoundRobin,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 3001,
|
||||
worker_startup_timeout_secs: int = 600,
|
||||
worker_startup_check_interval: int = 30,
|
||||
cache_threshold: float = 0.3,
|
||||
balance_abs_threshold: int = 64,
|
||||
balance_rel_threshold: float = 1.5,
|
||||
eviction_interval_secs: int = 120,
|
||||
max_tree_size: int = 2**26,
|
||||
max_payload_size: int = 512 * 1024 * 1024, # 512MB
|
||||
dp_aware: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
log_dir: Optional[str] = None,
|
||||
log_level: Optional[str] = None,
|
||||
service_discovery: bool = False,
|
||||
selector: Dict[str, str] = None,
|
||||
service_discovery_port: int = 80,
|
||||
service_discovery_namespace: Optional[str] = None,
|
||||
prefill_selector: Dict[str, str] = None,
|
||||
decode_selector: Dict[str, str] = None,
|
||||
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
|
||||
prometheus_port: Optional[int] = None,
|
||||
prometheus_host: Optional[str] = None,
|
||||
request_timeout_secs: int = 1800,
|
||||
request_id_headers: Optional[List[str]] = None,
|
||||
pd_disaggregation: bool = False,
|
||||
prefill_urls: Optional[List[tuple]] = None,
|
||||
decode_urls: Optional[List[str]] = None,
|
||||
prefill_policy: Optional[PolicyType] = None,
|
||||
decode_policy: Optional[PolicyType] = None,
|
||||
max_concurrent_requests: int = 256,
|
||||
queue_size: int = 100,
|
||||
queue_timeout_secs: int = 60,
|
||||
rate_limit_tokens_per_second: Optional[int] = None,
|
||||
cors_allowed_origins: List[str] = None,
|
||||
retry_max_retries: int = 5,
|
||||
retry_initial_backoff_ms: int = 50,
|
||||
retry_max_backoff_ms: int = 30_000,
|
||||
retry_backoff_multiplier: float = 1.5,
|
||||
retry_jitter_factor: float = 0.2,
|
||||
cb_failure_threshold: int = 10,
|
||||
cb_success_threshold: int = 3,
|
||||
cb_timeout_duration_secs: int = 60,
|
||||
cb_window_duration_secs: int = 120,
|
||||
disable_retries: bool = False,
|
||||
disable_circuit_breaker: bool = False,
|
||||
health_failure_threshold: int = 3,
|
||||
health_success_threshold: int = 2,
|
||||
health_check_timeout_secs: int = 5,
|
||||
health_check_interval_secs: int = 60,
|
||||
health_check_endpoint: str = "/health",
|
||||
model_path: Optional[str] = None,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
):
|
||||
if selector is None:
|
||||
selector = {}
|
||||
if prefill_selector is None:
|
||||
prefill_selector = {}
|
||||
if decode_selector is None:
|
||||
decode_selector = {}
|
||||
if cors_allowed_origins is None:
|
||||
cors_allowed_origins = []
|
||||
def __init__(self, router: _Router):
|
||||
self._router = router
|
||||
|
||||
self._router = _Router(
|
||||
worker_urls=worker_urls,
|
||||
policy=policy,
|
||||
host=host,
|
||||
port=port,
|
||||
worker_startup_timeout_secs=worker_startup_timeout_secs,
|
||||
worker_startup_check_interval=worker_startup_check_interval,
|
||||
cache_threshold=cache_threshold,
|
||||
balance_abs_threshold=balance_abs_threshold,
|
||||
balance_rel_threshold=balance_rel_threshold,
|
||||
eviction_interval_secs=eviction_interval_secs,
|
||||
max_tree_size=max_tree_size,
|
||||
max_payload_size=max_payload_size,
|
||||
dp_aware=dp_aware,
|
||||
api_key=api_key,
|
||||
log_dir=log_dir,
|
||||
log_level=log_level,
|
||||
service_discovery=service_discovery,
|
||||
selector=selector,
|
||||
service_discovery_port=service_discovery_port,
|
||||
service_discovery_namespace=service_discovery_namespace,
|
||||
prefill_selector=prefill_selector,
|
||||
decode_selector=decode_selector,
|
||||
bootstrap_port_annotation=bootstrap_port_annotation,
|
||||
prometheus_port=prometheus_port,
|
||||
prometheus_host=prometheus_host,
|
||||
request_timeout_secs=request_timeout_secs,
|
||||
request_id_headers=request_id_headers,
|
||||
pd_disaggregation=pd_disaggregation,
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
queue_size=queue_size,
|
||||
queue_timeout_secs=queue_timeout_secs,
|
||||
rate_limit_tokens_per_second=rate_limit_tokens_per_second,
|
||||
cors_allowed_origins=cors_allowed_origins,
|
||||
retry_max_retries=retry_max_retries,
|
||||
retry_initial_backoff_ms=retry_initial_backoff_ms,
|
||||
retry_max_backoff_ms=retry_max_backoff_ms,
|
||||
retry_backoff_multiplier=retry_backoff_multiplier,
|
||||
retry_jitter_factor=retry_jitter_factor,
|
||||
cb_failure_threshold=cb_failure_threshold,
|
||||
cb_success_threshold=cb_success_threshold,
|
||||
cb_timeout_duration_secs=cb_timeout_duration_secs,
|
||||
cb_window_duration_secs=cb_window_duration_secs,
|
||||
disable_retries=disable_retries,
|
||||
disable_circuit_breaker=disable_circuit_breaker,
|
||||
health_failure_threshold=health_failure_threshold,
|
||||
health_success_threshold=health_success_threshold,
|
||||
health_check_timeout_secs=health_check_timeout_secs,
|
||||
health_check_interval_secs=health_check_interval_secs,
|
||||
health_check_endpoint=health_check_endpoint,
|
||||
model_path=model_path,
|
||||
tokenizer_path=tokenizer_path,
|
||||
@staticmethod
|
||||
def from_args(args: RouterArgs) -> "Router":
|
||||
"""Create a router from a RouterArgs instance."""
|
||||
|
||||
args_dict = vars(args)
|
||||
# Convert RouterArgs to _Router parameters
|
||||
args_dict["worker_urls"] = (
|
||||
[]
|
||||
if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
|
||||
else args_dict["worker_urls"]
|
||||
)
|
||||
args_dict["policy"] = policy_from_str(args_dict["policy"])
|
||||
args_dict["prefill_urls"] = (
|
||||
args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
|
||||
)
|
||||
args_dict["decode_urls"] = (
|
||||
args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
|
||||
)
|
||||
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
|
||||
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
|
||||
|
||||
# remoge mini_lb parameter
|
||||
args_dict.pop("mini_lb")
|
||||
|
||||
return Router(_Router(**args_dict))
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the router server.
|
||||
|
||||
Reference in New Issue
Block a user