[router] support Openai router conversation API CRUD (#11297)

This commit is contained in:
Keyang Ru
2025-10-07 15:31:35 -07:00
committed by GitHub
parent cd4b39a900
commit 4ed67c27e3
15 changed files with 1258 additions and 45 deletions

View File

@@ -0,0 +1,57 @@
use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use super::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation,
Result,
};
/// In-memory conversation storage used for development and tests
#[derive(Default, Clone)]
pub struct MemoryConversationStorage {
inner: Arc<RwLock<HashMap<ConversationId, Conversation>>>,
}
impl MemoryConversationStorage {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(HashMap::new())),
}
}
}
#[async_trait]
impl ConversationStorage for MemoryConversationStorage {
async fn create_conversation(&self, input: NewConversation) -> Result<Conversation> {
let conversation = Conversation::new(input);
self.inner
.write()
.insert(conversation.id.clone(), conversation.clone());
Ok(conversation)
}
async fn get_conversation(&self, id: &ConversationId) -> Result<Option<Conversation>> {
Ok(self.inner.read().get(id).cloned())
}
async fn update_conversation(
&self,
id: &ConversationId,
metadata: Option<ConversationMetadata>,
) -> Result<Option<Conversation>> {
let mut store = self.inner.write();
if let Some(entry) = store.get_mut(id) {
entry.metadata = metadata;
return Ok(Some(entry.clone()));
}
Ok(None)
}
async fn delete_conversation(&self, id: &ConversationId) -> Result<bool> {
let removed = self.inner.write().remove(id).is_some();
Ok(removed)
}
}

View File

@@ -0,0 +1,41 @@
use async_trait::async_trait;
use super::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage, Result,
};
/// No-op implementation that synthesizes conversation responses without persistence
#[derive(Default, Debug, Clone)]
pub struct NoOpConversationStorage;
impl NoOpConversationStorage {
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl ConversationStorage for NoOpConversationStorage {
async fn create_conversation(
&self,
input: super::conversations::NewConversation,
) -> Result<Conversation> {
Ok(Conversation::new(input))
}
async fn get_conversation(&self, _id: &ConversationId) -> Result<Option<Conversation>> {
Ok(None)
}
async fn update_conversation(
&self,
_id: &ConversationId,
_metadata: Option<ConversationMetadata>,
) -> Result<Option<Conversation>> {
Ok(None)
}
async fn delete_conversation(&self, _id: &ConversationId) -> Result<bool> {
Ok(false)
}
}

View File

@@ -0,0 +1,338 @@
use crate::config::OracleConfig;
use crate::data_connector::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
ConversationStorageError, NewConversation, Result,
};
use async_trait::async_trait;
use chrono::Utc;
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::{sql_type::OracleType, Connection};
use serde_json::Value;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
#[derive(Clone)]
pub struct OracleConversationStorage {
pool: Pool<ConversationOracleConnectionManager>,
}
impl OracleConversationStorage {
pub fn new(config: OracleConfig) -> Result<Self> {
configure_oracle_client(&config)?;
initialize_schema(&config)?;
let config = Arc::new(config);
let manager = ConversationOracleConnectionManager::new(config.clone());
let mut builder = Pool::builder(manager)
.max_size(config.pool_max)
.runtime(deadpool::Runtime::Tokio1);
if config.pool_timeout_secs > 0 {
builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs)));
}
let pool = builder.build().map_err(|err| {
ConversationStorageError::StorageError(format!(
"failed to build Oracle pool for conversations: {err}"
))
})?;
Ok(Self { pool })
}
async fn with_connection<F, T>(&self, func: F) -> Result<T>
where
F: FnOnce(&Connection) -> Result<T> + Send + 'static,
T: Send + 'static,
{
let connection = self.pool.get().await.map_err(map_pool_error)?;
tokio::task::spawn_blocking(move || {
let result = func(&connection);
drop(connection);
result
})
.await
.map_err(|err| {
ConversationStorageError::StorageError(format!(
"failed to execute Oracle conversation task: {err}"
))
})?
}
fn parse_metadata(raw: Option<String>) -> Result<Option<ConversationMetadata>> {
match raw {
Some(json) if !json.is_empty() => {
let value: Value = serde_json::from_str(&json)?;
match value {
Value::Object(map) => Ok(Some(map)),
Value::Null => Ok(None),
other => Err(ConversationStorageError::StorageError(format!(
"conversation metadata expected object, got {other}"
))),
}
}
_ => Ok(None),
}
}
}
#[async_trait]
impl ConversationStorage for OracleConversationStorage {
async fn create_conversation(&self, input: NewConversation) -> Result<Conversation> {
let conversation = Conversation::new(input);
let id_str = conversation.id.0.clone();
let created_at = conversation.created_at;
let metadata_json = conversation
.metadata
.as_ref()
.map(serde_json::to_string)
.transpose()?;
self.with_connection(move |conn| {
conn.execute(
"INSERT INTO conversations (id, created_at, metadata) VALUES (:1, :2, :3)",
&[&id_str, &created_at, &metadata_json],
)
.map(|_| ())
.map_err(map_oracle_error)
})
.await?;
Ok(conversation)
}
async fn get_conversation(&self, id: &ConversationId) -> Result<Option<Conversation>> {
let lookup = id.0.clone();
self.with_connection(move |conn| {
let mut stmt = conn
.statement("SELECT id, created_at, metadata FROM conversations WHERE id = :1")
.build()
.map_err(map_oracle_error)?;
let mut rows = stmt.query(&[&lookup]).map_err(map_oracle_error)?;
if let Some(row_res) = rows.next() {
let row = row_res.map_err(map_oracle_error)?;
let id: String = row.get(0).map_err(map_oracle_error)?;
let created_at: chrono::DateTime<Utc> = row.get(1).map_err(map_oracle_error)?;
let metadata_raw: Option<String> = row.get(2).map_err(map_oracle_error)?;
let metadata = Self::parse_metadata(metadata_raw)?;
Ok(Some(Conversation::with_parts(
ConversationId(id),
created_at,
metadata,
)))
} else {
Ok(None)
}
})
.await
}
async fn update_conversation(
&self,
id: &ConversationId,
metadata: Option<ConversationMetadata>,
) -> Result<Option<Conversation>> {
let id_str = id.0.clone();
let metadata_json = metadata.as_ref().map(serde_json::to_string).transpose()?;
let conversation_id = id.clone();
self.with_connection(move |conn| {
let mut stmt = conn
.statement(
"UPDATE conversations \
SET metadata = :1 \
WHERE id = :2 \
RETURNING created_at INTO :3",
)
.build()
.map_err(map_oracle_error)?;
stmt.bind(3, &OracleType::TimestampTZ(6))
.map_err(map_oracle_error)?;
stmt.execute(&[&metadata_json, &id_str])
.map_err(map_oracle_error)?;
if stmt.row_count().map_err(map_oracle_error)? == 0 {
return Ok(None);
}
let mut created_at: Vec<chrono::DateTime<Utc>> =
stmt.returned_values(3).map_err(map_oracle_error)?;
let created_at = created_at.pop().ok_or_else(|| {
ConversationStorageError::StorageError(
"Oracle update did not return created_at".to_string(),
)
})?;
Ok(Some(Conversation::with_parts(
conversation_id,
created_at,
metadata,
)))
})
.await
}
async fn delete_conversation(&self, id: &ConversationId) -> Result<bool> {
let id_str = id.0.clone();
let res = self
.with_connection(move |conn| {
conn.execute("DELETE FROM conversations WHERE id = :1", &[&id_str])
.map_err(map_oracle_error)
})
.await?;
Ok(res.row_count().map_err(map_oracle_error)? > 0)
}
}
#[derive(Clone)]
struct ConversationOracleConnectionManager {
params: Arc<OracleConnectParams>,
}
impl ConversationOracleConnectionManager {
fn new(config: Arc<OracleConfig>) -> Self {
let params = OracleConnectParams {
username: config.username.clone(),
password: config.password.clone(),
connect_descriptor: config.connect_descriptor.clone(),
};
Self {
params: Arc::new(params),
}
}
}
impl std::fmt::Debug for ConversationOracleConnectionManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConversationOracleConnectionManager")
.field("username", &self.params.username)
.field("connect_descriptor", &self.params.connect_descriptor)
.finish()
}
}
#[derive(Clone)]
struct OracleConnectParams {
username: String,
password: String,
connect_descriptor: String,
}
impl std::fmt::Debug for OracleConnectParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OracleConnectParams")
.field("username", &self.username)
.field("connect_descriptor", &self.connect_descriptor)
.finish()
}
}
#[async_trait]
impl Manager for ConversationOracleConnectionManager {
type Type = Connection;
type Error = oracle::Error;
fn create(
&self,
) -> impl std::future::Future<Output = std::result::Result<Connection, oracle::Error>> + Send
{
let params = self.params.clone();
async move {
let mut conn = Connection::connect(
&params.username,
&params.password,
&params.connect_descriptor,
)?;
conn.set_autocommit(true);
Ok(conn)
}
}
#[allow(clippy::manual_async_fn)]
fn recycle(
&self,
conn: &mut Connection,
_: &Metrics,
) -> impl std::future::Future<Output = RecycleResult<Self::Error>> + Send {
async move { conn.ping().map_err(RecycleError::Backend) }
}
}
fn configure_oracle_client(config: &OracleConfig) -> Result<()> {
if let Some(wallet_path) = &config.wallet_path {
let wallet_path = Path::new(wallet_path);
if !wallet_path.is_dir() {
return Err(ConversationStorageError::StorageError(format!(
"Oracle wallet/config path '{}' is not a directory",
wallet_path.display()
)));
}
if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() {
return Err(ConversationStorageError::StorageError(format!(
"Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora",
wallet_path.display()
)));
}
std::env::set_var("TNS_ADMIN", wallet_path);
}
Ok(())
}
fn initialize_schema(config: &OracleConfig) -> Result<()> {
let conn = Connection::connect(
&config.username,
&config.password,
&config.connect_descriptor,
)
.map_err(map_oracle_error)?;
let exists: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tables WHERE table_name = 'CONVERSATIONS'",
&[],
)
.map_err(map_oracle_error)?;
if exists == 0 {
conn.execute(
"CREATE TABLE conversations (
id VARCHAR2(64) PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE,
metadata CLOB
)",
&[],
)
.map_err(map_oracle_error)?;
}
Ok(())
}
fn map_pool_error(err: PoolError<oracle::Error>) -> ConversationStorageError {
match err {
PoolError::Backend(e) => map_oracle_error(e),
other => ConversationStorageError::StorageError(format!(
"failed to obtain Oracle conversation connection: {other}"
)),
}
}
fn map_oracle_error(err: oracle::Error) -> ConversationStorageError {
if let Some(db_err) = err.db_error() {
ConversationStorageError::StorageError(format!(
"Oracle error (code {}): {}",
db_err.code(),
db_err.message()
))
} else {
ConversationStorageError::StorageError(err.to_string())
}
}

View File

@@ -0,0 +1,120 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use serde_json::{Map as JsonMap, Value};
use std::fmt::{Display, Formatter};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct ConversationId(pub String);
impl ConversationId {
pub fn new() -> Self {
let mut rng = rand::rng();
let mut bytes = [0u8; 24];
rng.fill_bytes(&mut bytes);
let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
Self(format!("conv_{}", hex_string))
}
}
impl Default for ConversationId {
fn default() -> Self {
Self::new()
}
}
impl From<String> for ConversationId {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<&str> for ConversationId {
fn from(value: &str) -> Self {
Self(value.to_string())
}
}
impl Display for ConversationId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}
/// Metadata payload persisted with a conversation
pub type ConversationMetadata = JsonMap<String, Value>;
/// Input payload for creating a conversation
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct NewConversation {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<ConversationMetadata>,
}
/// Stored conversation data structure
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Conversation {
pub id: ConversationId,
pub created_at: DateTime<Utc>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<ConversationMetadata>,
}
impl Conversation {
pub fn new(new_conversation: NewConversation) -> Self {
Self {
id: ConversationId::new(),
created_at: Utc::now(),
metadata: new_conversation.metadata,
}
}
pub fn with_parts(
id: ConversationId,
created_at: DateTime<Utc>,
metadata: Option<ConversationMetadata>,
) -> Self {
Self {
id,
created_at,
metadata,
}
}
}
/// Result alias for conversation storage operations
pub type Result<T> = std::result::Result<T, ConversationStorageError>;
/// Error type for conversation storage operations
#[derive(Debug, thiserror::Error)]
pub enum ConversationStorageError {
#[error("Conversation not found: {0}")]
ConversationNotFound(String),
#[error("Storage error: {0}")]
StorageError(String),
#[error("Serialization error: {0}")]
SerializationError(#[from] serde_json::Error),
}
/// Trait describing the CRUD interface for conversation storage backends
#[async_trait]
pub trait ConversationStorage: Send + Sync + 'static {
async fn create_conversation(&self, input: NewConversation) -> Result<Conversation>;
async fn get_conversation(&self, id: &ConversationId) -> Result<Option<Conversation>>;
async fn update_conversation(
&self,
id: &ConversationId,
metadata: Option<ConversationMetadata>,
) -> Result<Option<Conversation>>;
async fn delete_conversation(&self, id: &ConversationId) -> Result<bool>;
}
/// Shared pointer alias for conversation storage
pub type SharedConversationStorage = Arc<dyn ConversationStorage>;

View File

@@ -1,9 +1,21 @@
// Data connector module for response storage
// Data connector module for response storage and conversation storage
pub mod conversation_memory_store;
pub mod conversation_noop_store;
pub mod conversation_oracle_store;
pub mod conversations;
pub mod response_memory_store;
pub mod response_noop_store;
pub mod response_oracle_store;
pub mod responses;
pub use conversation_memory_store::MemoryConversationStorage;
pub use conversation_noop_store::NoOpConversationStorage;
pub use conversation_oracle_store::OracleConversationStorage;
pub use conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
ConversationStorageError, NewConversation, Result as ConversationResult,
SharedConversationStorage,
};
pub use response_memory_store::MemoryResponseStorage;
pub use response_noop_store::NoOpResponseStorage;
pub use response_oracle_store::OracleResponseStorage;

View File

@@ -207,10 +207,10 @@ mod tests {
async fn test_store_with_custom_id() {
let store = MemoryResponseStorage::new();
let mut response = StoredResponse::new("Input".to_string(), "Output".to_string(), None);
response.id = ResponseId::from_string("resp_custom".to_string());
response.id = ResponseId::from("resp_custom");
store.store_response(response.clone()).await.unwrap();
let retrieved = store
.get_response(&ResponseId::from_string("resp_custom".to_string()))
.get_response(&ResponseId::from("resp_custom"))
.await
.unwrap();
assert!(retrieved.is_some());

View File

@@ -12,10 +12,6 @@ impl ResponseId {
pub fn new() -> Self {
Self(ulid::Ulid::new().to_string())
}
pub fn from_string(s: String) -> Self {
Self(s)
}
}
impl Default for ResponseId {
@@ -24,6 +20,18 @@ impl Default for ResponseId {
}
}
impl From<String> for ResponseId {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<&str> for ResponseId {
fn from(value: &str) -> Self {
Self(value.to_string())
}
}
/// Stored response data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredResponse {