diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index f25fc6eea..51c3cdd65 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -986,7 +986,7 @@ pub fn start_health_checker( // Periodically reset load counters to prevent drift // Only do this when we believe all workers should be idle - if check_count.is_multiple_of(LOAD_RESET_INTERVAL) { + if check_count % LOAD_RESET_INTERVAL == 0 { let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0); // Only reset if load appears to be very low (likely drift) if max_load <= 2 { diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index 686ab4329..a297a7ede 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -146,17 +146,29 @@ impl RouterFactory { // Create policy let policy = PolicyFactory::create_from_config(policy_config); - // Determine which tokenizer path to use - // Priority: tokenizer_path > model_path - let tokenizer_path = ctx - .router_config - .tokenizer_path - .clone() - .or_else(|| ctx.router_config.model_path.clone()) + // Get tokenizer from context + let tokenizer = ctx + .tokenizer + .as_ref() .ok_or_else(|| { - "gRPC router requires either --tokenizer-path or --model-path to be specified" + "gRPC router requires tokenizer to be initialized in AppContext".to_string() + })? + .clone(); + + // Get reasoning parser factory from context + let reasoning_parser_factory = ctx + .reasoning_parser_factory + .as_ref() + .ok_or_else(|| { + "gRPC router requires reasoning parser factory to be initialized in AppContext" .to_string() - })?; + })? + .clone(); + + // Get tool parser registry from context + let tool_parser_registry = ctx.tool_parser_registry.ok_or_else(|| { + "gRPC router requires tool parser registry to be initialized in AppContext".to_string() + })?; // Create gRPC router let router = GrpcRouter::new( @@ -169,7 +181,9 @@ impl RouterFactory { ctx.router_config.effective_retry_config(), ctx.router_config.effective_circuit_breaker_config(), ctx.router_config.health_check.clone(), - tokenizer_path, + tokenizer, + reasoning_parser_factory, + tool_parser_registry, ) .await?; @@ -193,17 +207,30 @@ impl RouterFactory { let decode_policy = PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); - // Determine which tokenizer path to use - // Priority: tokenizer_path > model_path - let tokenizer_path = ctx - .router_config - .tokenizer_path - .clone() - .or_else(|| ctx.router_config.model_path.clone()) + // Get tokenizer from context + let tokenizer = ctx + .tokenizer + .as_ref() .ok_or_else(|| { - "gRPC PD router requires either --tokenizer-path or --model-path to be specified" + "gRPC PD router requires tokenizer to be initialized in AppContext".to_string() + })? + .clone(); + + // Get reasoning parser factory from context + let reasoning_parser_factory = ctx + .reasoning_parser_factory + .as_ref() + .ok_or_else(|| { + "gRPC PD router requires reasoning parser factory to be initialized in AppContext" .to_string() - })?; + })? + .clone(); + + // Get tool parser registry from context + let tool_parser_registry = ctx.tool_parser_registry.ok_or_else(|| { + "gRPC PD router requires tool parser registry to be initialized in AppContext" + .to_string() + })?; // Create gRPC PD router let router = GrpcPDRouter::new( @@ -218,7 +245,9 @@ impl RouterFactory { ctx.router_config.effective_retry_config(), ctx.router_config.effective_circuit_breaker_config(), ctx.router_config.health_check.clone(), - tokenizer_path, + tokenizer, + reasoning_parser_factory, + tool_parser_registry, ) .await?; diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 2f4c61649..43d143d81 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -12,7 +12,7 @@ use crate::metrics::RouterMetrics; use crate::policies::LoadBalancingPolicy; use crate::reasoning_parser::ParserFactory; use crate::routers::{RouterTrait, WorkerManagement}; -use crate::tokenizer::{factory, traits::Tokenizer}; +use crate::tokenizer::traits::Tokenizer; use crate::tool_parser::ParserRegistry; use async_trait::async_trait; use axum::{ @@ -74,21 +74,13 @@ impl GrpcPDRouter { retry_config: RetryConfig, circuit_breaker_config: ConfigCircuitBreakerConfig, health_check_config: ConfigHealthCheckConfig, - tokenizer_path_or_model: String, + tokenizer: Arc, + reasoning_parser_factory: ParserFactory, + tool_parser_registry: &'static ParserRegistry, ) -> Result { // Update metrics RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len()); - // Initialize tokenizer - let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model) - .map_err(|e| format!("Failed to create tokenizer: {}", e))?; - - // Initialize reasoning parser factory - let reasoning_parser_factory = ParserFactory::new(); - - // Get tool parser registry - let tool_parser_registry = ParserRegistry::new(); - // Convert config CircuitBreakerConfig to core CircuitBreakerConfig let core_cb_config = CircuitBreakerConfig { failure_threshold: circuit_breaker_config.failure_threshold, diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index f81a25917..be2f5ae33 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -12,7 +12,7 @@ use crate::metrics::RouterMetrics; use crate::policies::LoadBalancingPolicy; use crate::reasoning_parser::ParserFactory; use crate::routers::{RouterTrait, WorkerManagement}; -use crate::tokenizer::{factory, traits::Tokenizer}; +use crate::tokenizer::traits::Tokenizer; use crate::tool_parser::ParserRegistry; use async_trait::async_trait; use axum::{ @@ -65,21 +65,13 @@ impl GrpcRouter { retry_config: RetryConfig, circuit_breaker_config: ConfigCircuitBreakerConfig, health_check_config: ConfigHealthCheckConfig, - tokenizer_path_or_model: String, + tokenizer: Arc, + reasoning_parser_factory: ParserFactory, + tool_parser_registry: &'static ParserRegistry, ) -> Result { // Update metrics RouterMetrics::set_active_workers(worker_urls.len()); - // Initialize tokenizer - let tokenizer = factory::create_tokenizer(&tokenizer_path_or_model) - .map_err(|e| format!("Failed to create tokenizer: {}", e))?; - - // Initialize reasoning parser factory - let reasoning_parser_factory = ParserFactory::new(); - - // Get tool parser registry - let tool_parser_registry = ParserRegistry::new(); - // Convert config CircuitBreakerConfig to core CircuitBreakerConfig let core_cb_config = CircuitBreakerConfig { failure_threshold: circuit_breaker_config.failure_threshold, diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index e4af619c9..2762f9765 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -3,8 +3,11 @@ use crate::logging::{self, LoggingConfig}; use crate::metrics::{self, PrometheusConfig}; use crate::middleware::TokenBucket; use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; +use crate::reasoning_parser::ParserFactory; use crate::routers::{RouterFactory, RouterTrait}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; +use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer}; +use crate::tool_parser::ParserRegistry; use axum::{ extract::{Query, Request, State}, http::StatusCode, @@ -27,7 +30,9 @@ pub struct AppContext { pub client: Client, pub router_config: RouterConfig, pub rate_limiter: Arc, - // Future dependencies can be added here + pub tokenizer: Option>, + pub reasoning_parser_factory: Option, + pub tool_parser_registry: Option<&'static ParserRegistry>, } impl AppContext { @@ -36,14 +41,45 @@ impl AppContext { client: Client, max_concurrent_requests: usize, rate_limit_tokens_per_second: Option, - ) -> Self { + ) -> Result { let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests); let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens)); - Self { + + // Initialize gRPC-specific components only when in gRPC mode + let (tokenizer, reasoning_parser_factory, tool_parser_registry) = + if router_config.connection_mode == crate::config::ConnectionMode::Grpc { + // Get tokenizer path (required for gRPC mode) + let tokenizer_path = router_config + .tokenizer_path + .clone() + .or_else(|| router_config.model_path.clone()) + .ok_or_else(|| { + "gRPC mode requires either --tokenizer-path or --model-path to be specified" + .to_string() + })?; + + // Initialize all gRPC components + let tokenizer = Some( + tokenizer_factory::create_tokenizer(&tokenizer_path) + .map_err(|e| format!("Failed to create tokenizer: {}", e))?, + ); + let reasoning_parser_factory = Some(ParserFactory::new()); + let tool_parser_registry = Some(ParserRegistry::new()); + + (tokenizer, reasoning_parser_factory, tool_parser_registry) + } else { + // HTTP mode doesn't need these components + (None, None, None) + }; + + Ok(Self { client, router_config, rate_limiter, - } + tokenizer, + reasoning_parser_factory, + tool_parser_registry, + }) } } @@ -291,7 +327,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Arc { - Arc::new(AppContext::new( - config.clone(), - reqwest::Client::new(), - config.max_concurrent_requests, - config.rate_limit_tokens_per_second, - )) + Arc::new( + AppContext::new( + config.clone(), + reqwest::Client::new(), + config.max_concurrent_requests, + config.rate_limit_tokens_per_second, + ) + .expect("Failed to create AppContext in test"), + ) } // Tokenizer download configuration diff --git a/sgl-router/tests/common/test_app.rs b/sgl-router/tests/common/test_app.rs index 554845363..83d7d456a 100644 --- a/sgl-router/tests/common/test_app.rs +++ b/sgl-router/tests/common/test_app.rs @@ -15,12 +15,15 @@ pub fn create_test_app( router_config: &RouterConfig, ) -> Router { // Create AppContext - let app_context = Arc::new(AppContext::new( - router_config.clone(), - client, - router_config.max_concurrent_requests, - router_config.rate_limit_tokens_per_second, - )); + let app_context = Arc::new( + AppContext::new( + router_config.clone(), + client, + router_config.max_concurrent_requests, + router_config.rate_limit_tokens_per_second, + ) + .expect("Failed to create AppContext in test"), + ); // Create AppState with the test router and context let app_state = Arc::new(AppState { diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 8b16fad2a..7071106a4 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -195,7 +195,8 @@ mod test_pd_routing { // Router creation will fail due to health checks, but config should be valid let app_context = - sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None); + sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None) + .expect("Failed to create AppContext"); let app_context = std::sync::Arc::new(app_context); let result = RouterFactory::create_router(&app_context).await; assert!(result.is_err());