diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index d1d80ec60..e0522592f 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -99,6 +99,9 @@ class RouterArgs: cb_timeout_duration_secs: int = 60 cb_window_duration_secs: int = 120 disable_circuit_breaker: bool = False + # Tokenizer configuration + model_path: Optional[str] = None + tokenizer_path: Optional[str] = None @staticmethod def add_cli_args( @@ -433,6 +436,19 @@ class RouterArgs: default=[], help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)", ) + # Tokenizer configuration + parser.add_argument( + f"--{prefix}model-path", + type=str, + default=None, + help="Model path for loading tokenizer (HuggingFace model ID or local path)", + ) + parser.add_argument( + f"--{prefix}tokenizer-path", + type=str, + default=None, + help="Explicit tokenizer path (overrides model_path tokenizer if provided)", + ) @classmethod def from_cli_args( @@ -554,6 +570,8 @@ class RouterArgs: health_check_endpoint=getattr( args, f"{prefix}health_check_endpoint", RouterArgs.health_check_endpoint ), + model_path=getattr(args, f"{prefix}model_path", None), + tokenizer_path=getattr(args, f"{prefix}tokenizer_path", None), ) @staticmethod @@ -759,6 +777,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: health_check_timeout_secs=router_args.health_check_timeout_secs, health_check_interval_secs=router_args.health_check_interval_secs, health_check_endpoint=router_args.health_check_endpoint, + model_path=router_args.model_path, + tokenizer_path=router_args.tokenizer_path, ) router.start() diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index d6c53e032..de504bafc 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -74,6 +74,8 @@ class Router: health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5 health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60 health_check_endpoint: Health check endpoint path. Default: '/health' + model_path: Model path for loading tokenizer (HuggingFace model ID or local path). Default: None + tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None """ def __init__( @@ -131,6 +133,8 @@ class Router: health_check_timeout_secs: int = 5, health_check_interval_secs: int = 60, health_check_endpoint: str = "/health", + model_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, ): if selector is None: selector = {} @@ -195,6 +199,8 @@ class Router: health_check_timeout_secs=health_check_timeout_secs, health_check_interval_secs=health_check_interval_secs, health_check_endpoint=health_check_endpoint, + model_path=model_path, + tokenizer_path=tokenizer_path, ) def start(self) -> None: diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index 8e0d9e852..cc234e756 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -64,6 +64,8 @@ class TestLaunchRouter(unittest.TestCase): cb_window_duration_secs=60, disable_retries=False, disable_circuit_breaker=False, + model_path=None, + tokenizer_path=None, ) def create_router_args(self, **kwargs): diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 6afc3348e..a45d52bd2 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -7,6 +7,9 @@ use std::collections::HashMap; pub struct RouterConfig { /// Routing mode configuration pub mode: RoutingMode, + /// Worker connection mode + #[serde(default)] + pub connection_mode: ConnectionMode, /// Policy configuration pub policy: PolicyConfig, /// Server host address @@ -60,6 +63,20 @@ pub struct RouterConfig { /// Enable Inference Gateway mode (false = proxy mode, true = IGW mode) #[serde(default)] pub enable_igw: bool, + /// Model path for loading tokenizer (can be a HuggingFace model ID or local path) + pub model_path: Option, + /// Explicit tokenizer path (overrides model_path tokenizer if provided) + pub tokenizer_path: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] +#[serde(tag = "type")] +pub enum ConnectionMode { + #[default] + #[serde(rename = "http")] + Http, + #[serde(rename = "grpc")] + Grpc, } /// Routing mode configuration @@ -336,6 +353,9 @@ impl Default for RouterConfig { disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), enable_igw: false, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, } } } @@ -478,6 +498,9 @@ mod tests { queue_size: 100, queue_timeout_secs: 60, rate_limit_tokens_per_second: None, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; let json = serde_json::to_string(&config).unwrap(); @@ -914,6 +937,9 @@ mod tests { queue_size: 100, queue_timeout_secs: 60, rate_limit_tokens_per_second: None, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; assert!(config.mode.is_pd_mode()); @@ -974,6 +1000,9 @@ mod tests { queue_size: 100, queue_timeout_secs: 60, rate_limit_tokens_per_second: None, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; assert!(!config.mode.is_pd_mode()); @@ -1030,6 +1059,9 @@ mod tests { queue_size: 100, queue_timeout_secs: 60, rate_limit_tokens_per_second: None, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; assert!(config.has_service_discovery()); diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs index 542e2e467..a0a31fd23 100644 --- a/sgl-router/src/config/validation.rs +++ b/sgl-router/src/config/validation.rs @@ -349,6 +349,16 @@ impl ConfigValidator { return Ok(()); } + // Validate gRPC connection mode requires tokenizer configuration + if config.connection_mode == ConnectionMode::Grpc + && config.tokenizer_path.is_none() + && config.model_path.is_none() + { + return Err(ConfigError::ValidationFailed { + reason: "gRPC connection mode requires either --tokenizer-path or --model-path to be specified".to_string(), + }); + } + // All policies are now supported for both router types thanks to the unified trait design // No mode/policy restrictions needed anymore @@ -419,11 +429,14 @@ impl ConfigValidator { }); } - if !url.starts_with("http://") && !url.starts_with("https://") { + if !url.starts_with("http://") + && !url.starts_with("https://") + && !url.starts_with("grpc://") + { return Err(ConfigError::InvalidValue { field: "worker_url".to_string(), value: url.clone(), - reason: "URL must start with http:// or https://".to_string(), + reason: "URL must start with http://, https://, or grpc://".to_string(), }); } @@ -684,4 +697,60 @@ mod tests { assert!(e.to_string().contains("prefill requires at least 2")); } } + + #[test] + fn test_validate_grpc_requires_tokenizer() { + // Test that gRPC connection mode requires tokenizer configuration + let mut config = RouterConfig::new( + RoutingMode::Regular { + worker_urls: vec!["grpc://worker:50051".to_string()], + }, + PolicyConfig::Random, + ); + + // Set connection mode to gRPC without tokenizer config + config.connection_mode = ConnectionMode::Grpc; + config.tokenizer_path = None; + config.model_path = None; + + let result = ConfigValidator::validate(&config); + assert!(result.is_err()); + if let Err(e) = result { + assert!(e.to_string().contains("gRPC connection mode requires")); + } + } + + #[test] + fn test_validate_grpc_with_model_path() { + // Test that gRPC works with model_path + let mut config = RouterConfig::new( + RoutingMode::Regular { + worker_urls: vec!["grpc://worker:50051".to_string()], + }, + PolicyConfig::Random, + ); + + config.connection_mode = ConnectionMode::Grpc; + config.model_path = Some("meta-llama/Llama-3-8B".to_string()); + + let result = ConfigValidator::validate(&config); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_grpc_with_tokenizer_path() { + // Test that gRPC works with tokenizer_path + let mut config = RouterConfig::new( + RoutingMode::Regular { + worker_urls: vec!["grpc://worker:50051".to_string()], + }, + PolicyConfig::Random, + ); + + config.connection_mode = ConnectionMode::Grpc; + config.tokenizer_path = Some("/path/to/tokenizer.json".to_string()); + + let result = ConfigValidator::validate(&config); + assert!(result.is_ok()); + } } diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index c39e0d052..955185e07 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -2,6 +2,7 @@ use pyo3::prelude::*; pub mod config; pub mod logging; use std::collections::HashMap; + pub mod core; #[cfg(feature = "grpc-client")] pub mod grpc; @@ -89,9 +90,39 @@ struct Router { queue_size: usize, queue_timeout_secs: u64, rate_limit_tokens_per_second: Option, + // Connection mode (determined from worker URLs) + connection_mode: config::ConnectionMode, + // Model path for tokenizer + model_path: Option, + // Explicit tokenizer path + tokenizer_path: Option, } impl Router { + /// Determine connection mode from worker URLs + fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode { + // Check if any URL is a gRPC endpoint (starts with grpc:// or has port that commonly indicates gRPC) + for url in worker_urls { + if url.starts_with("grpc://") || url.starts_with("grpcs://") { + return config::ConnectionMode::Grpc; + } + // Also check for common gRPC ports if the scheme isn't specified + if let Ok(parsed_url) = url::Url::parse(url) { + if let Some(port) = parsed_url.port() { + // Common gRPC ports + if port == 50051 || port == 9090 || ((50000..=50100).contains(&port)) { + return config::ConnectionMode::Grpc; + } + } + } else if url.contains(":50051") || url.contains(":9090") || url.contains(":5000") { + // Fallback check for URLs that might not parse correctly + return config::ConnectionMode::Grpc; + } + } + // Default to HTTP + config::ConnectionMode::Http + } + /// Convert PyO3 Router to RouterConfig pub fn to_router_config(&self) -> config::ConfigResult { use config::{ @@ -168,6 +199,7 @@ impl Router { policy, host: self.host.clone(), port: self.port, + connection_mode: self.connection_mode.clone(), max_payload_size: self.max_payload_size, request_timeout_secs: self.request_timeout_secs, worker_startup_timeout_secs: self.worker_startup_timeout_secs, @@ -207,6 +239,8 @@ impl Router { endpoint: self.health_check_endpoint.clone(), }, enable_igw: self.enable_igw, + model_path: self.model_path.clone(), + tokenizer_path: self.tokenizer_path.clone(), }) } } @@ -273,6 +307,9 @@ impl Router { queue_size = 100, queue_timeout_secs = 60, rate_limit_tokens_per_second = None, + // Tokenizer defaults + model_path = None, + tokenizer_path = None, ))] #[allow(clippy::too_many_arguments)] fn new( @@ -330,7 +367,26 @@ impl Router { queue_size: usize, queue_timeout_secs: u64, rate_limit_tokens_per_second: Option, + model_path: Option, + tokenizer_path: Option, ) -> PyResult { + // Determine connection mode from worker URLs + let mut all_urls = worker_urls.clone(); + + // Add prefill URLs if in PD mode + if let Some(ref prefill_urls) = prefill_urls { + for (url, _) in prefill_urls { + all_urls.push(url.clone()); + } + } + + // Add decode URLs if in PD mode + if let Some(ref decode_urls) = decode_urls { + all_urls.extend(decode_urls.clone()); + } + + let connection_mode = Self::determine_connection_mode(&all_urls); + Ok(Router { host, port, @@ -386,6 +442,9 @@ impl Router { queue_size, queue_timeout_secs, rate_limit_tokens_per_second, + connection_mode, + model_path, + tokenizer_path, }) } diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index 1221d2b62..c745c0b3b 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -1,7 +1,7 @@ use clap::{ArgAction, Parser}; use sglang_router_rs::config::{ - CircuitBreakerConfig, ConfigError, ConfigResult, DiscoveryConfig, HealthCheckConfig, - MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, + CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig, + HealthCheckConfig, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, }; use sglang_router_rs::metrics::PrometheusConfig; use sglang_router_rs::server::{self, ServerConfig}; @@ -272,9 +272,42 @@ struct CliArgs { /// Enable Inference Gateway mode #[arg(long, default_value_t = false)] enable_igw: bool, + + // Tokenizer configuration + /// Model path for loading tokenizer (HuggingFace model ID or local path) + #[arg(long)] + model_path: Option, + + /// Explicit tokenizer path (overrides model_path tokenizer if provided) + #[arg(long)] + tokenizer_path: Option, } impl CliArgs { + /// Determine connection mode from worker URLs + fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode { + // Check if any URL is a gRPC endpoint (starts with grpc:// or has port that commonly indicates gRPC) + for url in worker_urls { + if url.starts_with("grpc://") || url.starts_with("grpcs://") { + return ConnectionMode::Grpc; + } + // Also check for common gRPC ports if the scheme isn't specified + if let Ok(parsed_url) = url::Url::parse(url) { + if let Some(port) = parsed_url.port() { + // Common gRPC ports + if port == 50051 || port == 9090 || ((50000..=50100).contains(&port)) { + return ConnectionMode::Grpc; + } + } + } else if url.contains(":50051") || url.contains(":9090") || url.contains(":5000") { + // Fallback check for URLs that might not parse correctly + return ConnectionMode::Grpc; + } + } + // Default to HTTP + ConnectionMode::Http + } + /// Parse selector strings into HashMap fn parse_selector(selector_list: &[String]) -> HashMap { let mut map = HashMap::new(); @@ -372,10 +405,30 @@ impl CliArgs { host: self.prometheus_host.clone(), }); + // Determine connection mode from all worker URLs + let mut all_urls = Vec::new(); + match &mode { + RoutingMode::Regular { worker_urls } => { + all_urls.extend(worker_urls.clone()); + } + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + .. + } => { + for (url, _) in prefill_urls { + all_urls.push(url.clone()); + } + all_urls.extend(decode_urls.clone()); + } + } + let connection_mode = Self::determine_connection_mode(&all_urls); + // Build RouterConfig Ok(RouterConfig { mode, policy, + connection_mode, host: self.host.clone(), port: self.port, max_payload_size: self.max_payload_size, @@ -421,6 +474,8 @@ impl CliArgs { }, enable_igw: self.enable_igw, rate_limit_tokens_per_second: None, + model_path: self.model_path.clone(), + tokenizer_path: self.tokenizer_path.clone(), }) } diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index 94845fdfb..686ab4329 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -4,7 +4,7 @@ use super::{ http::{pd_router::PDRouter, router::Router}, RouterTrait, }; -use crate::config::{PolicyConfig, RoutingMode}; +use crate::config::{ConnectionMode, PolicyConfig, RoutingMode}; use crate::policies::PolicyFactory; use crate::server::AppContext; use std::sync::Arc; @@ -20,28 +20,56 @@ impl RouterFactory { return Self::create_igw_router(ctx).await; } - // TODO: Add gRPC mode check here when implementing gRPC support - - // Default to HTTP proxy mode - match &ctx.router_config.mode { - RoutingMode::Regular { worker_urls } => { - Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await + // Check connection mode and route to appropriate implementation + match ctx.router_config.connection_mode { + ConnectionMode::Grpc => { + // Route to gRPC implementation based on routing mode + match &ctx.router_config.mode { + RoutingMode::Regular { worker_urls } => { + Self::create_grpc_router(worker_urls, &ctx.router_config.policy, ctx).await + } + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + prefill_policy, + decode_policy, + } => { + Self::create_grpc_pd_router( + prefill_urls, + decode_urls, + prefill_policy.as_ref(), + decode_policy.as_ref(), + &ctx.router_config.policy, + ctx, + ) + .await + } + } } - RoutingMode::PrefillDecode { - prefill_urls, - decode_urls, - prefill_policy, - decode_policy, - } => { - Self::create_pd_router( - prefill_urls, - decode_urls, - prefill_policy.as_ref(), - decode_policy.as_ref(), - &ctx.router_config.policy, - ctx, - ) - .await + ConnectionMode::Http => { + // Route to HTTP implementation based on routing mode + match &ctx.router_config.mode { + RoutingMode::Regular { worker_urls } => { + Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx) + .await + } + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + prefill_policy, + decode_policy, + } => { + Self::create_pd_router( + prefill_urls, + decode_urls, + prefill_policy.as_ref(), + decode_policy.as_ref(), + &ctx.router_config.policy, + ctx, + ) + .await + } + } } } } @@ -109,25 +137,92 @@ impl RouterFactory { /// Create a gRPC router with injected policy pub async fn create_grpc_router( - _worker_urls: &[String], - _policy_config: &PolicyConfig, - _ctx: &Arc, + worker_urls: &[String], + policy_config: &PolicyConfig, + ctx: &Arc, ) -> Result, String> { - // For now, return an error as gRPC router is not yet implemented - Err("gRPC router is not yet implemented".to_string()) + use super::grpc::router::GrpcRouter; + + // Create policy + let policy = PolicyFactory::create_from_config(policy_config); + + // Determine which tokenizer path to use + // Priority: tokenizer_path > model_path + let tokenizer_path = ctx + .router_config + .tokenizer_path + .clone() + .or_else(|| ctx.router_config.model_path.clone()) + .ok_or_else(|| { + "gRPC router requires either --tokenizer-path or --model-path to be specified" + .to_string() + })?; + + // Create gRPC router + let router = GrpcRouter::new( + worker_urls.to_vec(), + policy, + ctx.router_config.worker_startup_timeout_secs, + ctx.router_config.worker_startup_check_interval_secs, + ctx.router_config.dp_aware, + ctx.router_config.api_key.clone(), + ctx.router_config.effective_retry_config(), + ctx.router_config.effective_circuit_breaker_config(), + ctx.router_config.health_check.clone(), + tokenizer_path, + ) + .await?; + + Ok(Box::new(router)) } - /// Create a gRPC PD router (placeholder for now) + /// Create a gRPC PD router with tokenizer and worker configuration pub async fn create_grpc_pd_router( - _prefill_urls: &[(String, Option)], - _decode_urls: &[String], - _prefill_policy_config: Option<&PolicyConfig>, - _decode_policy_config: Option<&PolicyConfig>, - _main_policy_config: &PolicyConfig, - _ctx: &Arc, + prefill_urls: &[(String, Option)], + decode_urls: &[String], + prefill_policy_config: Option<&PolicyConfig>, + decode_policy_config: Option<&PolicyConfig>, + main_policy_config: &PolicyConfig, + ctx: &Arc, ) -> Result, String> { - // For now, return an error as gRPC PD router is not yet implemented - Err("gRPC PD router is not yet implemented".to_string()) + use super::grpc::pd_router::GrpcPDRouter; + + // Create policies - use specific policies if provided, otherwise fall back to main policy + let prefill_policy = + PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config)); + let decode_policy = + PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); + + // Determine which tokenizer path to use + // Priority: tokenizer_path > model_path + let tokenizer_path = ctx + .router_config + .tokenizer_path + .clone() + .or_else(|| ctx.router_config.model_path.clone()) + .ok_or_else(|| { + "gRPC PD router requires either --tokenizer-path or --model-path to be specified" + .to_string() + })?; + + // Create gRPC PD router + let router = GrpcPDRouter::new( + prefill_urls.to_vec(), + decode_urls.to_vec(), + prefill_policy, + decode_policy, + ctx.router_config.worker_startup_timeout_secs, + ctx.router_config.worker_startup_check_interval_secs, + ctx.router_config.dp_aware, + ctx.router_config.api_key.clone(), + ctx.router_config.effective_retry_config(), + ctx.router_config.effective_circuit_breaker_config(), + ctx.router_config.health_check.clone(), + tokenizer_path, + ) + .await?; + + Ok(Box::new(router)) } /// Create an IGW router (placeholder for future implementation) diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index e3f453186..2f4c61649 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -1,7 +1,19 @@ // PD (Prefill-Decode) gRPC Router Implementation -// TODO: Implement gRPC-based PD router for disaggregated prefill-decode systems +use crate::config::types::{ + CircuitBreakerConfig as ConfigCircuitBreakerConfig, + HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig, +}; +use crate::core::{ + BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, +}; +use crate::grpc::SglangSchedulerClient; +use crate::metrics::RouterMetrics; +use crate::policies::LoadBalancingPolicy; +use crate::reasoning_parser::ParserFactory; use crate::routers::{RouterTrait, WorkerManagement}; +use crate::tokenizer::{factory, traits::Tokenizer}; +use crate::tool_parser::ParserRegistry; use async_trait::async_trait; use axum::{ body::Body, @@ -9,15 +21,222 @@ use axum::{ http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, }; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use std::time::Duration; +use tracing::{info, warn}; -/// Placeholder for gRPC PD router -#[derive(Debug)] -pub struct GrpcPDRouter; +/// gRPC PD (Prefill-Decode) router implementation for SGLang +#[allow(dead_code)] // Fields will be used once implementation is complete +pub struct GrpcPDRouter { + /// Prefill worker connections + prefill_workers: Arc>>>, + /// Decode worker connections + decode_workers: Arc>>>, + /// gRPC clients for prefill workers + prefill_grpc_clients: Arc>>, + /// gRPC clients for decode workers + decode_grpc_clients: Arc>>, + /// Load balancing policy for prefill + prefill_policy: Arc, + /// Load balancing policy for decode + decode_policy: Arc, + /// Tokenizer for handling text encoding/decoding + tokenizer: Arc, + /// Reasoning parser factory for structured reasoning outputs + reasoning_parser_factory: ParserFactory, + /// Tool parser registry for function/tool calls + tool_parser_registry: &'static ParserRegistry, + /// Worker health checkers + _prefill_health_checker: Option, + _decode_health_checker: Option, + /// Configuration + timeout_secs: u64, + interval_secs: u64, + dp_aware: bool, + api_key: Option, + retry_config: RetryConfig, + circuit_breaker_config: CircuitBreakerConfig, +} impl GrpcPDRouter { - pub async fn new() -> Result { - // TODO: Implement gRPC PD router initialization - Err("gRPC PD router not yet implemented".to_string()) + /// Create a new gRPC PD router + #[allow(clippy::too_many_arguments)] + pub async fn new( + prefill_urls: Vec<(String, Option)>, + decode_urls: Vec, + prefill_policy: Arc, + decode_policy: Arc, + timeout_secs: u64, + interval_secs: u64, + dp_aware: bool, + api_key: Option, + retry_config: RetryConfig, + circuit_breaker_config: ConfigCircuitBreakerConfig, + health_check_config: ConfigHealthCheckConfig, + tokenizer_path_or_model: String, + ) -> Result { + // Update metrics + RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len()); + + // Initialize tokenizer + let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model) + .map_err(|e| format!("Failed to create tokenizer: {}", e))?; + + // Initialize reasoning parser factory + let reasoning_parser_factory = ParserFactory::new(); + + // Get tool parser registry + let tool_parser_registry = ParserRegistry::new(); + + // Convert config CircuitBreakerConfig to core CircuitBreakerConfig + let core_cb_config = CircuitBreakerConfig { + failure_threshold: circuit_breaker_config.failure_threshold, + success_threshold: circuit_breaker_config.success_threshold, + timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs), + window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), + }; + + // Create gRPC clients for prefill workers + let mut prefill_grpc_clients = HashMap::new(); + for (url, _bootstrap_port) in &prefill_urls { + match SglangSchedulerClient::connect(url).await { + Ok(client) => { + prefill_grpc_clients.insert(url.clone(), client); + info!("Connected to gRPC prefill worker at {}", url); + } + Err(e) => { + warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e); + // Continue with other workers + } + } + } + + // Create gRPC clients for decode workers + let mut decode_grpc_clients = HashMap::new(); + for url in &decode_urls { + match SglangSchedulerClient::connect(url).await { + Ok(client) => { + decode_grpc_clients.insert(url.clone(), client); + info!("Connected to gRPC decode worker at {}", url); + } + Err(e) => { + warn!("Failed to connect to gRPC decode worker at {}: {}", url, e); + // Continue with other workers + } + } + } + + if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() { + return Err("Failed to connect to any gRPC workers".to_string()); + } + + // Create Prefill Worker trait objects with gRPC connection mode + let prefill_workers: Vec> = prefill_urls + .iter() + .map(|(url, bootstrap_port)| { + let worker = BasicWorker::with_connection_mode( + url.clone(), + WorkerType::Prefill { + bootstrap_port: *bootstrap_port, + }, + crate::core::ConnectionMode::Grpc { + port: *bootstrap_port, + }, + ) + .with_circuit_breaker_config(core_cb_config.clone()) + .with_health_config(HealthConfig { + timeout_secs: health_check_config.timeout_secs, + check_interval_secs: health_check_config.check_interval_secs, + endpoint: health_check_config.endpoint.clone(), + failure_threshold: health_check_config.failure_threshold, + success_threshold: health_check_config.success_threshold, + }); + Box::new(worker) as Box + }) + .collect(); + + // Create Decode Worker trait objects with gRPC connection mode + let decode_workers: Vec> = decode_urls + .iter() + .map(|url| { + let worker = BasicWorker::with_connection_mode( + url.clone(), + WorkerType::Decode, + crate::core::ConnectionMode::Grpc { port: None }, + ) + .with_circuit_breaker_config(core_cb_config.clone()) + .with_health_config(HealthConfig { + timeout_secs: health_check_config.timeout_secs, + check_interval_secs: health_check_config.check_interval_secs, + endpoint: health_check_config.endpoint.clone(), + failure_threshold: health_check_config.failure_threshold, + success_threshold: health_check_config.success_threshold, + }); + Box::new(worker) as Box + }) + .collect(); + + // Initialize policies with workers if needed + if let Some(cache_aware) = prefill_policy + .as_any() + .downcast_ref::() + { + cache_aware.init_workers(&prefill_workers); + } + + if let Some(cache_aware) = decode_policy + .as_any() + .downcast_ref::() + { + cache_aware.init_workers(&decode_workers); + } + + let prefill_workers = Arc::new(RwLock::new(prefill_workers)); + let decode_workers = Arc::new(RwLock::new(decode_workers)); + + let prefill_health_checker = + crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs); + let decode_health_checker = + crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs); + + Ok(GrpcPDRouter { + prefill_workers, + decode_workers, + prefill_grpc_clients: Arc::new(RwLock::new(prefill_grpc_clients)), + decode_grpc_clients: Arc::new(RwLock::new(decode_grpc_clients)), + prefill_policy, + decode_policy, + tokenizer, + reasoning_parser_factory, + tool_parser_registry, + _prefill_health_checker: Some(prefill_health_checker), + _decode_health_checker: Some(decode_health_checker), + timeout_secs, + interval_secs, + dp_aware, + api_key, + retry_config, + circuit_breaker_config: core_cb_config, + }) + } +} + +impl std::fmt::Debug for GrpcPDRouter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("GrpcPDRouter") + .field( + "prefill_workers_count", + &self.prefill_workers.read().unwrap().len(), + ) + .field( + "decode_workers_count", + &self.decode_workers.read().unwrap().len(), + ) + .field("timeout_secs", &self.timeout_secs) + .field("interval_secs", &self.interval_secs) + .field("dp_aware", &self.dp_aware) + .finish() } } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index f5fc407f7..e7a0bd162 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -1,7 +1,19 @@ // gRPC Router Implementation -// TODO: Implement gRPC-based router +use crate::config::types::{ + CircuitBreakerConfig as ConfigCircuitBreakerConfig, + HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig, +}; +use crate::core::{ + BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, +}; +use crate::grpc::SglangSchedulerClient; +use crate::metrics::RouterMetrics; +use crate::policies::LoadBalancingPolicy; +use crate::reasoning_parser::ParserFactory; use crate::routers::{RouterTrait, WorkerManagement}; +use crate::tokenizer::{factory, traits::Tokenizer}; +use crate::tool_parser::ParserRegistry; use async_trait::async_trait; use axum::{ body::Body, @@ -9,15 +21,150 @@ use axum::{ http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, }; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use std::time::Duration; +use tracing::{info, warn}; -/// Placeholder for gRPC router -#[derive(Debug)] -pub struct GrpcRouter; +/// gRPC router implementation for SGLang +#[allow(dead_code)] // Fields will be used once implementation is complete +pub struct GrpcRouter { + /// Worker connections + workers: Arc>>>, + /// gRPC clients for each worker + grpc_clients: Arc>>, + /// Load balancing policy + policy: Arc, + /// Tokenizer for handling text encoding/decoding + tokenizer: Arc, + /// Reasoning parser factory for structured reasoning outputs + reasoning_parser_factory: ParserFactory, + /// Tool parser registry for function/tool calls + tool_parser_registry: &'static ParserRegistry, + /// Worker health checker + _health_checker: Option, + /// Configuration + timeout_secs: u64, + interval_secs: u64, + dp_aware: bool, + api_key: Option, + retry_config: RetryConfig, + circuit_breaker_config: CircuitBreakerConfig, +} impl GrpcRouter { - pub async fn new() -> Result { - // TODO: Implement gRPC router initialization - Err("gRPC router not yet implemented".to_string()) + /// Create a new gRPC router + #[allow(clippy::too_many_arguments)] + pub async fn new( + worker_urls: Vec, + policy: Arc, + timeout_secs: u64, + interval_secs: u64, + dp_aware: bool, + api_key: Option, + retry_config: RetryConfig, + circuit_breaker_config: ConfigCircuitBreakerConfig, + health_check_config: ConfigHealthCheckConfig, + tokenizer_path_or_model: String, + ) -> Result { + // Update metrics + RouterMetrics::set_active_workers(worker_urls.len()); + + // Initialize tokenizer + let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model) + .map_err(|e| format!("Failed to create tokenizer: {}", e))?; + + // Initialize reasoning parser factory + let reasoning_parser_factory = ParserFactory::new(); + + // Get tool parser registry + let tool_parser_registry = ParserRegistry::new(); + + // Convert config CircuitBreakerConfig to core CircuitBreakerConfig + let core_cb_config = CircuitBreakerConfig { + failure_threshold: circuit_breaker_config.failure_threshold, + success_threshold: circuit_breaker_config.success_threshold, + timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs), + window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), + }; + + // Create gRPC clients for each worker + let mut grpc_clients = HashMap::new(); + for url in &worker_urls { + match SglangSchedulerClient::connect(url).await { + Ok(client) => { + grpc_clients.insert(url.clone(), client); + info!("Connected to gRPC worker at {}", url); + } + Err(e) => { + warn!("Failed to connect to gRPC worker at {}: {}", url, e); + // Continue with other workers + } + } + } + + if grpc_clients.is_empty() { + return Err("Failed to connect to any gRPC workers".to_string()); + } + + // Create Worker trait objects with gRPC connection mode + let workers: Vec> = worker_urls + .iter() + .map(|url| { + let worker = BasicWorker::with_connection_mode( + url.clone(), + WorkerType::Regular, + crate::core::ConnectionMode::Grpc { port: None }, + ) + .with_circuit_breaker_config(core_cb_config.clone()) + .with_health_config(HealthConfig { + timeout_secs: health_check_config.timeout_secs, + check_interval_secs: health_check_config.check_interval_secs, + endpoint: health_check_config.endpoint.clone(), + failure_threshold: health_check_config.failure_threshold, + success_threshold: health_check_config.success_threshold, + }); + Box::new(worker) as Box + }) + .collect(); + + // Initialize policy with workers if needed + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::() + { + cache_aware.init_workers(&workers); + } + + let workers = Arc::new(RwLock::new(workers)); + let health_checker = crate::core::start_health_checker(Arc::clone(&workers), interval_secs); + + Ok(GrpcRouter { + workers, + grpc_clients: Arc::new(RwLock::new(grpc_clients)), + policy, + tokenizer, + reasoning_parser_factory, + tool_parser_registry, + _health_checker: Some(health_checker), + timeout_secs, + interval_secs, + dp_aware, + api_key, + retry_config, + circuit_breaker_config: core_cb_config, + }) + } +} + +impl std::fmt::Debug for GrpcRouter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("GrpcRouter") + .field("workers_count", &self.workers.read().unwrap().len()) + .field("timeout_secs", &self.timeout_secs) + .field("interval_secs", &self.interval_secs) + .field("dp_aware", &self.dp_aware) + .finish() } } diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index b3d5da6f0..8b2e29714 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -9,7 +9,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType use reqwest::Client; use serde_json::json; use sglang_router_rs::config::{ - CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, + CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, }; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; @@ -55,6 +55,9 @@ impl TestContext { disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), enable_igw: false, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; Self::new_with_config(config, worker_configs).await @@ -1101,6 +1104,9 @@ mod error_tests { disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), enable_igw: false, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; let ctx = TestContext::new_with_config( @@ -1456,6 +1462,9 @@ mod pd_mode_tests { disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), enable_igw: false, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; // Create app context @@ -1615,6 +1624,9 @@ mod request_id_tests { disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), enable_igw: false, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; let ctx = TestContext::new_with_config( diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index 2e91b82a6..606ca0a41 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -4,7 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType use reqwest::Client; use serde_json::json; use sglang_router_rs::config::{ - CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, + CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, }; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; @@ -46,6 +46,9 @@ impl TestContext { disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), enable_igw: false, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; let mut workers = Vec::new(); diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index ce8f8cfdf..29190a312 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -5,7 +5,7 @@ use futures_util::StreamExt; use reqwest::Client; use serde_json::json; use sglang_router_rs::config::{ - CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, + CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, }; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; @@ -47,6 +47,9 @@ impl TestContext { disable_circuit_breaker: false, health_check: sglang_router_rs::config::HealthCheckConfig::default(), enable_igw: false, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; let mut workers = Vec::new(); diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index bcea75a6a..8b16fad2a 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -2,7 +2,7 @@ mod test_pd_routing { use serde_json::json; use sglang_router_rs::config::{ - CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, + CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, }; use sglang_router_rs::core::{WorkerFactory, WorkerType}; use sglang_router_rs::routers::http::pd_types::get_hostname; @@ -188,6 +188,9 @@ mod test_pd_routing { health_check: sglang_router_rs::config::HealthCheckConfig::default(), enable_igw: false, rate_limit_tokens_per_second: None, + connection_mode: ConnectionMode::Http, + model_path: None, + tokenizer_path: None, }; // Router creation will fail due to health checks, but config should be valid