diff --git a/sgl-router/src/data_connector/conversation_memory_store.rs b/sgl-router/src/data_connector/conversation_memory_store.rs new file mode 100644 index 000000000..c2091c6b2 --- /dev/null +++ b/sgl-router/src/data_connector/conversation_memory_store.rs @@ -0,0 +1,57 @@ +use async_trait::async_trait; +use parking_lot::RwLock; +use std::collections::HashMap; +use std::sync::Arc; + +use super::conversations::{ + Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation, + Result, +}; + +/// In-memory conversation storage used for development and tests +#[derive(Default, Clone)] +pub struct MemoryConversationStorage { + inner: Arc>>, +} + +impl MemoryConversationStorage { + pub fn new() -> Self { + Self { + inner: Arc::new(RwLock::new(HashMap::new())), + } + } +} + +#[async_trait] +impl ConversationStorage for MemoryConversationStorage { + async fn create_conversation(&self, input: NewConversation) -> Result { + let conversation = Conversation::new(input); + self.inner + .write() + .insert(conversation.id.clone(), conversation.clone()); + Ok(conversation) + } + + async fn get_conversation(&self, id: &ConversationId) -> Result> { + Ok(self.inner.read().get(id).cloned()) + } + + async fn update_conversation( + &self, + id: &ConversationId, + metadata: Option, + ) -> Result> { + let mut store = self.inner.write(); + if let Some(entry) = store.get_mut(id) { + entry.metadata = metadata; + return Ok(Some(entry.clone())); + } + + Ok(None) + } + + async fn delete_conversation(&self, id: &ConversationId) -> Result { + let removed = self.inner.write().remove(id).is_some(); + Ok(removed) + } +} diff --git a/sgl-router/src/data_connector/conversation_noop_store.rs b/sgl-router/src/data_connector/conversation_noop_store.rs new file mode 100644 index 000000000..68b06bc59 --- /dev/null +++ b/sgl-router/src/data_connector/conversation_noop_store.rs @@ -0,0 +1,41 @@ +use async_trait::async_trait; + +use super::conversations::{ + Conversation, ConversationId, ConversationMetadata, ConversationStorage, Result, +}; + +/// No-op implementation that synthesizes conversation responses without persistence +#[derive(Default, Debug, Clone)] +pub struct NoOpConversationStorage; + +impl NoOpConversationStorage { + pub fn new() -> Self { + Self + } +} + +#[async_trait] +impl ConversationStorage for NoOpConversationStorage { + async fn create_conversation( + &self, + input: super::conversations::NewConversation, + ) -> Result { + Ok(Conversation::new(input)) + } + + async fn get_conversation(&self, _id: &ConversationId) -> Result> { + Ok(None) + } + + async fn update_conversation( + &self, + _id: &ConversationId, + _metadata: Option, + ) -> Result> { + Ok(None) + } + + async fn delete_conversation(&self, _id: &ConversationId) -> Result { + Ok(false) + } +} diff --git a/sgl-router/src/data_connector/conversation_oracle_store.rs b/sgl-router/src/data_connector/conversation_oracle_store.rs new file mode 100644 index 000000000..452b85c2c --- /dev/null +++ b/sgl-router/src/data_connector/conversation_oracle_store.rs @@ -0,0 +1,338 @@ +use crate::config::OracleConfig; +use crate::data_connector::conversations::{ + Conversation, ConversationId, ConversationMetadata, ConversationStorage, + ConversationStorageError, NewConversation, Result, +}; +use async_trait::async_trait; +use chrono::Utc; +use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; +use oracle::{sql_type::OracleType, Connection}; +use serde_json::Value; +use std::path::Path; +use std::sync::Arc; +use std::time::Duration; + +#[derive(Clone)] +pub struct OracleConversationStorage { + pool: Pool, +} + +impl OracleConversationStorage { + pub fn new(config: OracleConfig) -> Result { + configure_oracle_client(&config)?; + initialize_schema(&config)?; + + let config = Arc::new(config); + let manager = ConversationOracleConnectionManager::new(config.clone()); + + let mut builder = Pool::builder(manager) + .max_size(config.pool_max) + .runtime(deadpool::Runtime::Tokio1); + + if config.pool_timeout_secs > 0 { + builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs))); + } + + let pool = builder.build().map_err(|err| { + ConversationStorageError::StorageError(format!( + "failed to build Oracle pool for conversations: {err}" + )) + })?; + + Ok(Self { pool }) + } + + async fn with_connection(&self, func: F) -> Result + where + F: FnOnce(&Connection) -> Result + Send + 'static, + T: Send + 'static, + { + let connection = self.pool.get().await.map_err(map_pool_error)?; + tokio::task::spawn_blocking(move || { + let result = func(&connection); + drop(connection); + result + }) + .await + .map_err(|err| { + ConversationStorageError::StorageError(format!( + "failed to execute Oracle conversation task: {err}" + )) + })? + } + + fn parse_metadata(raw: Option) -> Result> { + match raw { + Some(json) if !json.is_empty() => { + let value: Value = serde_json::from_str(&json)?; + match value { + Value::Object(map) => Ok(Some(map)), + Value::Null => Ok(None), + other => Err(ConversationStorageError::StorageError(format!( + "conversation metadata expected object, got {other}" + ))), + } + } + _ => Ok(None), + } + } +} + +#[async_trait] +impl ConversationStorage for OracleConversationStorage { + async fn create_conversation(&self, input: NewConversation) -> Result { + let conversation = Conversation::new(input); + let id_str = conversation.id.0.clone(); + let created_at = conversation.created_at; + let metadata_json = conversation + .metadata + .as_ref() + .map(serde_json::to_string) + .transpose()?; + + self.with_connection(move |conn| { + conn.execute( + "INSERT INTO conversations (id, created_at, metadata) VALUES (:1, :2, :3)", + &[&id_str, &created_at, &metadata_json], + ) + .map(|_| ()) + .map_err(map_oracle_error) + }) + .await?; + + Ok(conversation) + } + + async fn get_conversation(&self, id: &ConversationId) -> Result> { + let lookup = id.0.clone(); + self.with_connection(move |conn| { + let mut stmt = conn + .statement("SELECT id, created_at, metadata FROM conversations WHERE id = :1") + .build() + .map_err(map_oracle_error)?; + let mut rows = stmt.query(&[&lookup]).map_err(map_oracle_error)?; + + if let Some(row_res) = rows.next() { + let row = row_res.map_err(map_oracle_error)?; + let id: String = row.get(0).map_err(map_oracle_error)?; + let created_at: chrono::DateTime = row.get(1).map_err(map_oracle_error)?; + let metadata_raw: Option = row.get(2).map_err(map_oracle_error)?; + let metadata = Self::parse_metadata(metadata_raw)?; + Ok(Some(Conversation::with_parts( + ConversationId(id), + created_at, + metadata, + ))) + } else { + Ok(None) + } + }) + .await + } + + async fn update_conversation( + &self, + id: &ConversationId, + metadata: Option, + ) -> Result> { + let id_str = id.0.clone(); + let metadata_json = metadata.as_ref().map(serde_json::to_string).transpose()?; + let conversation_id = id.clone(); + + self.with_connection(move |conn| { + let mut stmt = conn + .statement( + "UPDATE conversations \ + SET metadata = :1 \ + WHERE id = :2 \ + RETURNING created_at INTO :3", + ) + .build() + .map_err(map_oracle_error)?; + + stmt.bind(3, &OracleType::TimestampTZ(6)) + .map_err(map_oracle_error)?; + stmt.execute(&[&metadata_json, &id_str]) + .map_err(map_oracle_error)?; + + if stmt.row_count().map_err(map_oracle_error)? == 0 { + return Ok(None); + } + + let mut created_at: Vec> = + stmt.returned_values(3).map_err(map_oracle_error)?; + let created_at = created_at.pop().ok_or_else(|| { + ConversationStorageError::StorageError( + "Oracle update did not return created_at".to_string(), + ) + })?; + + Ok(Some(Conversation::with_parts( + conversation_id, + created_at, + metadata, + ))) + }) + .await + } + + async fn delete_conversation(&self, id: &ConversationId) -> Result { + let id_str = id.0.clone(); + let res = self + .with_connection(move |conn| { + conn.execute("DELETE FROM conversations WHERE id = :1", &[&id_str]) + .map_err(map_oracle_error) + }) + .await?; + + Ok(res.row_count().map_err(map_oracle_error)? > 0) + } +} + +#[derive(Clone)] +struct ConversationOracleConnectionManager { + params: Arc, +} + +impl ConversationOracleConnectionManager { + fn new(config: Arc) -> Self { + let params = OracleConnectParams { + username: config.username.clone(), + password: config.password.clone(), + connect_descriptor: config.connect_descriptor.clone(), + }; + + Self { + params: Arc::new(params), + } + } +} + +impl std::fmt::Debug for ConversationOracleConnectionManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConversationOracleConnectionManager") + .field("username", &self.params.username) + .field("connect_descriptor", &self.params.connect_descriptor) + .finish() + } +} + +#[derive(Clone)] +struct OracleConnectParams { + username: String, + password: String, + connect_descriptor: String, +} + +impl std::fmt::Debug for OracleConnectParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OracleConnectParams") + .field("username", &self.username) + .field("connect_descriptor", &self.connect_descriptor) + .finish() + } +} + +#[async_trait] +impl Manager for ConversationOracleConnectionManager { + type Type = Connection; + type Error = oracle::Error; + + fn create( + &self, + ) -> impl std::future::Future> + Send + { + let params = self.params.clone(); + async move { + let mut conn = Connection::connect( + ¶ms.username, + ¶ms.password, + ¶ms.connect_descriptor, + )?; + conn.set_autocommit(true); + Ok(conn) + } + } + + #[allow(clippy::manual_async_fn)] + fn recycle( + &self, + conn: &mut Connection, + _: &Metrics, + ) -> impl std::future::Future> + Send { + async move { conn.ping().map_err(RecycleError::Backend) } + } +} + +fn configure_oracle_client(config: &OracleConfig) -> Result<()> { + if let Some(wallet_path) = &config.wallet_path { + let wallet_path = Path::new(wallet_path); + if !wallet_path.is_dir() { + return Err(ConversationStorageError::StorageError(format!( + "Oracle wallet/config path '{}' is not a directory", + wallet_path.display() + ))); + } + + if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() { + return Err(ConversationStorageError::StorageError(format!( + "Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora", + wallet_path.display() + ))); + } + + std::env::set_var("TNS_ADMIN", wallet_path); + } + Ok(()) +} + +fn initialize_schema(config: &OracleConfig) -> Result<()> { + let conn = Connection::connect( + &config.username, + &config.password, + &config.connect_descriptor, + ) + .map_err(map_oracle_error)?; + + let exists: i64 = conn + .query_row_as( + "SELECT COUNT(*) FROM user_tables WHERE table_name = 'CONVERSATIONS'", + &[], + ) + .map_err(map_oracle_error)?; + + if exists == 0 { + conn.execute( + "CREATE TABLE conversations ( + id VARCHAR2(64) PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE, + metadata CLOB + )", + &[], + ) + .map_err(map_oracle_error)?; + } + + Ok(()) +} + +fn map_pool_error(err: PoolError) -> ConversationStorageError { + match err { + PoolError::Backend(e) => map_oracle_error(e), + other => ConversationStorageError::StorageError(format!( + "failed to obtain Oracle conversation connection: {other}" + )), + } +} + +fn map_oracle_error(err: oracle::Error) -> ConversationStorageError { + if let Some(db_err) = err.db_error() { + ConversationStorageError::StorageError(format!( + "Oracle error (code {}): {}", + db_err.code(), + db_err.message() + )) + } else { + ConversationStorageError::StorageError(err.to_string()) + } +} diff --git a/sgl-router/src/data_connector/conversations.rs b/sgl-router/src/data_connector/conversations.rs new file mode 100644 index 000000000..3c27555cb --- /dev/null +++ b/sgl-router/src/data_connector/conversations.rs @@ -0,0 +1,120 @@ +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; +use serde_json::{Map as JsonMap, Value}; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] +pub struct ConversationId(pub String); + +impl ConversationId { + pub fn new() -> Self { + let mut rng = rand::rng(); + let mut bytes = [0u8; 24]; + rng.fill_bytes(&mut bytes); + let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect(); + Self(format!("conv_{}", hex_string)) + } +} + +impl Default for ConversationId { + fn default() -> Self { + Self::new() + } +} + +impl From for ConversationId { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From<&str> for ConversationId { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} + +impl Display for ConversationId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +/// Metadata payload persisted with a conversation +pub type ConversationMetadata = JsonMap; + +/// Input payload for creating a conversation +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct NewConversation { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +/// Stored conversation data structure +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Conversation { + pub id: ConversationId, + pub created_at: DateTime, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +impl Conversation { + pub fn new(new_conversation: NewConversation) -> Self { + Self { + id: ConversationId::new(), + created_at: Utc::now(), + metadata: new_conversation.metadata, + } + } + + pub fn with_parts( + id: ConversationId, + created_at: DateTime, + metadata: Option, + ) -> Self { + Self { + id, + created_at, + metadata, + } + } +} + +/// Result alias for conversation storage operations +pub type Result = std::result::Result; + +/// Error type for conversation storage operations +#[derive(Debug, thiserror::Error)] +pub enum ConversationStorageError { + #[error("Conversation not found: {0}")] + ConversationNotFound(String), + + #[error("Storage error: {0}")] + StorageError(String), + + #[error("Serialization error: {0}")] + SerializationError(#[from] serde_json::Error), +} + +/// Trait describing the CRUD interface for conversation storage backends +#[async_trait] +pub trait ConversationStorage: Send + Sync + 'static { + async fn create_conversation(&self, input: NewConversation) -> Result; + + async fn get_conversation(&self, id: &ConversationId) -> Result>; + + async fn update_conversation( + &self, + id: &ConversationId, + metadata: Option, + ) -> Result>; + + async fn delete_conversation(&self, id: &ConversationId) -> Result; +} + +/// Shared pointer alias for conversation storage +pub type SharedConversationStorage = Arc; diff --git a/sgl-router/src/data_connector/mod.rs b/sgl-router/src/data_connector/mod.rs index e79cacea8..0877ad40b 100644 --- a/sgl-router/src/data_connector/mod.rs +++ b/sgl-router/src/data_connector/mod.rs @@ -1,9 +1,21 @@ -// Data connector module for response storage +// Data connector module for response storage and conversation storage +pub mod conversation_memory_store; +pub mod conversation_noop_store; +pub mod conversation_oracle_store; +pub mod conversations; pub mod response_memory_store; pub mod response_noop_store; pub mod response_oracle_store; pub mod responses; +pub use conversation_memory_store::MemoryConversationStorage; +pub use conversation_noop_store::NoOpConversationStorage; +pub use conversation_oracle_store::OracleConversationStorage; +pub use conversations::{ + Conversation, ConversationId, ConversationMetadata, ConversationStorage, + ConversationStorageError, NewConversation, Result as ConversationResult, + SharedConversationStorage, +}; pub use response_memory_store::MemoryResponseStorage; pub use response_noop_store::NoOpResponseStorage; pub use response_oracle_store::OracleResponseStorage; diff --git a/sgl-router/src/data_connector/response_memory_store.rs b/sgl-router/src/data_connector/response_memory_store.rs index b6ca9f527..d9067ef54 100644 --- a/sgl-router/src/data_connector/response_memory_store.rs +++ b/sgl-router/src/data_connector/response_memory_store.rs @@ -207,10 +207,10 @@ mod tests { async fn test_store_with_custom_id() { let store = MemoryResponseStorage::new(); let mut response = StoredResponse::new("Input".to_string(), "Output".to_string(), None); - response.id = ResponseId::from_string("resp_custom".to_string()); + response.id = ResponseId::from("resp_custom"); store.store_response(response.clone()).await.unwrap(); let retrieved = store - .get_response(&ResponseId::from_string("resp_custom".to_string())) + .get_response(&ResponseId::from("resp_custom")) .await .unwrap(); assert!(retrieved.is_some()); diff --git a/sgl-router/src/data_connector/responses.rs b/sgl-router/src/data_connector/responses.rs index 175311ef8..7a4277183 100644 --- a/sgl-router/src/data_connector/responses.rs +++ b/sgl-router/src/data_connector/responses.rs @@ -12,10 +12,6 @@ impl ResponseId { pub fn new() -> Self { Self(ulid::Ulid::new().to_string()) } - - pub fn from_string(s: String) -> Self { - Self(s) - } } impl Default for ResponseId { @@ -24,6 +20,18 @@ impl Default for ResponseId { } } +impl From for ResponseId { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From<&str> for ResponseId { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} + /// Stored response data #[derive(Debug, Clone, Serialize, Deserialize)] pub struct StoredResponse { diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index ef9e21c0a..dad0144e1 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -128,6 +128,7 @@ impl RouterFactory { base_url, Some(ctx.router_config.circuit_breaker.clone()), ctx.response_storage.clone(), + ctx.conversation_storage.clone(), ) .await?; diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index 04a9b5956..d14028592 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -2,7 +2,10 @@ use crate::config::CircuitBreakerConfig; use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}; -use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse}; +use crate::data_connector::{ + Conversation, ConversationId, ConversationMetadata, ResponseId, SharedConversationStorage, + SharedResponseStorage, StoredResponse, +}; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem, @@ -16,6 +19,7 @@ use axum::{ extract::Request, http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, response::{IntoResponse, Response}, + Json, }; use bytes::Bytes; use futures_util::StreamExt; @@ -75,6 +79,8 @@ pub struct OpenAIRouter { healthy: AtomicBool, /// Response storage for managing conversation history response_storage: SharedResponseStorage, + /// Conversation storage backend + conversation_storage: SharedConversationStorage, /// Optional MCP manager (enabled via config presence) mcp_manager: Option>, } @@ -705,6 +711,7 @@ impl OpenAIRouter { base_url: String, circuit_breaker_config: Option, response_storage: SharedResponseStorage, + conversation_storage: SharedConversationStorage, ) -> Result { let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(300)) @@ -751,6 +758,7 @@ impl OpenAIRouter { circuit_breaker, healthy: AtomicBool::new(true), response_storage, + conversation_storage, mcp_manager, }) } @@ -2337,16 +2345,16 @@ impl OpenAIRouter { stored_response.previous_response_id = response_json .get("previous_response_id") .and_then(|v| v.as_str()) - .map(|s| ResponseId::from_string(s.to_string())) + .map(ResponseId::from) .or_else(|| { original_body .previous_response_id .as_ref() - .map(|id| ResponseId::from_string(id.clone())) + .map(|id| ResponseId::from(id.as_str())) }); if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) { - stored_response.id = ResponseId::from_string(id_str.to_string()); + stored_response.id = ResponseId::from(id_str); } stored_response.raw_response = response_json.clone(); @@ -3393,7 +3401,7 @@ impl super::super::RouterTrait for OpenAIRouter { // Handle previous_response_id by loading prior context let mut conversation_items: Option> = None; if let Some(prev_id_str) = request_body.previous_response_id.clone() { - let prev_id = ResponseId::from_string(prev_id_str.clone()); + let prev_id = ResponseId::from(prev_id_str.as_str()); match self .response_storage .get_response_chain(&prev_id, None) @@ -3516,7 +3524,7 @@ impl super::super::RouterTrait for OpenAIRouter { response_id: &str, params: &ResponsesGetParams, ) -> Response { - let stored_id = ResponseId::from_string(response_id.to_string()); + let stored_id = ResponseId::from(response_id); if let Ok(Some(stored_response)) = self.response_storage.get_response(&stored_id).await { let stream_requested = params.stream.unwrap_or(false); let raw_value = stored_response.raw_response.clone(); @@ -3646,10 +3654,6 @@ impl super::super::RouterTrait for OpenAIRouter { } } - fn router_type(&self) -> &'static str { - "openai" - } - async fn route_embeddings( &self, _headers: Option<&HeaderMap>, @@ -3675,4 +3679,309 @@ impl super::super::RouterTrait for OpenAIRouter { ) .into_response() } + + async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response { + // TODO: move this spec validation to the right place + let metadata = match body.get("metadata") { + Some(Value::Object(map)) => { + if map.len() > MAX_METADATA_PROPERTIES { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": { + "message": format!( + "Invalid 'metadata': too many properties. Max {}, got {}", + MAX_METADATA_PROPERTIES, map.len() + ), + "type": "invalid_request_error", + "param": "metadata", + "code": "metadata_max_properties_exceeded" + } + })), + ) + .into_response(); + } + Some(map.clone()) + } + Some(Value::Null) | None => None, + Some(other) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": { + "message": format!( + "Invalid 'metadata': expected object or null but got {}", + other + ), + "type": "invalid_request_error", + "param": "metadata", + "code": "metadata_invalid_type" + } + })), + ) + .into_response(); + } + }; + + match self + .conversation_storage + .create_conversation(crate::data_connector::NewConversation { metadata }) + .await + { + Ok(conversation) => { + (StatusCode::OK, Json(conversation_to_json(&conversation))).into_response() + } + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": err.to_string(), + "type": "internal_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(), + } + } + + async fn get_conversation( + &self, + _headers: Option<&HeaderMap>, + conversation_id: &str, + ) -> Response { + let id: ConversationId = conversation_id.to_string().into(); + match self.conversation_storage.get_conversation(&id).await { + Ok(Some(conv)) => (StatusCode::OK, Json(conversation_to_json(&conv))).into_response(), + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(json!({ + "error": { + "message": format!("Conversation with id '{}' not found.", conversation_id), + "type": "invalid_request_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": err.to_string(), + "type": "internal_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(), + } + } + + async fn update_conversation( + &self, + _headers: Option<&HeaderMap>, + conversation_id: &str, + body: &Value, + ) -> Response { + let id: ConversationId = conversation_id.to_string().into(); + let existing = match self.conversation_storage.get_conversation(&id).await { + Ok(Some(c)) => c, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({ + "error": { + "message": format!("Conversation with id '{}' not found.", conversation_id), + "type": "invalid_request_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(); + } + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": err.to_string(), + "type": "internal_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(); + } + }; + + // Parse metadata patch + enum Patch { + NoChange, + ClearAll, + Merge(ConversationMetadata), + } + let patch = match body.get("metadata") { + None => Patch::NoChange, + Some(Value::Null) => Patch::ClearAll, + Some(Value::Object(map)) => Patch::Merge(map.clone()), + Some(other) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": { + "message": format!( + "Invalid 'metadata': expected object or null but got {}", + other + ), + "type": "invalid_request_error", + "param": "metadata", + "code": "metadata_invalid_type" + } + })), + ) + .into_response(); + } + }; + + let merged_metadata = match patch { + Patch::NoChange => { + return (StatusCode::OK, Json(conversation_to_json(&existing))).into_response(); + } + Patch::ClearAll => None, + Patch::Merge(upd) => { + let mut merged = existing.metadata.clone().unwrap_or_default(); + let previous = merged.len(); + for (k, v) in upd.into_iter() { + if v.is_null() { + merged.remove(&k); + } else { + merged.insert(k, v); + } + } + let updated = merged.len(); + if updated > MAX_METADATA_PROPERTIES { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": { + "message": format!( + "Invalid 'metadata': too many properties after update. Max {} ({} -> {}).", + MAX_METADATA_PROPERTIES, previous, updated + ), + "type": "invalid_request_error", + "param": "metadata", + "code": "metadata_max_properties_exceeded", + "extra": { + "previous_property_count": previous, + "updated_property_count": updated + } + } + })), + ) + .into_response(); + } + if merged.is_empty() { + None + } else { + Some(merged) + } + } + }; + + match self + .conversation_storage + .update_conversation(&id, merged_metadata) + .await + { + Ok(Some(conv)) => (StatusCode::OK, Json(conversation_to_json(&conv))).into_response(), + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(json!({ + "error": { + "message": format!("Conversation with id '{}' not found.", conversation_id), + "type": "invalid_request_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": err.to_string(), + "type": "internal_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(), + } + } + + async fn delete_conversation( + &self, + _headers: Option<&HeaderMap>, + conversation_id: &str, + ) -> Response { + let id: ConversationId = conversation_id.to_string().into(); + match self.conversation_storage.delete_conversation(&id).await { + Ok(true) => ( + StatusCode::OK, + Json(json!({ + "id": conversation_id, + "object": "conversation.deleted", + "deleted": true + })), + ) + .into_response(), + Ok(false) => ( + StatusCode::NOT_FOUND, + Json(json!({ + "error": { + "message": format!("Conversation with id '{}' not found.", conversation_id), + "type": "invalid_request_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": err.to_string(), + "type": "internal_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(), + } + } + + fn router_type(&self) -> &'static str { + "openai" + } +} +// Maximum number of properties allowed in conversation metadata (align with server) +const MAX_METADATA_PROPERTIES: usize = 16; + +fn conversation_to_json(conversation: &Conversation) -> Value { + json!({ + "id": conversation.id.0, + "object": "conversation", + "created_at": conversation.created_at.timestamp(), + "metadata": to_value(&conversation.metadata).unwrap_or(Value::Null), + }) } diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 212afbfcf..45b4a5fcd 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -13,6 +13,7 @@ use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams, ResponsesRequest, }; +use serde_json::Value; pub mod factory; pub mod grpc; @@ -126,6 +127,52 @@ pub trait RouterTrait: Send + Sync + Debug { model_id: Option<&str>, ) -> Response; + // Conversations API + async fn create_conversation(&self, _headers: Option<&HeaderMap>, _body: &Value) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Conversations create endpoint not implemented", + ) + .into_response() + } + + async fn get_conversation( + &self, + _headers: Option<&HeaderMap>, + _conversation_id: &str, + ) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Conversations get endpoint not implemented", + ) + .into_response() + } + + async fn update_conversation( + &self, + _headers: Option<&HeaderMap>, + _conversation_id: &str, + _body: &Value, + ) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Conversations update endpoint not implemented", + ) + .into_response() + } + + async fn delete_conversation( + &self, + _headers: Option<&HeaderMap>, + _conversation_id: &str, + ) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Conversations delete endpoint not implemented", + ) + .into_response() + } + /// Get router type name fn router_type(&self) -> &'static str; diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index 5ad875212..cc05266bc 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -20,6 +20,7 @@ use axum::{ response::{IntoResponse, Response}, }; use dashmap::DashMap; +use serde_json::Value; use std::sync::Arc; use tracing::{debug, info, warn}; @@ -511,6 +512,83 @@ impl RouterTrait for RouterManager { fn router_type(&self) -> &'static str { "manager" } + + // Conversations API delegates + async fn create_conversation(&self, headers: Option<&HeaderMap>, body: &Value) -> Response { + let router = self.select_router_for_request(headers, None); + if let Some(router) = router { + router.create_conversation(headers, body).await + } else { + ( + StatusCode::NOT_FOUND, + "No router available to create conversation", + ) + .into_response() + } + } + + async fn get_conversation( + &self, + headers: Option<&HeaderMap>, + conversation_id: &str, + ) -> Response { + let router = self.select_router_for_request(headers, None); + if let Some(router) = router { + router.get_conversation(headers, conversation_id).await + } else { + ( + StatusCode::NOT_FOUND, + format!( + "No router available to get conversation '{}'", + conversation_id + ), + ) + .into_response() + } + } + + async fn update_conversation( + &self, + headers: Option<&HeaderMap>, + conversation_id: &str, + body: &Value, + ) -> Response { + let router = self.select_router_for_request(headers, None); + if let Some(router) = router { + router + .update_conversation(headers, conversation_id, body) + .await + } else { + ( + StatusCode::NOT_FOUND, + format!( + "No router available to update conversation '{}'", + conversation_id + ), + ) + .into_response() + } + } + + async fn delete_conversation( + &self, + headers: Option<&HeaderMap>, + conversation_id: &str, + ) -> Response { + let router = self.select_router_for_request(headers, None); + if let Some(router) = router { + router.delete_conversation(headers, conversation_id).await + } else { + ( + StatusCode::NOT_FOUND, + format!( + "No router available to delete conversation '{}'", + conversation_id + ), + ) + .into_response() + } + } } impl std::fmt::Debug for RouterManager { diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 0e1676794..1992daea5 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -2,7 +2,9 @@ use crate::{ config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType}, data_connector::{ - MemoryResponseStorage, NoOpResponseStorage, OracleResponseStorage, SharedResponseStorage, + MemoryConversationStorage, MemoryResponseStorage, NoOpConversationStorage, + NoOpResponseStorage, OracleConversationStorage, OracleResponseStorage, + SharedConversationStorage, SharedResponseStorage, }, logging::{self, LoggingConfig}, metrics::{self, PrometheusConfig}, @@ -39,6 +41,8 @@ use std::{ use tokio::{net::TcpListener, signal, spawn}; use tracing::{error, info, warn, Level}; +// + #[derive(Clone)] pub struct AppContext { pub client: Client, @@ -51,6 +55,7 @@ pub struct AppContext { pub policy_registry: Arc, pub router_manager: Option>, pub response_storage: SharedResponseStorage, + pub conversation_storage: SharedConversationStorage, pub load_monitor: Option>, pub configured_reasoning_parser: Option, pub configured_tool_parser: Option, @@ -94,19 +99,34 @@ impl AppContext { let router_manager = None; - let response_storage: SharedResponseStorage = match router_config.history_backend { - HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()), - HistoryBackend::None => Arc::new(NoOpResponseStorage::new()), + let (response_storage, conversation_storage): ( + SharedResponseStorage, + SharedConversationStorage, + ) = match router_config.history_backend { + HistoryBackend::Memory => ( + Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), + ), + HistoryBackend::None => ( + Arc::new(NoOpResponseStorage::new()), + Arc::new(NoOpConversationStorage::new()), + ), HistoryBackend::Oracle => { let oracle_cfg = router_config.oracle.clone().ok_or_else(|| { "oracle configuration is required when history_backend=oracle".to_string() })?; - let storage = OracleResponseStorage::new(oracle_cfg).map_err(|err| { - format!("failed to initialize Oracle response storage: {err}") - })?; + let response_storage = + OracleResponseStorage::new(oracle_cfg.clone()).map_err(|err| { + format!("failed to initialize Oracle response storage: {err}") + })?; - Arc::new(storage) + let conversation_storage = + OracleConversationStorage::new(oracle_cfg).map_err(|err| { + format!("failed to initialize Oracle conversation storage: {err}") + })?; + + (Arc::new(response_storage), Arc::new(conversation_storage)) } }; @@ -131,6 +151,7 @@ impl AppContext { policy_registry, router_manager, response_storage, + conversation_storage, load_monitor, configured_reasoning_parser, configured_tool_parser, @@ -334,6 +355,51 @@ async fn v1_responses_list_input_items( .await } +async fn v1_conversations_create( + State(state): State>, + headers: http::HeaderMap, + Json(body): Json, +) -> Response { + state + .router + .create_conversation(Some(&headers), &body) + .await +} + +async fn v1_conversations_get( + State(state): State>, + Path(conversation_id): Path, + headers: http::HeaderMap, +) -> Response { + state + .router + .get_conversation(Some(&headers), &conversation_id) + .await +} + +async fn v1_conversations_update( + State(state): State>, + Path(conversation_id): Path, + headers: http::HeaderMap, + Json(body): Json, +) -> Response { + state + .router + .update_conversation(Some(&headers), &conversation_id, &body) + .await +} + +async fn v1_conversations_delete( + State(state): State>, + Path(conversation_id): Path, + headers: http::HeaderMap, +) -> Response { + state + .router + .delete_conversation(Some(&headers), &conversation_id) + .await +} + #[derive(Deserialize)] struct AddWorkerQuery { url: String, @@ -601,6 +667,13 @@ pub fn build_app( "/v1/responses/{response_id}/input", get(v1_responses_list_input_items), ) + .route("/v1/conversations", post(v1_conversations_create)) + .route( + "/v1/conversations/{conversation_id}", + get(v1_conversations_get) + .post(v1_conversations_update) + .delete(v1_conversations_delete), + ) .route_layer(axum::middleware::from_fn_with_state( app_state.clone(), middleware::concurrency_limit_middleware, diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 4266661e5..9e317fd6b 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -542,6 +542,7 @@ mod tests { tool_parser_factory: None, router_manager: None, response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()), + conversation_storage: Arc::new(crate::data_connector::MemoryConversationStorage::new()), load_monitor: None, configured_reasoning_parser: None, configured_tool_parser: None, diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index ebbc76135..528bf6a5a 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -239,6 +239,100 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { mcp.stop().await; } +#[tokio::test] +async fn test_conversations_crud_basic() { + // Router in OpenAI mode (no actual upstream calls in these tests) + let router_cfg = RouterConfig { + mode: RoutingMode::OpenAI { + worker_urls: vec!["http://localhost".to_string()], + }, + connection_mode: ConnectionMode::Http, + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 0, + max_payload_size: 8 * 1024 * 1024, + request_timeout_secs: 60, + worker_startup_timeout_secs: 1, + worker_startup_check_interval_secs: 1, + dp_aware: false, + api_key: None, + discovery: None, + metrics: None, + log_dir: None, + log_level: Some("warn".to_string()), + request_id_headers: None, + max_concurrent_requests: 8, + queue_size: 0, + queue_timeout_secs: 5, + rate_limit_tokens_per_second: None, + cors_allowed_origins: vec![], + retry: RetryConfig::default(), + circuit_breaker: CircuitBreakerConfig::default(), + disable_retries: false, + disable_circuit_breaker: false, + health_check: HealthCheckConfig::default(), + enable_igw: false, + model_path: None, + tokenizer_path: None, + history_backend: sglang_router_rs::config::HistoryBackend::Memory, + oracle: None, + reasoning_parser: None, + tool_call_parser: None, + }; + + let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx"); + let router = RouterFactory::create_router(&Arc::new(ctx)) + .await + .expect("router"); + + // Create + let create_body = serde_json::json!({ "metadata": { "project": "alpha" } }); + let create_resp = router.create_conversation(None, &create_body).await; + assert_eq!(create_resp.status(), axum::http::StatusCode::OK); + let create_bytes = axum::body::to_bytes(create_resp.into_body(), usize::MAX) + .await + .unwrap(); + let create_json: serde_json::Value = serde_json::from_slice(&create_bytes).unwrap(); + let conv_id = create_json["id"].as_str().expect("id missing"); + assert!(conv_id.starts_with("conv_")); + assert_eq!(create_json["object"], "conversation"); + + // Get + let get_resp = router.get_conversation(None, conv_id).await; + assert_eq!(get_resp.status(), axum::http::StatusCode::OK); + let get_bytes = axum::body::to_bytes(get_resp.into_body(), usize::MAX) + .await + .unwrap(); + let get_json: serde_json::Value = serde_json::from_slice(&get_bytes).unwrap(); + assert_eq!(get_json["metadata"]["project"], serde_json::json!("alpha")); + + // Update (merge) + let update_body = serde_json::json!({ "metadata": { "owner": "alice" } }); + let upd_resp = router + .update_conversation(None, conv_id, &update_body) + .await; + assert_eq!(upd_resp.status(), axum::http::StatusCode::OK); + let upd_bytes = axum::body::to_bytes(upd_resp.into_body(), usize::MAX) + .await + .unwrap(); + let upd_json: serde_json::Value = serde_json::from_slice(&upd_bytes).unwrap(); + assert_eq!(upd_json["metadata"]["project"], serde_json::json!("alpha")); + assert_eq!(upd_json["metadata"]["owner"], serde_json::json!("alice")); + + // Delete + let del_resp = router.delete_conversation(None, conv_id).await; + assert_eq!(del_resp.status(), axum::http::StatusCode::OK); + let del_bytes = axum::body::to_bytes(del_resp.into_body(), usize::MAX) + .await + .unwrap(); + let del_json: serde_json::Value = serde_json::from_slice(&del_bytes).unwrap(); + assert_eq!(del_json["deleted"], serde_json::json!(true)); + + // Get again -> 404 + let not_found = router.get_conversation(None, conv_id).await; + assert_eq!(not_found.status(), axum::http::StatusCode::NOT_FOUND); +} + #[test] fn test_responses_request_creation() { let request = ResponsesRequest { diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index 98ad8d6b3..680d6dcb8 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ -13,7 +13,10 @@ use sglang_router_rs::{ config::{ ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode, }, - data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage, StoredResponse}, + data_connector::{ + MemoryConversationStorage, MemoryResponseStorage, ResponseId, ResponseStorage, + StoredResponse, + }, protocols::spec::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput, ResponsesGetParams, ResponsesRequest, UserMessageContent, @@ -91,6 +94,7 @@ async fn test_openai_router_creation() { "https://api.openai.com".to_string(), None, Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), ) .await; @@ -108,6 +112,7 @@ async fn test_openai_router_server_info() { "https://api.openai.com".to_string(), None, Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), ) .await .unwrap(); @@ -137,6 +142,7 @@ async fn test_openai_router_models() { mock_server.base_url(), None, Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), ) .await .unwrap(); @@ -211,9 +217,14 @@ async fn test_openai_router_responses_with_mock() { let base_url = format!("http://{}", addr); let storage = Arc::new(MemoryResponseStorage::new()); - let router = OpenAIRouter::new(base_url, None, storage.clone()) - .await - .unwrap(); + let router = OpenAIRouter::new( + base_url, + None, + storage.clone(), + Arc::new(MemoryConversationStorage::new()), + ) + .await + .unwrap(); let request1 = ResponsesRequest { model: Some("gpt-4o-mini".to_string()), @@ -252,7 +263,7 @@ async fn test_openai_router_responses_with_mock() { ); let stored1 = storage - .get_response(&ResponseId::from_string(resp1_id.clone())) + .get_response(&ResponseId::from(resp1_id.clone())) .await .unwrap() .expect("first response missing"); @@ -261,7 +272,7 @@ async fn test_openai_router_responses_with_mock() { assert!(stored1.previous_response_id.is_none()); let stored2 = storage - .get_response(&ResponseId::from_string(resp2_id.to_string())) + .get_response(&ResponseId::from(resp2_id)) .await .unwrap() .expect("second response missing"); @@ -463,12 +474,17 @@ async fn test_openai_router_responses_streaming_with_mock() { "Earlier answer".to_string(), None, ); - previous.id = ResponseId::from_string("resp_prev_chain".to_string()); + previous.id = ResponseId::from("resp_prev_chain"); storage.store_response(previous).await.unwrap(); - let router = OpenAIRouter::new(base_url, None, storage.clone()) - .await - .unwrap(); + let router = OpenAIRouter::new( + base_url, + None, + storage.clone(), + Arc::new(MemoryConversationStorage::new()), + ) + .await + .unwrap(); let mut metadata = HashMap::new(); metadata.insert("topic".to_string(), json!("unicorns")); @@ -504,7 +520,7 @@ async fn test_openai_router_responses_streaming_with_mock() { assert!(body_text.contains("Once upon a streamed unicorn adventure.")); // Wait for the storage task to persist the streaming response. - let target_id = ResponseId::from_string("resp_stream_123".to_string()); + let target_id = ResponseId::from("resp_stream_123"); let stored = loop { if let Some(resp) = storage.get_response(&target_id).await.unwrap() { break resp; @@ -569,6 +585,7 @@ async fn test_unsupported_endpoints() { "https://api.openai.com".to_string(), None, Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), ) .await .unwrap(); @@ -605,9 +622,14 @@ async fn test_openai_router_chat_completion_with_mock() { let base_url = mock_server.base_url(); // Create router pointing to mock server - let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new())) - .await - .unwrap(); + let router = OpenAIRouter::new( + base_url, + None, + Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), + ) + .await + .unwrap(); // Create a minimal chat completion request let mut chat_request = create_minimal_chat_request(); @@ -642,9 +664,14 @@ async fn test_openai_e2e_with_server() { let base_url = mock_server.base_url(); // Create router - let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new())) - .await - .unwrap(); + let router = OpenAIRouter::new( + base_url, + None, + Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), + ) + .await + .unwrap(); // Create Axum app with chat completions endpoint let app = Router::new().route( @@ -707,9 +734,14 @@ async fn test_openai_e2e_with_server() { async fn test_openai_router_chat_streaming_with_mock() { let mock_server = MockOpenAIServer::new().await; let base_url = mock_server.base_url(); - let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new())) - .await - .unwrap(); + let router = OpenAIRouter::new( + base_url, + None, + Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), + ) + .await + .unwrap(); // Build a streaming chat request let val = json!({ @@ -759,6 +791,7 @@ async fn test_openai_router_circuit_breaker() { "http://invalid-url-that-will-fail".to_string(), Some(cb_config), Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), ) .await .unwrap(); @@ -786,6 +819,7 @@ async fn test_openai_router_models_auth_forwarding() { mock_server.base_url(), None, Arc::new(MemoryResponseStorage::new()), + Arc::new(MemoryConversationStorage::new()), ) .await .unwrap();