[router] add router db connector for responses api (#10487)
This commit is contained in:
@@ -45,6 +45,8 @@ k8s-openapi = { version = "0.25.0", features = ["v1_33"] }
|
||||
metrics = "0.24.2"
|
||||
metrics-exporter-prometheus = "0.17.0"
|
||||
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||
ulid = "1.2.1"
|
||||
parking_lot = "0.12.4"
|
||||
thiserror = "2.0.12"
|
||||
regex = "1.10"
|
||||
url = "2.5.4"
|
||||
|
||||
@@ -67,6 +67,23 @@ pub struct RouterConfig {
|
||||
pub model_path: Option<String>,
|
||||
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
|
||||
pub tokenizer_path: Option<String>,
|
||||
/// History backend configuration (memory or none, default: memory)
|
||||
#[serde(default = "default_history_backend")]
|
||||
pub history_backend: HistoryBackend,
|
||||
}
|
||||
|
||||
fn default_history_backend() -> HistoryBackend {
|
||||
HistoryBackend::Memory
|
||||
}
|
||||
|
||||
/// History backend configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum HistoryBackend {
|
||||
/// In-memory storage (default)
|
||||
Memory,
|
||||
/// No history storage
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
|
||||
@@ -363,6 +380,7 @@ impl Default for RouterConfig {
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -484,31 +502,9 @@ mod tests {
|
||||
policy: PolicyConfig::Random,
|
||||
host: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
max_payload_size: 1024,
|
||||
request_timeout_secs: 30,
|
||||
worker_startup_timeout_secs: 60,
|
||||
worker_startup_check_interval_secs: 5,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: Some(DiscoveryConfig::default()),
|
||||
metrics: Some(MetricsConfig::default()),
|
||||
log_dir: Some("/var/log".to_string()),
|
||||
log_level: Some("debug".to_string()),
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
queue_size: 100,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
@@ -517,8 +513,11 @@ mod tests {
|
||||
assert_eq!(config.host, deserialized.host);
|
||||
assert_eq!(config.port, deserialized.port);
|
||||
assert_eq!(config.max_payload_size, deserialized.max_payload_size);
|
||||
assert!(deserialized.discovery.is_some());
|
||||
assert!(deserialized.metrics.is_some());
|
||||
assert_eq!(config.log_dir, deserialized.log_dir);
|
||||
assert_eq!(config.log_level, deserialized.log_level);
|
||||
// discovery and metrics are None in Default implementation
|
||||
assert!(deserialized.discovery.is_none());
|
||||
assert!(deserialized.metrics.is_none());
|
||||
}
|
||||
|
||||
// ============= RoutingMode Tests =============
|
||||
@@ -948,6 +947,7 @@ mod tests {
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
};
|
||||
|
||||
assert!(config.mode.is_pd_mode());
|
||||
@@ -1011,6 +1011,7 @@ mod tests {
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
};
|
||||
|
||||
assert!(!config.mode.is_pd_mode());
|
||||
@@ -1070,6 +1071,7 @@ mod tests {
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
};
|
||||
|
||||
assert!(config.has_service_discovery());
|
||||
|
||||
11
sgl-router/src/data_connector/mod.rs
Normal file
11
sgl-router/src/data_connector/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
// Data connector module for response storage
|
||||
pub mod response_memory_store;
|
||||
pub mod response_noop_store;
|
||||
pub mod responses;
|
||||
|
||||
pub use response_memory_store::MemoryResponseStorage;
|
||||
pub use response_noop_store::NoOpResponseStorage;
|
||||
pub use responses::{
|
||||
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, SharedResponseStorage,
|
||||
StoredResponse,
|
||||
};
|
||||
325
sgl-router/src/data_connector/response_memory_store.rs
Normal file
325
sgl-router/src/data_connector/response_memory_store.rs
Normal file
@@ -0,0 +1,325 @@
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse};
|
||||
|
||||
/// Internal store structure holding both maps together
|
||||
#[derive(Default)]
|
||||
struct InnerStore {
|
||||
/// All stored responses indexed by ID
|
||||
responses: HashMap<ResponseId, StoredResponse>,
|
||||
/// Index of response IDs by user
|
||||
user_index: HashMap<String, Vec<ResponseId>>,
|
||||
}
|
||||
|
||||
/// In-memory implementation of response storage
|
||||
pub struct MemoryResponseStorage {
|
||||
/// Single lock wrapping both maps to prevent deadlocks and ensure atomic updates
|
||||
store: Arc<RwLock<InnerStore>>,
|
||||
}
|
||||
|
||||
impl MemoryResponseStorage {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
store: Arc::new(RwLock::new(InnerStore::default())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get statistics about the store
|
||||
pub fn stats(&self) -> MemoryStoreStats {
|
||||
let store = self.store.read();
|
||||
MemoryStoreStats {
|
||||
response_count: store.responses.len(),
|
||||
user_count: store.user_index.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear all data (useful for testing)
|
||||
pub fn clear(&self) {
|
||||
let mut store = self.store.write();
|
||||
store.responses.clear();
|
||||
store.user_index.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MemoryResponseStorage {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ResponseStorage for MemoryResponseStorage {
|
||||
async fn store_response(&self, mut response: StoredResponse) -> Result<ResponseId> {
|
||||
// Generate ID if not set
|
||||
if response.id.0.is_empty() {
|
||||
response.id = ResponseId::new();
|
||||
}
|
||||
|
||||
let response_id = response.id.clone();
|
||||
|
||||
// Single lock acquisition for atomic update
|
||||
let mut store = self.store.write();
|
||||
|
||||
// Update user index if user is specified
|
||||
if let Some(ref user) = response.user {
|
||||
store
|
||||
.user_index
|
||||
.entry(user.clone())
|
||||
.or_default()
|
||||
.push(response_id.clone());
|
||||
}
|
||||
|
||||
// Store the response
|
||||
store.responses.insert(response_id.clone(), response);
|
||||
|
||||
Ok(response_id)
|
||||
}
|
||||
|
||||
async fn get_response(&self, response_id: &ResponseId) -> Result<Option<StoredResponse>> {
|
||||
let store = self.store.read();
|
||||
Ok(store.responses.get(response_id).cloned())
|
||||
}
|
||||
|
||||
async fn delete_response(&self, response_id: &ResponseId) -> Result<()> {
|
||||
let mut store = self.store.write();
|
||||
|
||||
// Remove the response and update user index if needed
|
||||
if let Some(response) = store.responses.remove(response_id) {
|
||||
if let Some(ref user) = response.user {
|
||||
if let Some(user_responses) = store.user_index.get_mut(user) {
|
||||
user_responses.retain(|id| id != response_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_response_chain(
|
||||
&self,
|
||||
response_id: &ResponseId,
|
||||
max_depth: Option<usize>,
|
||||
) -> Result<ResponseChain> {
|
||||
let mut chain = ResponseChain::new();
|
||||
let max_depth = max_depth.unwrap_or(100); // Default max depth to prevent infinite loops
|
||||
|
||||
// Collect all response IDs first
|
||||
let mut response_ids = Vec::new();
|
||||
let mut current_id = Some(response_id.clone());
|
||||
let mut depth = 0;
|
||||
|
||||
// Single lock acquisition to collect the chain
|
||||
{
|
||||
let store = self.store.read();
|
||||
while let Some(id) = current_id {
|
||||
if depth >= max_depth {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(response) = store.responses.get(&id) {
|
||||
response_ids.push(id);
|
||||
current_id = response.previous_response_id.clone();
|
||||
depth += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse to get chronological order (oldest first)
|
||||
response_ids.reverse();
|
||||
|
||||
// Now collect the actual responses
|
||||
let store = self.store.read();
|
||||
for id in response_ids {
|
||||
if let Some(response) = store.responses.get(&id) {
|
||||
chain.add_response(response.clone());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(chain)
|
||||
}
|
||||
|
||||
async fn list_user_responses(
|
||||
&self,
|
||||
user: &str,
|
||||
limit: Option<usize>,
|
||||
) -> Result<Vec<StoredResponse>> {
|
||||
let store = self.store.read();
|
||||
|
||||
if let Some(user_response_ids) = store.user_index.get(user) {
|
||||
// Collect responses with their timestamps for sorting
|
||||
let mut responses_with_time: Vec<_> = user_response_ids
|
||||
.iter()
|
||||
.filter_map(|id| store.responses.get(id).map(|r| (r.created_at, id)))
|
||||
.collect();
|
||||
|
||||
// Sort by creation time (newest first)
|
||||
responses_with_time.sort_by(|a, b| b.0.cmp(&a.0));
|
||||
|
||||
// Apply limit and collect the actual responses
|
||||
let limit = limit.unwrap_or(responses_with_time.len());
|
||||
let user_responses: Vec<StoredResponse> = responses_with_time
|
||||
.into_iter()
|
||||
.take(limit)
|
||||
.filter_map(|(_, id)| store.responses.get(id).cloned())
|
||||
.collect();
|
||||
|
||||
Ok(user_responses)
|
||||
} else {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_user_responses(&self, user: &str) -> Result<usize> {
|
||||
let mut store = self.store.write();
|
||||
|
||||
if let Some(user_response_ids) = store.user_index.remove(user) {
|
||||
let count = user_response_ids.len();
|
||||
for id in user_response_ids {
|
||||
store.responses.remove(&id);
|
||||
}
|
||||
Ok(count)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics for the memory store
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryStoreStats {
|
||||
pub response_count: usize,
|
||||
pub user_count: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_store_basic() {
|
||||
let store = MemoryResponseStorage::new();
|
||||
|
||||
// Store a response
|
||||
let response = StoredResponse::new("Hello".to_string(), "Hi there!".to_string(), None);
|
||||
let response_id = store.store_response(response).await.unwrap();
|
||||
|
||||
// Retrieve it
|
||||
let retrieved = store.get_response(&response_id).await.unwrap();
|
||||
assert!(retrieved.is_some());
|
||||
assert_eq!(retrieved.unwrap().input, "Hello");
|
||||
|
||||
// Delete it
|
||||
store.delete_response(&response_id).await.unwrap();
|
||||
let deleted = store.get_response(&response_id).await.unwrap();
|
||||
assert!(deleted.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_response_chain() {
|
||||
let store = MemoryResponseStorage::new();
|
||||
|
||||
// Create a chain of responses
|
||||
let response1 =
|
||||
StoredResponse::new("First".to_string(), "First response".to_string(), None);
|
||||
let id1 = store.store_response(response1).await.unwrap();
|
||||
|
||||
let response2 = StoredResponse::new(
|
||||
"Second".to_string(),
|
||||
"Second response".to_string(),
|
||||
Some(id1.clone()),
|
||||
);
|
||||
let id2 = store.store_response(response2).await.unwrap();
|
||||
|
||||
let response3 = StoredResponse::new(
|
||||
"Third".to_string(),
|
||||
"Third response".to_string(),
|
||||
Some(id2.clone()),
|
||||
);
|
||||
let id3 = store.store_response(response3).await.unwrap();
|
||||
|
||||
// Get the chain
|
||||
let chain = store.get_response_chain(&id3, None).await.unwrap();
|
||||
assert_eq!(chain.responses.len(), 3);
|
||||
assert_eq!(chain.responses[0].input, "First");
|
||||
assert_eq!(chain.responses[1].input, "Second");
|
||||
assert_eq!(chain.responses[2].input, "Third");
|
||||
|
||||
// Test with max_depth
|
||||
let limited_chain = store.get_response_chain(&id3, Some(2)).await.unwrap();
|
||||
assert_eq!(limited_chain.responses.len(), 2);
|
||||
assert_eq!(limited_chain.responses[0].input, "Second");
|
||||
assert_eq!(limited_chain.responses[1].input, "Third");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_responses() {
|
||||
let store = MemoryResponseStorage::new();
|
||||
|
||||
// Store responses for different users
|
||||
let mut response1 = StoredResponse::new(
|
||||
"User1 message".to_string(),
|
||||
"Response to user1".to_string(),
|
||||
None,
|
||||
);
|
||||
response1.user = Some("user1".to_string());
|
||||
store.store_response(response1).await.unwrap();
|
||||
|
||||
let mut response2 = StoredResponse::new(
|
||||
"Another user1 message".to_string(),
|
||||
"Another response to user1".to_string(),
|
||||
None,
|
||||
);
|
||||
response2.user = Some("user1".to_string());
|
||||
store.store_response(response2).await.unwrap();
|
||||
|
||||
let mut response3 = StoredResponse::new(
|
||||
"User2 message".to_string(),
|
||||
"Response to user2".to_string(),
|
||||
None,
|
||||
);
|
||||
response3.user = Some("user2".to_string());
|
||||
store.store_response(response3).await.unwrap();
|
||||
|
||||
// List user1's responses
|
||||
let user1_responses = store.list_user_responses("user1", None).await.unwrap();
|
||||
assert_eq!(user1_responses.len(), 2);
|
||||
|
||||
// List user2's responses
|
||||
let user2_responses = store.list_user_responses("user2", None).await.unwrap();
|
||||
assert_eq!(user2_responses.len(), 1);
|
||||
|
||||
// Delete user1's responses
|
||||
let deleted_count = store.delete_user_responses("user1").await.unwrap();
|
||||
assert_eq!(deleted_count, 2);
|
||||
|
||||
// Verify they're gone
|
||||
let user1_responses_after = store.list_user_responses("user1", None).await.unwrap();
|
||||
assert_eq!(user1_responses_after.len(), 0);
|
||||
|
||||
// User2's responses should still be there
|
||||
let user2_responses_after = store.list_user_responses("user2", None).await.unwrap();
|
||||
assert_eq!(user2_responses_after.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_store_stats() {
|
||||
let store = MemoryResponseStorage::new();
|
||||
|
||||
let mut response1 = StoredResponse::new("Test1".to_string(), "Reply1".to_string(), None);
|
||||
response1.user = Some("user1".to_string());
|
||||
store.store_response(response1).await.unwrap();
|
||||
|
||||
let mut response2 = StoredResponse::new("Test2".to_string(), "Reply2".to_string(), None);
|
||||
response2.user = Some("user2".to_string());
|
||||
store.store_response(response2).await.unwrap();
|
||||
|
||||
let stats = store.stats();
|
||||
assert_eq!(stats.response_count, 2);
|
||||
assert_eq!(stats.user_count, 2);
|
||||
}
|
||||
}
|
||||
53
sgl-router/src/data_connector/response_noop_store.rs
Normal file
53
sgl-router/src/data_connector/response_noop_store.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse};
|
||||
|
||||
/// No-op implementation of response storage (does nothing)
|
||||
pub struct NoOpResponseStorage;
|
||||
|
||||
impl NoOpResponseStorage {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NoOpResponseStorage {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ResponseStorage for NoOpResponseStorage {
|
||||
async fn store_response(&self, response: StoredResponse) -> Result<ResponseId> {
|
||||
Ok(response.id)
|
||||
}
|
||||
|
||||
async fn get_response(&self, _response_id: &ResponseId) -> Result<Option<StoredResponse>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn delete_response(&self, _response_id: &ResponseId) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_response_chain(
|
||||
&self,
|
||||
_response_id: &ResponseId,
|
||||
_max_depth: Option<usize>,
|
||||
) -> Result<ResponseChain> {
|
||||
Ok(ResponseChain::new())
|
||||
}
|
||||
|
||||
async fn list_user_responses(
|
||||
&self,
|
||||
_user: &str,
|
||||
_limit: Option<usize>,
|
||||
) -> Result<Vec<StoredResponse>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
async fn delete_user_responses(&self, _user: &str) -> Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
177
sgl-router/src/data_connector/responses.rs
Normal file
177
sgl-router/src/data_connector/responses.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Response identifier
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ResponseId(pub String);
|
||||
|
||||
impl ResponseId {
|
||||
pub fn new() -> Self {
|
||||
Self(ulid::Ulid::new().to_string())
|
||||
}
|
||||
|
||||
pub fn from_string(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ResponseId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Stored response data
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StoredResponse {
|
||||
/// Unique response ID
|
||||
pub id: ResponseId,
|
||||
|
||||
/// ID of the previous response in the chain (if any)
|
||||
pub previous_response_id: Option<ResponseId>,
|
||||
|
||||
/// The user input for this response
|
||||
pub input: String,
|
||||
|
||||
/// System instructions used
|
||||
pub instructions: Option<String>,
|
||||
|
||||
/// The model's output
|
||||
pub output: String,
|
||||
|
||||
/// Tool calls made by the model (if any)
|
||||
pub tool_calls: Vec<serde_json::Value>,
|
||||
|
||||
/// Custom metadata
|
||||
pub metadata: HashMap<String, serde_json::Value>,
|
||||
|
||||
/// When this response was created
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
|
||||
/// User identifier (optional)
|
||||
pub user: Option<String>,
|
||||
|
||||
/// Model used for generation
|
||||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
impl StoredResponse {
|
||||
pub fn new(input: String, output: String, previous_response_id: Option<ResponseId>) -> Self {
|
||||
Self {
|
||||
id: ResponseId::new(),
|
||||
previous_response_id,
|
||||
input,
|
||||
instructions: None,
|
||||
output,
|
||||
tool_calls: Vec::new(),
|
||||
metadata: HashMap::new(),
|
||||
created_at: chrono::Utc::now(),
|
||||
user: None,
|
||||
model: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response chain - a sequence of related responses
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResponseChain {
|
||||
/// The responses in chronological order
|
||||
pub responses: Vec<StoredResponse>,
|
||||
|
||||
/// Metadata about the chain
|
||||
pub metadata: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl Default for ResponseChain {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseChain {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
responses: Vec::new(),
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the ID of the most recent response in the chain
|
||||
pub fn latest_response_id(&self) -> Option<&ResponseId> {
|
||||
self.responses.last().map(|r| &r.id)
|
||||
}
|
||||
|
||||
/// Add a response to the chain
|
||||
pub fn add_response(&mut self, response: StoredResponse) {
|
||||
self.responses.push(response);
|
||||
}
|
||||
|
||||
/// Build context from the chain for the next request
|
||||
pub fn build_context(&self, max_responses: Option<usize>) -> Vec<(String, String)> {
|
||||
let responses = if let Some(max) = max_responses {
|
||||
let start = self.responses.len().saturating_sub(max);
|
||||
&self.responses[start..]
|
||||
} else {
|
||||
&self.responses[..]
|
||||
};
|
||||
|
||||
responses
|
||||
.iter()
|
||||
.map(|r| (r.input.clone(), r.output.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Error type for response storage operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ResponseStorageError {
|
||||
#[error("Response not found: {0}")]
|
||||
ResponseNotFound(String),
|
||||
|
||||
#[error("Invalid chain: {0}")]
|
||||
InvalidChain(String),
|
||||
|
||||
#[error("Storage error: {0}")]
|
||||
StorageError(String),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
SerializationError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ResponseStorageError>;
|
||||
|
||||
/// Trait for response storage
|
||||
#[async_trait]
|
||||
pub trait ResponseStorage: Send + Sync {
|
||||
/// Store a new response
|
||||
async fn store_response(&self, response: StoredResponse) -> Result<ResponseId>;
|
||||
|
||||
/// Get a response by ID
|
||||
async fn get_response(&self, response_id: &ResponseId) -> Result<Option<StoredResponse>>;
|
||||
|
||||
/// Delete a response
|
||||
async fn delete_response(&self, response_id: &ResponseId) -> Result<()>;
|
||||
|
||||
/// Get the chain of responses leading to a given response
|
||||
/// Returns responses in chronological order (oldest first)
|
||||
async fn get_response_chain(
|
||||
&self,
|
||||
response_id: &ResponseId,
|
||||
max_depth: Option<usize>,
|
||||
) -> Result<ResponseChain>;
|
||||
|
||||
/// List recent responses for a user
|
||||
async fn list_user_responses(
|
||||
&self,
|
||||
user: &str,
|
||||
limit: Option<usize>,
|
||||
) -> Result<Vec<StoredResponse>>;
|
||||
|
||||
/// Delete all responses for a user
|
||||
async fn delete_user_responses(&self, user: &str) -> Result<usize>;
|
||||
}
|
||||
|
||||
/// Type alias for shared storage
|
||||
pub type SharedResponseStorage = Arc<dyn ResponseStorage>;
|
||||
@@ -4,6 +4,7 @@ pub mod logging;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub mod core;
|
||||
pub mod data_connector;
|
||||
#[cfg(feature = "grpc-client")]
|
||||
pub mod grpc;
|
||||
pub mod mcp;
|
||||
@@ -229,6 +230,7 @@ impl Router {
|
||||
enable_igw: self.enable_igw,
|
||||
model_path: self.model_path.clone(),
|
||||
tokenizer_path: self.tokenizer_path.clone(),
|
||||
history_backend: config::HistoryBackend::Memory,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use clap::{ArgAction, Parser, ValueEnum};
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
|
||||
HealthCheckConfig, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||
HealthCheckConfig, HistoryBackend, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig,
|
||||
RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::metrics::PrometheusConfig;
|
||||
use sglang_router_rs::server::{self, ServerConfig};
|
||||
@@ -312,6 +313,10 @@ struct CliArgs {
|
||||
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
|
||||
#[arg(long)]
|
||||
tokenizer_path: Option<String>,
|
||||
|
||||
/// History backend configuration (memory or none)
|
||||
#[arg(long, default_value = "memory", value_parser = ["memory", "none"])]
|
||||
history_backend: String,
|
||||
}
|
||||
|
||||
impl CliArgs {
|
||||
@@ -506,6 +511,10 @@ impl CliArgs {
|
||||
rate_limit_tokens_per_second: None,
|
||||
model_path: self.model_path.clone(),
|
||||
tokenizer_path: self.tokenizer_path.clone(),
|
||||
history_backend: match self.history_backend.as_str() {
|
||||
"none" => HistoryBackend::None,
|
||||
_ => HistoryBackend::Memory,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::{
|
||||
config::{ConnectionMode, RouterConfig},
|
||||
config::{ConnectionMode, HistoryBackend, RouterConfig},
|
||||
core::{WorkerRegistry, WorkerType},
|
||||
data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage},
|
||||
logging::{self, LoggingConfig},
|
||||
metrics::{self, PrometheusConfig},
|
||||
middleware::{self, QueuedRequest, TokenBucket},
|
||||
@@ -50,6 +51,7 @@ pub struct AppContext {
|
||||
pub worker_registry: Arc<WorkerRegistry>,
|
||||
pub policy_registry: Arc<PolicyRegistry>,
|
||||
pub router_manager: Option<Arc<RouterManager>>,
|
||||
pub response_storage: SharedResponseStorage,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
@@ -94,6 +96,12 @@ impl AppContext {
|
||||
|
||||
let router_manager = None;
|
||||
|
||||
// Initialize response storage based on configuration
|
||||
let response_storage: SharedResponseStorage = match router_config.history_backend {
|
||||
HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()),
|
||||
HistoryBackend::None => Arc::new(NoOpResponseStorage::new()),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
router_config,
|
||||
@@ -104,6 +112,7 @@ impl AppContext {
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
router_manager,
|
||||
response_storage,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -603,6 +603,7 @@ mod tests {
|
||||
reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
|
||||
tool_parser_registry: None, // HTTP mode doesn't need tool parser
|
||||
router_manager: None, // Test doesn't need router manager
|
||||
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
||||
});
|
||||
|
||||
let router = Router::new(vec![], &app_context).await.unwrap();
|
||||
|
||||
@@ -58,6 +58,7 @@ impl TestContext {
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
};
|
||||
|
||||
Self::new_with_config(config, worker_configs).await
|
||||
@@ -1392,6 +1393,7 @@ mod error_tests {
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
};
|
||||
|
||||
let ctx = TestContext::new_with_config(
|
||||
@@ -1750,6 +1752,7 @@ mod pd_mode_tests {
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
};
|
||||
|
||||
// Create app context
|
||||
@@ -1912,6 +1915,7 @@ mod request_id_tests {
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
};
|
||||
|
||||
let ctx = TestContext::new_with_config(
|
||||
|
||||
@@ -3,9 +3,7 @@ mod common;
|
||||
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::config::{RouterConfig, RoutingMode};
|
||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -21,34 +19,10 @@ impl TestContext {
|
||||
mode: RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
},
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 3003,
|
||||
max_payload_size: 256 * 1024 * 1024,
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 1,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut workers = Vec::new();
|
||||
|
||||
@@ -4,9 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::config::{RouterConfig, RoutingMode};
|
||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -22,34 +20,10 @@ impl TestContext {
|
||||
mode: RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
},
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 3004,
|
||||
max_payload_size: 256 * 1024 * 1024,
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 1,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut workers = Vec::new();
|
||||
|
||||
@@ -191,6 +191,7 @@ mod test_pd_routing {
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
};
|
||||
|
||||
// Router creation will fail due to health checks, but config should be valid
|
||||
|
||||
Reference in New Issue
Block a user