diff --git a/sgl-router/src/data_connector/conversation_item_memory_store.rs b/sgl-router/src/data_connector/conversation_item_memory_store.rs new file mode 100644 index 000000000..cb8c65787 --- /dev/null +++ b/sgl-router/src/data_connector/conversation_item_memory_store.rs @@ -0,0 +1,260 @@ +use std::collections::{BTreeMap, HashMap}; +use std::sync::RwLock; + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; + +use super::conversation_items::{ + make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams, + Result, SortOrder, +}; +use super::conversations::ConversationId; + +#[derive(Default)] +pub struct MemoryConversationItemStorage { + items: RwLock>, // item_id -> item + #[allow(clippy::type_complexity)] + links: RwLock>>, + // Per-conversation reverse index for fast after cursor lookup: item_id_str -> (ts, item_id_str) + #[allow(clippy::type_complexity)] + rev_index: RwLock>>, +} + +impl MemoryConversationItemStorage { + pub fn new() -> Self { + Self::default() + } +} + +#[async_trait] +impl ConversationItemStorage for MemoryConversationItemStorage { + async fn create_item( + &self, + new_item: super::conversation_items::NewConversationItem, + ) -> Result { + let id = new_item + .id + .clone() + .unwrap_or_else(|| make_item_id(&new_item.item_type)); + let created_at = Utc::now(); + let item = ConversationItem { + id: id.clone(), + response_id: new_item.response_id, + item_type: new_item.item_type, + role: new_item.role, + content: new_item.content, + status: new_item.status, + created_at, + }; + let mut items = self.items.write().unwrap(); + items.insert(id.clone(), item.clone()); + Ok(item) + } + + async fn link_item( + &self, + conversation_id: &ConversationId, + item_id: &ConversationItemId, + added_at: DateTime, + ) -> Result<()> { + { + let mut links = self.links.write().unwrap(); + let entry = links.entry(conversation_id.clone()).or_default(); + entry.insert((added_at.timestamp(), item_id.0.clone()), item_id.clone()); + } + { + let mut rev = self.rev_index.write().unwrap(); + let entry = rev.entry(conversation_id.clone()).or_default(); + entry.insert(item_id.0.clone(), (added_at.timestamp(), item_id.0.clone())); + } + Ok(()) + } + + async fn list_items( + &self, + conversation_id: &ConversationId, + params: ListParams, + ) -> Result> { + let links_guard = self.links.read().unwrap(); + let map = match links_guard.get(conversation_id) { + Some(m) => m, + None => return Ok(Vec::new()), + }; + + let mut results: Vec = Vec::new(); + let after_key: Option<(i64, String)> = if let Some(after_id) = ¶ms.after { + // O(1) lookup via reverse index for this conversation + if let Some(conv_idx) = self.rev_index.read().unwrap().get(conversation_id) { + conv_idx.get(after_id).cloned() + } else { + None + } + } else { + None + }; + + let take = params.limit; + let items_guard = self.items.read().unwrap(); + + use std::ops::Bound::{Excluded, Unbounded}; + + // Helper to push item if it exists and stop when reaching the limit + let mut push_item = |key: &ConversationItemId| -> bool { + if let Some(it) = items_guard.get(key) { + results.push(it.clone()); + if results.len() == take { + return true; + } + } + false + }; + + match (params.order, after_key) { + (SortOrder::Desc, Some(k)) => { + for ((_ts, _id), item_key) in map.range(..k).rev() { + if push_item(item_key) { + break; + } + } + } + (SortOrder::Desc, None) => { + for ((_ts, _id), item_key) in map.iter().rev() { + if push_item(item_key) { + break; + } + } + } + (SortOrder::Asc, Some(k)) => { + for ((_ts, _id), item_key) in map.range((Excluded(k), Unbounded)) { + if push_item(item_key) { + break; + } + } + } + (SortOrder::Asc, None) => { + for ((_ts, _id), item_key) in map.iter() { + if push_item(item_key) { + break; + } + } + } + } + + Ok(results) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::{TimeZone, Utc}; + + fn make_item( + item_type: &str, + role: Option<&str>, + content: serde_json::Value, + ) -> super::super::conversation_items::NewConversationItem { + super::super::conversation_items::NewConversationItem { + id: None, + response_id: None, + item_type: item_type.to_string(), + role: role.map(|r| r.to_string()), + content, + status: Some("completed".to_string()), + } + } + + #[tokio::test] + async fn test_list_ordering_and_cursors() { + let store = MemoryConversationItemStorage::new(); + let conv: ConversationId = "conv_test".into(); + + // Create 3 items and link them at controlled timestamps + let i1 = store + .create_item(make_item("message", Some("user"), serde_json::json!([]))) + .await + .unwrap(); + let i2 = store + .create_item(make_item( + "message", + Some("assistant"), + serde_json::json!([]), + )) + .await + .unwrap(); + let i3 = store + .create_item(make_item("reasoning", None, serde_json::json!([]))) + .await + .unwrap(); + + let t1 = Utc.timestamp_opt(1_700_000_001, 0).single().unwrap(); + let t2 = Utc.timestamp_opt(1_700_000_002, 0).single().unwrap(); + let t3 = Utc.timestamp_opt(1_700_000_003, 0).single().unwrap(); + + store.link_item(&conv, &i1.id, t1).await.unwrap(); + store.link_item(&conv, &i2.id, t2).await.unwrap(); + store.link_item(&conv, &i3.id, t3).await.unwrap(); + + // Desc order, no cursor + let desc = store + .list_items( + &conv, + ListParams { + limit: 2, + order: SortOrder::Desc, + after: None, + }, + ) + .await + .unwrap(); + assert!(desc.len() >= 2); + assert_eq!(desc[0].id, i3.id); + assert_eq!(desc[1].id, i2.id); + + // Desc with cursor = i2 -> expect i1 next + let desc_after = store + .list_items( + &conv, + ListParams { + limit: 2, + order: SortOrder::Desc, + after: Some(i2.id.0.clone()), + }, + ) + .await + .unwrap(); + assert!(!desc_after.is_empty()); + assert_eq!(desc_after[0].id, i1.id); + + // Asc order, no cursor + let asc = store + .list_items( + &conv, + ListParams { + limit: 2, + order: SortOrder::Asc, + after: None, + }, + ) + .await + .unwrap(); + assert!(asc.len() >= 2); + assert_eq!(asc[0].id, i1.id); + assert_eq!(asc[1].id, i2.id); + + // Asc with cursor = i2 -> expect i3 next + let asc_after = store + .list_items( + &conv, + ListParams { + limit: 2, + order: SortOrder::Asc, + after: Some(i2.id.0.clone()), + }, + ) + .await + .unwrap(); + assert!(!asc_after.is_empty()); + assert_eq!(asc_after[0].id, i3.id); + } +} diff --git a/sgl-router/src/data_connector/conversation_item_oracle_store.rs b/sgl-router/src/data_connector/conversation_item_oracle_store.rs new file mode 100644 index 000000000..fb3b0bc57 --- /dev/null +++ b/sgl-router/src/data_connector/conversation_item_oracle_store.rs @@ -0,0 +1,409 @@ +use crate::config::OracleConfig; +use crate::data_connector::conversation_items::{ + make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, + ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder, +}; +use crate::data_connector::conversations::ConversationId; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; +use oracle::sql_type::ToSql; +use oracle::Connection; +use serde_json::Value; +use std::path::Path; +use std::sync::Arc; +use std::time::Duration; + +#[derive(Clone)] +pub struct OracleConversationItemStorage { + pool: Pool, +} + +impl OracleConversationItemStorage { + pub fn new(config: OracleConfig) -> ItemResult { + configure_oracle_client(&config)?; + initialize_schema(&config)?; + + let config = Arc::new(config); + let manager = ConversationItemOracleConnectionManager::new(config.clone()); + let mut builder = Pool::builder(manager) + .max_size(config.pool_max) + .runtime(deadpool::Runtime::Tokio1); + if config.pool_timeout_secs > 0 { + builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs))); + } + let pool = builder.build().map_err(|err| { + ConversationItemStorageError::StorageError(format!( + "failed to build Oracle pool for conversation items: {err}" + )) + })?; + Ok(Self { pool }) + } + + async fn with_connection(&self, func: F) -> ItemResult + where + F: FnOnce(&Connection) -> ItemResult + Send + 'static, + T: Send + 'static, + { + let connection = self.pool.get().await.map_err(map_pool_error)?; + tokio::task::spawn_blocking(move || { + let result = func(&connection); + drop(connection); + result + }) + .await + .map_err(|err| { + ConversationItemStorageError::StorageError(format!( + "failed to execute Oracle conversation item task: {err}" + )) + })? + } + + // reserved for future use when parsing JSON columns directly into Value + // fn parse_json(raw: Option) -> ItemResult { ... } +} + +#[async_trait] +impl ConversationItemStorage for OracleConversationItemStorage { + async fn create_item( + &self, + item: crate::data_connector::conversation_items::NewConversationItem, + ) -> ItemResult { + let id = item + .id + .clone() + .unwrap_or_else(|| make_item_id(&item.item_type)); + let created_at = Utc::now(); + let content_json = serde_json::to_string(&item.content)?; + + // Build the return value up-front; move inexpensive clones as needed for SQL + let conversation_item = ConversationItem { + id: id.clone(), + response_id: item.response_id.clone(), + item_type: item.item_type.clone(), + role: item.role.clone(), + content: item.content, + status: item.status.clone(), + created_at, + }; + + // Prepare values for SQL insertion + let id_str = conversation_item.id.0.clone(); + let response_id = conversation_item.response_id.clone(); + let item_type = conversation_item.item_type.clone(); + let role = conversation_item.role.clone(); + let status = conversation_item.status.clone(); + + self.with_connection(move |conn| { + conn.execute( + "INSERT INTO conversation_items (id, response_id, item_type, role, content, status, created_at) \ + VALUES (:1, :2, :3, :4, :5, :6, :7)", + &[&id_str, &response_id, &item_type, &role, &content_json, &status, &created_at], + ) + .map_err(map_oracle_error)?; + Ok(()) + }) + .await?; + + Ok(conversation_item) + } + + async fn link_item( + &self, + conversation_id: &ConversationId, + item_id: &ConversationItemId, + added_at: DateTime, + ) -> ItemResult<()> { + let cid = conversation_id.0.clone(); + let iid = item_id.0.clone(); + self.with_connection(move |conn| { + conn.execute( + "INSERT INTO conversation_item_links (conversation_id, item_id, added_at) VALUES (:1, :2, :3)", + &[&cid, &iid, &added_at], + ) + .map_err(map_oracle_error)?; + Ok(()) + }) + .await + } + + async fn list_items( + &self, + conversation_id: &ConversationId, + params: ListParams, + ) -> ItemResult> { + let cid = conversation_id.0.clone(); + let limit: i64 = params.limit as i64; + let order_desc = matches!(params.order, SortOrder::Desc); + let after_id = params.after.clone(); + + // Resolve the added_at of the after cursor if provided + let after_key: Option<(DateTime, String)> = if let Some(ref aid) = after_id { + self.with_connection({ + let cid = cid.clone(); + let aid = aid.clone(); + move |conn| { + let mut stmt = conn + .statement( + "SELECT added_at FROM conversation_item_links WHERE conversation_id = :1 AND item_id = :2", + ) + .build() + .map_err(map_oracle_error)?; + let mut rows = stmt.query(&[&cid, &aid]).map_err(map_oracle_error)?; + if let Some(row_res) = rows.next() { + let row = row_res.map_err(map_oracle_error)?; + let ts: DateTime = row.get(0).map_err(map_oracle_error)?; + Ok(Some((ts, aid))) + } else { + Ok(None) + } + } + }) + .await? + } else { + None + }; + + // Build the main list query + let rows: Vec<(String, Option, String, Option, Option, Option, DateTime)> = + self.with_connection({ + let cid = cid.clone(); + move |conn| { + let mut sql = String::from( + "SELECT i.id, i.response_id, i.item_type, i.role, i.content, i.status, i.created_at \ + FROM conversation_item_links l \ + JOIN conversation_items i ON i.id = l.item_id \ + WHERE l.conversation_id = :cid", + ); + + // Cursor predicate + if let Some((_ts, _iid)) = &after_key { + if order_desc { + sql.push_str(" AND (l.added_at < :ats OR (l.added_at = :ats AND l.item_id < :iid))"); + } else { + sql.push_str(" AND (l.added_at > :ats OR (l.added_at = :ats AND l.item_id > :iid))"); + } + } + + // Order and limit + if order_desc { + sql.push_str(" ORDER BY l.added_at DESC, l.item_id DESC"); + } else { + sql.push_str(" ORDER BY l.added_at ASC, l.item_id ASC"); + } + sql.push_str(" FETCH NEXT :limit ROWS ONLY"); + + // Build params and perform a named SELECT query + let mut params_vec: Vec<(&str, &dyn ToSql)> = vec![("cid", &cid)]; + if let Some((ts, iid)) = &after_key { + params_vec.push(("ats", ts)); + params_vec.push(("iid", iid)); + } + params_vec.push(("limit", &limit)); + + let rows_iter = conn.query_named(&sql, ¶ms_vec).map_err(map_oracle_error)?; + + let mut out = Vec::new(); + for row_res in rows_iter { + let row = row_res.map_err(map_oracle_error)?; + let id: String = row.get(0).map_err(map_oracle_error)?; + let resp_id: Option = row.get(1).map_err(map_oracle_error)?; + let item_type: String = row.get(2).map_err(map_oracle_error)?; + let role: Option = row.get(3).map_err(map_oracle_error)?; + let content_raw: Option = row.get(4).map_err(map_oracle_error)?; + let status: Option = row.get(5).map_err(map_oracle_error)?; + let created_at: DateTime = row.get(6).map_err(map_oracle_error)?; + out.push((id, resp_id, item_type, role, content_raw, status, created_at)); + } + Ok(out) + } + }) + .await?; + + // Map rows to ConversationItem; propagate JSON parse errors instead of swallowing + rows.into_iter() + .map( + |(id, resp_id, item_type, role, content_raw, status, created_at)| { + let content = match content_raw { + Some(s) => { + serde_json::from_str(&s).map_err(ConversationItemStorageError::from)? + } + None => Value::Null, + }; + Ok(ConversationItem { + id: ConversationItemId(id), + response_id: resp_id, + item_type, + role, + content, + status, + created_at, + }) + }, + ) + .collect() + } +} + +#[derive(Clone)] +struct ConversationItemOracleConnectionManager { + params: Arc, +} + +#[derive(Clone)] +struct OracleConnectParams { + username: String, + password: String, + connect_descriptor: String, +} + +impl ConversationItemOracleConnectionManager { + fn new(config: Arc) -> Self { + let params = OracleConnectParams { + username: config.username.clone(), + password: config.password.clone(), + connect_descriptor: config.connect_descriptor.clone(), + }; + Self { + params: Arc::new(params), + } + } +} + +impl std::fmt::Debug for ConversationItemOracleConnectionManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConversationItemOracleConnectionManager") + .field("username", &self.params.username) + .field("connect_descriptor", &self.params.connect_descriptor) + .finish() + } +} + +#[async_trait] +impl Manager for ConversationItemOracleConnectionManager { + type Type = Connection; + type Error = oracle::Error; + + fn create( + &self, + ) -> impl std::future::Future> + Send + { + let params = self.params.clone(); + async move { + let mut conn = Connection::connect( + ¶ms.username, + ¶ms.password, + ¶ms.connect_descriptor, + )?; + conn.set_autocommit(true); + Ok(conn) + } + } + + #[allow(clippy::manual_async_fn)] + fn recycle( + &self, + conn: &mut Connection, + _: &Metrics, + ) -> impl std::future::Future> + Send { + async move { conn.ping().map_err(RecycleError::Backend) } + } +} + +fn configure_oracle_client(config: &OracleConfig) -> ItemResult<()> { + if let Some(wallet_path) = &config.wallet_path { + let wallet_path = Path::new(wallet_path); + if !wallet_path.is_dir() { + return Err(ConversationItemStorageError::StorageError(format!( + "Oracle wallet/config path '{}' is not a directory", + wallet_path.display() + ))); + } + if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() { + return Err(ConversationItemStorageError::StorageError(format!( + "Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora", + wallet_path.display() + ))); + } + std::env::set_var("TNS_ADMIN", wallet_path); + } + Ok(()) +} + +fn initialize_schema(config: &OracleConfig) -> ItemResult<()> { + let conn = Connection::connect( + &config.username, + &config.password, + &config.connect_descriptor, + ) + .map_err(map_oracle_error)?; + + let exists_items: i64 = conn + .query_row_as( + "SELECT COUNT(*) FROM user_tables WHERE table_name = 'CONVERSATION_ITEMS'", + &[], + ) + .map_err(map_oracle_error)?; + if exists_items == 0 { + conn.execute( + "CREATE TABLE conversation_items ( + id VARCHAR2(64) PRIMARY KEY, + response_id VARCHAR2(64), + item_type VARCHAR2(32) NOT NULL, + role VARCHAR2(32), + content CLOB, + status VARCHAR2(32), + created_at TIMESTAMP WITH TIME ZONE + )", + &[], + ) + .map_err(map_oracle_error)?; + } + + let exists_links: i64 = conn + .query_row_as( + "SELECT COUNT(*) FROM user_tables WHERE table_name = 'CONVERSATION_ITEM_LINKS'", + &[], + ) + .map_err(map_oracle_error)?; + if exists_links == 0 { + conn.execute( + "CREATE TABLE conversation_item_links ( + conversation_id VARCHAR2(64) NOT NULL, + item_id VARCHAR2(64) NOT NULL, + added_at TIMESTAMP WITH TIME ZONE, + CONSTRAINT pk_conv_item_link PRIMARY KEY (conversation_id, item_id) + )", + &[], + ) + .map_err(map_oracle_error)?; + conn.execute( + "CREATE INDEX conv_item_links_conv_idx ON conversation_item_links (conversation_id, added_at)", + &[], + ) + .map_err(map_oracle_error)?; + } + + Ok(()) +} + +fn map_pool_error(err: PoolError) -> ConversationItemStorageError { + match err { + PoolError::Backend(e) => map_oracle_error(e), + other => ConversationItemStorageError::StorageError(format!( + "failed to obtain Oracle conversation item connection: {other}" + )), + } +} + +fn map_oracle_error(err: oracle::Error) -> ConversationItemStorageError { + if let Some(db_err) = err.db_error() { + ConversationItemStorageError::StorageError(format!( + "Oracle error (code {}): {}", + db_err.code(), + db_err.message() + )) + } else { + ConversationItemStorageError::StorageError(err.to_string()) + } +} diff --git a/sgl-router/src/data_connector/conversation_items.rs b/sgl-router/src/data_connector/conversation_items.rs new file mode 100644 index 000000000..58f624403 --- /dev/null +++ b/sgl-router/src/data_connector/conversation_items.rs @@ -0,0 +1,125 @@ +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use super::conversations::ConversationId; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] +pub struct ConversationItemId(pub String); + +impl Display for ConversationItemId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl From for ConversationItemId { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From<&str> for ConversationItemId { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationItem { + pub id: ConversationItemId, + pub response_id: Option, + pub item_type: String, + pub role: Option, + pub content: Value, + pub status: Option, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NewConversationItem { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub id: Option, + pub response_id: Option, + pub item_type: String, + pub role: Option, + pub content: Value, + pub status: Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub enum SortOrder { + Asc, + Desc, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ListParams { + pub limit: usize, + pub order: SortOrder, + pub after: Option, // item_id cursor +} + +pub type Result = std::result::Result; + +#[derive(Debug, thiserror::Error)] +pub enum ConversationItemStorageError { + #[error("Not found: {0}")] + NotFound(String), + + #[error("Storage error: {0}")] + StorageError(String), + + #[error("Serialization error: {0}")] + SerializationError(#[from] serde_json::Error), +} + +#[async_trait] +pub trait ConversationItemStorage: Send + Sync + 'static { + async fn create_item(&self, item: NewConversationItem) -> Result; + + async fn link_item( + &self, + conversation_id: &ConversationId, + item_id: &ConversationItemId, + added_at: DateTime, + ) -> Result<()>; + + async fn list_items( + &self, + conversation_id: &ConversationId, + params: ListParams, + ) -> Result>; +} + +pub type SharedConversationItemStorage = Arc; + +/// Helper to build id prefix based on item_type +pub fn make_item_id(item_type: &str) -> ConversationItemId { + // Generate a 24-byte random hex string (48 hex chars), consistent with conversation id style + let mut rng = rand::rng(); + let mut bytes = [0u8; 24]; + rng.fill_bytes(&mut bytes); + let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect(); + + let prefix: String = match item_type { + "message" => "msg".to_string(), + "reasoning" => "rs".to_string(), + "mcp_call" => "mcp".to_string(), + "mcp_list_tools" => "mcpl".to_string(), + "function_tool_call" => "ftc".to_string(), + other => { + // Fallback: first 3 letters of type or "itm" + let mut p = other.chars().take(3).collect::(); + if p.is_empty() { + p = "itm".to_string(); + } + p + } + }; + ConversationItemId(format!("{}_{}", prefix, hex_string)) +} diff --git a/sgl-router/src/data_connector/mod.rs b/sgl-router/src/data_connector/mod.rs index 0877ad40b..32892d505 100644 --- a/sgl-router/src/data_connector/mod.rs +++ b/sgl-router/src/data_connector/mod.rs @@ -1,4 +1,7 @@ // Data connector module for response storage and conversation storage +pub mod conversation_item_memory_store; +pub mod conversation_item_oracle_store; +pub mod conversation_items; pub mod conversation_memory_store; pub mod conversation_noop_store; pub mod conversation_oracle_store; @@ -8,6 +11,14 @@ pub mod response_noop_store; pub mod response_oracle_store; pub mod responses; +pub use conversation_item_memory_store::MemoryConversationItemStorage; +pub use conversation_item_oracle_store::OracleConversationItemStorage; +pub use conversation_items::{ + ConversationItem, ConversationItemId, ConversationItemStorage, ConversationItemStorageError, + ListParams as ConversationItemsListParams, NewConversationItem, + Result as ConversationItemsResult, SharedConversationItemStorage, + SortOrder as ConversationItemsSortOrder, +}; pub use conversation_memory_store::MemoryConversationStorage; pub use conversation_noop_store::NoOpConversationStorage; pub use conversation_oracle_store::OracleConversationStorage; diff --git a/sgl-router/src/data_connector/response_oracle_store.rs b/sgl-router/src/data_connector/response_oracle_store.rs index 2622b59e2..5ad3fab5f 100644 --- a/sgl-router/src/data_connector/response_oracle_store.rs +++ b/sgl-router/src/data_connector/response_oracle_store.rs @@ -13,7 +13,7 @@ use std::sync::Arc; use std::time::Duration; const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \ - tool_calls, metadata, created_at, user_id, model, raw_response FROM responses"; + tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses"; #[derive(Clone)] pub struct OracleResponseStorage { @@ -95,8 +95,11 @@ impl OracleResponseStorage { let model: Option = row .get(9) .map_err(|err| map_oracle_error(err).into_storage_error("fetch model"))?; - let raw_response_json: Option = row + let conversation_id: Option = row .get(10) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch conversation_id"))?; + let raw_response_json: Option = row + .get(11) .map_err(|err| map_oracle_error(err).into_storage_error("fetch raw_response"))?; let previous_response_id = previous.map(ResponseId); @@ -115,6 +118,7 @@ impl OracleResponseStorage { created_at, user: user_id, model, + conversation_id, raw_response, }) } @@ -134,6 +138,7 @@ impl ResponseStorage for OracleResponseStorage { created_at, user, model, + conversation_id, raw_response, } = response; @@ -147,8 +152,8 @@ impl ResponseStorage for OracleResponseStorage { self.with_connection(move |conn| { conn.execute( "INSERT INTO responses (id, previous_response_id, input, instructions, output, \ - tool_calls, metadata, created_at, user_id, model, raw_response) \ - VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11)", + tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response) \ + VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11, :12)", &[ &response_id_str, &previous_id, @@ -160,6 +165,7 @@ impl ResponseStorage for OracleResponseStorage { &created_at, &user, &model, + &conversation_id, &json_raw_response, ], ) @@ -394,6 +400,7 @@ fn initialize_schema(config: &OracleConfig) -> StorageResult<()> { conn.execute( "CREATE TABLE responses ( id VARCHAR2(64) PRIMARY KEY, + conversation_id VARCHAR2(64), previous_response_id VARCHAR2(64), input CLOB, instructions CLOB, diff --git a/sgl-router/src/data_connector/responses.rs b/sgl-router/src/data_connector/responses.rs index e0420f3c6..bb203652b 100644 --- a/sgl-router/src/data_connector/responses.rs +++ b/sgl-router/src/data_connector/responses.rs @@ -65,6 +65,10 @@ pub struct StoredResponse { /// Model used for generation pub model: Option, + /// Conversation id if associated with a conversation + #[serde(default)] + pub conversation_id: Option, + /// Raw OpenAI response payload #[serde(default)] pub raw_response: Value, @@ -83,6 +87,7 @@ impl StoredResponse { created_at: chrono::Utc::now(), user: None, model: None, + conversation_id: None, raw_response: Value::Null, } } diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index 7e12a8cfe..10998b718 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -1103,6 +1103,10 @@ pub struct ResponsesRequest { #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, + /// Optional conversation id to persist input/output as items + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation: Option, + /// Whether to enable parallel tool calls #[serde(default = "default_true")] pub parallel_tool_calls: bool, @@ -1214,6 +1218,7 @@ impl Default for ResponsesRequest { max_tool_calls: None, metadata: None, model: None, + conversation: None, parallel_tool_calls: true, previous_response_id: None, reasoning: None, diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index dad0144e1..eecf8f839 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -129,6 +129,7 @@ impl RouterFactory { Some(ctx.router_config.circuit_breaker.clone()), ctx.response_storage.clone(), ctx.conversation_storage.clone(), + ctx.conversation_item_storage.clone(), ) .await?; diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index 33034457f..a18ed3da6 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -3,8 +3,10 @@ use crate::config::CircuitBreakerConfig; use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}; use crate::data_connector::{ - Conversation, ConversationId, ConversationMetadata, ResponseId, SharedConversationStorage, - SharedResponseStorage, StoredResponse, + Conversation, ConversationId, ConversationItemsListParams, ConversationItemsSortOrder, + ConversationMetadata, NewConversationItem as DCNewConversationItem, ResponseId, + SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, + StoredResponse, }; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, @@ -81,6 +83,8 @@ pub struct OpenAIRouter { response_storage: SharedResponseStorage, /// Conversation storage backend conversation_storage: SharedConversationStorage, + /// Conversation item storage backend + conversation_item_storage: SharedConversationItemStorage, /// Optional MCP manager (enabled via config presence) mcp_manager: Option>, } @@ -706,12 +710,15 @@ impl StreamingResponseAccumulator { } impl OpenAIRouter { + // Maximum number of conversation items to attach as input when a conversation is provided + const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100; /// Create a new OpenAI router pub async fn new( base_url: String, circuit_breaker_config: Option, response_storage: SharedResponseStorage, conversation_storage: SharedConversationStorage, + conversation_item_storage: SharedConversationItemStorage, ) -> Result { let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(300)) @@ -759,6 +766,7 @@ impl OpenAIRouter { healthy: AtomicBool::new(true), response_storage, conversation_storage, + conversation_item_storage, mcp_manager, }) } @@ -946,6 +954,12 @@ impl OpenAIRouter { // Mask tools back to MCP format for client Self::mask_tools_as_mcp(&mut final_response_json, original_body); + // Attach conversation id for client response if present (not forwarded upstream) + if let Some(conv_id) = original_body.conversation.clone() { + if let Some(obj) = final_response_json.as_object_mut() { + obj.insert("conversation".to_string(), json!({"id": conv_id})); + } + } if original_body.store { if let Err(e) = self .store_response_internal(&final_response_json, original_body) @@ -954,6 +968,18 @@ impl OpenAIRouter { warn!("Failed to store response: {}", e); } } + if let Some(conv_id) = original_body.conversation.clone() { + if let Err(err) = self + .persist_conversation_items( + &conv_id, + original_body, + &final_response_json, + ) + .await + { + warn!("Failed to persist conversation items: {}", err); + } + } match serde_json::to_string(&final_response_json) { Ok(json_str) => ( @@ -990,6 +1016,22 @@ impl OpenAIRouter { } } + async fn persist_conversation_items( + &self, + conversation_id: &str, + original_body: &ResponsesRequest, + final_response_json: &Value, + ) -> Result<(), String> { + persist_items_with_storages( + self.conversation_storage.clone(), + self.conversation_item_storage.clone(), + conversation_id.to_string(), + original_body.clone(), + final_response_json.clone(), + ) + .await + } + /// Build a request-scoped MCP manager from request tools, if present. async fn mcp_manager_from_request_tools( tools: &[ResponseTool], @@ -1123,7 +1165,10 @@ impl OpenAIRouter { let should_store = original_body.store; let storage = self.response_storage.clone(); + let conv_storage = self.conversation_storage.clone(); + let conv_item_storage = self.conversation_item_storage.clone(); let original_request = original_body.clone(); + let persist_needed = original_request.conversation.is_some(); let previous_response_id = original_previous_response_id.clone(); tokio::spawn(async move { @@ -1160,7 +1205,7 @@ impl OpenAIRouter { Cow::Borrowed(raw_block.as_str()) }; - if should_store { + if should_store || persist_needed { accumulator.ingest_block(block_cow.as_ref()); } @@ -1189,7 +1234,7 @@ impl OpenAIRouter { } } - if should_store && !upstream_failed { + if (should_store || persist_needed) && !upstream_failed { if !pending.trim().is_empty() { accumulator.ingest_block(&pending); } @@ -1201,10 +1246,28 @@ impl OpenAIRouter { previous_response_id.as_deref(), ); - if let Err(err) = - Self::store_response_impl(&storage, &response_json, &original_request).await - { - warn!("Failed to store streaming response: {}", err); + if should_store { + if let Err(err) = + Self::store_response_impl(&storage, &response_json, &original_request) + .await + { + warn!("Failed to store streaming response: {}", err); + } + } + if persist_needed { + if let Some(conv_id) = original_request.conversation.clone() { + if let Err(err) = persist_items_with_storages( + conv_storage.clone(), + conv_item_storage.clone(), + conv_id, + original_request.clone(), + response_json.clone(), + ) + .await + { + warn!("Failed to persist conversation items (stream): {}", err); + } + } } } else if let Some(error_payload) = encountered_error { warn!("Upstream streaming error payload: {}", error_payload); @@ -1683,7 +1746,10 @@ impl OpenAIRouter { let (tx, rx) = mpsc::unbounded_channel::>(); let should_store = original_body.store; let storage = self.response_storage.clone(); + let conv_storage = self.conversation_storage.clone(); + let conv_item_storage = self.conversation_item_storage.clone(); let original_request = original_body.clone(); + let persist_needed = original_request.conversation.is_some(); let previous_response_id = original_previous_response_id.clone(); let client = self.client.clone(); @@ -1901,30 +1967,33 @@ impl OpenAIRouter { return; } - // Send final events and done marker - if should_store { - if let Some(mut response_json) = handler.accumulator.into_final_response() { - if let Some(ref id) = preserved_response_id { - if let Some(obj) = response_json.as_object_mut() { - obj.insert("id".to_string(), Value::String(id.clone())); - } + let final_response_json = if should_store || persist_needed { + handler.accumulator.into_final_response() + } else { + None + }; + + if let Some(mut response_json) = final_response_json { + if let Some(ref id) = preserved_response_id { + if let Some(obj) = response_json.as_object_mut() { + obj.insert("id".to_string(), Value::String(id.clone())); } - Self::inject_mcp_metadata_streaming( - &mut response_json, - &state, - &active_mcp_clone, - server_label, - ); + } + Self::inject_mcp_metadata_streaming( + &mut response_json, + &state, + &active_mcp_clone, + server_label, + ); - // Mask tools back to MCP format - Self::mask_tools_as_mcp(&mut response_json, &original_request); - - Self::patch_streaming_response_json( - &mut response_json, - &original_request, - previous_response_id.as_deref(), - ); + Self::mask_tools_as_mcp(&mut response_json, &original_request); + Self::patch_streaming_response_json( + &mut response_json, + &original_request, + previous_response_id.as_deref(), + ); + if should_store { if let Err(err) = Self::store_response_impl( &storage, &response_json, @@ -1935,6 +2004,25 @@ impl OpenAIRouter { warn!("Failed to store streaming response: {}", err); } } + + if persist_needed { + if let Some(conv_id) = original_request.conversation.clone() { + if let Err(err) = persist_items_with_storages( + conv_storage.clone(), + conv_item_storage.clone(), + conv_id, + original_request.clone(), + response_json.clone(), + ) + .await + { + warn!( + "Failed to persist conversation items (stream + MCP): {}", + err + ); + } + } + } } let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); @@ -2332,6 +2420,11 @@ impl OpenAIRouter { .map(|s| s.to_string()) .or_else(|| original_body.user.clone()); + // Set conversation id from request if provided + if let Some(conv_id) = original_body.conversation.clone() { + stored_response.conversation_id = Some(conv_id); + } + stored_response.metadata = response_json .get("metadata") .and_then(|v| v.as_object()) @@ -2428,6 +2521,11 @@ impl OpenAIRouter { obj.insert("user".to_string(), Value::String(user.clone())); } } + + // Attach conversation id for client response if present (final aggregated JSON) + if let Some(conv_id) = original_body.conversation.clone() { + obj.insert("conversation".to_string(), json!({"id": conv_id})); + } } } @@ -2500,6 +2598,12 @@ impl OpenAIRouter { changed = true; } } + + // Attach conversation id into streaming event response content with ordering + if let Some(conv_id) = original_body.conversation.clone() { + response_obj.insert("conversation".to_string(), json!({"id": conv_id})); + changed = true; + } } if !changed { @@ -3389,11 +3493,30 @@ impl super::super::RouterTrait for OpenAIRouter { "openai_responses_request" ); + // Validate mutually exclusive params: previous_response_id and conversation + // TODO: this validation logic should move the right place, also we need a proper error message module + if body.previous_response_id.is_some() && body.conversation.is_some() { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": { + "message": "Mutually exclusive parameters. Ensure you are only providing one of: 'previous_response_id' or 'conversation'.", + "type": "invalid_request_error", + "param": Value::Null, + "code": "mutually_exclusive_parameters" + } + })), + ) + .into_response(); + } + // Clone the body and override model if needed let mut request_body = body.clone(); if let Some(model) = model_id { request_body.model = Some(model.to_string()); } + // Do not forward conversation field upstream; retain for local persistence only + request_body.conversation = None; // Store the original previous_response_id for the response let original_previous_response_id = request_body.previous_response_id.clone(); @@ -3448,6 +3571,75 @@ impl super::super::RouterTrait for OpenAIRouter { request_body.previous_response_id = None; } + // If conversation is provided, attach its items as input to upstream request + if let Some(conv_id_str) = body.conversation.clone() { + let conv_id: ConversationId = conv_id_str.as_str().into(); + let mut items: Vec = Vec::new(); + // Fetch up to MAX_CONVERSATION_HISTORY_ITEMS items in ascending order + let params = ConversationItemsListParams { + limit: Self::MAX_CONVERSATION_HISTORY_ITEMS, + order: ConversationItemsSortOrder::Asc, + after: None, + }; + match self + .conversation_item_storage + .list_items(&conv_id, params) + .await + { + Ok(stored_items) => { + for it in stored_items { + match it.item_type.as_str() { + "message" => { + // content is expected to be an array of ResponseContentPart + let parts: Vec = match serde_json::from_value( + it.content.clone(), + ) { + Ok(parts) => parts, + Err(e) => { + warn!( + item_id = %it.id.0, + error = %e, + "Failed to deserialize conversation item content; skipping message item" + ); + continue; + } + }; + let role = it.role.unwrap_or_else(|| "user".to_string()); + items.push(ResponseInputOutputItem::Message { + id: it.id.0, + role, + content: parts, + status: it.status, + }); + } + _ => { + // Skip unsupported types for request input (e.g., MCP items) + } + } + } + } + Err(err) => { + warn!(conversation_id = %conv_id.0, error = %err.to_string(), "Failed to load conversation items for request input"); + } + } + + // Append the current request input at the end + match &request_body.input { + ResponseInput::Text(text) => { + items.push(ResponseInputOutputItem::Message { + id: format!("msg_u_current_{}", items.len()), + role: "user".to_string(), + status: Some("completed".to_string()), + content: vec![ResponseContentPart::InputText { text: text.clone() }], + }); + } + ResponseInput::Items(existing) => { + items.extend(existing.clone()); + } + } + request_body.input = ResponseInput::Items(items); + } + if let Some(mut items) = conversation_items { match &request_body.input { ResponseInput::Text(text) => { @@ -3489,6 +3681,7 @@ impl super::super::RouterTrait for OpenAIRouter { "top_k", "min_p", "repetition_penalty", + "conversation", ] { obj.remove(key); } @@ -3973,6 +4166,113 @@ impl super::super::RouterTrait for OpenAIRouter { fn router_type(&self) -> &'static str { "openai" } + + async fn list_conversation_items( + &self, + _headers: Option<&HeaderMap>, + conversation_id: &str, + limit: Option, + order: Option, + after: Option, + ) -> Response { + let id: ConversationId = conversation_id.into(); + match self.conversation_storage.get_conversation(&id).await { + Ok(Some(_)) => {} + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({ + "error": { + "message": format!("Conversation with id '{}' not found.", conversation_id), + "type": "invalid_request_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(); + } + Err(err) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": err.to_string(), + "type": "internal_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(); + } + } + + let lim = limit.unwrap_or(20).clamp(1, 100); + let sort = match order.as_deref() { + Some("asc") => ConversationItemsSortOrder::Asc, + _ => ConversationItemsSortOrder::Desc, + }; + let params = ConversationItemsListParams { + limit: lim + 1, + order: sort, + after, + }; + + match self.conversation_item_storage.list_items(&id, params).await { + Ok(mut items) => { + let has_more = items.len() > lim; + if has_more { + items.truncate(lim); + } + let data: Vec = items + .into_iter() + .map(|it| { + json!({ + "id": it.id.0, + "type": it.item_type, + "status": it.status.unwrap_or_else(|| "completed".to_string()), + "content": it.content, + "role": it.role, + }) + }) + .collect(); + let first_id = data + .first() + .and_then(|v| v.get("id")) + .cloned() + .unwrap_or(Value::Null); + let last_id = data + .last() + .and_then(|v| v.get("id")) + .cloned() + .unwrap_or(Value::Null); + ( + StatusCode::OK, + Json(json!({ + "object": "list", + "data": data, + "first_id": first_id, + "last_id": last_id, + "has_more": has_more + })), + ) + .into_response() + } + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": err.to_string(), + "type": "internal_error", + "param": Value::Null, + "code": Value::Null + } + })), + ) + .into_response(), + } + } } // Maximum number of properties allowed in conversation metadata (align with server) const MAX_METADATA_PROPERTIES: usize = 16; @@ -3985,3 +4285,263 @@ fn conversation_to_json(conversation: &Conversation) -> Value { "metadata": to_value(&conversation.metadata).unwrap_or(Value::Null), }) } + +async fn persist_items_with_storages( + conv_storage: SharedConversationStorage, + item_storage: SharedConversationItemStorage, + conversation_id: String, + request: ResponsesRequest, + response: Value, +) -> Result<(), String> { + let conv_id: ConversationId = conversation_id.as_str().into(); + match conv_storage.get_conversation(&conv_id).await { + Ok(Some(_)) => {} + Ok(None) => { + warn!(conversation_id = %conv_id.0, "Conversation not found; skipping item persistence"); + return Ok(()); + } + Err(err) => return Err(err.to_string()), + } + + // Extract response_id once for attaching to both input and output items + let response_id_opt = response + .get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + // Helper to ensure status defaults to completed + async fn create_and_link_item( + item_storage: &SharedConversationItemStorage, + conv_id: &ConversationId, + mut new_item: DCNewConversationItem, + ) -> Result<(), String> { + if new_item.status.is_none() { + new_item.status = Some("completed".to_string()); + } + let created = item_storage + .create_item(new_item) + .await + .map_err(|e| e.to_string())?; + item_storage + .link_item(conv_id, &created.id, chrono::Utc::now()) + .await + .map_err(|e| e.to_string())?; + tracing::info!(conversation_id = %conv_id.0, item_id = %created.id.0, item_type = %created.item_type, "Persisted conversation item and link"); + Ok(()) + } + + match request.input.clone() { + ResponseInput::Text(text) => { + let new_item = DCNewConversationItem { + id: None, // generate new message id for input + response_id: response_id_opt.clone(), + item_type: "message".to_string(), + role: Some("user".to_string()), + content: json!([{ "type": "input_text", "text": text }]), + status: Some("completed".to_string()), + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + ResponseInput::Items(items) => { + for input_item in items { + match input_item { + ResponseInputOutputItem::Message { + role, + content, + status, + .. + } => { + let content_v = + serde_json::to_value(&content).map_err(|e| e.to_string())?; + let new_item = DCNewConversationItem { + id: None, // generate new id for input items + response_id: response_id_opt.clone(), + item_type: "message".to_string(), + role: Some(role), + content: content_v, + status, + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + ResponseInputOutputItem::Reasoning { + summary, + content, + status, + .. + } => { + let new_item = DCNewConversationItem { + id: None, // generate new id for input items + response_id: response_id_opt.clone(), + item_type: "reasoning".to_string(), + role: None, + content: json!({ "summary": summary, "content": content }), + status, + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + ResponseInputOutputItem::FunctionToolCall { + name, + arguments, + output, + status, + .. + } => { + let new_item = DCNewConversationItem { + id: None, // generate new id for input items + response_id: response_id_opt.clone(), + item_type: "function_tool_call".to_string(), + role: None, + content: json!({ "name": name, "arguments": arguments, "output": output }), + status, + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + } + } + } + } + + if let Some(output_array) = response.get("output").and_then(|v| v.as_array()) { + for item in output_array { + let item_type = match item.get("type").and_then(|v| v.as_str()) { + Some(t) => t, + None => continue, + }; + + match item_type { + "message" => { + let id_in = item + .get("id") + .and_then(|v| v.as_str()) + .map(|s| crate::data_connector::ConversationItemId(s.to_string())); + let role = item + .get("role") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let content_v = item + .get("content") + .cloned() + .unwrap_or_else(|| Value::Array(Vec::new())); + let status = item + .get("status") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let new_item = DCNewConversationItem { + id: id_in, + response_id: response_id_opt.clone(), + item_type: "message".to_string(), + role, + content: content_v, + status, + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + "reasoning" => { + let id_in = item + .get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let summary_v = item + .get("summary") + .cloned() + .unwrap_or_else(|| Value::Array(Vec::new())); + let content_v = item + .get("content") + .cloned() + .unwrap_or_else(|| Value::Array(Vec::new())); + let status = item + .get("status") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let new_item = DCNewConversationItem { + id: id_in.map(crate::data_connector::ConversationItemId), + response_id: response_id_opt.clone(), + item_type: "reasoning".to_string(), + role: None, + content: json!({ "summary": summary_v, "content": content_v }), + status, + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + "function_tool_call" => { + let id_in = item + .get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let arguments = item.get("arguments").and_then(|v| v.as_str()).unwrap_or(""); + let output_str = item.get("output").and_then(|v| v.as_str()).unwrap_or(""); + let status = item + .get("status") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let new_item = DCNewConversationItem { + id: id_in.map(crate::data_connector::ConversationItemId), + response_id: response_id_opt.clone(), + item_type: "function_tool_call".to_string(), + role: None, + content: json!({ + "name": name, + "arguments": arguments, + "output": output_str + }), + status, + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + "mcp_call" => { + let id_in = item + .get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let arguments = item.get("arguments").and_then(|v| v.as_str()).unwrap_or(""); + let output_str = item.get("output").and_then(|v| v.as_str()).unwrap_or(""); + let status = item + .get("status") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let content_v = json!({ + "server_label": item.get("server_label").cloned().unwrap_or(Value::Null), + "name": name, + "arguments": arguments, + "output": output_str, + "error": item.get("error").cloned().unwrap_or(Value::Null), + "approval_request_id": item.get("approval_request_id").cloned().unwrap_or(Value::Null) + }); + let new_item = DCNewConversationItem { + id: id_in.map(crate::data_connector::ConversationItemId), + response_id: response_id_opt.clone(), + item_type: "mcp_call".to_string(), + role: None, + content: content_v, + status, + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + "mcp_list_tools" => { + let id_in = item + .get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let content_v = json!({ + "server_label": item.get("server_label").cloned().unwrap_or(Value::Null), + "tools": item.get("tools").cloned().unwrap_or_else(|| Value::Array(Vec::new())) + }); + let new_item = DCNewConversationItem { + id: id_in.map(crate::data_connector::ConversationItemId), + response_id: response_id_opt.clone(), + item_type: "mcp_list_tools".to_string(), + role: None, + content: content_v, + status: Some("completed".to_string()), + }; + create_and_link_item(&item_storage, &conv_id, new_item).await?; + } + _ => {} + } + } + } + + Ok(()) +} diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 45b4a5fcd..a74503424 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -173,6 +173,22 @@ pub trait RouterTrait: Send + Sync + Debug { .into_response() } + /// List items for a conversation + async fn list_conversation_items( + &self, + _headers: Option<&HeaderMap>, + _conversation_id: &str, + _limit: Option, + _order: Option, + _after: Option, + ) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Conversation items list endpoint not implemented", + ) + .into_response() + } + /// Get router type name fn router_type(&self) -> &'static str; diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index cc05266bc..063aa660e 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -589,6 +589,31 @@ impl RouterTrait for RouterManager { .into_response() } } + + async fn list_conversation_items( + &self, + headers: Option<&HeaderMap>, + conversation_id: &str, + limit: Option, + order: Option, + after: Option, + ) -> Response { + let router = self.select_router_for_request(headers, None); + if let Some(router) = router { + router + .list_conversation_items(headers, conversation_id, limit, order, after) + .await + } else { + ( + StatusCode::NOT_FOUND, + format!( + "No router available to list conversation items for '{}'", + conversation_id + ), + ) + .into_response() + } + } } impl std::fmt::Debug for RouterManager { diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 1992daea5..f20c21b26 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -2,9 +2,10 @@ use crate::{ config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType}, data_connector::{ - MemoryConversationStorage, MemoryResponseStorage, NoOpConversationStorage, - NoOpResponseStorage, OracleConversationStorage, OracleResponseStorage, - SharedConversationStorage, SharedResponseStorage, + MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, + NoOpConversationStorage, NoOpResponseStorage, OracleConversationItemStorage, + OracleConversationStorage, OracleResponseStorage, SharedConversationStorage, + SharedResponseStorage, }, logging::{self, LoggingConfig}, metrics::{self, PrometheusConfig}, @@ -56,6 +57,7 @@ pub struct AppContext { pub router_manager: Option>, pub response_storage: SharedResponseStorage, pub conversation_storage: SharedConversationStorage, + pub conversation_item_storage: crate::data_connector::SharedConversationItemStorage, pub load_monitor: Option>, pub configured_reasoning_parser: Option, pub configured_tool_parser: Option, @@ -121,8 +123,8 @@ impl AppContext { format!("failed to initialize Oracle response storage: {err}") })?; - let conversation_storage = - OracleConversationStorage::new(oracle_cfg).map_err(|err| { + let conversation_storage = OracleConversationStorage::new(oracle_cfg.clone()) + .map_err(|err| { format!("failed to initialize Oracle conversation storage: {err}") })?; @@ -130,6 +132,20 @@ impl AppContext { } }; + // Conversation items storage (memory-backed for now) + let conversation_item_storage: crate::data_connector::SharedConversationItemStorage = + match router_config.history_backend { + HistoryBackend::Oracle => { + let oracle_cfg = router_config.oracle.clone().ok_or_else(|| { + "oracle configuration is required when history_backend=oracle".to_string() + })?; + Arc::new(OracleConversationItemStorage::new(oracle_cfg).map_err(|e| { + format!("failed to initialize Oracle conversation item storage: {e}") + })?) + } + _ => Arc::new(MemoryConversationItemStorage::new()), + }; + let load_monitor = Some(Arc::new(LoadMonitor::new( worker_registry.clone(), policy_registry.clone(), @@ -152,6 +168,7 @@ impl AppContext { router_manager, response_storage, conversation_storage, + conversation_item_storage, load_monitor, configured_reasoning_parser, configured_tool_parser, @@ -400,6 +417,29 @@ async fn v1_conversations_delete( .await } +#[derive(Deserialize, Default)] +struct ListItemsQuery { + limit: Option, + order: Option, + after: Option, +} + +async fn v1_conversations_list_items( + State(state): State>, + Path(conversation_id): Path, + Query(ListItemsQuery { + limit, + order, + after, + }): Query, + headers: http::HeaderMap, +) -> Response { + state + .router + .list_conversation_items(Some(&headers), &conversation_id, limit, order, after) + .await +} + #[derive(Deserialize)] struct AddWorkerQuery { url: String, @@ -674,6 +714,10 @@ pub fn build_app( .post(v1_conversations_update) .delete(v1_conversations_delete), ) + .route( + "/v1/conversations/{conversation_id}/items", + get(v1_conversations_list_items), + ) .route_layer(axum::middleware::from_fn_with_state( app_state.clone(), middleware::concurrency_limit_middleware, diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 9e317fd6b..00746bf1d 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -543,6 +543,9 @@ mod tests { router_manager: None, response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()), conversation_storage: Arc::new(crate::data_connector::MemoryConversationStorage::new()), + conversation_item_storage: Arc::new( + crate::data_connector::MemoryConversationItemStorage::new(), + ), load_monitor: None, configured_reasoning_parser: None, configured_tool_parser: None, diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index 528bf6a5a..0ab4d720e 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -125,6 +125,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { top_k: -1, min_p: 0.0, repetition_penalty: 1.0, + conversation: None, }; let resp = router @@ -371,6 +372,7 @@ fn test_responses_request_creation() { top_k: -1, min_p: 0.0, repetition_penalty: 1.0, + conversation: None, }; assert!(!request.is_stream()); @@ -411,6 +413,7 @@ fn test_sampling_params_conversion() { top_k: 10, min_p: 0.05, repetition_penalty: 1.1, + conversation: None, }; let params = request.to_sampling_params(1000, None); @@ -524,6 +527,7 @@ fn test_json_serialization() { top_k: 50, min_p: 0.1, repetition_penalty: 1.2, + conversation: None, }; let json = serde_json::to_string(&request).expect("Serialization should work"); @@ -651,6 +655,7 @@ async fn test_multi_turn_loop_with_mcp() { top_k: 50, min_p: 0.0, repetition_penalty: 1.0, + conversation: None, }; // Execute the request (this should trigger the multi-turn loop) @@ -828,6 +833,7 @@ async fn test_max_tool_calls_limit() { top_k: 50, min_p: 0.0, repetition_penalty: 1.0, + conversation: None, }; let response = router.route_responses(None, &req, None).await; @@ -1023,6 +1029,7 @@ async fn test_streaming_with_mcp_tool_calls() { top_k: 50, min_p: 0.0, repetition_penalty: 1.0, + conversation: None, }; let response = router.route_responses(None, &req, None).await; @@ -1301,6 +1308,7 @@ async fn test_streaming_multi_turn_with_mcp() { top_k: 50, min_p: 0.0, repetition_penalty: 1.0, + conversation: None, }; let response = router.route_responses(None, &req, None).await; diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index 680d6dcb8..1b53ed42f 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ -9,6 +9,7 @@ use axum::{ Json, Router, }; use serde_json::json; +use sglang_router_rs::data_connector::MemoryConversationItemStorage; use sglang_router_rs::{ config::{ ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode, @@ -95,6 +96,7 @@ async fn test_openai_router_creation() { None, Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryConversationStorage::new()), + Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()), ) .await; @@ -113,6 +115,7 @@ async fn test_openai_router_server_info() { None, Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryConversationStorage::new()), + Arc::new(MemoryConversationItemStorage::new()), ) .await .unwrap(); @@ -143,6 +146,7 @@ async fn test_openai_router_models() { None, Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryConversationStorage::new()), + Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()), ) .await .unwrap(); @@ -222,6 +226,7 @@ async fn test_openai_router_responses_with_mock() { None, storage.clone(), Arc::new(MemoryConversationStorage::new()), + Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()), ) .await .unwrap(); @@ -482,6 +487,7 @@ async fn test_openai_router_responses_streaming_with_mock() { None, storage.clone(), Arc::new(MemoryConversationStorage::new()), + Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()), ) .await .unwrap(); @@ -586,6 +592,7 @@ async fn test_unsupported_endpoints() { None, Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryConversationStorage::new()), + Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()), ) .await .unwrap(); @@ -627,6 +634,7 @@ async fn test_openai_router_chat_completion_with_mock() { None, Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryConversationStorage::new()), + Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()), ) .await .unwrap(); @@ -669,6 +677,7 @@ async fn test_openai_e2e_with_server() { None, Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryConversationStorage::new()), + Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()), ) .await .unwrap(); @@ -739,6 +748,7 @@ async fn test_openai_router_chat_streaming_with_mock() { None, Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryConversationStorage::new()), + Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()), ) .await .unwrap(); @@ -792,6 +802,7 @@ async fn test_openai_router_circuit_breaker() { Some(cb_config), Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryConversationStorage::new()), + Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()), ) .await .unwrap(); @@ -820,6 +831,7 @@ async fn test_openai_router_models_auth_forwarding() { None, Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryConversationStorage::new()), + Arc::new(sglang_router_rs::data_connector::MemoryConversationItemStorage::new()), ) .await .unwrap();