diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index b84198a3e..d8ad5253f 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -7,8 +7,8 @@ use crate::{ data_connector::{ MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, NoOpConversationStorage, NoOpResponseStorage, OracleConversationItemStorage, - OracleConversationStorage, OracleResponseStorage, SharedConversationStorage, - SharedResponseStorage, + OracleConversationStorage, OracleResponseStorage, SharedConversationItemStorage, + SharedConversationStorage, SharedResponseStorage, }, logging::{self, LoggingConfig}, metrics::{self, PrometheusConfig}, @@ -62,7 +62,7 @@ pub struct AppContext { pub router_manager: Option>, pub response_storage: SharedResponseStorage, pub conversation_storage: SharedConversationStorage, - pub conversation_item_storage: crate::data_connector::SharedConversationItemStorage, + pub conversation_item_storage: SharedConversationItemStorage, pub load_monitor: Option>, pub configured_reasoning_parser: Option, pub configured_tool_parser: Option, @@ -70,135 +70,26 @@ pub struct AppContext { } impl AppContext { + #[allow(clippy::too_many_arguments)] pub fn new( router_config: RouterConfig, client: Client, - max_concurrent_requests: i32, - rate_limit_tokens_per_second: Option, - ) -> Result { - let rate_limiter = match max_concurrent_requests { - n if n <= 0 => None, - n => { - let rate_limit_tokens = - rate_limit_tokens_per_second.filter(|&t| t > 0).unwrap_or(n); - Some(Arc::new(TokenBucket::new( - n as usize, - rate_limit_tokens as usize, - ))) - } - }; - - let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if router_config - .connection_mode - == ConnectionMode::Grpc - { - let tokenizer_path = router_config - .tokenizer_path - .clone() - .or_else(|| router_config.model_path.clone()) - .ok_or_else(|| { - "gRPC mode requires either --tokenizer-path or --model-path to be specified" - .to_string() - })?; - - let tokenizer = Some( - tokenizer_factory::create_tokenizer_with_chat_template_blocking( - &tokenizer_path, - router_config.chat_template.as_deref(), - ) - .map_err(|e| { - format!( - "Failed to create tokenizer from '{}': {}. \ - Ensure the path is valid and points to a tokenizer file (tokenizer.json) \ - or a HuggingFace model ID. For directories, ensure they contain tokenizer files.", - tokenizer_path, e - ) - })?, - ); - let reasoning_parser_factory = Some(crate::reasoning_parser::ParserFactory::new()); - let tool_parser_factory = Some(crate::tool_parser::ParserFactory::new()); - - (tokenizer, reasoning_parser_factory, tool_parser_factory) - } else { - (None, None, None) - }; - - let worker_registry = Arc::new(WorkerRegistry::new()); - let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone())); - - let router_manager = None; - - let (response_storage, conversation_storage): ( - SharedResponseStorage, - SharedConversationStorage, - ) = match router_config.history_backend { - HistoryBackend::Memory => { - info!("Initializing data connector: Memory"); - ( - Arc::new(MemoryResponseStorage::new()), - Arc::new(MemoryConversationStorage::new()), - ) - } - HistoryBackend::None => { - info!("Initializing data connector: None (no persistence)"); - ( - Arc::new(NoOpResponseStorage::new()), - Arc::new(NoOpConversationStorage::new()), - ) - } - HistoryBackend::Oracle => { - let oracle_cfg = router_config.oracle.clone().ok_or_else(|| { - "oracle configuration is required when history_backend=oracle".to_string() - })?; - - info!( - "Initializing data connector: Oracle ATP (pool: {}-{})", - oracle_cfg.pool_min, oracle_cfg.pool_max - ); - - let response_storage = - OracleResponseStorage::new(oracle_cfg.clone()).map_err(|err| { - format!("failed to initialize Oracle response storage: {err}") - })?; - - let conversation_storage = OracleConversationStorage::new(oracle_cfg.clone()) - .map_err(|err| { - format!("failed to initialize Oracle conversation storage: {err}") - })?; - - info!("Data connector initialized successfully: Oracle ATP"); - (Arc::new(response_storage), Arc::new(conversation_storage)) - } - }; - - // Conversation items storage (memory-backed for now) - let conversation_item_storage: crate::data_connector::SharedConversationItemStorage = - match router_config.history_backend { - HistoryBackend::Oracle => { - let oracle_cfg = router_config.oracle.clone().ok_or_else(|| { - "oracle configuration is required when history_backend=oracle".to_string() - })?; - Arc::new(OracleConversationItemStorage::new(oracle_cfg).map_err(|e| { - format!("failed to initialize Oracle conversation item storage: {e}") - })?) - } - _ => Arc::new(MemoryConversationItemStorage::new()), - }; - - let load_monitor = Some(Arc::new(LoadMonitor::new( - worker_registry.clone(), - policy_registry.clone(), - client.clone(), - router_config.worker_startup_check_interval_secs, - ))); - + rate_limiter: Option>, + tokenizer: Option>, + reasoning_parser_factory: Option, + tool_parser_factory: Option, + worker_registry: Arc, + policy_registry: Arc, + response_storage: SharedResponseStorage, + conversation_storage: SharedConversationStorage, + conversation_item_storage: SharedConversationItemStorage, + load_monitor: Option>, + worker_job_queue: Arc>>, + ) -> Self { let configured_reasoning_parser = router_config.reasoning_parser.clone(); let configured_tool_parser = router_config.tool_call_parser.clone(); - // Create empty OnceLock for worker job queue (will be initialized in startup()) - let worker_job_queue = Arc::new(OnceLock::new()); - - Ok(Self { + Self { client, router_config, rate_limiter, @@ -207,7 +98,7 @@ impl AppContext { tool_parser_factory, worker_registry, policy_registry, - router_manager, + router_manager: None, response_storage, conversation_storage, conversation_item_storage, @@ -215,7 +106,7 @@ impl AppContext { configured_reasoning_parser, configured_tool_parser, worker_job_queue, - }) + } } } @@ -936,12 +827,146 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box None, + n => { + let rate_limit_tokens = config + .router_config + .rate_limit_tokens_per_second + .filter(|&t| t > 0) + .unwrap_or(n); + Some(Arc::new(TokenBucket::new( + n as usize, + rate_limit_tokens as usize, + ))) + } + }; + + // Initialize tokenizer and parser factories for gRPC mode + let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if config + .router_config + .connection_mode + == ConnectionMode::Grpc + { + let tokenizer_path = config + .router_config + .tokenizer_path + .clone() + .or_else(|| config.router_config.model_path.clone()) + .ok_or_else(|| { + "gRPC mode requires either --tokenizer-path or --model-path to be specified" + .to_string() + })?; + + let tokenizer = Some( + tokenizer_factory::create_tokenizer_with_chat_template_blocking( + &tokenizer_path, + config.router_config.chat_template.as_deref(), + ) + .map_err(|e| { + format!( + "Failed to create tokenizer from '{}': {}. \ + Ensure the path is valid and points to a tokenizer file (tokenizer.json) \ + or a HuggingFace model ID. For directories, ensure they contain tokenizer files.", + tokenizer_path, e + ) + })?, + ); + let reasoning_parser_factory = Some(ReasoningParserFactory::new()); + let tool_parser_factory = Some(ToolParserFactory::new()); + + (tokenizer, reasoning_parser_factory, tool_parser_factory) + } else { + (None, None, None) + }; + + // Initialize worker registry and policy registry + let worker_registry = Arc::new(WorkerRegistry::new()); + let policy_registry = Arc::new(PolicyRegistry::new(config.router_config.policy.clone())); + + // Initialize storage backends + let (response_storage, conversation_storage): ( + SharedResponseStorage, + SharedConversationStorage, + ) = match config.router_config.history_backend { + HistoryBackend::Memory => { + info!("Initializing data connector: Memory"); + ( + Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), + ) + } + HistoryBackend::None => { + info!("Initializing data connector: None (no persistence)"); + ( + Arc::new(NoOpResponseStorage::new()), + Arc::new(NoOpConversationStorage::new()), + ) + } + HistoryBackend::Oracle => { + let oracle_cfg = config.router_config.oracle.clone().ok_or_else(|| { + "oracle configuration is required when history_backend=oracle".to_string() + })?; + info!( + "Initializing data connector: Oracle ATP (pool: {}-{})", + oracle_cfg.pool_min, oracle_cfg.pool_max + ); + + let response_storage = OracleResponseStorage::new(oracle_cfg.clone()) + .map_err(|err| format!("failed to initialize Oracle response storage: {err}"))?; + + let conversation_storage = + OracleConversationStorage::new(oracle_cfg.clone()).map_err(|err| { + format!("failed to initialize Oracle conversation storage: {err}") + })?; + info!("Data connector initialized successfully: Oracle ATP"); + + (Arc::new(response_storage), Arc::new(conversation_storage)) + } + }; + + // Initialize conversation items storage + let conversation_item_storage: SharedConversationItemStorage = + match config.router_config.history_backend { + HistoryBackend::Oracle => { + let oracle_cfg = config.router_config.oracle.clone().ok_or_else(|| { + "oracle configuration is required when history_backend=oracle".to_string() + })?; + Arc::new(OracleConversationItemStorage::new(oracle_cfg).map_err(|e| { + format!("failed to initialize Oracle conversation item storage: {e}") + })?) + } + _ => Arc::new(MemoryConversationItemStorage::new()), + }; + + // Initialize load monitor + let load_monitor = Some(Arc::new(LoadMonitor::new( + worker_registry.clone(), + policy_registry.clone(), + client.clone(), + config.router_config.worker_startup_check_interval_secs, + ))); + + // Create empty OnceLock for worker job queue (will be initialized below) + let worker_job_queue = Arc::new(OnceLock::new()); + + // Create AppContext with all initialized components let app_context = AppContext::new( config.router_config.clone(), client.clone(), - config.router_config.max_concurrent_requests, - config.router_config.rate_limit_tokens_per_second, - )?; + rate_limiter, + tokenizer, + reasoning_parser_factory, + tool_parser_factory, + worker_registry, + policy_registry, + response_storage, + conversation_storage, + conversation_item_storage, + load_monitor, + worker_job_queue, + ); let app_context = Arc::new(app_context); diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs index 4a2ed2f90..dc7e7a8d5 100644 --- a/sgl-router/tests/common/mod.rs +++ b/sgl-router/tests/common/mod.rs @@ -9,6 +9,12 @@ pub mod test_app; use serde_json::json; use sglang_router_rs::config::RouterConfig; +use sglang_router_rs::core::{LoadMonitor, WorkerRegistry}; +use sglang_router_rs::data_connector::{ + MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, +}; +use sglang_router_rs::middleware::TokenBucket; +use sglang_router_rs::policies::PolicyRegistry; use sglang_router_rs::protocols::spec::{Function, Tool}; use sglang_router_rs::server::AppContext; use std::fs; @@ -17,15 +23,58 @@ use std::sync::{Arc, Mutex, OnceLock}; /// Helper function to create AppContext for tests pub fn create_test_context(config: RouterConfig) -> Arc { - Arc::new( - AppContext::new( - config.clone(), - reqwest::Client::new(), - config.max_concurrent_requests, - config.rate_limit_tokens_per_second, - ) - .expect("Failed to create AppContext in test"), - ) + let client = reqwest::Client::new(); + + // Initialize rate limiter + let rate_limiter = match config.max_concurrent_requests { + n if n <= 0 => None, + n => { + let rate_limit_tokens = config + .rate_limit_tokens_per_second + .filter(|&t| t > 0) + .unwrap_or(n); + Some(Arc::new(TokenBucket::new( + n as usize, + rate_limit_tokens as usize, + ))) + } + }; + + // Initialize registries + let worker_registry = Arc::new(WorkerRegistry::new()); + let policy_registry = Arc::new(PolicyRegistry::new(config.policy.clone())); + + // Initialize storage backends (Memory for tests) + let response_storage = Arc::new(MemoryResponseStorage::new()); + let conversation_storage = Arc::new(MemoryConversationStorage::new()); + let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new()); + + // Initialize load monitor + let load_monitor = Some(Arc::new(LoadMonitor::new( + worker_registry.clone(), + policy_registry.clone(), + client.clone(), + config.worker_startup_check_interval_secs, + ))); + + // Create empty OnceLock for worker job queue + let worker_job_queue = Arc::new(OnceLock::new()); + + Arc::new(AppContext::new( + config, + client, + rate_limiter, + None, // tokenizer + None, // reasoning_parser_factory + None, // tool_parser_factory + worker_registry, + policy_registry, + response_storage, + conversation_storage, + conversation_item_storage, + load_monitor, + worker_job_queue, + )) } // Tokenizer download configuration diff --git a/sgl-router/tests/common/test_app.rs b/sgl-router/tests/common/test_app.rs index 50959eec0..c293b7074 100644 --- a/sgl-router/tests/common/test_app.rs +++ b/sgl-router/tests/common/test_app.rs @@ -2,11 +2,16 @@ use axum::Router; use reqwest::Client; use sglang_router_rs::{ config::RouterConfig, - middleware::AuthConfig, + core::{LoadMonitor, WorkerRegistry}, + data_connector::{ + MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, + }, + middleware::{AuthConfig, TokenBucket}, + policies::PolicyRegistry, routers::RouterTrait, server::{build_app, AppContext, AppState}, }; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; /// Create a test Axum application using the actual server's build_app function #[allow(dead_code)] @@ -15,16 +20,57 @@ pub fn create_test_app( client: Client, router_config: &RouterConfig, ) -> Router { + // Initialize rate limiter + let rate_limiter = match router_config.max_concurrent_requests { + n if n <= 0 => None, + n => { + let rate_limit_tokens = router_config + .rate_limit_tokens_per_second + .filter(|&t| t > 0) + .unwrap_or(n); + Some(Arc::new(TokenBucket::new( + n as usize, + rate_limit_tokens as usize, + ))) + } + }; + + // Initialize registries + let worker_registry = Arc::new(WorkerRegistry::new()); + let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone())); + + // Initialize storage backends + let response_storage = Arc::new(MemoryResponseStorage::new()); + let conversation_storage = Arc::new(MemoryConversationStorage::new()); + let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new()); + + // Initialize load monitor + let load_monitor = Some(Arc::new(LoadMonitor::new( + worker_registry.clone(), + policy_registry.clone(), + client.clone(), + router_config.worker_startup_check_interval_secs, + ))); + + // Create empty OnceLock for worker job queue + let worker_job_queue = Arc::new(OnceLock::new()); + // Create AppContext - let app_context = Arc::new( - AppContext::new( - router_config.clone(), - client, - router_config.max_concurrent_requests, - router_config.rate_limit_tokens_per_second, - ) - .expect("Failed to create AppContext in test"), - ); + let app_context = Arc::new(AppContext::new( + router_config.clone(), + client, + rate_limiter, + None, // tokenizer + None, // reasoning_parser_factory + None, // tool_parser_factory + worker_registry, + policy_registry, + response_storage, + conversation_storage, + conversation_item_storage, + load_monitor, + worker_job_queue, + )); // Create AppState with the test router and context let app_state = Arc::new(AppState { diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index 9d641d795..cd155baf8 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -15,8 +15,6 @@ use sglang_router_rs::config::{ RouterConfig, RoutingMode, }; use sglang_router_rs::routers::RouterFactory; -use sglang_router_rs::server::AppContext; -use std::sync::Arc; #[tokio::test] async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { @@ -83,10 +81,8 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { }; // Create router and context - let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx"); - let router = RouterFactory::create_router(&Arc::new(ctx)) - .await - .expect("router"); + let ctx = common::create_test_context(router_cfg); + let router = RouterFactory::create_router(&ctx).await.expect("router"); // Build a simple ResponsesRequest that will trigger the tool call let req = ResponsesRequest { @@ -284,10 +280,8 @@ async fn test_conversations_crud_basic() { tool_call_parser: None, }; - let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx"); - let router = RouterFactory::create_router(&Arc::new(ctx)) - .await - .expect("router"); + let ctx = common::create_test_context(router_cfg); + let router = RouterFactory::create_router(&ctx).await.expect("router"); // Create let create_body = serde_json::json!({ "metadata": { "project": "alpha" } }); @@ -616,10 +610,8 @@ async fn test_multi_turn_loop_with_mcp() { tool_call_parser: None, }; - let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx"); - let router = RouterFactory::create_router(&Arc::new(ctx)) - .await - .expect("router"); + let ctx = common::create_test_context(router_cfg); + let router = RouterFactory::create_router(&ctx).await.expect("router"); // Build request with MCP tools let req = ResponsesRequest { @@ -794,10 +786,8 @@ async fn test_max_tool_calls_limit() { tool_call_parser: None, }; - let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx"); - let router = RouterFactory::create_router(&Arc::new(ctx)) - .await - .expect("router"); + let ctx = common::create_test_context(router_cfg); + let router = RouterFactory::create_router(&ctx).await.expect("router"); let req = ResponsesRequest { background: Some(false), @@ -938,10 +928,8 @@ async fn setup_streaming_mcp_test() -> ( tool_call_parser: None, }; - let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx"); - let router = RouterFactory::create_router(&Arc::new(ctx)) - .await - .expect("router"); + let ctx = common::create_test_context(router_cfg); + let router = RouterFactory::create_router(&ctx).await.expect("router"); (mcp, worker, router, dir) } @@ -1381,10 +1369,8 @@ async fn test_conversation_items_create_and_get() { tool_call_parser: None, }; - let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx"); - let router = RouterFactory::create_router(&Arc::new(ctx)) - .await - .expect("router"); + let ctx = common::create_test_context(router_cfg); + let router = RouterFactory::create_router(&ctx).await.expect("router"); // Create conversation let create_conv = serde_json::json!({}); @@ -1484,10 +1470,8 @@ async fn test_conversation_items_delete() { tool_call_parser: None, }; - let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx"); - let router = RouterFactory::create_router(&Arc::new(ctx)) - .await - .expect("router"); + let ctx = common::create_test_context(router_cfg); + let router = RouterFactory::create_router(&ctx).await.expect("router"); // Create conversation let create_conv = serde_json::json!({}); @@ -1593,10 +1577,8 @@ async fn test_conversation_items_max_limit() { tool_call_parser: None, }; - let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx"); - let router = RouterFactory::create_router(&Arc::new(ctx)) - .await - .expect("router"); + let ctx = common::create_test_context(router_cfg); + let router = RouterFactory::create_router(&ctx).await.expect("router"); // Create conversation let create_conv = serde_json::json!({}); @@ -1672,10 +1654,8 @@ async fn test_conversation_items_unsupported_type() { tool_call_parser: None, }; - let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx"); - let router = RouterFactory::create_router(&Arc::new(ctx)) - .await - .expect("router"); + let ctx = common::create_test_context(router_cfg); + let router = RouterFactory::create_router(&ctx).await.expect("router"); // Create conversation let create_conv = serde_json::json!({}); @@ -1750,10 +1730,8 @@ async fn test_conversation_items_multi_conversation_sharing() { tool_call_parser: None, }; - let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx"); - let router = RouterFactory::create_router(&Arc::new(ctx)) - .await - .expect("router"); + let ctx = common::create_test_context(router_cfg); + let router = RouterFactory::create_router(&ctx).await.expect("router"); // Create two conversations let conv_a_resp = router diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 5ca3a6c1e..ae50693b9 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -200,10 +200,56 @@ mod test_pd_routing { tool_call_parser: None, }; - let app_context = - sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None) - .expect("Failed to create AppContext"); - let app_context = std::sync::Arc::new(app_context); + let app_context = { + use sglang_router_rs::core::{LoadMonitor, WorkerRegistry}; + use sglang_router_rs::data_connector::{ + MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, + }; + use sglang_router_rs::middleware::TokenBucket; + use sglang_router_rs::policies::PolicyRegistry; + use std::sync::{Arc, OnceLock}; + + let client = reqwest::Client::new(); + + // Initialize rate limiter + let rate_limiter = Some(Arc::new(TokenBucket::new(64, 64))); + + // Initialize registries + let worker_registry = Arc::new(WorkerRegistry::new()); + let policy_registry = Arc::new(PolicyRegistry::new(config.policy.clone())); + + // Initialize storage backends + let response_storage = Arc::new(MemoryResponseStorage::new()); + let conversation_storage = Arc::new(MemoryConversationStorage::new()); + let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new()); + + // Initialize load monitor + let load_monitor = Some(Arc::new(LoadMonitor::new( + worker_registry.clone(), + policy_registry.clone(), + client.clone(), + config.worker_startup_check_interval_secs, + ))); + + // Create empty OnceLock for worker job queue + let worker_job_queue = Arc::new(OnceLock::new()); + + Arc::new(sglang_router_rs::server::AppContext::new( + config, + client, + rate_limiter, + None, // tokenizer + None, // reasoning_parser_factory + None, // tool_parser_factory + worker_registry, + policy_registry, + response_storage, + conversation_storage, + conversation_item_storage, + load_monitor, + worker_job_queue, + )) + }; let result = RouterFactory::create_router(&app_context).await; assert!( result.is_ok(),