[router] Support history management using conversation (#11339)
This commit is contained in:
260
sgl-router/src/data_connector/conversation_item_memory_store.rs
Normal file
260
sgl-router/src/data_connector/conversation_item_memory_store.rs
Normal file
@@ -0,0 +1,260 @@
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::sync::RwLock;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
use super::conversation_items::{
|
||||
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams,
|
||||
Result, SortOrder,
|
||||
};
|
||||
use super::conversations::ConversationId;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MemoryConversationItemStorage {
|
||||
items: RwLock<HashMap<ConversationItemId, ConversationItem>>, // item_id -> item
|
||||
#[allow(clippy::type_complexity)]
|
||||
links: RwLock<HashMap<ConversationId, BTreeMap<(i64, String), ConversationItemId>>>,
|
||||
// Per-conversation reverse index for fast after cursor lookup: item_id_str -> (ts, item_id_str)
|
||||
#[allow(clippy::type_complexity)]
|
||||
rev_index: RwLock<HashMap<ConversationId, HashMap<String, (i64, String)>>>,
|
||||
}
|
||||
|
||||
impl MemoryConversationItemStorage {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConversationItemStorage for MemoryConversationItemStorage {
|
||||
async fn create_item(
|
||||
&self,
|
||||
new_item: super::conversation_items::NewConversationItem,
|
||||
) -> Result<ConversationItem> {
|
||||
let id = new_item
|
||||
.id
|
||||
.clone()
|
||||
.unwrap_or_else(|| make_item_id(&new_item.item_type));
|
||||
let created_at = Utc::now();
|
||||
let item = ConversationItem {
|
||||
id: id.clone(),
|
||||
response_id: new_item.response_id,
|
||||
item_type: new_item.item_type,
|
||||
role: new_item.role,
|
||||
content: new_item.content,
|
||||
status: new_item.status,
|
||||
created_at,
|
||||
};
|
||||
let mut items = self.items.write().unwrap();
|
||||
items.insert(id.clone(), item.clone());
|
||||
Ok(item)
|
||||
}
|
||||
|
||||
async fn link_item(
|
||||
&self,
|
||||
conversation_id: &ConversationId,
|
||||
item_id: &ConversationItemId,
|
||||
added_at: DateTime<Utc>,
|
||||
) -> Result<()> {
|
||||
{
|
||||
let mut links = self.links.write().unwrap();
|
||||
let entry = links.entry(conversation_id.clone()).or_default();
|
||||
entry.insert((added_at.timestamp(), item_id.0.clone()), item_id.clone());
|
||||
}
|
||||
{
|
||||
let mut rev = self.rev_index.write().unwrap();
|
||||
let entry = rev.entry(conversation_id.clone()).or_default();
|
||||
entry.insert(item_id.0.clone(), (added_at.timestamp(), item_id.0.clone()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_items(
|
||||
&self,
|
||||
conversation_id: &ConversationId,
|
||||
params: ListParams,
|
||||
) -> Result<Vec<ConversationItem>> {
|
||||
let links_guard = self.links.read().unwrap();
|
||||
let map = match links_guard.get(conversation_id) {
|
||||
Some(m) => m,
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let mut results: Vec<ConversationItem> = Vec::new();
|
||||
let after_key: Option<(i64, String)> = if let Some(after_id) = ¶ms.after {
|
||||
// O(1) lookup via reverse index for this conversation
|
||||
if let Some(conv_idx) = self.rev_index.read().unwrap().get(conversation_id) {
|
||||
conv_idx.get(after_id).cloned()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let take = params.limit;
|
||||
let items_guard = self.items.read().unwrap();
|
||||
|
||||
use std::ops::Bound::{Excluded, Unbounded};
|
||||
|
||||
// Helper to push item if it exists and stop when reaching the limit
|
||||
let mut push_item = |key: &ConversationItemId| -> bool {
|
||||
if let Some(it) = items_guard.get(key) {
|
||||
results.push(it.clone());
|
||||
if results.len() == take {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
};
|
||||
|
||||
match (params.order, after_key) {
|
||||
(SortOrder::Desc, Some(k)) => {
|
||||
for ((_ts, _id), item_key) in map.range(..k).rev() {
|
||||
if push_item(item_key) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
(SortOrder::Desc, None) => {
|
||||
for ((_ts, _id), item_key) in map.iter().rev() {
|
||||
if push_item(item_key) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
(SortOrder::Asc, Some(k)) => {
|
||||
for ((_ts, _id), item_key) in map.range((Excluded(k), Unbounded)) {
|
||||
if push_item(item_key) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
(SortOrder::Asc, None) => {
|
||||
for ((_ts, _id), item_key) in map.iter() {
|
||||
if push_item(item_key) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::{TimeZone, Utc};
|
||||
|
||||
fn make_item(
|
||||
item_type: &str,
|
||||
role: Option<&str>,
|
||||
content: serde_json::Value,
|
||||
) -> super::super::conversation_items::NewConversationItem {
|
||||
super::super::conversation_items::NewConversationItem {
|
||||
id: None,
|
||||
response_id: None,
|
||||
item_type: item_type.to_string(),
|
||||
role: role.map(|r| r.to_string()),
|
||||
content,
|
||||
status: Some("completed".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_ordering_and_cursors() {
|
||||
let store = MemoryConversationItemStorage::new();
|
||||
let conv: ConversationId = "conv_test".into();
|
||||
|
||||
// Create 3 items and link them at controlled timestamps
|
||||
let i1 = store
|
||||
.create_item(make_item("message", Some("user"), serde_json::json!([])))
|
||||
.await
|
||||
.unwrap();
|
||||
let i2 = store
|
||||
.create_item(make_item(
|
||||
"message",
|
||||
Some("assistant"),
|
||||
serde_json::json!([]),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
let i3 = store
|
||||
.create_item(make_item("reasoning", None, serde_json::json!([])))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let t1 = Utc.timestamp_opt(1_700_000_001, 0).single().unwrap();
|
||||
let t2 = Utc.timestamp_opt(1_700_000_002, 0).single().unwrap();
|
||||
let t3 = Utc.timestamp_opt(1_700_000_003, 0).single().unwrap();
|
||||
|
||||
store.link_item(&conv, &i1.id, t1).await.unwrap();
|
||||
store.link_item(&conv, &i2.id, t2).await.unwrap();
|
||||
store.link_item(&conv, &i3.id, t3).await.unwrap();
|
||||
|
||||
// Desc order, no cursor
|
||||
let desc = store
|
||||
.list_items(
|
||||
&conv,
|
||||
ListParams {
|
||||
limit: 2,
|
||||
order: SortOrder::Desc,
|
||||
after: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(desc.len() >= 2);
|
||||
assert_eq!(desc[0].id, i3.id);
|
||||
assert_eq!(desc[1].id, i2.id);
|
||||
|
||||
// Desc with cursor = i2 -> expect i1 next
|
||||
let desc_after = store
|
||||
.list_items(
|
||||
&conv,
|
||||
ListParams {
|
||||
limit: 2,
|
||||
order: SortOrder::Desc,
|
||||
after: Some(i2.id.0.clone()),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!desc_after.is_empty());
|
||||
assert_eq!(desc_after[0].id, i1.id);
|
||||
|
||||
// Asc order, no cursor
|
||||
let asc = store
|
||||
.list_items(
|
||||
&conv,
|
||||
ListParams {
|
||||
limit: 2,
|
||||
order: SortOrder::Asc,
|
||||
after: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(asc.len() >= 2);
|
||||
assert_eq!(asc[0].id, i1.id);
|
||||
assert_eq!(asc[1].id, i2.id);
|
||||
|
||||
// Asc with cursor = i2 -> expect i3 next
|
||||
let asc_after = store
|
||||
.list_items(
|
||||
&conv,
|
||||
ListParams {
|
||||
limit: 2,
|
||||
order: SortOrder::Asc,
|
||||
after: Some(i2.id.0.clone()),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!asc_after.is_empty());
|
||||
assert_eq!(asc_after[0].id, i3.id);
|
||||
}
|
||||
}
|
||||
409
sgl-router/src/data_connector/conversation_item_oracle_store.rs
Normal file
409
sgl-router/src/data_connector/conversation_item_oracle_store.rs
Normal file
@@ -0,0 +1,409 @@
|
||||
use crate::config::OracleConfig;
|
||||
use crate::data_connector::conversation_items::{
|
||||
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage,
|
||||
ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder,
|
||||
};
|
||||
use crate::data_connector::conversations::ConversationId;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
|
||||
use oracle::sql_type::ToSql;
|
||||
use oracle::Connection;
|
||||
use serde_json::Value;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OracleConversationItemStorage {
|
||||
pool: Pool<ConversationItemOracleConnectionManager>,
|
||||
}
|
||||
|
||||
impl OracleConversationItemStorage {
|
||||
pub fn new(config: OracleConfig) -> ItemResult<Self> {
|
||||
configure_oracle_client(&config)?;
|
||||
initialize_schema(&config)?;
|
||||
|
||||
let config = Arc::new(config);
|
||||
let manager = ConversationItemOracleConnectionManager::new(config.clone());
|
||||
let mut builder = Pool::builder(manager)
|
||||
.max_size(config.pool_max)
|
||||
.runtime(deadpool::Runtime::Tokio1);
|
||||
if config.pool_timeout_secs > 0 {
|
||||
builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs)));
|
||||
}
|
||||
let pool = builder.build().map_err(|err| {
|
||||
ConversationItemStorageError::StorageError(format!(
|
||||
"failed to build Oracle pool for conversation items: {err}"
|
||||
))
|
||||
})?;
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
async fn with_connection<F, T>(&self, func: F) -> ItemResult<T>
|
||||
where
|
||||
F: FnOnce(&Connection) -> ItemResult<T> + Send + 'static,
|
||||
T: Send + 'static,
|
||||
{
|
||||
let connection = self.pool.get().await.map_err(map_pool_error)?;
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let result = func(&connection);
|
||||
drop(connection);
|
||||
result
|
||||
})
|
||||
.await
|
||||
.map_err(|err| {
|
||||
ConversationItemStorageError::StorageError(format!(
|
||||
"failed to execute Oracle conversation item task: {err}"
|
||||
))
|
||||
})?
|
||||
}
|
||||
|
||||
// reserved for future use when parsing JSON columns directly into Value
|
||||
// fn parse_json(raw: Option<String>) -> ItemResult<Value> { ... }
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConversationItemStorage for OracleConversationItemStorage {
|
||||
async fn create_item(
|
||||
&self,
|
||||
item: crate::data_connector::conversation_items::NewConversationItem,
|
||||
) -> ItemResult<ConversationItem> {
|
||||
let id = item
|
||||
.id
|
||||
.clone()
|
||||
.unwrap_or_else(|| make_item_id(&item.item_type));
|
||||
let created_at = Utc::now();
|
||||
let content_json = serde_json::to_string(&item.content)?;
|
||||
|
||||
// Build the return value up-front; move inexpensive clones as needed for SQL
|
||||
let conversation_item = ConversationItem {
|
||||
id: id.clone(),
|
||||
response_id: item.response_id.clone(),
|
||||
item_type: item.item_type.clone(),
|
||||
role: item.role.clone(),
|
||||
content: item.content,
|
||||
status: item.status.clone(),
|
||||
created_at,
|
||||
};
|
||||
|
||||
// Prepare values for SQL insertion
|
||||
let id_str = conversation_item.id.0.clone();
|
||||
let response_id = conversation_item.response_id.clone();
|
||||
let item_type = conversation_item.item_type.clone();
|
||||
let role = conversation_item.role.clone();
|
||||
let status = conversation_item.status.clone();
|
||||
|
||||
self.with_connection(move |conn| {
|
||||
conn.execute(
|
||||
"INSERT INTO conversation_items (id, response_id, item_type, role, content, status, created_at) \
|
||||
VALUES (:1, :2, :3, :4, :5, :6, :7)",
|
||||
&[&id_str, &response_id, &item_type, &role, &content_json, &status, &created_at],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(conversation_item)
|
||||
}
|
||||
|
||||
async fn link_item(
|
||||
&self,
|
||||
conversation_id: &ConversationId,
|
||||
item_id: &ConversationItemId,
|
||||
added_at: DateTime<Utc>,
|
||||
) -> ItemResult<()> {
|
||||
let cid = conversation_id.0.clone();
|
||||
let iid = item_id.0.clone();
|
||||
self.with_connection(move |conn| {
|
||||
conn.execute(
|
||||
"INSERT INTO conversation_item_links (conversation_id, item_id, added_at) VALUES (:1, :2, :3)",
|
||||
&[&cid, &iid, &added_at],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn list_items(
|
||||
&self,
|
||||
conversation_id: &ConversationId,
|
||||
params: ListParams,
|
||||
) -> ItemResult<Vec<ConversationItem>> {
|
||||
let cid = conversation_id.0.clone();
|
||||
let limit: i64 = params.limit as i64;
|
||||
let order_desc = matches!(params.order, SortOrder::Desc);
|
||||
let after_id = params.after.clone();
|
||||
|
||||
// Resolve the added_at of the after cursor if provided
|
||||
let after_key: Option<(DateTime<Utc>, String)> = if let Some(ref aid) = after_id {
|
||||
self.with_connection({
|
||||
let cid = cid.clone();
|
||||
let aid = aid.clone();
|
||||
move |conn| {
|
||||
let mut stmt = conn
|
||||
.statement(
|
||||
"SELECT added_at FROM conversation_item_links WHERE conversation_id = :1 AND item_id = :2",
|
||||
)
|
||||
.build()
|
||||
.map_err(map_oracle_error)?;
|
||||
let mut rows = stmt.query(&[&cid, &aid]).map_err(map_oracle_error)?;
|
||||
if let Some(row_res) = rows.next() {
|
||||
let row = row_res.map_err(map_oracle_error)?;
|
||||
let ts: DateTime<Utc> = row.get(0).map_err(map_oracle_error)?;
|
||||
Ok(Some((ts, aid)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
})
|
||||
.await?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build the main list query
|
||||
let rows: Vec<(String, Option<String>, String, Option<String>, Option<String>, Option<String>, DateTime<Utc>)> =
|
||||
self.with_connection({
|
||||
let cid = cid.clone();
|
||||
move |conn| {
|
||||
let mut sql = String::from(
|
||||
"SELECT i.id, i.response_id, i.item_type, i.role, i.content, i.status, i.created_at \
|
||||
FROM conversation_item_links l \
|
||||
JOIN conversation_items i ON i.id = l.item_id \
|
||||
WHERE l.conversation_id = :cid",
|
||||
);
|
||||
|
||||
// Cursor predicate
|
||||
if let Some((_ts, _iid)) = &after_key {
|
||||
if order_desc {
|
||||
sql.push_str(" AND (l.added_at < :ats OR (l.added_at = :ats AND l.item_id < :iid))");
|
||||
} else {
|
||||
sql.push_str(" AND (l.added_at > :ats OR (l.added_at = :ats AND l.item_id > :iid))");
|
||||
}
|
||||
}
|
||||
|
||||
// Order and limit
|
||||
if order_desc {
|
||||
sql.push_str(" ORDER BY l.added_at DESC, l.item_id DESC");
|
||||
} else {
|
||||
sql.push_str(" ORDER BY l.added_at ASC, l.item_id ASC");
|
||||
}
|
||||
sql.push_str(" FETCH NEXT :limit ROWS ONLY");
|
||||
|
||||
// Build params and perform a named SELECT query
|
||||
let mut params_vec: Vec<(&str, &dyn ToSql)> = vec![("cid", &cid)];
|
||||
if let Some((ts, iid)) = &after_key {
|
||||
params_vec.push(("ats", ts));
|
||||
params_vec.push(("iid", iid));
|
||||
}
|
||||
params_vec.push(("limit", &limit));
|
||||
|
||||
let rows_iter = conn.query_named(&sql, ¶ms_vec).map_err(map_oracle_error)?;
|
||||
|
||||
let mut out = Vec::new();
|
||||
for row_res in rows_iter {
|
||||
let row = row_res.map_err(map_oracle_error)?;
|
||||
let id: String = row.get(0).map_err(map_oracle_error)?;
|
||||
let resp_id: Option<String> = row.get(1).map_err(map_oracle_error)?;
|
||||
let item_type: String = row.get(2).map_err(map_oracle_error)?;
|
||||
let role: Option<String> = row.get(3).map_err(map_oracle_error)?;
|
||||
let content_raw: Option<String> = row.get(4).map_err(map_oracle_error)?;
|
||||
let status: Option<String> = row.get(5).map_err(map_oracle_error)?;
|
||||
let created_at: DateTime<Utc> = row.get(6).map_err(map_oracle_error)?;
|
||||
out.push((id, resp_id, item_type, role, content_raw, status, created_at));
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Map rows to ConversationItem; propagate JSON parse errors instead of swallowing
|
||||
rows.into_iter()
|
||||
.map(
|
||||
|(id, resp_id, item_type, role, content_raw, status, created_at)| {
|
||||
let content = match content_raw {
|
||||
Some(s) => {
|
||||
serde_json::from_str(&s).map_err(ConversationItemStorageError::from)?
|
||||
}
|
||||
None => Value::Null,
|
||||
};
|
||||
Ok(ConversationItem {
|
||||
id: ConversationItemId(id),
|
||||
response_id: resp_id,
|
||||
item_type,
|
||||
role,
|
||||
content,
|
||||
status,
|
||||
created_at,
|
||||
})
|
||||
},
|
||||
)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConversationItemOracleConnectionManager {
|
||||
params: Arc<OracleConnectParams>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct OracleConnectParams {
|
||||
username: String,
|
||||
password: String,
|
||||
connect_descriptor: String,
|
||||
}
|
||||
|
||||
impl ConversationItemOracleConnectionManager {
|
||||
fn new(config: Arc<OracleConfig>) -> Self {
|
||||
let params = OracleConnectParams {
|
||||
username: config.username.clone(),
|
||||
password: config.password.clone(),
|
||||
connect_descriptor: config.connect_descriptor.clone(),
|
||||
};
|
||||
Self {
|
||||
params: Arc::new(params),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ConversationItemOracleConnectionManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ConversationItemOracleConnectionManager")
|
||||
.field("username", &self.params.username)
|
||||
.field("connect_descriptor", &self.params.connect_descriptor)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Manager for ConversationItemOracleConnectionManager {
|
||||
type Type = Connection;
|
||||
type Error = oracle::Error;
|
||||
|
||||
fn create(
|
||||
&self,
|
||||
) -> impl std::future::Future<Output = std::result::Result<Connection, oracle::Error>> + Send
|
||||
{
|
||||
let params = self.params.clone();
|
||||
async move {
|
||||
let mut conn = Connection::connect(
|
||||
¶ms.username,
|
||||
¶ms.password,
|
||||
¶ms.connect_descriptor,
|
||||
)?;
|
||||
conn.set_autocommit(true);
|
||||
Ok(conn)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::manual_async_fn)]
|
||||
fn recycle(
|
||||
&self,
|
||||
conn: &mut Connection,
|
||||
_: &Metrics,
|
||||
) -> impl std::future::Future<Output = RecycleResult<Self::Error>> + Send {
|
||||
async move { conn.ping().map_err(RecycleError::Backend) }
|
||||
}
|
||||
}
|
||||
|
||||
fn configure_oracle_client(config: &OracleConfig) -> ItemResult<()> {
|
||||
if let Some(wallet_path) = &config.wallet_path {
|
||||
let wallet_path = Path::new(wallet_path);
|
||||
if !wallet_path.is_dir() {
|
||||
return Err(ConversationItemStorageError::StorageError(format!(
|
||||
"Oracle wallet/config path '{}' is not a directory",
|
||||
wallet_path.display()
|
||||
)));
|
||||
}
|
||||
if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() {
|
||||
return Err(ConversationItemStorageError::StorageError(format!(
|
||||
"Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora",
|
||||
wallet_path.display()
|
||||
)));
|
||||
}
|
||||
std::env::set_var("TNS_ADMIN", wallet_path);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn initialize_schema(config: &OracleConfig) -> ItemResult<()> {
|
||||
let conn = Connection::connect(
|
||||
&config.username,
|
||||
&config.password,
|
||||
&config.connect_descriptor,
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
|
||||
let exists_items: i64 = conn
|
||||
.query_row_as(
|
||||
"SELECT COUNT(*) FROM user_tables WHERE table_name = 'CONVERSATION_ITEMS'",
|
||||
&[],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
if exists_items == 0 {
|
||||
conn.execute(
|
||||
"CREATE TABLE conversation_items (
|
||||
id VARCHAR2(64) PRIMARY KEY,
|
||||
response_id VARCHAR2(64),
|
||||
item_type VARCHAR2(32) NOT NULL,
|
||||
role VARCHAR2(32),
|
||||
content CLOB,
|
||||
status VARCHAR2(32),
|
||||
created_at TIMESTAMP WITH TIME ZONE
|
||||
)",
|
||||
&[],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
}
|
||||
|
||||
let exists_links: i64 = conn
|
||||
.query_row_as(
|
||||
"SELECT COUNT(*) FROM user_tables WHERE table_name = 'CONVERSATION_ITEM_LINKS'",
|
||||
&[],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
if exists_links == 0 {
|
||||
conn.execute(
|
||||
"CREATE TABLE conversation_item_links (
|
||||
conversation_id VARCHAR2(64) NOT NULL,
|
||||
item_id VARCHAR2(64) NOT NULL,
|
||||
added_at TIMESTAMP WITH TIME ZONE,
|
||||
CONSTRAINT pk_conv_item_link PRIMARY KEY (conversation_id, item_id)
|
||||
)",
|
||||
&[],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
conn.execute(
|
||||
"CREATE INDEX conv_item_links_conv_idx ON conversation_item_links (conversation_id, added_at)",
|
||||
&[],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn map_pool_error(err: PoolError<oracle::Error>) -> ConversationItemStorageError {
|
||||
match err {
|
||||
PoolError::Backend(e) => map_oracle_error(e),
|
||||
other => ConversationItemStorageError::StorageError(format!(
|
||||
"failed to obtain Oracle conversation item connection: {other}"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn map_oracle_error(err: oracle::Error) -> ConversationItemStorageError {
|
||||
if let Some(db_err) = err.db_error() {
|
||||
ConversationItemStorageError::StorageError(format!(
|
||||
"Oracle error (code {}): {}",
|
||||
db_err.code(),
|
||||
db_err.message()
|
||||
))
|
||||
} else {
|
||||
ConversationItemStorageError::StorageError(err.to_string())
|
||||
}
|
||||
}
|
||||
125
sgl-router/src/data_connector/conversation_items.rs
Normal file
125
sgl-router/src/data_connector/conversation_items.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::RngCore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::conversations::ConversationId;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct ConversationItemId(pub String);
|
||||
|
||||
impl Display for ConversationItemId {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for ConversationItemId {
|
||||
fn from(value: String) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for ConversationItemId {
|
||||
fn from(value: &str) -> Self {
|
||||
Self(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ConversationItem {
|
||||
pub id: ConversationItemId,
|
||||
pub response_id: Option<String>,
|
||||
pub item_type: String,
|
||||
pub role: Option<String>,
|
||||
pub content: Value,
|
||||
pub status: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NewConversationItem {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<ConversationItemId>,
|
||||
pub response_id: Option<String>,
|
||||
pub item_type: String,
|
||||
pub role: Option<String>,
|
||||
pub content: Value,
|
||||
pub status: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum SortOrder {
|
||||
Asc,
|
||||
Desc,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ListParams {
|
||||
pub limit: usize,
|
||||
pub order: SortOrder,
|
||||
pub after: Option<String>, // item_id cursor
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ConversationItemStorageError>;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConversationItemStorageError {
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Storage error: {0}")]
|
||||
StorageError(String),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
SerializationError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ConversationItemStorage: Send + Sync + 'static {
|
||||
async fn create_item(&self, item: NewConversationItem) -> Result<ConversationItem>;
|
||||
|
||||
async fn link_item(
|
||||
&self,
|
||||
conversation_id: &ConversationId,
|
||||
item_id: &ConversationItemId,
|
||||
added_at: DateTime<Utc>,
|
||||
) -> Result<()>;
|
||||
|
||||
async fn list_items(
|
||||
&self,
|
||||
conversation_id: &ConversationId,
|
||||
params: ListParams,
|
||||
) -> Result<Vec<ConversationItem>>;
|
||||
}
|
||||
|
||||
pub type SharedConversationItemStorage = Arc<dyn ConversationItemStorage>;
|
||||
|
||||
/// Helper to build id prefix based on item_type
|
||||
pub fn make_item_id(item_type: &str) -> ConversationItemId {
|
||||
// Generate a 24-byte random hex string (48 hex chars), consistent with conversation id style
|
||||
let mut rng = rand::rng();
|
||||
let mut bytes = [0u8; 24];
|
||||
rng.fill_bytes(&mut bytes);
|
||||
let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
|
||||
|
||||
let prefix: String = match item_type {
|
||||
"message" => "msg".to_string(),
|
||||
"reasoning" => "rs".to_string(),
|
||||
"mcp_call" => "mcp".to_string(),
|
||||
"mcp_list_tools" => "mcpl".to_string(),
|
||||
"function_tool_call" => "ftc".to_string(),
|
||||
other => {
|
||||
// Fallback: first 3 letters of type or "itm"
|
||||
let mut p = other.chars().take(3).collect::<String>();
|
||||
if p.is_empty() {
|
||||
p = "itm".to_string();
|
||||
}
|
||||
p
|
||||
}
|
||||
};
|
||||
ConversationItemId(format!("{}_{}", prefix, hex_string))
|
||||
}
|
||||
@@ -1,4 +1,7 @@
|
||||
// Data connector module for response storage and conversation storage
|
||||
pub mod conversation_item_memory_store;
|
||||
pub mod conversation_item_oracle_store;
|
||||
pub mod conversation_items;
|
||||
pub mod conversation_memory_store;
|
||||
pub mod conversation_noop_store;
|
||||
pub mod conversation_oracle_store;
|
||||
@@ -8,6 +11,14 @@ pub mod response_noop_store;
|
||||
pub mod response_oracle_store;
|
||||
pub mod responses;
|
||||
|
||||
pub use conversation_item_memory_store::MemoryConversationItemStorage;
|
||||
pub use conversation_item_oracle_store::OracleConversationItemStorage;
|
||||
pub use conversation_items::{
|
||||
ConversationItem, ConversationItemId, ConversationItemStorage, ConversationItemStorageError,
|
||||
ListParams as ConversationItemsListParams, NewConversationItem,
|
||||
Result as ConversationItemsResult, SharedConversationItemStorage,
|
||||
SortOrder as ConversationItemsSortOrder,
|
||||
};
|
||||
pub use conversation_memory_store::MemoryConversationStorage;
|
||||
pub use conversation_noop_store::NoOpConversationStorage;
|
||||
pub use conversation_oracle_store::OracleConversationStorage;
|
||||
|
||||
@@ -13,7 +13,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
|
||||
tool_calls, metadata, created_at, user_id, model, raw_response FROM responses";
|
||||
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OracleResponseStorage {
|
||||
@@ -95,8 +95,11 @@ impl OracleResponseStorage {
|
||||
let model: Option<String> = row
|
||||
.get(9)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch model"))?;
|
||||
let raw_response_json: Option<String> = row
|
||||
let conversation_id: Option<String> = row
|
||||
.get(10)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch conversation_id"))?;
|
||||
let raw_response_json: Option<String> = row
|
||||
.get(11)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch raw_response"))?;
|
||||
|
||||
let previous_response_id = previous.map(ResponseId);
|
||||
@@ -115,6 +118,7 @@ impl OracleResponseStorage {
|
||||
created_at,
|
||||
user: user_id,
|
||||
model,
|
||||
conversation_id,
|
||||
raw_response,
|
||||
})
|
||||
}
|
||||
@@ -134,6 +138,7 @@ impl ResponseStorage for OracleResponseStorage {
|
||||
created_at,
|
||||
user,
|
||||
model,
|
||||
conversation_id,
|
||||
raw_response,
|
||||
} = response;
|
||||
|
||||
@@ -147,8 +152,8 @@ impl ResponseStorage for OracleResponseStorage {
|
||||
self.with_connection(move |conn| {
|
||||
conn.execute(
|
||||
"INSERT INTO responses (id, previous_response_id, input, instructions, output, \
|
||||
tool_calls, metadata, created_at, user_id, model, raw_response) \
|
||||
VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11)",
|
||||
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response) \
|
||||
VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11, :12)",
|
||||
&[
|
||||
&response_id_str,
|
||||
&previous_id,
|
||||
@@ -160,6 +165,7 @@ impl ResponseStorage for OracleResponseStorage {
|
||||
&created_at,
|
||||
&user,
|
||||
&model,
|
||||
&conversation_id,
|
||||
&json_raw_response,
|
||||
],
|
||||
)
|
||||
@@ -394,6 +400,7 @@ fn initialize_schema(config: &OracleConfig) -> StorageResult<()> {
|
||||
conn.execute(
|
||||
"CREATE TABLE responses (
|
||||
id VARCHAR2(64) PRIMARY KEY,
|
||||
conversation_id VARCHAR2(64),
|
||||
previous_response_id VARCHAR2(64),
|
||||
input CLOB,
|
||||
instructions CLOB,
|
||||
|
||||
@@ -65,6 +65,10 @@ pub struct StoredResponse {
|
||||
/// Model used for generation
|
||||
pub model: Option<String>,
|
||||
|
||||
/// Conversation id if associated with a conversation
|
||||
#[serde(default)]
|
||||
pub conversation_id: Option<String>,
|
||||
|
||||
/// Raw OpenAI response payload
|
||||
#[serde(default)]
|
||||
pub raw_response: Value,
|
||||
@@ -83,6 +87,7 @@ impl StoredResponse {
|
||||
created_at: chrono::Utc::now(),
|
||||
user: None,
|
||||
model: None,
|
||||
conversation_id: None,
|
||||
raw_response: Value::Null,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user