[router] cleanup app context and move to startup (#11617)
This commit is contained in:
@@ -7,8 +7,8 @@ use crate::{
|
|||||||
data_connector::{
|
data_connector::{
|
||||||
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||||
NoOpConversationStorage, NoOpResponseStorage, OracleConversationItemStorage,
|
NoOpConversationStorage, NoOpResponseStorage, OracleConversationItemStorage,
|
||||||
OracleConversationStorage, OracleResponseStorage, SharedConversationStorage,
|
OracleConversationStorage, OracleResponseStorage, SharedConversationItemStorage,
|
||||||
SharedResponseStorage,
|
SharedConversationStorage, SharedResponseStorage,
|
||||||
},
|
},
|
||||||
logging::{self, LoggingConfig},
|
logging::{self, LoggingConfig},
|
||||||
metrics::{self, PrometheusConfig},
|
metrics::{self, PrometheusConfig},
|
||||||
@@ -62,7 +62,7 @@ pub struct AppContext {
|
|||||||
pub router_manager: Option<Arc<RouterManager>>,
|
pub router_manager: Option<Arc<RouterManager>>,
|
||||||
pub response_storage: SharedResponseStorage,
|
pub response_storage: SharedResponseStorage,
|
||||||
pub conversation_storage: SharedConversationStorage,
|
pub conversation_storage: SharedConversationStorage,
|
||||||
pub conversation_item_storage: crate::data_connector::SharedConversationItemStorage,
|
pub conversation_item_storage: SharedConversationItemStorage,
|
||||||
pub load_monitor: Option<Arc<LoadMonitor>>,
|
pub load_monitor: Option<Arc<LoadMonitor>>,
|
||||||
pub configured_reasoning_parser: Option<String>,
|
pub configured_reasoning_parser: Option<String>,
|
||||||
pub configured_tool_parser: Option<String>,
|
pub configured_tool_parser: Option<String>,
|
||||||
@@ -70,135 +70,26 @@ pub struct AppContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AppContext {
|
impl AppContext {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
router_config: RouterConfig,
|
router_config: RouterConfig,
|
||||||
client: Client,
|
client: Client,
|
||||||
max_concurrent_requests: i32,
|
rate_limiter: Option<Arc<TokenBucket>>,
|
||||||
rate_limit_tokens_per_second: Option<i32>,
|
tokenizer: Option<Arc<dyn Tokenizer>>,
|
||||||
) -> Result<Self, String> {
|
reasoning_parser_factory: Option<ReasoningParserFactory>,
|
||||||
let rate_limiter = match max_concurrent_requests {
|
tool_parser_factory: Option<ToolParserFactory>,
|
||||||
n if n <= 0 => None,
|
worker_registry: Arc<WorkerRegistry>,
|
||||||
n => {
|
policy_registry: Arc<PolicyRegistry>,
|
||||||
let rate_limit_tokens =
|
response_storage: SharedResponseStorage,
|
||||||
rate_limit_tokens_per_second.filter(|&t| t > 0).unwrap_or(n);
|
conversation_storage: SharedConversationStorage,
|
||||||
Some(Arc::new(TokenBucket::new(
|
conversation_item_storage: SharedConversationItemStorage,
|
||||||
n as usize,
|
load_monitor: Option<Arc<LoadMonitor>>,
|
||||||
rate_limit_tokens as usize,
|
worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
|
||||||
)))
|
) -> Self {
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if router_config
|
|
||||||
.connection_mode
|
|
||||||
== ConnectionMode::Grpc
|
|
||||||
{
|
|
||||||
let tokenizer_path = router_config
|
|
||||||
.tokenizer_path
|
|
||||||
.clone()
|
|
||||||
.or_else(|| router_config.model_path.clone())
|
|
||||||
.ok_or_else(|| {
|
|
||||||
"gRPC mode requires either --tokenizer-path or --model-path to be specified"
|
|
||||||
.to_string()
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let tokenizer = Some(
|
|
||||||
tokenizer_factory::create_tokenizer_with_chat_template_blocking(
|
|
||||||
&tokenizer_path,
|
|
||||||
router_config.chat_template.as_deref(),
|
|
||||||
)
|
|
||||||
.map_err(|e| {
|
|
||||||
format!(
|
|
||||||
"Failed to create tokenizer from '{}': {}. \
|
|
||||||
Ensure the path is valid and points to a tokenizer file (tokenizer.json) \
|
|
||||||
or a HuggingFace model ID. For directories, ensure they contain tokenizer files.",
|
|
||||||
tokenizer_path, e
|
|
||||||
)
|
|
||||||
})?,
|
|
||||||
);
|
|
||||||
let reasoning_parser_factory = Some(crate::reasoning_parser::ParserFactory::new());
|
|
||||||
let tool_parser_factory = Some(crate::tool_parser::ParserFactory::new());
|
|
||||||
|
|
||||||
(tokenizer, reasoning_parser_factory, tool_parser_factory)
|
|
||||||
} else {
|
|
||||||
(None, None, None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
|
||||||
let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));
|
|
||||||
|
|
||||||
let router_manager = None;
|
|
||||||
|
|
||||||
let (response_storage, conversation_storage): (
|
|
||||||
SharedResponseStorage,
|
|
||||||
SharedConversationStorage,
|
|
||||||
) = match router_config.history_backend {
|
|
||||||
HistoryBackend::Memory => {
|
|
||||||
info!("Initializing data connector: Memory");
|
|
||||||
(
|
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
HistoryBackend::None => {
|
|
||||||
info!("Initializing data connector: None (no persistence)");
|
|
||||||
(
|
|
||||||
Arc::new(NoOpResponseStorage::new()),
|
|
||||||
Arc::new(NoOpConversationStorage::new()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
HistoryBackend::Oracle => {
|
|
||||||
let oracle_cfg = router_config.oracle.clone().ok_or_else(|| {
|
|
||||||
"oracle configuration is required when history_backend=oracle".to_string()
|
|
||||||
})?;
|
|
||||||
|
|
||||||
info!(
|
|
||||||
"Initializing data connector: Oracle ATP (pool: {}-{})",
|
|
||||||
oracle_cfg.pool_min, oracle_cfg.pool_max
|
|
||||||
);
|
|
||||||
|
|
||||||
let response_storage =
|
|
||||||
OracleResponseStorage::new(oracle_cfg.clone()).map_err(|err| {
|
|
||||||
format!("failed to initialize Oracle response storage: {err}")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let conversation_storage = OracleConversationStorage::new(oracle_cfg.clone())
|
|
||||||
.map_err(|err| {
|
|
||||||
format!("failed to initialize Oracle conversation storage: {err}")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
info!("Data connector initialized successfully: Oracle ATP");
|
|
||||||
(Arc::new(response_storage), Arc::new(conversation_storage))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Conversation items storage (memory-backed for now)
|
|
||||||
let conversation_item_storage: crate::data_connector::SharedConversationItemStorage =
|
|
||||||
match router_config.history_backend {
|
|
||||||
HistoryBackend::Oracle => {
|
|
||||||
let oracle_cfg = router_config.oracle.clone().ok_or_else(|| {
|
|
||||||
"oracle configuration is required when history_backend=oracle".to_string()
|
|
||||||
})?;
|
|
||||||
Arc::new(OracleConversationItemStorage::new(oracle_cfg).map_err(|e| {
|
|
||||||
format!("failed to initialize Oracle conversation item storage: {e}")
|
|
||||||
})?)
|
|
||||||
}
|
|
||||||
_ => Arc::new(MemoryConversationItemStorage::new()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let load_monitor = Some(Arc::new(LoadMonitor::new(
|
|
||||||
worker_registry.clone(),
|
|
||||||
policy_registry.clone(),
|
|
||||||
client.clone(),
|
|
||||||
router_config.worker_startup_check_interval_secs,
|
|
||||||
)));
|
|
||||||
|
|
||||||
let configured_reasoning_parser = router_config.reasoning_parser.clone();
|
let configured_reasoning_parser = router_config.reasoning_parser.clone();
|
||||||
let configured_tool_parser = router_config.tool_call_parser.clone();
|
let configured_tool_parser = router_config.tool_call_parser.clone();
|
||||||
|
|
||||||
// Create empty OnceLock for worker job queue (will be initialized in startup())
|
Self {
|
||||||
let worker_job_queue = Arc::new(OnceLock::new());
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
client,
|
client,
|
||||||
router_config,
|
router_config,
|
||||||
rate_limiter,
|
rate_limiter,
|
||||||
@@ -207,7 +98,7 @@ impl AppContext {
|
|||||||
tool_parser_factory,
|
tool_parser_factory,
|
||||||
worker_registry,
|
worker_registry,
|
||||||
policy_registry,
|
policy_registry,
|
||||||
router_manager,
|
router_manager: None,
|
||||||
response_storage,
|
response_storage,
|
||||||
conversation_storage,
|
conversation_storage,
|
||||||
conversation_item_storage,
|
conversation_item_storage,
|
||||||
@@ -215,7 +106,7 @@ impl AppContext {
|
|||||||
configured_reasoning_parser,
|
configured_reasoning_parser,
|
||||||
configured_tool_parser,
|
configured_tool_parser,
|
||||||
worker_job_queue,
|
worker_job_queue,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -936,12 +827,146 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
.build()
|
.build()
|
||||||
.expect("Failed to create HTTP client");
|
.expect("Failed to create HTTP client");
|
||||||
|
|
||||||
|
// Initialize rate limiter
|
||||||
|
let rate_limiter = match config.router_config.max_concurrent_requests {
|
||||||
|
n if n <= 0 => None,
|
||||||
|
n => {
|
||||||
|
let rate_limit_tokens = config
|
||||||
|
.router_config
|
||||||
|
.rate_limit_tokens_per_second
|
||||||
|
.filter(|&t| t > 0)
|
||||||
|
.unwrap_or(n);
|
||||||
|
Some(Arc::new(TokenBucket::new(
|
||||||
|
n as usize,
|
||||||
|
rate_limit_tokens as usize,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize tokenizer and parser factories for gRPC mode
|
||||||
|
let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if config
|
||||||
|
.router_config
|
||||||
|
.connection_mode
|
||||||
|
== ConnectionMode::Grpc
|
||||||
|
{
|
||||||
|
let tokenizer_path = config
|
||||||
|
.router_config
|
||||||
|
.tokenizer_path
|
||||||
|
.clone()
|
||||||
|
.or_else(|| config.router_config.model_path.clone())
|
||||||
|
.ok_or_else(|| {
|
||||||
|
"gRPC mode requires either --tokenizer-path or --model-path to be specified"
|
||||||
|
.to_string()
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let tokenizer = Some(
|
||||||
|
tokenizer_factory::create_tokenizer_with_chat_template_blocking(
|
||||||
|
&tokenizer_path,
|
||||||
|
config.router_config.chat_template.as_deref(),
|
||||||
|
)
|
||||||
|
.map_err(|e| {
|
||||||
|
format!(
|
||||||
|
"Failed to create tokenizer from '{}': {}. \
|
||||||
|
Ensure the path is valid and points to a tokenizer file (tokenizer.json) \
|
||||||
|
or a HuggingFace model ID. For directories, ensure they contain tokenizer files.",
|
||||||
|
tokenizer_path, e
|
||||||
|
)
|
||||||
|
})?,
|
||||||
|
);
|
||||||
|
let reasoning_parser_factory = Some(ReasoningParserFactory::new());
|
||||||
|
let tool_parser_factory = Some(ToolParserFactory::new());
|
||||||
|
|
||||||
|
(tokenizer, reasoning_parser_factory, tool_parser_factory)
|
||||||
|
} else {
|
||||||
|
(None, None, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize worker registry and policy registry
|
||||||
|
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||||
|
let policy_registry = Arc::new(PolicyRegistry::new(config.router_config.policy.clone()));
|
||||||
|
|
||||||
|
// Initialize storage backends
|
||||||
|
let (response_storage, conversation_storage): (
|
||||||
|
SharedResponseStorage,
|
||||||
|
SharedConversationStorage,
|
||||||
|
) = match config.router_config.history_backend {
|
||||||
|
HistoryBackend::Memory => {
|
||||||
|
info!("Initializing data connector: Memory");
|
||||||
|
(
|
||||||
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
HistoryBackend::None => {
|
||||||
|
info!("Initializing data connector: None (no persistence)");
|
||||||
|
(
|
||||||
|
Arc::new(NoOpResponseStorage::new()),
|
||||||
|
Arc::new(NoOpConversationStorage::new()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
HistoryBackend::Oracle => {
|
||||||
|
let oracle_cfg = config.router_config.oracle.clone().ok_or_else(|| {
|
||||||
|
"oracle configuration is required when history_backend=oracle".to_string()
|
||||||
|
})?;
|
||||||
|
info!(
|
||||||
|
"Initializing data connector: Oracle ATP (pool: {}-{})",
|
||||||
|
oracle_cfg.pool_min, oracle_cfg.pool_max
|
||||||
|
);
|
||||||
|
|
||||||
|
let response_storage = OracleResponseStorage::new(oracle_cfg.clone())
|
||||||
|
.map_err(|err| format!("failed to initialize Oracle response storage: {err}"))?;
|
||||||
|
|
||||||
|
let conversation_storage =
|
||||||
|
OracleConversationStorage::new(oracle_cfg.clone()).map_err(|err| {
|
||||||
|
format!("failed to initialize Oracle conversation storage: {err}")
|
||||||
|
})?;
|
||||||
|
info!("Data connector initialized successfully: Oracle ATP");
|
||||||
|
|
||||||
|
(Arc::new(response_storage), Arc::new(conversation_storage))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize conversation items storage
|
||||||
|
let conversation_item_storage: SharedConversationItemStorage =
|
||||||
|
match config.router_config.history_backend {
|
||||||
|
HistoryBackend::Oracle => {
|
||||||
|
let oracle_cfg = config.router_config.oracle.clone().ok_or_else(|| {
|
||||||
|
"oracle configuration is required when history_backend=oracle".to_string()
|
||||||
|
})?;
|
||||||
|
Arc::new(OracleConversationItemStorage::new(oracle_cfg).map_err(|e| {
|
||||||
|
format!("failed to initialize Oracle conversation item storage: {e}")
|
||||||
|
})?)
|
||||||
|
}
|
||||||
|
_ => Arc::new(MemoryConversationItemStorage::new()),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize load monitor
|
||||||
|
let load_monitor = Some(Arc::new(LoadMonitor::new(
|
||||||
|
worker_registry.clone(),
|
||||||
|
policy_registry.clone(),
|
||||||
|
client.clone(),
|
||||||
|
config.router_config.worker_startup_check_interval_secs,
|
||||||
|
)));
|
||||||
|
|
||||||
|
// Create empty OnceLock for worker job queue (will be initialized below)
|
||||||
|
let worker_job_queue = Arc::new(OnceLock::new());
|
||||||
|
|
||||||
|
// Create AppContext with all initialized components
|
||||||
let app_context = AppContext::new(
|
let app_context = AppContext::new(
|
||||||
config.router_config.clone(),
|
config.router_config.clone(),
|
||||||
client.clone(),
|
client.clone(),
|
||||||
config.router_config.max_concurrent_requests,
|
rate_limiter,
|
||||||
config.router_config.rate_limit_tokens_per_second,
|
tokenizer,
|
||||||
)?;
|
reasoning_parser_factory,
|
||||||
|
tool_parser_factory,
|
||||||
|
worker_registry,
|
||||||
|
policy_registry,
|
||||||
|
response_storage,
|
||||||
|
conversation_storage,
|
||||||
|
conversation_item_storage,
|
||||||
|
load_monitor,
|
||||||
|
worker_job_queue,
|
||||||
|
);
|
||||||
|
|
||||||
let app_context = Arc::new(app_context);
|
let app_context = Arc::new(app_context);
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,12 @@ pub mod test_app;
|
|||||||
|
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::RouterConfig;
|
use sglang_router_rs::config::RouterConfig;
|
||||||
|
use sglang_router_rs::core::{LoadMonitor, WorkerRegistry};
|
||||||
|
use sglang_router_rs::data_connector::{
|
||||||
|
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||||
|
};
|
||||||
|
use sglang_router_rs::middleware::TokenBucket;
|
||||||
|
use sglang_router_rs::policies::PolicyRegistry;
|
||||||
use sglang_router_rs::protocols::spec::{Function, Tool};
|
use sglang_router_rs::protocols::spec::{Function, Tool};
|
||||||
use sglang_router_rs::server::AppContext;
|
use sglang_router_rs::server::AppContext;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
@@ -17,15 +23,58 @@ use std::sync::{Arc, Mutex, OnceLock};
|
|||||||
|
|
||||||
/// Helper function to create AppContext for tests
|
/// Helper function to create AppContext for tests
|
||||||
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||||
Arc::new(
|
let client = reqwest::Client::new();
|
||||||
AppContext::new(
|
|
||||||
config.clone(),
|
// Initialize rate limiter
|
||||||
reqwest::Client::new(),
|
let rate_limiter = match config.max_concurrent_requests {
|
||||||
config.max_concurrent_requests,
|
n if n <= 0 => None,
|
||||||
config.rate_limit_tokens_per_second,
|
n => {
|
||||||
)
|
let rate_limit_tokens = config
|
||||||
.expect("Failed to create AppContext in test"),
|
.rate_limit_tokens_per_second
|
||||||
)
|
.filter(|&t| t > 0)
|
||||||
|
.unwrap_or(n);
|
||||||
|
Some(Arc::new(TokenBucket::new(
|
||||||
|
n as usize,
|
||||||
|
rate_limit_tokens as usize,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize registries
|
||||||
|
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||||
|
let policy_registry = Arc::new(PolicyRegistry::new(config.policy.clone()));
|
||||||
|
|
||||||
|
// Initialize storage backends (Memory for tests)
|
||||||
|
let response_storage = Arc::new(MemoryResponseStorage::new());
|
||||||
|
let conversation_storage = Arc::new(MemoryConversationStorage::new());
|
||||||
|
let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());
|
||||||
|
|
||||||
|
// Initialize load monitor
|
||||||
|
let load_monitor = Some(Arc::new(LoadMonitor::new(
|
||||||
|
worker_registry.clone(),
|
||||||
|
policy_registry.clone(),
|
||||||
|
client.clone(),
|
||||||
|
config.worker_startup_check_interval_secs,
|
||||||
|
)));
|
||||||
|
|
||||||
|
// Create empty OnceLock for worker job queue
|
||||||
|
let worker_job_queue = Arc::new(OnceLock::new());
|
||||||
|
|
||||||
|
Arc::new(AppContext::new(
|
||||||
|
config,
|
||||||
|
client,
|
||||||
|
rate_limiter,
|
||||||
|
None, // tokenizer
|
||||||
|
None, // reasoning_parser_factory
|
||||||
|
None, // tool_parser_factory
|
||||||
|
worker_registry,
|
||||||
|
policy_registry,
|
||||||
|
response_storage,
|
||||||
|
conversation_storage,
|
||||||
|
conversation_item_storage,
|
||||||
|
load_monitor,
|
||||||
|
worker_job_queue,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tokenizer download configuration
|
// Tokenizer download configuration
|
||||||
|
|||||||
@@ -2,11 +2,16 @@ use axum::Router;
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use sglang_router_rs::{
|
use sglang_router_rs::{
|
||||||
config::RouterConfig,
|
config::RouterConfig,
|
||||||
middleware::AuthConfig,
|
core::{LoadMonitor, WorkerRegistry},
|
||||||
|
data_connector::{
|
||||||
|
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||||
|
},
|
||||||
|
middleware::{AuthConfig, TokenBucket},
|
||||||
|
policies::PolicyRegistry,
|
||||||
routers::RouterTrait,
|
routers::RouterTrait,
|
||||||
server::{build_app, AppContext, AppState},
|
server::{build_app, AppContext, AppState},
|
||||||
};
|
};
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, OnceLock};
|
||||||
|
|
||||||
/// Create a test Axum application using the actual server's build_app function
|
/// Create a test Axum application using the actual server's build_app function
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
@@ -15,16 +20,57 @@ pub fn create_test_app(
|
|||||||
client: Client,
|
client: Client,
|
||||||
router_config: &RouterConfig,
|
router_config: &RouterConfig,
|
||||||
) -> Router {
|
) -> Router {
|
||||||
|
// Initialize rate limiter
|
||||||
|
let rate_limiter = match router_config.max_concurrent_requests {
|
||||||
|
n if n <= 0 => None,
|
||||||
|
n => {
|
||||||
|
let rate_limit_tokens = router_config
|
||||||
|
.rate_limit_tokens_per_second
|
||||||
|
.filter(|&t| t > 0)
|
||||||
|
.unwrap_or(n);
|
||||||
|
Some(Arc::new(TokenBucket::new(
|
||||||
|
n as usize,
|
||||||
|
rate_limit_tokens as usize,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize registries
|
||||||
|
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||||
|
let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));
|
||||||
|
|
||||||
|
// Initialize storage backends
|
||||||
|
let response_storage = Arc::new(MemoryResponseStorage::new());
|
||||||
|
let conversation_storage = Arc::new(MemoryConversationStorage::new());
|
||||||
|
let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());
|
||||||
|
|
||||||
|
// Initialize load monitor
|
||||||
|
let load_monitor = Some(Arc::new(LoadMonitor::new(
|
||||||
|
worker_registry.clone(),
|
||||||
|
policy_registry.clone(),
|
||||||
|
client.clone(),
|
||||||
|
router_config.worker_startup_check_interval_secs,
|
||||||
|
)));
|
||||||
|
|
||||||
|
// Create empty OnceLock for worker job queue
|
||||||
|
let worker_job_queue = Arc::new(OnceLock::new());
|
||||||
|
|
||||||
// Create AppContext
|
// Create AppContext
|
||||||
let app_context = Arc::new(
|
let app_context = Arc::new(AppContext::new(
|
||||||
AppContext::new(
|
router_config.clone(),
|
||||||
router_config.clone(),
|
client,
|
||||||
client,
|
rate_limiter,
|
||||||
router_config.max_concurrent_requests,
|
None, // tokenizer
|
||||||
router_config.rate_limit_tokens_per_second,
|
None, // reasoning_parser_factory
|
||||||
)
|
None, // tool_parser_factory
|
||||||
.expect("Failed to create AppContext in test"),
|
worker_registry,
|
||||||
);
|
policy_registry,
|
||||||
|
response_storage,
|
||||||
|
conversation_storage,
|
||||||
|
conversation_item_storage,
|
||||||
|
load_monitor,
|
||||||
|
worker_job_queue,
|
||||||
|
));
|
||||||
|
|
||||||
// Create AppState with the test router and context
|
// Create AppState with the test router and context
|
||||||
let app_state = Arc::new(AppState {
|
let app_state = Arc::new(AppState {
|
||||||
|
|||||||
@@ -15,8 +15,6 @@ use sglang_router_rs::config::{
|
|||||||
RouterConfig, RoutingMode,
|
RouterConfig, RoutingMode,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::routers::RouterFactory;
|
use sglang_router_rs::routers::RouterFactory;
|
||||||
use sglang_router_rs::server::AppContext;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
||||||
@@ -83,10 +81,8 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Create router and context
|
// Create router and context
|
||||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
let ctx = common::create_test_context(router_cfg);
|
||||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||||
.await
|
|
||||||
.expect("router");
|
|
||||||
|
|
||||||
// Build a simple ResponsesRequest that will trigger the tool call
|
// Build a simple ResponsesRequest that will trigger the tool call
|
||||||
let req = ResponsesRequest {
|
let req = ResponsesRequest {
|
||||||
@@ -284,10 +280,8 @@ async fn test_conversations_crud_basic() {
|
|||||||
tool_call_parser: None,
|
tool_call_parser: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
let ctx = common::create_test_context(router_cfg);
|
||||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||||
.await
|
|
||||||
.expect("router");
|
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
let create_body = serde_json::json!({ "metadata": { "project": "alpha" } });
|
let create_body = serde_json::json!({ "metadata": { "project": "alpha" } });
|
||||||
@@ -616,10 +610,8 @@ async fn test_multi_turn_loop_with_mcp() {
|
|||||||
tool_call_parser: None,
|
tool_call_parser: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
let ctx = common::create_test_context(router_cfg);
|
||||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||||
.await
|
|
||||||
.expect("router");
|
|
||||||
|
|
||||||
// Build request with MCP tools
|
// Build request with MCP tools
|
||||||
let req = ResponsesRequest {
|
let req = ResponsesRequest {
|
||||||
@@ -794,10 +786,8 @@ async fn test_max_tool_calls_limit() {
|
|||||||
tool_call_parser: None,
|
tool_call_parser: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
let ctx = common::create_test_context(router_cfg);
|
||||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||||
.await
|
|
||||||
.expect("router");
|
|
||||||
|
|
||||||
let req = ResponsesRequest {
|
let req = ResponsesRequest {
|
||||||
background: Some(false),
|
background: Some(false),
|
||||||
@@ -938,10 +928,8 @@ async fn setup_streaming_mcp_test() -> (
|
|||||||
tool_call_parser: None,
|
tool_call_parser: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 64, None).expect("ctx");
|
let ctx = common::create_test_context(router_cfg);
|
||||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||||
.await
|
|
||||||
.expect("router");
|
|
||||||
|
|
||||||
(mcp, worker, router, dir)
|
(mcp, worker, router, dir)
|
||||||
}
|
}
|
||||||
@@ -1381,10 +1369,8 @@ async fn test_conversation_items_create_and_get() {
|
|||||||
tool_call_parser: None,
|
tool_call_parser: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
let ctx = common::create_test_context(router_cfg);
|
||||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||||
.await
|
|
||||||
.expect("router");
|
|
||||||
|
|
||||||
// Create conversation
|
// Create conversation
|
||||||
let create_conv = serde_json::json!({});
|
let create_conv = serde_json::json!({});
|
||||||
@@ -1484,10 +1470,8 @@ async fn test_conversation_items_delete() {
|
|||||||
tool_call_parser: None,
|
tool_call_parser: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
let ctx = common::create_test_context(router_cfg);
|
||||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||||
.await
|
|
||||||
.expect("router");
|
|
||||||
|
|
||||||
// Create conversation
|
// Create conversation
|
||||||
let create_conv = serde_json::json!({});
|
let create_conv = serde_json::json!({});
|
||||||
@@ -1593,10 +1577,8 @@ async fn test_conversation_items_max_limit() {
|
|||||||
tool_call_parser: None,
|
tool_call_parser: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
let ctx = common::create_test_context(router_cfg);
|
||||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||||
.await
|
|
||||||
.expect("router");
|
|
||||||
|
|
||||||
// Create conversation
|
// Create conversation
|
||||||
let create_conv = serde_json::json!({});
|
let create_conv = serde_json::json!({});
|
||||||
@@ -1672,10 +1654,8 @@ async fn test_conversation_items_unsupported_type() {
|
|||||||
tool_call_parser: None,
|
tool_call_parser: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
let ctx = common::create_test_context(router_cfg);
|
||||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||||
.await
|
|
||||||
.expect("router");
|
|
||||||
|
|
||||||
// Create conversation
|
// Create conversation
|
||||||
let create_conv = serde_json::json!({});
|
let create_conv = serde_json::json!({});
|
||||||
@@ -1750,10 +1730,8 @@ async fn test_conversation_items_multi_conversation_sharing() {
|
|||||||
tool_call_parser: None,
|
tool_call_parser: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
let ctx = common::create_test_context(router_cfg);
|
||||||
let router = RouterFactory::create_router(&Arc::new(ctx))
|
let router = RouterFactory::create_router(&ctx).await.expect("router");
|
||||||
.await
|
|
||||||
.expect("router");
|
|
||||||
|
|
||||||
// Create two conversations
|
// Create two conversations
|
||||||
let conv_a_resp = router
|
let conv_a_resp = router
|
||||||
|
|||||||
@@ -200,10 +200,56 @@ mod test_pd_routing {
|
|||||||
tool_call_parser: None,
|
tool_call_parser: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let app_context =
|
let app_context = {
|
||||||
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None)
|
use sglang_router_rs::core::{LoadMonitor, WorkerRegistry};
|
||||||
.expect("Failed to create AppContext");
|
use sglang_router_rs::data_connector::{
|
||||||
let app_context = std::sync::Arc::new(app_context);
|
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||||
|
};
|
||||||
|
use sglang_router_rs::middleware::TokenBucket;
|
||||||
|
use sglang_router_rs::policies::PolicyRegistry;
|
||||||
|
use std::sync::{Arc, OnceLock};
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
// Initialize rate limiter
|
||||||
|
let rate_limiter = Some(Arc::new(TokenBucket::new(64, 64)));
|
||||||
|
|
||||||
|
// Initialize registries
|
||||||
|
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||||
|
let policy_registry = Arc::new(PolicyRegistry::new(config.policy.clone()));
|
||||||
|
|
||||||
|
// Initialize storage backends
|
||||||
|
let response_storage = Arc::new(MemoryResponseStorage::new());
|
||||||
|
let conversation_storage = Arc::new(MemoryConversationStorage::new());
|
||||||
|
let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());
|
||||||
|
|
||||||
|
// Initialize load monitor
|
||||||
|
let load_monitor = Some(Arc::new(LoadMonitor::new(
|
||||||
|
worker_registry.clone(),
|
||||||
|
policy_registry.clone(),
|
||||||
|
client.clone(),
|
||||||
|
config.worker_startup_check_interval_secs,
|
||||||
|
)));
|
||||||
|
|
||||||
|
// Create empty OnceLock for worker job queue
|
||||||
|
let worker_job_queue = Arc::new(OnceLock::new());
|
||||||
|
|
||||||
|
Arc::new(sglang_router_rs::server::AppContext::new(
|
||||||
|
config,
|
||||||
|
client,
|
||||||
|
rate_limiter,
|
||||||
|
None, // tokenizer
|
||||||
|
None, // reasoning_parser_factory
|
||||||
|
None, // tool_parser_factory
|
||||||
|
worker_registry,
|
||||||
|
policy_registry,
|
||||||
|
response_storage,
|
||||||
|
conversation_storage,
|
||||||
|
conversation_item_storage,
|
||||||
|
load_monitor,
|
||||||
|
worker_job_queue,
|
||||||
|
))
|
||||||
|
};
|
||||||
let result = RouterFactory::create_router(&app_context).await;
|
let result = RouterFactory::create_router(&app_context).await;
|
||||||
assert!(
|
assert!(
|
||||||
result.is_ok(),
|
result.is_ok(),
|
||||||
|
|||||||
Reference in New Issue
Block a user