[router] support Openai router conversation API CRUD (#11297)
This commit is contained in:
57
sgl-router/src/data_connector/conversation_memory_store.rs
Normal file
57
sgl-router/src/data_connector/conversation_memory_store.rs
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::conversations::{
|
||||||
|
Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation,
|
||||||
|
Result,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// In-memory conversation storage used for development and tests
|
||||||
|
#[derive(Default, Clone)]
|
||||||
|
pub struct MemoryConversationStorage {
|
||||||
|
inner: Arc<RwLock<HashMap<ConversationId, Conversation>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryConversationStorage {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
inner: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ConversationStorage for MemoryConversationStorage {
|
||||||
|
async fn create_conversation(&self, input: NewConversation) -> Result<Conversation> {
|
||||||
|
let conversation = Conversation::new(input);
|
||||||
|
self.inner
|
||||||
|
.write()
|
||||||
|
.insert(conversation.id.clone(), conversation.clone());
|
||||||
|
Ok(conversation)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_conversation(&self, id: &ConversationId) -> Result<Option<Conversation>> {
|
||||||
|
Ok(self.inner.read().get(id).cloned())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_conversation(
|
||||||
|
&self,
|
||||||
|
id: &ConversationId,
|
||||||
|
metadata: Option<ConversationMetadata>,
|
||||||
|
) -> Result<Option<Conversation>> {
|
||||||
|
let mut store = self.inner.write();
|
||||||
|
if let Some(entry) = store.get_mut(id) {
|
||||||
|
entry.metadata = metadata;
|
||||||
|
return Ok(Some(entry.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_conversation(&self, id: &ConversationId) -> Result<bool> {
|
||||||
|
let removed = self.inner.write().remove(id).is_some();
|
||||||
|
Ok(removed)
|
||||||
|
}
|
||||||
|
}
|
||||||
41
sgl-router/src/data_connector/conversation_noop_store.rs
Normal file
41
sgl-router/src/data_connector/conversation_noop_store.rs
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use super::conversations::{
|
||||||
|
Conversation, ConversationId, ConversationMetadata, ConversationStorage, Result,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Conversation storage that persists nothing: responses are synthesized on
/// the fly and no state is kept between calls.
#[derive(Clone, Debug, Default)]
pub struct NoOpConversationStorage;
|
||||||
|
|
||||||
|
impl NoOpConversationStorage {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ConversationStorage for NoOpConversationStorage {
|
||||||
|
async fn create_conversation(
|
||||||
|
&self,
|
||||||
|
input: super::conversations::NewConversation,
|
||||||
|
) -> Result<Conversation> {
|
||||||
|
Ok(Conversation::new(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_conversation(&self, _id: &ConversationId) -> Result<Option<Conversation>> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_conversation(
|
||||||
|
&self,
|
||||||
|
_id: &ConversationId,
|
||||||
|
_metadata: Option<ConversationMetadata>,
|
||||||
|
) -> Result<Option<Conversation>> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_conversation(&self, _id: &ConversationId) -> Result<bool> {
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
338
sgl-router/src/data_connector/conversation_oracle_store.rs
Normal file
338
sgl-router/src/data_connector/conversation_oracle_store.rs
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
use crate::config::OracleConfig;
|
||||||
|
use crate::data_connector::conversations::{
|
||||||
|
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
|
||||||
|
ConversationStorageError, NewConversation, Result,
|
||||||
|
};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use chrono::Utc;
|
||||||
|
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
|
||||||
|
use oracle::{sql_type::OracleType, Connection};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OracleConversationStorage {
|
||||||
|
pool: Pool<ConversationOracleConnectionManager>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OracleConversationStorage {
|
||||||
|
pub fn new(config: OracleConfig) -> Result<Self> {
|
||||||
|
configure_oracle_client(&config)?;
|
||||||
|
initialize_schema(&config)?;
|
||||||
|
|
||||||
|
let config = Arc::new(config);
|
||||||
|
let manager = ConversationOracleConnectionManager::new(config.clone());
|
||||||
|
|
||||||
|
let mut builder = Pool::builder(manager)
|
||||||
|
.max_size(config.pool_max)
|
||||||
|
.runtime(deadpool::Runtime::Tokio1);
|
||||||
|
|
||||||
|
if config.pool_timeout_secs > 0 {
|
||||||
|
builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let pool = builder.build().map_err(|err| {
|
||||||
|
ConversationStorageError::StorageError(format!(
|
||||||
|
"failed to build Oracle pool for conversations: {err}"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(Self { pool })
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn with_connection<F, T>(&self, func: F) -> Result<T>
|
||||||
|
where
|
||||||
|
F: FnOnce(&Connection) -> Result<T> + Send + 'static,
|
||||||
|
T: Send + 'static,
|
||||||
|
{
|
||||||
|
let connection = self.pool.get().await.map_err(map_pool_error)?;
|
||||||
|
tokio::task::spawn_blocking(move || {
|
||||||
|
let result = func(&connection);
|
||||||
|
drop(connection);
|
||||||
|
result
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|err| {
|
||||||
|
ConversationStorageError::StorageError(format!(
|
||||||
|
"failed to execute Oracle conversation task: {err}"
|
||||||
|
))
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_metadata(raw: Option<String>) -> Result<Option<ConversationMetadata>> {
|
||||||
|
match raw {
|
||||||
|
Some(json) if !json.is_empty() => {
|
||||||
|
let value: Value = serde_json::from_str(&json)?;
|
||||||
|
match value {
|
||||||
|
Value::Object(map) => Ok(Some(map)),
|
||||||
|
Value::Null => Ok(None),
|
||||||
|
other => Err(ConversationStorageError::StorageError(format!(
|
||||||
|
"conversation metadata expected object, got {other}"
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ConversationStorage for OracleConversationStorage {
|
||||||
|
async fn create_conversation(&self, input: NewConversation) -> Result<Conversation> {
|
||||||
|
let conversation = Conversation::new(input);
|
||||||
|
let id_str = conversation.id.0.clone();
|
||||||
|
let created_at = conversation.created_at;
|
||||||
|
let metadata_json = conversation
|
||||||
|
.metadata
|
||||||
|
.as_ref()
|
||||||
|
.map(serde_json::to_string)
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
self.with_connection(move |conn| {
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO conversations (id, created_at, metadata) VALUES (:1, :2, :3)",
|
||||||
|
&[&id_str, &created_at, &metadata_json],
|
||||||
|
)
|
||||||
|
.map(|_| ())
|
||||||
|
.map_err(map_oracle_error)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(conversation)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_conversation(&self, id: &ConversationId) -> Result<Option<Conversation>> {
|
||||||
|
let lookup = id.0.clone();
|
||||||
|
self.with_connection(move |conn| {
|
||||||
|
let mut stmt = conn
|
||||||
|
.statement("SELECT id, created_at, metadata FROM conversations WHERE id = :1")
|
||||||
|
.build()
|
||||||
|
.map_err(map_oracle_error)?;
|
||||||
|
let mut rows = stmt.query(&[&lookup]).map_err(map_oracle_error)?;
|
||||||
|
|
||||||
|
if let Some(row_res) = rows.next() {
|
||||||
|
let row = row_res.map_err(map_oracle_error)?;
|
||||||
|
let id: String = row.get(0).map_err(map_oracle_error)?;
|
||||||
|
let created_at: chrono::DateTime<Utc> = row.get(1).map_err(map_oracle_error)?;
|
||||||
|
let metadata_raw: Option<String> = row.get(2).map_err(map_oracle_error)?;
|
||||||
|
let metadata = Self::parse_metadata(metadata_raw)?;
|
||||||
|
Ok(Some(Conversation::with_parts(
|
||||||
|
ConversationId(id),
|
||||||
|
created_at,
|
||||||
|
metadata,
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_conversation(
|
||||||
|
&self,
|
||||||
|
id: &ConversationId,
|
||||||
|
metadata: Option<ConversationMetadata>,
|
||||||
|
) -> Result<Option<Conversation>> {
|
||||||
|
let id_str = id.0.clone();
|
||||||
|
let metadata_json = metadata.as_ref().map(serde_json::to_string).transpose()?;
|
||||||
|
let conversation_id = id.clone();
|
||||||
|
|
||||||
|
self.with_connection(move |conn| {
|
||||||
|
let mut stmt = conn
|
||||||
|
.statement(
|
||||||
|
"UPDATE conversations \
|
||||||
|
SET metadata = :1 \
|
||||||
|
WHERE id = :2 \
|
||||||
|
RETURNING created_at INTO :3",
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
.map_err(map_oracle_error)?;
|
||||||
|
|
||||||
|
stmt.bind(3, &OracleType::TimestampTZ(6))
|
||||||
|
.map_err(map_oracle_error)?;
|
||||||
|
stmt.execute(&[&metadata_json, &id_str])
|
||||||
|
.map_err(map_oracle_error)?;
|
||||||
|
|
||||||
|
if stmt.row_count().map_err(map_oracle_error)? == 0 {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut created_at: Vec<chrono::DateTime<Utc>> =
|
||||||
|
stmt.returned_values(3).map_err(map_oracle_error)?;
|
||||||
|
let created_at = created_at.pop().ok_or_else(|| {
|
||||||
|
ConversationStorageError::StorageError(
|
||||||
|
"Oracle update did not return created_at".to_string(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(Some(Conversation::with_parts(
|
||||||
|
conversation_id,
|
||||||
|
created_at,
|
||||||
|
metadata,
|
||||||
|
)))
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_conversation(&self, id: &ConversationId) -> Result<bool> {
|
||||||
|
let id_str = id.0.clone();
|
||||||
|
let res = self
|
||||||
|
.with_connection(move |conn| {
|
||||||
|
conn.execute("DELETE FROM conversations WHERE id = :1", &[&id_str])
|
||||||
|
.map_err(map_oracle_error)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(res.row_count().map_err(map_oracle_error)? > 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct ConversationOracleConnectionManager {
|
||||||
|
params: Arc<OracleConnectParams>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConversationOracleConnectionManager {
|
||||||
|
fn new(config: Arc<OracleConfig>) -> Self {
|
||||||
|
let params = OracleConnectParams {
|
||||||
|
username: config.username.clone(),
|
||||||
|
password: config.password.clone(),
|
||||||
|
connect_descriptor: config.connect_descriptor.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Self {
|
||||||
|
params: Arc::new(params),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for ConversationOracleConnectionManager {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("ConversationOracleConnectionManager")
|
||||||
|
.field("username", &self.params.username)
|
||||||
|
.field("connect_descriptor", &self.params.connect_descriptor)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Values passed to `Connection::connect` when opening an Oracle connection.
#[derive(Clone)]
struct OracleConnectParams {
    /// Database user name.
    username: String,
    /// Database password; omitted from the hand-written `Debug` output.
    password: String,
    /// Oracle connect descriptor string.
    connect_descriptor: String,
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for OracleConnectParams {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("OracleConnectParams")
|
||||||
|
.field("username", &self.username)
|
||||||
|
.field("connect_descriptor", &self.connect_descriptor)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Manager for ConversationOracleConnectionManager {
|
||||||
|
type Type = Connection;
|
||||||
|
type Error = oracle::Error;
|
||||||
|
|
||||||
|
fn create(
|
||||||
|
&self,
|
||||||
|
) -> impl std::future::Future<Output = std::result::Result<Connection, oracle::Error>> + Send
|
||||||
|
{
|
||||||
|
let params = self.params.clone();
|
||||||
|
async move {
|
||||||
|
let mut conn = Connection::connect(
|
||||||
|
¶ms.username,
|
||||||
|
¶ms.password,
|
||||||
|
¶ms.connect_descriptor,
|
||||||
|
)?;
|
||||||
|
conn.set_autocommit(true);
|
||||||
|
Ok(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::manual_async_fn)]
|
||||||
|
fn recycle(
|
||||||
|
&self,
|
||||||
|
conn: &mut Connection,
|
||||||
|
_: &Metrics,
|
||||||
|
) -> impl std::future::Future<Output = RecycleResult<Self::Error>> + Send {
|
||||||
|
async move { conn.ping().map_err(RecycleError::Backend) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn configure_oracle_client(config: &OracleConfig) -> Result<()> {
|
||||||
|
if let Some(wallet_path) = &config.wallet_path {
|
||||||
|
let wallet_path = Path::new(wallet_path);
|
||||||
|
if !wallet_path.is_dir() {
|
||||||
|
return Err(ConversationStorageError::StorageError(format!(
|
||||||
|
"Oracle wallet/config path '{}' is not a directory",
|
||||||
|
wallet_path.display()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() {
|
||||||
|
return Err(ConversationStorageError::StorageError(format!(
|
||||||
|
"Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora",
|
||||||
|
wallet_path.display()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::env::set_var("TNS_ADMIN", wallet_path);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn initialize_schema(config: &OracleConfig) -> Result<()> {
|
||||||
|
let conn = Connection::connect(
|
||||||
|
&config.username,
|
||||||
|
&config.password,
|
||||||
|
&config.connect_descriptor,
|
||||||
|
)
|
||||||
|
.map_err(map_oracle_error)?;
|
||||||
|
|
||||||
|
let exists: i64 = conn
|
||||||
|
.query_row_as(
|
||||||
|
"SELECT COUNT(*) FROM user_tables WHERE table_name = 'CONVERSATIONS'",
|
||||||
|
&[],
|
||||||
|
)
|
||||||
|
.map_err(map_oracle_error)?;
|
||||||
|
|
||||||
|
if exists == 0 {
|
||||||
|
conn.execute(
|
||||||
|
"CREATE TABLE conversations (
|
||||||
|
id VARCHAR2(64) PRIMARY KEY,
|
||||||
|
created_at TIMESTAMP WITH TIME ZONE,
|
||||||
|
metadata CLOB
|
||||||
|
)",
|
||||||
|
&[],
|
||||||
|
)
|
||||||
|
.map_err(map_oracle_error)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_pool_error(err: PoolError<oracle::Error>) -> ConversationStorageError {
|
||||||
|
match err {
|
||||||
|
PoolError::Backend(e) => map_oracle_error(e),
|
||||||
|
other => ConversationStorageError::StorageError(format!(
|
||||||
|
"failed to obtain Oracle conversation connection: {other}"
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_oracle_error(err: oracle::Error) -> ConversationStorageError {
|
||||||
|
if let Some(db_err) = err.db_error() {
|
||||||
|
ConversationStorageError::StorageError(format!(
|
||||||
|
"Oracle error (code {}): {}",
|
||||||
|
db_err.code(),
|
||||||
|
db_err.message()
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
ConversationStorageError::StorageError(err.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
120
sgl-router/src/data_connector/conversations.rs
Normal file
120
sgl-router/src/data_connector/conversations.rs
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use rand::RngCore;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{Map as JsonMap, Value};
|
||||||
|
use std::fmt::{Display, Formatter};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||||
|
pub struct ConversationId(pub String);
|
||||||
|
|
||||||
|
impl ConversationId {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let mut rng = rand::rng();
|
||||||
|
let mut bytes = [0u8; 24];
|
||||||
|
rng.fill_bytes(&mut bytes);
|
||||||
|
let hex_string: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
|
||||||
|
Self(format!("conv_{}", hex_string))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ConversationId {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for ConversationId {
|
||||||
|
fn from(value: String) -> Self {
|
||||||
|
Self(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for ConversationId {
|
||||||
|
fn from(value: &str) -> Self {
|
||||||
|
Self(value.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for ConversationId {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.write_str(&self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Metadata payload persisted with a conversation
|
||||||
|
pub type ConversationMetadata = JsonMap<String, Value>;
|
||||||
|
|
||||||
|
/// Input payload for creating a conversation
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
pub struct NewConversation {
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<ConversationMetadata>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stored conversation data structure
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
pub struct Conversation {
|
||||||
|
pub id: ConversationId,
|
||||||
|
pub created_at: DateTime<Utc>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<ConversationMetadata>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Conversation {
|
||||||
|
pub fn new(new_conversation: NewConversation) -> Self {
|
||||||
|
Self {
|
||||||
|
id: ConversationId::new(),
|
||||||
|
created_at: Utc::now(),
|
||||||
|
metadata: new_conversation.metadata,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_parts(
|
||||||
|
id: ConversationId,
|
||||||
|
created_at: DateTime<Utc>,
|
||||||
|
metadata: Option<ConversationMetadata>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
id,
|
||||||
|
created_at,
|
||||||
|
metadata,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result alias for conversation storage operations
|
||||||
|
pub type Result<T> = std::result::Result<T, ConversationStorageError>;
|
||||||
|
|
||||||
|
/// Error type for conversation storage operations
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ConversationStorageError {
|
||||||
|
#[error("Conversation not found: {0}")]
|
||||||
|
ConversationNotFound(String),
|
||||||
|
|
||||||
|
#[error("Storage error: {0}")]
|
||||||
|
StorageError(String),
|
||||||
|
|
||||||
|
#[error("Serialization error: {0}")]
|
||||||
|
SerializationError(#[from] serde_json::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trait describing the CRUD interface for conversation storage backends
|
||||||
|
#[async_trait]
|
||||||
|
pub trait ConversationStorage: Send + Sync + 'static {
|
||||||
|
async fn create_conversation(&self, input: NewConversation) -> Result<Conversation>;
|
||||||
|
|
||||||
|
async fn get_conversation(&self, id: &ConversationId) -> Result<Option<Conversation>>;
|
||||||
|
|
||||||
|
async fn update_conversation(
|
||||||
|
&self,
|
||||||
|
id: &ConversationId,
|
||||||
|
metadata: Option<ConversationMetadata>,
|
||||||
|
) -> Result<Option<Conversation>>;
|
||||||
|
|
||||||
|
async fn delete_conversation(&self, id: &ConversationId) -> Result<bool>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shared pointer alias for conversation storage
|
||||||
|
pub type SharedConversationStorage = Arc<dyn ConversationStorage>;
|
||||||
@@ -1,9 +1,21 @@
|
|||||||
// Data connector module for response storage
|
// Data connector module for response storage and conversation storage
|
||||||
|
pub mod conversation_memory_store;
|
||||||
|
pub mod conversation_noop_store;
|
||||||
|
pub mod conversation_oracle_store;
|
||||||
|
pub mod conversations;
|
||||||
pub mod response_memory_store;
|
pub mod response_memory_store;
|
||||||
pub mod response_noop_store;
|
pub mod response_noop_store;
|
||||||
pub mod response_oracle_store;
|
pub mod response_oracle_store;
|
||||||
pub mod responses;
|
pub mod responses;
|
||||||
|
|
||||||
|
pub use conversation_memory_store::MemoryConversationStorage;
|
||||||
|
pub use conversation_noop_store::NoOpConversationStorage;
|
||||||
|
pub use conversation_oracle_store::OracleConversationStorage;
|
||||||
|
pub use conversations::{
|
||||||
|
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
|
||||||
|
ConversationStorageError, NewConversation, Result as ConversationResult,
|
||||||
|
SharedConversationStorage,
|
||||||
|
};
|
||||||
pub use response_memory_store::MemoryResponseStorage;
|
pub use response_memory_store::MemoryResponseStorage;
|
||||||
pub use response_noop_store::NoOpResponseStorage;
|
pub use response_noop_store::NoOpResponseStorage;
|
||||||
pub use response_oracle_store::OracleResponseStorage;
|
pub use response_oracle_store::OracleResponseStorage;
|
||||||
|
|||||||
@@ -207,10 +207,10 @@ mod tests {
|
|||||||
async fn test_store_with_custom_id() {
|
async fn test_store_with_custom_id() {
|
||||||
let store = MemoryResponseStorage::new();
|
let store = MemoryResponseStorage::new();
|
||||||
let mut response = StoredResponse::new("Input".to_string(), "Output".to_string(), None);
|
let mut response = StoredResponse::new("Input".to_string(), "Output".to_string(), None);
|
||||||
response.id = ResponseId::from_string("resp_custom".to_string());
|
response.id = ResponseId::from("resp_custom");
|
||||||
store.store_response(response.clone()).await.unwrap();
|
store.store_response(response.clone()).await.unwrap();
|
||||||
let retrieved = store
|
let retrieved = store
|
||||||
.get_response(&ResponseId::from_string("resp_custom".to_string()))
|
.get_response(&ResponseId::from("resp_custom"))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(retrieved.is_some());
|
assert!(retrieved.is_some());
|
||||||
|
|||||||
@@ -12,10 +12,6 @@ impl ResponseId {
|
|||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self(ulid::Ulid::new().to_string())
|
Self(ulid::Ulid::new().to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_string(s: String) -> Self {
|
|
||||||
Self(s)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ResponseId {
|
impl Default for ResponseId {
|
||||||
@@ -24,6 +20,18 @@ impl Default for ResponseId {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<String> for ResponseId {
|
||||||
|
fn from(value: String) -> Self {
|
||||||
|
Self(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for ResponseId {
|
||||||
|
fn from(value: &str) -> Self {
|
||||||
|
Self(value.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Stored response data
|
/// Stored response data
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct StoredResponse {
|
pub struct StoredResponse {
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ impl RouterFactory {
|
|||||||
base_url,
|
base_url,
|
||||||
Some(ctx.router_config.circuit_breaker.clone()),
|
Some(ctx.router_config.circuit_breaker.clone()),
|
||||||
ctx.response_storage.clone(),
|
ctx.response_storage.clone(),
|
||||||
|
ctx.conversation_storage.clone(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,10 @@
|
|||||||
|
|
||||||
use crate::config::CircuitBreakerConfig;
|
use crate::config::CircuitBreakerConfig;
|
||||||
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
|
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
|
||||||
use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse};
|
use crate::data_connector::{
|
||||||
|
Conversation, ConversationId, ConversationMetadata, ResponseId, SharedConversationStorage,
|
||||||
|
SharedResponseStorage, StoredResponse,
|
||||||
|
};
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||||
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
|
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
|
||||||
@@ -16,6 +19,7 @@ use axum::{
|
|||||||
extract::Request,
|
extract::Request,
|
||||||
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
|
Json,
|
||||||
};
|
};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
@@ -75,6 +79,8 @@ pub struct OpenAIRouter {
|
|||||||
healthy: AtomicBool,
|
healthy: AtomicBool,
|
||||||
/// Response storage for managing conversation history
|
/// Response storage for managing conversation history
|
||||||
response_storage: SharedResponseStorage,
|
response_storage: SharedResponseStorage,
|
||||||
|
/// Conversation storage backend
|
||||||
|
conversation_storage: SharedConversationStorage,
|
||||||
/// Optional MCP manager (enabled via config presence)
|
/// Optional MCP manager (enabled via config presence)
|
||||||
mcp_manager: Option<Arc<crate::mcp::McpClientManager>>,
|
mcp_manager: Option<Arc<crate::mcp::McpClientManager>>,
|
||||||
}
|
}
|
||||||
@@ -705,6 +711,7 @@ impl OpenAIRouter {
|
|||||||
base_url: String,
|
base_url: String,
|
||||||
circuit_breaker_config: Option<CircuitBreakerConfig>,
|
circuit_breaker_config: Option<CircuitBreakerConfig>,
|
||||||
response_storage: SharedResponseStorage,
|
response_storage: SharedResponseStorage,
|
||||||
|
conversation_storage: SharedConversationStorage,
|
||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(300))
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
@@ -751,6 +758,7 @@ impl OpenAIRouter {
|
|||||||
circuit_breaker,
|
circuit_breaker,
|
||||||
healthy: AtomicBool::new(true),
|
healthy: AtomicBool::new(true),
|
||||||
response_storage,
|
response_storage,
|
||||||
|
conversation_storage,
|
||||||
mcp_manager,
|
mcp_manager,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -2337,16 +2345,16 @@ impl OpenAIRouter {
|
|||||||
stored_response.previous_response_id = response_json
|
stored_response.previous_response_id = response_json
|
||||||
.get("previous_response_id")
|
.get("previous_response_id")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.map(|s| ResponseId::from_string(s.to_string()))
|
.map(ResponseId::from)
|
||||||
.or_else(|| {
|
.or_else(|| {
|
||||||
original_body
|
original_body
|
||||||
.previous_response_id
|
.previous_response_id
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|id| ResponseId::from_string(id.clone()))
|
.map(|id| ResponseId::from(id.as_str()))
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) {
|
if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) {
|
||||||
stored_response.id = ResponseId::from_string(id_str.to_string());
|
stored_response.id = ResponseId::from(id_str);
|
||||||
}
|
}
|
||||||
|
|
||||||
stored_response.raw_response = response_json.clone();
|
stored_response.raw_response = response_json.clone();
|
||||||
@@ -3393,7 +3401,7 @@ impl super::super::RouterTrait for OpenAIRouter {
|
|||||||
// Handle previous_response_id by loading prior context
|
// Handle previous_response_id by loading prior context
|
||||||
let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None;
|
let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None;
|
||||||
if let Some(prev_id_str) = request_body.previous_response_id.clone() {
|
if let Some(prev_id_str) = request_body.previous_response_id.clone() {
|
||||||
let prev_id = ResponseId::from_string(prev_id_str.clone());
|
let prev_id = ResponseId::from(prev_id_str.as_str());
|
||||||
match self
|
match self
|
||||||
.response_storage
|
.response_storage
|
||||||
.get_response_chain(&prev_id, None)
|
.get_response_chain(&prev_id, None)
|
||||||
@@ -3516,7 +3524,7 @@ impl super::super::RouterTrait for OpenAIRouter {
|
|||||||
response_id: &str,
|
response_id: &str,
|
||||||
params: &ResponsesGetParams,
|
params: &ResponsesGetParams,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let stored_id = ResponseId::from_string(response_id.to_string());
|
let stored_id = ResponseId::from(response_id);
|
||||||
if let Ok(Some(stored_response)) = self.response_storage.get_response(&stored_id).await {
|
if let Ok(Some(stored_response)) = self.response_storage.get_response(&stored_id).await {
|
||||||
let stream_requested = params.stream.unwrap_or(false);
|
let stream_requested = params.stream.unwrap_or(false);
|
||||||
let raw_value = stored_response.raw_response.clone();
|
let raw_value = stored_response.raw_response.clone();
|
||||||
@@ -3646,10 +3654,6 @@ impl super::super::RouterTrait for OpenAIRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn router_type(&self) -> &'static str {
|
|
||||||
"openai"
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn route_embeddings(
|
async fn route_embeddings(
|
||||||
&self,
|
&self,
|
||||||
_headers: Option<&HeaderMap>,
|
_headers: Option<&HeaderMap>,
|
||||||
@@ -3675,4 +3679,309 @@ impl super::super::RouterTrait for OpenAIRouter {
|
|||||||
)
|
)
|
||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response {
|
||||||
|
// TODO: move this spec validation to the right place
|
||||||
|
let metadata = match body.get("metadata") {
|
||||||
|
Some(Value::Object(map)) => {
|
||||||
|
if map.len() > MAX_METADATA_PROPERTIES {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!(
|
||||||
|
"Invalid 'metadata': too many properties. Max {}, got {}",
|
||||||
|
MAX_METADATA_PROPERTIES, map.len()
|
||||||
|
),
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": "metadata",
|
||||||
|
"code": "metadata_max_properties_exceeded"
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
Some(map.clone())
|
||||||
|
}
|
||||||
|
Some(Value::Null) | None => None,
|
||||||
|
Some(other) => {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!(
|
||||||
|
"Invalid 'metadata': expected object or null but got {}",
|
||||||
|
other
|
||||||
|
),
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": "metadata",
|
||||||
|
"code": "metadata_invalid_type"
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match self
|
||||||
|
.conversation_storage
|
||||||
|
.create_conversation(crate::data_connector::NewConversation { metadata })
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(conversation) => {
|
||||||
|
(StatusCode::OK, Json(conversation_to_json(&conversation))).into_response()
|
||||||
|
}
|
||||||
|
Err(err) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": err.to_string(),
|
||||||
|
"type": "internal_error",
|
||||||
|
"param": Value::Null,
|
||||||
|
"code": Value::Null
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_conversation(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
conversation_id: &str,
|
||||||
|
) -> Response {
|
||||||
|
let id: ConversationId = conversation_id.to_string().into();
|
||||||
|
match self.conversation_storage.get_conversation(&id).await {
|
||||||
|
Ok(Some(conv)) => (StatusCode::OK, Json(conversation_to_json(&conv))).into_response(),
|
||||||
|
Ok(None) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!("Conversation with id '{}' not found.", conversation_id),
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": Value::Null,
|
||||||
|
"code": Value::Null
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(err) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": err.to_string(),
|
||||||
|
"type": "internal_error",
|
||||||
|
"param": Value::Null,
|
||||||
|
"code": Value::Null
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_conversation(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
conversation_id: &str,
|
||||||
|
body: &Value,
|
||||||
|
) -> Response {
|
||||||
|
let id: ConversationId = conversation_id.to_string().into();
|
||||||
|
let existing = match self.conversation_storage.get_conversation(&id).await {
|
||||||
|
Ok(Some(c)) => c,
|
||||||
|
Ok(None) => {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!("Conversation with id '{}' not found.", conversation_id),
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": Value::Null,
|
||||||
|
"code": Value::Null
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": err.to_string(),
|
||||||
|
"type": "internal_error",
|
||||||
|
"param": Value::Null,
|
||||||
|
"code": Value::Null
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse metadata patch
|
||||||
|
enum Patch {
|
||||||
|
NoChange,
|
||||||
|
ClearAll,
|
||||||
|
Merge(ConversationMetadata),
|
||||||
|
}
|
||||||
|
let patch = match body.get("metadata") {
|
||||||
|
None => Patch::NoChange,
|
||||||
|
Some(Value::Null) => Patch::ClearAll,
|
||||||
|
Some(Value::Object(map)) => Patch::Merge(map.clone()),
|
||||||
|
Some(other) => {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!(
|
||||||
|
"Invalid 'metadata': expected object or null but got {}",
|
||||||
|
other
|
||||||
|
),
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": "metadata",
|
||||||
|
"code": "metadata_invalid_type"
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let merged_metadata = match patch {
|
||||||
|
Patch::NoChange => {
|
||||||
|
return (StatusCode::OK, Json(conversation_to_json(&existing))).into_response();
|
||||||
|
}
|
||||||
|
Patch::ClearAll => None,
|
||||||
|
Patch::Merge(upd) => {
|
||||||
|
let mut merged = existing.metadata.clone().unwrap_or_default();
|
||||||
|
let previous = merged.len();
|
||||||
|
for (k, v) in upd.into_iter() {
|
||||||
|
if v.is_null() {
|
||||||
|
merged.remove(&k);
|
||||||
|
} else {
|
||||||
|
merged.insert(k, v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let updated = merged.len();
|
||||||
|
if updated > MAX_METADATA_PROPERTIES {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!(
|
||||||
|
"Invalid 'metadata': too many properties after update. Max {} ({} -> {}).",
|
||||||
|
MAX_METADATA_PROPERTIES, previous, updated
|
||||||
|
),
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": "metadata",
|
||||||
|
"code": "metadata_max_properties_exceeded",
|
||||||
|
"extra": {
|
||||||
|
"previous_property_count": previous,
|
||||||
|
"updated_property_count": updated
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
if merged.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(merged)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match self
|
||||||
|
.conversation_storage
|
||||||
|
.update_conversation(&id, merged_metadata)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(Some(conv)) => (StatusCode::OK, Json(conversation_to_json(&conv))).into_response(),
|
||||||
|
Ok(None) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!("Conversation with id '{}' not found.", conversation_id),
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": Value::Null,
|
||||||
|
"code": Value::Null
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(err) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": err.to_string(),
|
||||||
|
"type": "internal_error",
|
||||||
|
"param": Value::Null,
|
||||||
|
"code": Value::Null
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_conversation(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
conversation_id: &str,
|
||||||
|
) -> Response {
|
||||||
|
let id: ConversationId = conversation_id.to_string().into();
|
||||||
|
match self.conversation_storage.delete_conversation(&id).await {
|
||||||
|
Ok(true) => (
|
||||||
|
StatusCode::OK,
|
||||||
|
Json(json!({
|
||||||
|
"id": conversation_id,
|
||||||
|
"object": "conversation.deleted",
|
||||||
|
"deleted": true
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Ok(false) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!("Conversation with id '{}' not found.", conversation_id),
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": Value::Null,
|
||||||
|
"code": Value::Null
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(err) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": err.to_string(),
|
||||||
|
"type": "internal_error",
|
||||||
|
"param": Value::Null,
|
||||||
|
"code": Value::Null
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn router_type(&self) -> &'static str {
|
||||||
|
"openai"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Maximum number of properties allowed in conversation metadata (align with server)
|
||||||
|
const MAX_METADATA_PROPERTIES: usize = 16;
|
||||||
|
|
||||||
|
fn conversation_to_json(conversation: &Conversation) -> Value {
|
||||||
|
json!({
|
||||||
|
"id": conversation.id.0,
|
||||||
|
"object": "conversation",
|
||||||
|
"created_at": conversation.created_at.timestamp(),
|
||||||
|
"metadata": to_value(&conversation.metadata).unwrap_or(Value::Null),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ use crate::protocols::spec::{
|
|||||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||||
ResponsesGetParams, ResponsesRequest,
|
ResponsesGetParams, ResponsesRequest,
|
||||||
};
|
};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod grpc;
|
pub mod grpc;
|
||||||
@@ -126,6 +127,52 @@ pub trait RouterTrait: Send + Sync + Debug {
|
|||||||
model_id: Option<&str>,
|
model_id: Option<&str>,
|
||||||
) -> Response;
|
) -> Response;
|
||||||
|
|
||||||
|
// Conversations API
|
||||||
|
async fn create_conversation(&self, _headers: Option<&HeaderMap>, _body: &Value) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Conversations create endpoint not implemented",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_conversation(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_conversation_id: &str,
|
||||||
|
) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Conversations get endpoint not implemented",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_conversation(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_conversation_id: &str,
|
||||||
|
_body: &Value,
|
||||||
|
) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Conversations update endpoint not implemented",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_conversation(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_conversation_id: &str,
|
||||||
|
) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Conversations delete endpoint not implemented",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
/// Get router type name
|
/// Get router type name
|
||||||
fn router_type(&self) -> &'static str;
|
fn router_type(&self) -> &'static str;
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ use axum::{
|
|||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
|
use serde_json::Value;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::{debug, info, warn};
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
@@ -511,6 +512,83 @@ impl RouterTrait for RouterManager {
|
|||||||
fn router_type(&self) -> &'static str {
|
fn router_type(&self) -> &'static str {
|
||||||
"manager"
|
"manager"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Conversations API delegates
|
||||||
|
async fn create_conversation(&self, headers: Option<&HeaderMap>, body: &Value) -> Response {
|
||||||
|
let router = self.select_router_for_request(headers, None);
|
||||||
|
if let Some(router) = router {
|
||||||
|
router.create_conversation(headers, body).await
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
"No router available to create conversation",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_conversation(
|
||||||
|
&self,
|
||||||
|
headers: Option<&HeaderMap>,
|
||||||
|
conversation_id: &str,
|
||||||
|
) -> Response {
|
||||||
|
let router = self.select_router_for_request(headers, None);
|
||||||
|
if let Some(router) = router {
|
||||||
|
router.get_conversation(headers, conversation_id).await
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
format!(
|
||||||
|
"No router available to get conversation '{}'",
|
||||||
|
conversation_id
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_conversation(
|
||||||
|
&self,
|
||||||
|
headers: Option<&HeaderMap>,
|
||||||
|
conversation_id: &str,
|
||||||
|
body: &Value,
|
||||||
|
) -> Response {
|
||||||
|
let router = self.select_router_for_request(headers, None);
|
||||||
|
if let Some(router) = router {
|
||||||
|
router
|
||||||
|
.update_conversation(headers, conversation_id, body)
|
||||||
|
.await
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
format!(
|
||||||
|
"No router available to update conversation '{}'",
|
||||||
|
conversation_id
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_conversation(
|
||||||
|
&self,
|
||||||
|
headers: Option<&HeaderMap>,
|
||||||
|
conversation_id: &str,
|
||||||
|
) -> Response {
|
||||||
|
let router = self.select_router_for_request(headers, None);
|
||||||
|
if let Some(router) = router {
|
||||||
|
router.delete_conversation(headers, conversation_id).await
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
format!(
|
||||||
|
"No router available to delete conversation '{}'",
|
||||||
|
conversation_id
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for RouterManager {
|
impl std::fmt::Debug for RouterManager {
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ use crate::{
|
|||||||
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
||||||
core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType},
|
core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType},
|
||||||
data_connector::{
|
data_connector::{
|
||||||
MemoryResponseStorage, NoOpResponseStorage, OracleResponseStorage, SharedResponseStorage,
|
MemoryConversationStorage, MemoryResponseStorage, NoOpConversationStorage,
|
||||||
|
NoOpResponseStorage, OracleConversationStorage, OracleResponseStorage,
|
||||||
|
SharedConversationStorage, SharedResponseStorage,
|
||||||
},
|
},
|
||||||
logging::{self, LoggingConfig},
|
logging::{self, LoggingConfig},
|
||||||
metrics::{self, PrometheusConfig},
|
metrics::{self, PrometheusConfig},
|
||||||
@@ -39,6 +41,8 @@ use std::{
|
|||||||
use tokio::{net::TcpListener, signal, spawn};
|
use tokio::{net::TcpListener, signal, spawn};
|
||||||
use tracing::{error, info, warn, Level};
|
use tracing::{error, info, warn, Level};
|
||||||
|
|
||||||
|
//
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppContext {
|
pub struct AppContext {
|
||||||
pub client: Client,
|
pub client: Client,
|
||||||
@@ -51,6 +55,7 @@ pub struct AppContext {
|
|||||||
pub policy_registry: Arc<PolicyRegistry>,
|
pub policy_registry: Arc<PolicyRegistry>,
|
||||||
pub router_manager: Option<Arc<RouterManager>>,
|
pub router_manager: Option<Arc<RouterManager>>,
|
||||||
pub response_storage: SharedResponseStorage,
|
pub response_storage: SharedResponseStorage,
|
||||||
|
pub conversation_storage: SharedConversationStorage,
|
||||||
pub load_monitor: Option<Arc<LoadMonitor>>,
|
pub load_monitor: Option<Arc<LoadMonitor>>,
|
||||||
pub configured_reasoning_parser: Option<String>,
|
pub configured_reasoning_parser: Option<String>,
|
||||||
pub configured_tool_parser: Option<String>,
|
pub configured_tool_parser: Option<String>,
|
||||||
@@ -94,19 +99,34 @@ impl AppContext {
|
|||||||
|
|
||||||
let router_manager = None;
|
let router_manager = None;
|
||||||
|
|
||||||
let response_storage: SharedResponseStorage = match router_config.history_backend {
|
let (response_storage, conversation_storage): (
|
||||||
HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()),
|
SharedResponseStorage,
|
||||||
HistoryBackend::None => Arc::new(NoOpResponseStorage::new()),
|
SharedConversationStorage,
|
||||||
|
) = match router_config.history_backend {
|
||||||
|
HistoryBackend::Memory => (
|
||||||
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
|
),
|
||||||
|
HistoryBackend::None => (
|
||||||
|
Arc::new(NoOpResponseStorage::new()),
|
||||||
|
Arc::new(NoOpConversationStorage::new()),
|
||||||
|
),
|
||||||
HistoryBackend::Oracle => {
|
HistoryBackend::Oracle => {
|
||||||
let oracle_cfg = router_config.oracle.clone().ok_or_else(|| {
|
let oracle_cfg = router_config.oracle.clone().ok_or_else(|| {
|
||||||
"oracle configuration is required when history_backend=oracle".to_string()
|
"oracle configuration is required when history_backend=oracle".to_string()
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let storage = OracleResponseStorage::new(oracle_cfg).map_err(|err| {
|
let response_storage =
|
||||||
format!("failed to initialize Oracle response storage: {err}")
|
OracleResponseStorage::new(oracle_cfg.clone()).map_err(|err| {
|
||||||
})?;
|
format!("failed to initialize Oracle response storage: {err}")
|
||||||
|
})?;
|
||||||
|
|
||||||
Arc::new(storage)
|
let conversation_storage =
|
||||||
|
OracleConversationStorage::new(oracle_cfg).map_err(|err| {
|
||||||
|
format!("failed to initialize Oracle conversation storage: {err}")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
(Arc::new(response_storage), Arc::new(conversation_storage))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -131,6 +151,7 @@ impl AppContext {
|
|||||||
policy_registry,
|
policy_registry,
|
||||||
router_manager,
|
router_manager,
|
||||||
response_storage,
|
response_storage,
|
||||||
|
conversation_storage,
|
||||||
load_monitor,
|
load_monitor,
|
||||||
configured_reasoning_parser,
|
configured_reasoning_parser,
|
||||||
configured_tool_parser,
|
configured_tool_parser,
|
||||||
@@ -334,6 +355,51 @@ async fn v1_responses_list_input_items(
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn v1_conversations_create(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
headers: http::HeaderMap,
|
||||||
|
Json(body): Json<Value>,
|
||||||
|
) -> Response {
|
||||||
|
state
|
||||||
|
.router
|
||||||
|
.create_conversation(Some(&headers), &body)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn v1_conversations_get(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Path(conversation_id): Path<String>,
|
||||||
|
headers: http::HeaderMap,
|
||||||
|
) -> Response {
|
||||||
|
state
|
||||||
|
.router
|
||||||
|
.get_conversation(Some(&headers), &conversation_id)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn v1_conversations_update(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Path(conversation_id): Path<String>,
|
||||||
|
headers: http::HeaderMap,
|
||||||
|
Json(body): Json<Value>,
|
||||||
|
) -> Response {
|
||||||
|
state
|
||||||
|
.router
|
||||||
|
.update_conversation(Some(&headers), &conversation_id, &body)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn v1_conversations_delete(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Path(conversation_id): Path<String>,
|
||||||
|
headers: http::HeaderMap,
|
||||||
|
) -> Response {
|
||||||
|
state
|
||||||
|
.router
|
||||||
|
.delete_conversation(Some(&headers), &conversation_id)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct AddWorkerQuery {
|
struct AddWorkerQuery {
|
||||||
url: String,
|
url: String,
|
||||||
@@ -601,6 +667,13 @@ pub fn build_app(
|
|||||||
"/v1/responses/{response_id}/input",
|
"/v1/responses/{response_id}/input",
|
||||||
get(v1_responses_list_input_items),
|
get(v1_responses_list_input_items),
|
||||||
)
|
)
|
||||||
|
.route("/v1/conversations", post(v1_conversations_create))
|
||||||
|
.route(
|
||||||
|
"/v1/conversations/{conversation_id}",
|
||||||
|
get(v1_conversations_get)
|
||||||
|
.post(v1_conversations_update)
|
||||||
|
.delete(v1_conversations_delete),
|
||||||
|
)
|
||||||
.route_layer(axum::middleware::from_fn_with_state(
|
.route_layer(axum::middleware::from_fn_with_state(
|
||||||
app_state.clone(),
|
app_state.clone(),
|
||||||
middleware::concurrency_limit_middleware,
|
middleware::concurrency_limit_middleware,
|
||||||
|
|||||||
@@ -542,6 +542,7 @@ mod tests {
|
|||||||
tool_parser_factory: None,
|
tool_parser_factory: None,
|
||||||
router_manager: None,
|
router_manager: None,
|
||||||
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
|
||||||
|
conversation_storage: Arc::new(crate::data_connector::MemoryConversationStorage::new()),
|
||||||
load_monitor: None,
|
load_monitor: None,
|
||||||
configured_reasoning_parser: None,
|
configured_reasoning_parser: None,
|
||||||
configured_tool_parser: None,
|
configured_tool_parser: None,
|
||||||
|
|||||||
@@ -239,6 +239,100 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
|||||||
mcp.stop().await;
|
mcp.stop().await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_conversations_crud_basic() {
|
||||||
|
// Router in OpenAI mode (no actual upstream calls in these tests)
|
||||||
|
let router_cfg = RouterConfig {
|
||||||
|
mode: RoutingMode::OpenAI {
|
||||||
|
worker_urls: vec!["http://localhost".to_string()],
|
||||||
|
},
|
||||||
|
connection_mode: ConnectionMode::Http,
|
||||||
|
policy: PolicyConfig::Random,
|
||||||
|
host: "127.0.0.1".to_string(),
|
||||||
|
port: 0,
|
||||||
|
max_payload_size: 8 * 1024 * 1024,
|
||||||
|
request_timeout_secs: 60,
|
||||||
|
worker_startup_timeout_secs: 1,
|
||||||
|
worker_startup_check_interval_secs: 1,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
|
discovery: None,
|
||||||
|
metrics: None,
|
||||||
|
log_dir: None,
|
||||||
|
log_level: Some("warn".to_string()),
|
||||||
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 8,
|
||||||
|
queue_size: 0,
|
||||||
|
queue_timeout_secs: 5,
|
||||||
|
rate_limit_tokens_per_second: None,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
|
retry: RetryConfig::default(),
|
||||||
|
circuit_breaker: CircuitBreakerConfig::default(),
|
||||||
|
disable_retries: false,
|
||||||
|
disable_circuit_breaker: false,
|
||||||
|
health_check: HealthCheckConfig::default(),
|
||||||
|
enable_igw: false,
|
||||||
|
model_path: None,
|
||||||
|
tokenizer_path: None,
|
||||||
|
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||||
|
oracle: None,
|
||||||
|
reasoning_parser: None,
|
||||||
|
tool_call_parser: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let ctx = AppContext::new(router_cfg, reqwest::Client::new(), 8, None).expect("ctx");
|
||||||
|
let router = RouterFactory::create_router(&Arc::new(ctx))
|
||||||
|
.await
|
||||||
|
.expect("router");
|
||||||
|
|
||||||
|
// Create
|
||||||
|
let create_body = serde_json::json!({ "metadata": { "project": "alpha" } });
|
||||||
|
let create_resp = router.create_conversation(None, &create_body).await;
|
||||||
|
assert_eq!(create_resp.status(), axum::http::StatusCode::OK);
|
||||||
|
let create_bytes = axum::body::to_bytes(create_resp.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let create_json: serde_json::Value = serde_json::from_slice(&create_bytes).unwrap();
|
||||||
|
let conv_id = create_json["id"].as_str().expect("id missing");
|
||||||
|
assert!(conv_id.starts_with("conv_"));
|
||||||
|
assert_eq!(create_json["object"], "conversation");
|
||||||
|
|
||||||
|
// Get
|
||||||
|
let get_resp = router.get_conversation(None, conv_id).await;
|
||||||
|
assert_eq!(get_resp.status(), axum::http::StatusCode::OK);
|
||||||
|
let get_bytes = axum::body::to_bytes(get_resp.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let get_json: serde_json::Value = serde_json::from_slice(&get_bytes).unwrap();
|
||||||
|
assert_eq!(get_json["metadata"]["project"], serde_json::json!("alpha"));
|
||||||
|
|
||||||
|
// Update (merge)
|
||||||
|
let update_body = serde_json::json!({ "metadata": { "owner": "alice" } });
|
||||||
|
let upd_resp = router
|
||||||
|
.update_conversation(None, conv_id, &update_body)
|
||||||
|
.await;
|
||||||
|
assert_eq!(upd_resp.status(), axum::http::StatusCode::OK);
|
||||||
|
let upd_bytes = axum::body::to_bytes(upd_resp.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let upd_json: serde_json::Value = serde_json::from_slice(&upd_bytes).unwrap();
|
||||||
|
assert_eq!(upd_json["metadata"]["project"], serde_json::json!("alpha"));
|
||||||
|
assert_eq!(upd_json["metadata"]["owner"], serde_json::json!("alice"));
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
let del_resp = router.delete_conversation(None, conv_id).await;
|
||||||
|
assert_eq!(del_resp.status(), axum::http::StatusCode::OK);
|
||||||
|
let del_bytes = axum::body::to_bytes(del_resp.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let del_json: serde_json::Value = serde_json::from_slice(&del_bytes).unwrap();
|
||||||
|
assert_eq!(del_json["deleted"], serde_json::json!(true));
|
||||||
|
|
||||||
|
// Get again -> 404
|
||||||
|
let not_found = router.get_conversation(None, conv_id).await;
|
||||||
|
assert_eq!(not_found.status(), axum::http::StatusCode::NOT_FOUND);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_responses_request_creation() {
|
fn test_responses_request_creation() {
|
||||||
let request = ResponsesRequest {
|
let request = ResponsesRequest {
|
||||||
|
|||||||
@@ -13,7 +13,10 @@ use sglang_router_rs::{
|
|||||||
config::{
|
config::{
|
||||||
ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode,
|
ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode,
|
||||||
},
|
},
|
||||||
data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage, StoredResponse},
|
data_connector::{
|
||||||
|
MemoryConversationStorage, MemoryResponseStorage, ResponseId, ResponseStorage,
|
||||||
|
StoredResponse,
|
||||||
|
},
|
||||||
protocols::spec::{
|
protocols::spec::{
|
||||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
|
||||||
ResponsesGetParams, ResponsesRequest, UserMessageContent,
|
ResponsesGetParams, ResponsesRequest, UserMessageContent,
|
||||||
@@ -91,6 +94,7 @@ async fn test_openai_router_creation() {
|
|||||||
"https://api.openai.com".to_string(),
|
"https://api.openai.com".to_string(),
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -108,6 +112,7 @@ async fn test_openai_router_server_info() {
|
|||||||
"https://api.openai.com".to_string(),
|
"https://api.openai.com".to_string(),
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -137,6 +142,7 @@ async fn test_openai_router_models() {
|
|||||||
mock_server.base_url(),
|
mock_server.base_url(),
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -211,9 +217,14 @@ async fn test_openai_router_responses_with_mock() {
|
|||||||
let base_url = format!("http://{}", addr);
|
let base_url = format!("http://{}", addr);
|
||||||
let storage = Arc::new(MemoryResponseStorage::new());
|
let storage = Arc::new(MemoryResponseStorage::new());
|
||||||
|
|
||||||
let router = OpenAIRouter::new(base_url, None, storage.clone())
|
let router = OpenAIRouter::new(
|
||||||
.await
|
base_url,
|
||||||
.unwrap();
|
None,
|
||||||
|
storage.clone(),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let request1 = ResponsesRequest {
|
let request1 = ResponsesRequest {
|
||||||
model: Some("gpt-4o-mini".to_string()),
|
model: Some("gpt-4o-mini".to_string()),
|
||||||
@@ -252,7 +263,7 @@ async fn test_openai_router_responses_with_mock() {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let stored1 = storage
|
let stored1 = storage
|
||||||
.get_response(&ResponseId::from_string(resp1_id.clone()))
|
.get_response(&ResponseId::from(resp1_id.clone()))
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.expect("first response missing");
|
.expect("first response missing");
|
||||||
@@ -261,7 +272,7 @@ async fn test_openai_router_responses_with_mock() {
|
|||||||
assert!(stored1.previous_response_id.is_none());
|
assert!(stored1.previous_response_id.is_none());
|
||||||
|
|
||||||
let stored2 = storage
|
let stored2 = storage
|
||||||
.get_response(&ResponseId::from_string(resp2_id.to_string()))
|
.get_response(&ResponseId::from(resp2_id))
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.expect("second response missing");
|
.expect("second response missing");
|
||||||
@@ -463,12 +474,17 @@ async fn test_openai_router_responses_streaming_with_mock() {
|
|||||||
"Earlier answer".to_string(),
|
"Earlier answer".to_string(),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
previous.id = ResponseId::from_string("resp_prev_chain".to_string());
|
previous.id = ResponseId::from("resp_prev_chain");
|
||||||
storage.store_response(previous).await.unwrap();
|
storage.store_response(previous).await.unwrap();
|
||||||
|
|
||||||
let router = OpenAIRouter::new(base_url, None, storage.clone())
|
let router = OpenAIRouter::new(
|
||||||
.await
|
base_url,
|
||||||
.unwrap();
|
None,
|
||||||
|
storage.clone(),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut metadata = HashMap::new();
|
let mut metadata = HashMap::new();
|
||||||
metadata.insert("topic".to_string(), json!("unicorns"));
|
metadata.insert("topic".to_string(), json!("unicorns"));
|
||||||
@@ -504,7 +520,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
|
|||||||
assert!(body_text.contains("Once upon a streamed unicorn adventure."));
|
assert!(body_text.contains("Once upon a streamed unicorn adventure."));
|
||||||
|
|
||||||
// Wait for the storage task to persist the streaming response.
|
// Wait for the storage task to persist the streaming response.
|
||||||
let target_id = ResponseId::from_string("resp_stream_123".to_string());
|
let target_id = ResponseId::from("resp_stream_123");
|
||||||
let stored = loop {
|
let stored = loop {
|
||||||
if let Some(resp) = storage.get_response(&target_id).await.unwrap() {
|
if let Some(resp) = storage.get_response(&target_id).await.unwrap() {
|
||||||
break resp;
|
break resp;
|
||||||
@@ -569,6 +585,7 @@ async fn test_unsupported_endpoints() {
|
|||||||
"https://api.openai.com".to_string(),
|
"https://api.openai.com".to_string(),
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -605,9 +622,14 @@ async fn test_openai_router_chat_completion_with_mock() {
|
|||||||
let base_url = mock_server.base_url();
|
let base_url = mock_server.base_url();
|
||||||
|
|
||||||
// Create router pointing to mock server
|
// Create router pointing to mock server
|
||||||
let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
|
let router = OpenAIRouter::new(
|
||||||
.await
|
base_url,
|
||||||
.unwrap();
|
None,
|
||||||
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Create a minimal chat completion request
|
// Create a minimal chat completion request
|
||||||
let mut chat_request = create_minimal_chat_request();
|
let mut chat_request = create_minimal_chat_request();
|
||||||
@@ -642,9 +664,14 @@ async fn test_openai_e2e_with_server() {
|
|||||||
let base_url = mock_server.base_url();
|
let base_url = mock_server.base_url();
|
||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
|
let router = OpenAIRouter::new(
|
||||||
.await
|
base_url,
|
||||||
.unwrap();
|
None,
|
||||||
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Create Axum app with chat completions endpoint
|
// Create Axum app with chat completions endpoint
|
||||||
let app = Router::new().route(
|
let app = Router::new().route(
|
||||||
@@ -707,9 +734,14 @@ async fn test_openai_e2e_with_server() {
|
|||||||
async fn test_openai_router_chat_streaming_with_mock() {
|
async fn test_openai_router_chat_streaming_with_mock() {
|
||||||
let mock_server = MockOpenAIServer::new().await;
|
let mock_server = MockOpenAIServer::new().await;
|
||||||
let base_url = mock_server.base_url();
|
let base_url = mock_server.base_url();
|
||||||
let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
|
let router = OpenAIRouter::new(
|
||||||
.await
|
base_url,
|
||||||
.unwrap();
|
None,
|
||||||
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Build a streaming chat request
|
// Build a streaming chat request
|
||||||
let val = json!({
|
let val = json!({
|
||||||
@@ -759,6 +791,7 @@ async fn test_openai_router_circuit_breaker() {
|
|||||||
"http://invalid-url-that-will-fail".to_string(),
|
"http://invalid-url-that-will-fail".to_string(),
|
||||||
Some(cb_config),
|
Some(cb_config),
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -786,6 +819,7 @@ async fn test_openai_router_models_auth_forwarding() {
|
|||||||
mock_server.base_url(),
|
mock_server.base_url(),
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|||||||
Reference in New Issue
Block a user