diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 4ecfae55d..a58eff9f6 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -45,6 +45,8 @@ k8s-openapi = { version = "0.25.0", features = ["v1_33"] } metrics = "0.24.2" metrics-exporter-prometheus = "0.17.0" uuid = { version = "1.10", features = ["v4", "serde"] } +ulid = "1.2.1" +parking_lot = "0.12.4" thiserror = "2.0.12" regex = "1.10" url = "2.5.4" diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 5f5b227af..a46c23dbd 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -67,6 +67,23 @@ pub struct RouterConfig { pub model_path: Option, /// Explicit tokenizer path (overrides model_path tokenizer if provided) pub tokenizer_path: Option, + /// History backend configuration (memory or none, default: memory) + #[serde(default = "default_history_backend")] + pub history_backend: HistoryBackend, +} + +fn default_history_backend() -> HistoryBackend { + HistoryBackend::Memory +} + +/// History backend configuration +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum HistoryBackend { + /// In-memory storage (default) + Memory, + /// No history storage + None, } #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] @@ -363,6 +380,7 @@ impl Default for RouterConfig { connection_mode: ConnectionMode::Http, model_path: None, tokenizer_path: None, + history_backend: default_history_backend(), } } } @@ -484,31 +502,9 @@ mod tests { policy: PolicyConfig::Random, host: "0.0.0.0".to_string(), port: 8080, - max_payload_size: 1024, - request_timeout_secs: 30, - worker_startup_timeout_secs: 60, - worker_startup_check_interval_secs: 5, - dp_aware: false, - api_key: None, - discovery: Some(DiscoveryConfig::default()), - metrics: Some(MetricsConfig::default()), log_dir: Some("/var/log".to_string()), log_level: Some("debug".to_string()), - request_id_headers: None, - max_concurrent_requests: 64, - cors_allowed_origins: vec![], - retry: RetryConfig::default(), - circuit_breaker: CircuitBreakerConfig::default(), - disable_retries: false, - disable_circuit_breaker: false, - health_check: HealthCheckConfig::default(), - enable_igw: false, - queue_size: 100, - queue_timeout_secs: 60, - rate_limit_tokens_per_second: None, - connection_mode: ConnectionMode::Http, - model_path: None, - tokenizer_path: None, + ..Default::default() }; let json = serde_json::to_string(&config).unwrap(); @@ -517,8 +513,11 @@ mod tests { assert_eq!(config.host, deserialized.host); assert_eq!(config.port, deserialized.port); assert_eq!(config.max_payload_size, deserialized.max_payload_size); - assert!(deserialized.discovery.is_some()); - assert!(deserialized.metrics.is_some()); + assert_eq!(config.log_dir, deserialized.log_dir); + assert_eq!(config.log_level, deserialized.log_level); + // discovery and metrics are None in Default implementation + assert!(deserialized.discovery.is_none()); + assert!(deserialized.metrics.is_none()); } // ============= RoutingMode Tests ============= @@ -948,6 +947,7 @@ mod tests { connection_mode: ConnectionMode::Http, model_path: None, tokenizer_path: None, + history_backend: default_history_backend(), }; assert!(config.mode.is_pd_mode()); @@ -1011,6 +1011,7 @@ mod tests { connection_mode: ConnectionMode::Http, model_path: None, tokenizer_path: None, + history_backend: default_history_backend(), }; assert!(!config.mode.is_pd_mode()); @@ -1070,6 +1071,7 @@ mod tests { connection_mode: ConnectionMode::Http, model_path: None, tokenizer_path: None, + history_backend: default_history_backend(), }; assert!(config.has_service_discovery()); diff --git a/sgl-router/src/data_connector/mod.rs b/sgl-router/src/data_connector/mod.rs new file mode 100644 index 000000000..1b4bf2073 --- /dev/null +++ b/sgl-router/src/data_connector/mod.rs @@ -0,0 +1,11 @@ +// Data connector module for response storage +pub mod response_memory_store; +pub mod response_noop_store; +pub mod responses; + +pub use response_memory_store::MemoryResponseStorage; +pub use response_noop_store::NoOpResponseStorage; +pub use responses::{ + ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, SharedResponseStorage, + StoredResponse, +}; diff --git a/sgl-router/src/data_connector/response_memory_store.rs b/sgl-router/src/data_connector/response_memory_store.rs new file mode 100644 index 000000000..003d07ac7 --- /dev/null +++ b/sgl-router/src/data_connector/response_memory_store.rs @@ -0,0 +1,325 @@ +use async_trait::async_trait; +use parking_lot::RwLock; +use std::collections::HashMap; +use std::sync::Arc; + +use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse}; + +/// Internal store structure holding both maps together +#[derive(Default)] +struct InnerStore { + /// All stored responses indexed by ID + responses: HashMap, + /// Index of response IDs by user + user_index: HashMap>, +} + +/// In-memory implementation of response storage +pub struct MemoryResponseStorage { + /// Single lock wrapping both maps to prevent deadlocks and ensure atomic updates + store: Arc>, +} + +impl MemoryResponseStorage { + pub fn new() -> Self { + Self { + store: Arc::new(RwLock::new(InnerStore::default())), + } + } + + /// Get statistics about the store + pub fn stats(&self) -> MemoryStoreStats { + let store = self.store.read(); + MemoryStoreStats { + response_count: store.responses.len(), + user_count: store.user_index.len(), + } + } + + /// Clear all data (useful for testing) + pub fn clear(&self) { + let mut store = self.store.write(); + store.responses.clear(); + store.user_index.clear(); + } +} + +impl Default for MemoryResponseStorage { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ResponseStorage for MemoryResponseStorage { + async fn store_response(&self, mut response: StoredResponse) -> Result { + // Generate ID if not set + if response.id.0.is_empty() { + response.id = ResponseId::new(); + } + + let response_id = response.id.clone(); + + // Single lock acquisition for atomic update + let mut store = self.store.write(); + + // Update user index if user is specified + if let Some(ref user) = response.user { + store + .user_index + .entry(user.clone()) + .or_default() + .push(response_id.clone()); + } + + // Store the response + store.responses.insert(response_id.clone(), response); + + Ok(response_id) + } + + async fn get_response(&self, response_id: &ResponseId) -> Result> { + let store = self.store.read(); + Ok(store.responses.get(response_id).cloned()) + } + + async fn delete_response(&self, response_id: &ResponseId) -> Result<()> { + let mut store = self.store.write(); + + // Remove the response and update user index if needed + if let Some(response) = store.responses.remove(response_id) { + if let Some(ref user) = response.user { + if let Some(user_responses) = store.user_index.get_mut(user) { + user_responses.retain(|id| id != response_id); + } + } + } + + Ok(()) + } + + async fn get_response_chain( + &self, + response_id: &ResponseId, + max_depth: Option, + ) -> Result { + let mut chain = ResponseChain::new(); + let max_depth = max_depth.unwrap_or(100); // Default max depth to prevent infinite loops + + // Collect all response IDs first + let mut response_ids = Vec::new(); + let mut current_id = Some(response_id.clone()); + let mut depth = 0; + + // Single lock acquisition to collect the chain + { + let store = self.store.read(); + while let Some(id) = current_id { + if depth >= max_depth { + break; + } + + if let Some(response) = store.responses.get(&id) { + response_ids.push(id); + current_id = response.previous_response_id.clone(); + depth += 1; + } else { + break; + } + } + } + + // Reverse to get chronological order (oldest first) + response_ids.reverse(); + + // Now collect the actual responses + let store = self.store.read(); + for id in response_ids { + if let Some(response) = store.responses.get(&id) { + chain.add_response(response.clone()); + } + } + + Ok(chain) + } + + async fn list_user_responses( + &self, + user: &str, + limit: Option, + ) -> Result> { + let store = self.store.read(); + + if let Some(user_response_ids) = store.user_index.get(user) { + // Collect responses with their timestamps for sorting + let mut responses_with_time: Vec<_> = user_response_ids + .iter() + .filter_map(|id| store.responses.get(id).map(|r| (r.created_at, id))) + .collect(); + + // Sort by creation time (newest first) + responses_with_time.sort_by(|a, b| b.0.cmp(&a.0)); + + // Apply limit and collect the actual responses + let limit = limit.unwrap_or(responses_with_time.len()); + let user_responses: Vec = responses_with_time + .into_iter() + .take(limit) + .filter_map(|(_, id)| store.responses.get(id).cloned()) + .collect(); + + Ok(user_responses) + } else { + Ok(Vec::new()) + } + } + + async fn delete_user_responses(&self, user: &str) -> Result { + let mut store = self.store.write(); + + if let Some(user_response_ids) = store.user_index.remove(user) { + let count = user_response_ids.len(); + for id in user_response_ids { + store.responses.remove(&id); + } + Ok(count) + } else { + Ok(0) + } + } +} + +/// Statistics for the memory store +#[derive(Debug, Clone)] +pub struct MemoryStoreStats { + pub response_count: usize, + pub user_count: usize, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_memory_store_basic() { + let store = MemoryResponseStorage::new(); + + // Store a response + let response = StoredResponse::new("Hello".to_string(), "Hi there!".to_string(), None); + let response_id = store.store_response(response).await.unwrap(); + + // Retrieve it + let retrieved = store.get_response(&response_id).await.unwrap(); + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().input, "Hello"); + + // Delete it + store.delete_response(&response_id).await.unwrap(); + let deleted = store.get_response(&response_id).await.unwrap(); + assert!(deleted.is_none()); + } + + #[tokio::test] + async fn test_response_chain() { + let store = MemoryResponseStorage::new(); + + // Create a chain of responses + let response1 = + StoredResponse::new("First".to_string(), "First response".to_string(), None); + let id1 = store.store_response(response1).await.unwrap(); + + let response2 = StoredResponse::new( + "Second".to_string(), + "Second response".to_string(), + Some(id1.clone()), + ); + let id2 = store.store_response(response2).await.unwrap(); + + let response3 = StoredResponse::new( + "Third".to_string(), + "Third response".to_string(), + Some(id2.clone()), + ); + let id3 = store.store_response(response3).await.unwrap(); + + // Get the chain + let chain = store.get_response_chain(&id3, None).await.unwrap(); + assert_eq!(chain.responses.len(), 3); + assert_eq!(chain.responses[0].input, "First"); + assert_eq!(chain.responses[1].input, "Second"); + assert_eq!(chain.responses[2].input, "Third"); + + // Test with max_depth + let limited_chain = store.get_response_chain(&id3, Some(2)).await.unwrap(); + assert_eq!(limited_chain.responses.len(), 2); + assert_eq!(limited_chain.responses[0].input, "Second"); + assert_eq!(limited_chain.responses[1].input, "Third"); + } + + #[tokio::test] + async fn test_user_responses() { + let store = MemoryResponseStorage::new(); + + // Store responses for different users + let mut response1 = StoredResponse::new( + "User1 message".to_string(), + "Response to user1".to_string(), + None, + ); + response1.user = Some("user1".to_string()); + store.store_response(response1).await.unwrap(); + + let mut response2 = StoredResponse::new( + "Another user1 message".to_string(), + "Another response to user1".to_string(), + None, + ); + response2.user = Some("user1".to_string()); + store.store_response(response2).await.unwrap(); + + let mut response3 = StoredResponse::new( + "User2 message".to_string(), + "Response to user2".to_string(), + None, + ); + response3.user = Some("user2".to_string()); + store.store_response(response3).await.unwrap(); + + // List user1's responses + let user1_responses = store.list_user_responses("user1", None).await.unwrap(); + assert_eq!(user1_responses.len(), 2); + + // List user2's responses + let user2_responses = store.list_user_responses("user2", None).await.unwrap(); + assert_eq!(user2_responses.len(), 1); + + // Delete user1's responses + let deleted_count = store.delete_user_responses("user1").await.unwrap(); + assert_eq!(deleted_count, 2); + + // Verify they're gone + let user1_responses_after = store.list_user_responses("user1", None).await.unwrap(); + assert_eq!(user1_responses_after.len(), 0); + + // User2's responses should still be there + let user2_responses_after = store.list_user_responses("user2", None).await.unwrap(); + assert_eq!(user2_responses_after.len(), 1); + } + + #[tokio::test] + async fn test_memory_store_stats() { + let store = MemoryResponseStorage::new(); + + let mut response1 = StoredResponse::new("Test1".to_string(), "Reply1".to_string(), None); + response1.user = Some("user1".to_string()); + store.store_response(response1).await.unwrap(); + + let mut response2 = StoredResponse::new("Test2".to_string(), "Reply2".to_string(), None); + response2.user = Some("user2".to_string()); + store.store_response(response2).await.unwrap(); + + let stats = store.stats(); + assert_eq!(stats.response_count, 2); + assert_eq!(stats.user_count, 2); + } +} diff --git a/sgl-router/src/data_connector/response_noop_store.rs b/sgl-router/src/data_connector/response_noop_store.rs new file mode 100644 index 000000000..968a329d7 --- /dev/null +++ b/sgl-router/src/data_connector/response_noop_store.rs @@ -0,0 +1,53 @@ +use async_trait::async_trait; + +use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse}; + +/// No-op implementation of response storage (does nothing) +pub struct NoOpResponseStorage; + +impl NoOpResponseStorage { + pub fn new() -> Self { + Self + } +} + +impl Default for NoOpResponseStorage { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ResponseStorage for NoOpResponseStorage { + async fn store_response(&self, response: StoredResponse) -> Result { + Ok(response.id) + } + + async fn get_response(&self, _response_id: &ResponseId) -> Result> { + Ok(None) + } + + async fn delete_response(&self, _response_id: &ResponseId) -> Result<()> { + Ok(()) + } + + async fn get_response_chain( + &self, + _response_id: &ResponseId, + _max_depth: Option, + ) -> Result { + Ok(ResponseChain::new()) + } + + async fn list_user_responses( + &self, + _user: &str, + _limit: Option, + ) -> Result> { + Ok(Vec::new()) + } + + async fn delete_user_responses(&self, _user: &str) -> Result { + Ok(0) + } +} diff --git a/sgl-router/src/data_connector/responses.rs b/sgl-router/src/data_connector/responses.rs new file mode 100644 index 000000000..49693e984 --- /dev/null +++ b/sgl-router/src/data_connector/responses.rs @@ -0,0 +1,177 @@ +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; + +/// Response identifier +#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)] +pub struct ResponseId(pub String); + +impl ResponseId { + pub fn new() -> Self { + Self(ulid::Ulid::new().to_string()) + } + + pub fn from_string(s: String) -> Self { + Self(s) + } +} + +impl Default for ResponseId { + fn default() -> Self { + Self::new() + } +} + +/// Stored response data +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StoredResponse { + /// Unique response ID + pub id: ResponseId, + + /// ID of the previous response in the chain (if any) + pub previous_response_id: Option, + + /// The user input for this response + pub input: String, + + /// System instructions used + pub instructions: Option, + + /// The model's output + pub output: String, + + /// Tool calls made by the model (if any) + pub tool_calls: Vec, + + /// Custom metadata + pub metadata: HashMap, + + /// When this response was created + pub created_at: chrono::DateTime, + + /// User identifier (optional) + pub user: Option, + + /// Model used for generation + pub model: Option, +} + +impl StoredResponse { + pub fn new(input: String, output: String, previous_response_id: Option) -> Self { + Self { + id: ResponseId::new(), + previous_response_id, + input, + instructions: None, + output, + tool_calls: Vec::new(), + metadata: HashMap::new(), + created_at: chrono::Utc::now(), + user: None, + model: None, + } + } +} + +/// Response chain - a sequence of related responses +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponseChain { + /// The responses in chronological order + pub responses: Vec, + + /// Metadata about the chain + pub metadata: HashMap, +} + +impl Default for ResponseChain { + fn default() -> Self { + Self::new() + } +} + +impl ResponseChain { + pub fn new() -> Self { + Self { + responses: Vec::new(), + metadata: HashMap::new(), + } + } + + /// Get the ID of the most recent response in the chain + pub fn latest_response_id(&self) -> Option<&ResponseId> { + self.responses.last().map(|r| &r.id) + } + + /// Add a response to the chain + pub fn add_response(&mut self, response: StoredResponse) { + self.responses.push(response); + } + + /// Build context from the chain for the next request + pub fn build_context(&self, max_responses: Option) -> Vec<(String, String)> { + let responses = if let Some(max) = max_responses { + let start = self.responses.len().saturating_sub(max); + &self.responses[start..] + } else { + &self.responses[..] + }; + + responses + .iter() + .map(|r| (r.input.clone(), r.output.clone())) + .collect() + } +} + +/// Error type for response storage operations +#[derive(Debug, thiserror::Error)] +pub enum ResponseStorageError { + #[error("Response not found: {0}")] + ResponseNotFound(String), + + #[error("Invalid chain: {0}")] + InvalidChain(String), + + #[error("Storage error: {0}")] + StorageError(String), + + #[error("Serialization error: {0}")] + SerializationError(#[from] serde_json::Error), +} + +pub type Result = std::result::Result; + +/// Trait for response storage +#[async_trait] +pub trait ResponseStorage: Send + Sync { + /// Store a new response + async fn store_response(&self, response: StoredResponse) -> Result; + + /// Get a response by ID + async fn get_response(&self, response_id: &ResponseId) -> Result>; + + /// Delete a response + async fn delete_response(&self, response_id: &ResponseId) -> Result<()>; + + /// Get the chain of responses leading to a given response + /// Returns responses in chronological order (oldest first) + async fn get_response_chain( + &self, + response_id: &ResponseId, + max_depth: Option, + ) -> Result; + + /// List recent responses for a user + async fn list_user_responses( + &self, + user: &str, + limit: Option, + ) -> Result>; + + /// Delete all responses for a user + async fn delete_user_responses(&self, user: &str) -> Result; +} + +/// Type alias for shared storage +pub type SharedResponseStorage = Arc; diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 938a1ba0c..1907ed2e5 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -4,6 +4,7 @@ pub mod logging; use std::collections::HashMap; pub mod core; +pub mod data_connector; #[cfg(feature = "grpc-client")] pub mod grpc; pub mod mcp; @@ -229,6 +230,7 @@ impl Router { enable_igw: self.enable_igw, model_path: self.model_path.clone(), tokenizer_path: self.tokenizer_path.clone(), + history_backend: config::HistoryBackend::Memory, }) } } diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index 8ec24a722..243c32b33 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -1,7 +1,8 @@ use clap::{ArgAction, Parser, ValueEnum}; use sglang_router_rs::config::{ CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig, - HealthCheckConfig, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, + HealthCheckConfig, HistoryBackend, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, + RoutingMode, }; use sglang_router_rs::metrics::PrometheusConfig; use sglang_router_rs::server::{self, ServerConfig}; @@ -312,6 +313,10 @@ struct CliArgs { /// Explicit tokenizer path (overrides model_path tokenizer if provided) #[arg(long)] tokenizer_path: Option, + + /// History backend configuration (memory or none) + #[arg(long, default_value = "memory", value_parser = ["memory", "none"])] + history_backend: String, } impl CliArgs { @@ -506,6 +511,10 @@ impl CliArgs { rate_limit_tokens_per_second: None, model_path: self.model_path.clone(), tokenizer_path: self.tokenizer_path.clone(), + history_backend: match self.history_backend.as_str() { + "none" => HistoryBackend::None, + _ => HistoryBackend::Memory, + }, }) } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 2bb9bbb05..348d70aea 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,6 +1,7 @@ use crate::{ - config::{ConnectionMode, RouterConfig}, + config::{ConnectionMode, HistoryBackend, RouterConfig}, core::{WorkerRegistry, WorkerType}, + data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage}, logging::{self, LoggingConfig}, metrics::{self, PrometheusConfig}, middleware::{self, QueuedRequest, TokenBucket}, @@ -50,6 +51,7 @@ pub struct AppContext { pub worker_registry: Arc, pub policy_registry: Arc, pub router_manager: Option>, + pub response_storage: SharedResponseStorage, } impl AppContext { @@ -94,6 +96,12 @@ impl AppContext { let router_manager = None; + // Initialize response storage based on configuration + let response_storage: SharedResponseStorage = match router_config.history_backend { + HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()), + HistoryBackend::None => Arc::new(NoOpResponseStorage::new()), + }; + Ok(Self { client, router_config, @@ -104,6 +112,7 @@ impl AppContext { worker_registry, policy_registry, router_manager, + response_storage, }) } } diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index a099933dc..8272e5f35 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -603,6 +603,7 @@ mod tests { reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser tool_parser_registry: None, // HTTP mode doesn't need tool parser router_manager: None, // Test doesn't need router manager + response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()), }); let router = Router::new(vec![], &app_context).await.unwrap(); diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index 14b9f2f99..b715dad55 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -58,6 +58,7 @@ impl TestContext { connection_mode: ConnectionMode::Http, model_path: None, tokenizer_path: None, + history_backend: sglang_router_rs::config::HistoryBackend::Memory, }; Self::new_with_config(config, worker_configs).await @@ -1392,6 +1393,7 @@ mod error_tests { connection_mode: ConnectionMode::Http, model_path: None, tokenizer_path: None, + history_backend: sglang_router_rs::config::HistoryBackend::Memory, }; let ctx = TestContext::new_with_config( @@ -1750,6 +1752,7 @@ mod pd_mode_tests { connection_mode: ConnectionMode::Http, model_path: None, tokenizer_path: None, + history_backend: sglang_router_rs::config::HistoryBackend::Memory, }; // Create app context @@ -1912,6 +1915,7 @@ mod request_id_tests { connection_mode: ConnectionMode::Http, model_path: None, tokenizer_path: None, + history_backend: sglang_router_rs::config::HistoryBackend::Memory, }; let ctx = TestContext::new_with_config( diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index 606ca0a41..2ec7f0039 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -3,9 +3,7 @@ mod common; use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{ - CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, -}; +use sglang_router_rs::config::{RouterConfig, RoutingMode}; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; @@ -21,34 +19,10 @@ impl TestContext { mode: RoutingMode::Regular { worker_urls: vec![], }, - policy: PolicyConfig::Random, - host: "127.0.0.1".to_string(), port: 3003, - max_payload_size: 256 * 1024 * 1024, - request_timeout_secs: 600, worker_startup_timeout_secs: 1, worker_startup_check_interval_secs: 1, - dp_aware: false, - api_key: None, - discovery: None, - metrics: None, - log_dir: None, - log_level: None, - request_id_headers: None, - max_concurrent_requests: 64, - queue_size: 0, - queue_timeout_secs: 60, - rate_limit_tokens_per_second: None, - cors_allowed_origins: vec![], - retry: RetryConfig::default(), - circuit_breaker: CircuitBreakerConfig::default(), - disable_retries: false, - disable_circuit_breaker: false, - health_check: sglang_router_rs::config::HealthCheckConfig::default(), - enable_igw: false, - connection_mode: ConnectionMode::Http, - model_path: None, - tokenizer_path: None, + ..Default::default() }; let mut workers = Vec::new(); diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index 29190a312..b998625c1 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -4,9 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType use futures_util::StreamExt; use reqwest::Client; use serde_json::json; -use sglang_router_rs::config::{ - CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, -}; +use sglang_router_rs::config::{RouterConfig, RoutingMode}; use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use std::sync::Arc; @@ -22,34 +20,10 @@ impl TestContext { mode: RoutingMode::Regular { worker_urls: vec![], }, - policy: PolicyConfig::Random, - host: "127.0.0.1".to_string(), port: 3004, - max_payload_size: 256 * 1024 * 1024, - request_timeout_secs: 600, worker_startup_timeout_secs: 1, worker_startup_check_interval_secs: 1, - dp_aware: false, - api_key: None, - discovery: None, - metrics: None, - log_dir: None, - log_level: None, - request_id_headers: None, - max_concurrent_requests: 64, - queue_size: 0, - queue_timeout_secs: 60, - rate_limit_tokens_per_second: None, - cors_allowed_origins: vec![], - retry: RetryConfig::default(), - circuit_breaker: CircuitBreakerConfig::default(), - disable_retries: false, - disable_circuit_breaker: false, - health_check: sglang_router_rs::config::HealthCheckConfig::default(), - enable_igw: false, - connection_mode: ConnectionMode::Http, - model_path: None, - tokenizer_path: None, + ..Default::default() }; let mut workers = Vec::new(); diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 7071106a4..45651478a 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -191,6 +191,7 @@ mod test_pd_routing { connection_mode: ConnectionMode::Http, model_path: None, tokenizer_path: None, + history_backend: sglang_router_rs::config::HistoryBackend::Memory, }; // Router creation will fail due to health checks, but config should be valid