[router] cleanup app context and move to startup (#11617)
This commit is contained in:
@@ -9,6 +9,12 @@ pub mod test_app;
|
||||
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::config::RouterConfig;
|
||||
use sglang_router_rs::core::{LoadMonitor, WorkerRegistry};
|
||||
use sglang_router_rs::data_connector::{
|
||||
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||
};
|
||||
use sglang_router_rs::middleware::TokenBucket;
|
||||
use sglang_router_rs::policies::PolicyRegistry;
|
||||
use sglang_router_rs::protocols::spec::{Function, Tool};
|
||||
use sglang_router_rs::server::AppContext;
|
||||
use std::fs;
|
||||
@@ -17,15 +23,58 @@ use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
/// Helper function to create AppContext for tests
|
||||
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||
Arc::new(
|
||||
AppContext::new(
|
||||
config.clone(),
|
||||
reqwest::Client::new(),
|
||||
config.max_concurrent_requests,
|
||||
config.rate_limit_tokens_per_second,
|
||||
)
|
||||
.expect("Failed to create AppContext in test"),
|
||||
)
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Initialize rate limiter
|
||||
let rate_limiter = match config.max_concurrent_requests {
|
||||
n if n <= 0 => None,
|
||||
n => {
|
||||
let rate_limit_tokens = config
|
||||
.rate_limit_tokens_per_second
|
||||
.filter(|&t| t > 0)
|
||||
.unwrap_or(n);
|
||||
Some(Arc::new(TokenBucket::new(
|
||||
n as usize,
|
||||
rate_limit_tokens as usize,
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize registries
|
||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||
let policy_registry = Arc::new(PolicyRegistry::new(config.policy.clone()));
|
||||
|
||||
// Initialize storage backends (Memory for tests)
|
||||
let response_storage = Arc::new(MemoryResponseStorage::new());
|
||||
let conversation_storage = Arc::new(MemoryConversationStorage::new());
|
||||
let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());
|
||||
|
||||
// Initialize load monitor
|
||||
let load_monitor = Some(Arc::new(LoadMonitor::new(
|
||||
worker_registry.clone(),
|
||||
policy_registry.clone(),
|
||||
client.clone(),
|
||||
config.worker_startup_check_interval_secs,
|
||||
)));
|
||||
|
||||
// Create empty OnceLock for worker job queue
|
||||
let worker_job_queue = Arc::new(OnceLock::new());
|
||||
|
||||
Arc::new(AppContext::new(
|
||||
config,
|
||||
client,
|
||||
rate_limiter,
|
||||
None, // tokenizer
|
||||
None, // reasoning_parser_factory
|
||||
None, // tool_parser_factory
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
response_storage,
|
||||
conversation_storage,
|
||||
conversation_item_storage,
|
||||
load_monitor,
|
||||
worker_job_queue,
|
||||
))
|
||||
}
|
||||
|
||||
// Tokenizer download configuration
|
||||
|
||||
@@ -2,11 +2,16 @@ use axum::Router;
|
||||
use reqwest::Client;
|
||||
use sglang_router_rs::{
|
||||
config::RouterConfig,
|
||||
middleware::AuthConfig,
|
||||
core::{LoadMonitor, WorkerRegistry},
|
||||
data_connector::{
|
||||
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||
},
|
||||
middleware::{AuthConfig, TokenBucket},
|
||||
policies::PolicyRegistry,
|
||||
routers::RouterTrait,
|
||||
server::{build_app, AppContext, AppState},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
/// Create a test Axum application using the actual server's build_app function
|
||||
#[allow(dead_code)]
|
||||
@@ -15,16 +20,57 @@ pub fn create_test_app(
|
||||
client: Client,
|
||||
router_config: &RouterConfig,
|
||||
) -> Router {
|
||||
// Initialize rate limiter
|
||||
let rate_limiter = match router_config.max_concurrent_requests {
|
||||
n if n <= 0 => None,
|
||||
n => {
|
||||
let rate_limit_tokens = router_config
|
||||
.rate_limit_tokens_per_second
|
||||
.filter(|&t| t > 0)
|
||||
.unwrap_or(n);
|
||||
Some(Arc::new(TokenBucket::new(
|
||||
n as usize,
|
||||
rate_limit_tokens as usize,
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize registries
|
||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||
let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));
|
||||
|
||||
// Initialize storage backends
|
||||
let response_storage = Arc::new(MemoryResponseStorage::new());
|
||||
let conversation_storage = Arc::new(MemoryConversationStorage::new());
|
||||
let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());
|
||||
|
||||
// Initialize load monitor
|
||||
let load_monitor = Some(Arc::new(LoadMonitor::new(
|
||||
worker_registry.clone(),
|
||||
policy_registry.clone(),
|
||||
client.clone(),
|
||||
router_config.worker_startup_check_interval_secs,
|
||||
)));
|
||||
|
||||
// Create empty OnceLock for worker job queue
|
||||
let worker_job_queue = Arc::new(OnceLock::new());
|
||||
|
||||
// Create AppContext
|
||||
let app_context = Arc::new(
|
||||
AppContext::new(
|
||||
router_config.clone(),
|
||||
client,
|
||||
router_config.max_concurrent_requests,
|
||||
router_config.rate_limit_tokens_per_second,
|
||||
)
|
||||
.expect("Failed to create AppContext in test"),
|
||||
);
|
||||
let app_context = Arc::new(AppContext::new(
|
||||
router_config.clone(),
|
||||
client,
|
||||
rate_limiter,
|
||||
None, // tokenizer
|
||||
None, // reasoning_parser_factory
|
||||
None, // tool_parser_factory
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
response_storage,
|
||||
conversation_storage,
|
||||
conversation_item_storage,
|
||||
load_monitor,
|
||||
worker_job_queue,
|
||||
));
|
||||
|
||||
// Create AppState with the test router and context
|
||||
let app_state = Arc::new(AppState {
|
||||
|
||||
@@ -15,8 +15,6 @@ use sglang_router_rs::config::{
|
||||
RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::routers::RouterFactory;
|
||||
use sglang_router_rs::server::AppContext;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
||||
@@ -83,10 +81,8 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
||||
};
|
||||
|
||||
// Create router and context
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
let ctx = common::create_test_context(router_cfg);
|
||||
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||
|
||||
// Build a simple ResponsesRequest that will trigger the tool call
|
||||
let req = ResponsesRequest {
|
||||
@@ -284,10 +280,8 @@ async fn test_conversations_crud_basic() {
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
let ctx = common::create_test_context(router_cfg);
|
||||
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||
|
||||
// Create
|
||||
let create_body = serde_json::json!({ "metadata": { "project": "alpha" } });
|
||||
@@ -616,10 +610,8 @@ async fn test_multi_turn_loop_with_mcp() {
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
let ctx = common::create_test_context(router_cfg);
|
||||
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||
|
||||
// Build request with MCP tools
|
||||
let req = ResponsesRequest {
|
||||
@@ -794,10 +786,8 @@ async fn test_max_tool_calls_limit() {
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
let ctx = common::create_test_context(router_cfg);
|
||||
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||
|
||||
let req = ResponsesRequest {
|
||||
background: Some(false),
|
||||
@@ -938,10 +928,8 @@ async fn setup_streaming_mcp_test() -> (
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
let ctx = common::create_test_context(router_cfg);
|
||||
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||
|
||||
(mcp, worker, router, dir)
|
||||
}
|
||||
@@ -1381,10 +1369,8 @@ async fn test_conversation_items_create_and_get() {
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
let ctx = common::create_test_context(router_cfg);
|
||||
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||
|
||||
// Create conversation
|
||||
let create_conv = serde_json::json!({});
|
||||
@@ -1484,10 +1470,8 @@ async fn test_conversation_items_delete() {
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
let ctx = common::create_test_context(router_cfg);
|
||||
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||
|
||||
// Create conversation
|
||||
let create_conv = serde_json::json!({});
|
||||
@@ -1593,10 +1577,8 @@ async fn test_conversation_items_max_limit() {
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
let ctx = common::create_test_context(router_cfg);
|
||||
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||
|
||||
// Create conversation
|
||||
let create_conv = serde_json::json!({});
|
||||
@@ -1672,10 +1654,8 @@ async fn test_conversation_items_unsupported_type() {
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
let ctx = common::create_test_context(router_cfg);
|
||||
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||
|
||||
// Create conversation
|
||||
let create_conv = serde_json::json!({});
|
||||
@@ -1750,10 +1730,8 @@ async fn test_conversation_items_multi_conversation_sharing() {
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||
.await
|
||||
.expect("router");
|
||||
let ctx = common::create_test_context(router_cfg);
|
||||
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||
|
||||
// Create two conversations
|
||||
let conv_a_resp = router
|
||||
|
||||
@@ -200,10 +200,56 @@ mod test_pd_routing {
|
||||
tool_call_parser: None,
|
||||
};
|
||||
|
||||
let app_context =
|
||||
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None)
|
||||
.expect("Failed to create AppContext");
|
||||
let app_context = std::sync::Arc::new(app_context);
|
||||
let app_context = {
|
||||
use sglang_router_rs::core::{LoadMonitor, WorkerRegistry};
|
||||
use sglang_router_rs::data_connector::{
|
||||
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||
};
|
||||
use sglang_router_rs::middleware::TokenBucket;
|
||||
use sglang_router_rs::policies::PolicyRegistry;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Initialize rate limiter
|
||||
let rate_limiter = Some(Arc::new(TokenBucket::new(64, 64)));
|
||||
|
||||
// Initialize registries
|
||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||
let policy_registry = Arc::new(PolicyRegistry::new(config.policy.clone()));
|
||||
|
||||
// Initialize storage backends
|
||||
let response_storage = Arc::new(MemoryResponseStorage::new());
|
||||
let conversation_storage = Arc::new(MemoryConversationStorage::new());
|
||||
let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());
|
||||
|
||||
// Initialize load monitor
|
||||
let load_monitor = Some(Arc::new(LoadMonitor::new(
|
||||
worker_registry.clone(),
|
||||
policy_registry.clone(),
|
||||
client.clone(),
|
||||
config.worker_startup_check_interval_secs,
|
||||
)));
|
||||
|
||||
// Create empty OnceLock for worker job queue
|
||||
let worker_job_queue = Arc::new(OnceLock::new());
|
||||
|
||||
Arc::new(sglang_router_rs::server::AppContext::new(
|
||||
config,
|
||||
client,
|
||||
rate_limiter,
|
||||
None, // tokenizer
|
||||
None, // reasoning_parser_factory
|
||||
None, // tool_parser_factory
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
response_storage,
|
||||
conversation_storage,
|
||||
conversation_item_storage,
|
||||
load_monitor,
|
||||
worker_job_queue,
|
||||
))
|
||||
};
|
||||
let result = RouterFactory::create_router(&app_context).await;
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
|
||||
Reference in New Issue
Block a user