[router] Support Oracle DB(ATP) Data Connector (#10845)
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
// Data connector module for response storage
|
||||
pub mod response_memory_store;
|
||||
pub mod response_noop_store;
|
||||
pub mod response_oracle_store;
|
||||
pub mod responses;
|
||||
|
||||
pub use response_memory_store::MemoryResponseStorage;
|
||||
pub use response_noop_store::NoOpResponseStorage;
|
||||
pub use response_oracle_store::OracleResponseStorage;
|
||||
pub use responses::{
|
||||
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, SharedResponseStorage,
|
||||
StoredResponse,
|
||||
|
||||
548
sgl-router/src/data_connector/response_oracle_store.rs
Normal file
548
sgl-router/src/data_connector/response_oracle_store.rs
Normal file
@@ -0,0 +1,548 @@
|
||||
use crate::config::OracleConfig;
|
||||
use crate::data_connector::responses::{
|
||||
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
|
||||
StoredResponse,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
|
||||
use oracle::{Connection, Row};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
|
||||
tool_calls, metadata, created_at, user_id, model, raw_response FROM responses";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OracleResponseStorage {
|
||||
pool: Pool<OracleConnectionManager>,
|
||||
}
|
||||
|
||||
impl OracleResponseStorage {
|
||||
pub fn new(config: OracleConfig) -> StorageResult<Self> {
|
||||
let config = Arc::new(config);
|
||||
configure_oracle_client(&config)?;
|
||||
initialize_schema(&config)?;
|
||||
|
||||
let manager = OracleConnectionManager::new(config.clone());
|
||||
let mut builder = Pool::builder(manager)
|
||||
.max_size(config.pool_max)
|
||||
.runtime(deadpool::Runtime::Tokio1);
|
||||
|
||||
if config.pool_timeout_secs > 0 {
|
||||
builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs)));
|
||||
}
|
||||
|
||||
let pool = builder.build().map_err(|err| {
|
||||
ResponseStorageError::StorageError(format!(
|
||||
"failed to build Oracle connection pool: {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
async fn with_connection<F, T>(&self, func: F) -> StorageResult<T>
|
||||
where
|
||||
F: FnOnce(&Connection) -> StorageResult<T> + Send + 'static,
|
||||
T: Send + 'static,
|
||||
{
|
||||
let connection = self.pool.get().await.map_err(map_pool_error)?;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let result = func(&connection);
|
||||
drop(connection);
|
||||
result
|
||||
})
|
||||
.await
|
||||
.map_err(|err| {
|
||||
ResponseStorageError::StorageError(format!(
|
||||
"failed to execute Oracle query task: {err}"
|
||||
))
|
||||
})?
|
||||
}
|
||||
|
||||
fn build_response_from_row(row: &Row) -> StorageResult<StoredResponse> {
|
||||
let id: String = row
|
||||
.get(0)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch id"))?;
|
||||
let previous: Option<String> = row.get(1).map_err(|err| {
|
||||
map_oracle_error(err).into_storage_error("fetch previous_response_id")
|
||||
})?;
|
||||
let input: String = row
|
||||
.get(2)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch input"))?;
|
||||
let instructions: Option<String> = row
|
||||
.get(3)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch instructions"))?;
|
||||
let output: String = row
|
||||
.get(4)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch output"))?;
|
||||
let tool_calls_json: Option<String> = row
|
||||
.get(5)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch tool_calls"))?;
|
||||
let metadata_json: Option<String> = row
|
||||
.get(6)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch metadata"))?;
|
||||
let created_at: chrono::DateTime<chrono::Utc> = row
|
||||
.get(7)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch created_at"))?;
|
||||
let user_id: Option<String> = row
|
||||
.get(8)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch user_id"))?;
|
||||
let model: Option<String> = row
|
||||
.get(9)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch model"))?;
|
||||
let raw_response_json: Option<String> = row
|
||||
.get(10)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch raw_response"))?;
|
||||
|
||||
let previous_response_id = previous.map(ResponseId);
|
||||
let tool_calls = parse_tool_calls(tool_calls_json)?;
|
||||
let metadata = parse_metadata(metadata_json)?;
|
||||
let raw_response = parse_raw_response(raw_response_json)?;
|
||||
|
||||
Ok(StoredResponse {
|
||||
id: ResponseId(id),
|
||||
previous_response_id,
|
||||
input,
|
||||
instructions,
|
||||
output,
|
||||
tool_calls,
|
||||
metadata,
|
||||
created_at,
|
||||
user: user_id,
|
||||
model,
|
||||
raw_response,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ResponseStorage for OracleResponseStorage {
|
||||
async fn store_response(&self, response: StoredResponse) -> StorageResult<ResponseId> {
|
||||
let StoredResponse {
|
||||
id,
|
||||
previous_response_id,
|
||||
input,
|
||||
instructions,
|
||||
output,
|
||||
tool_calls,
|
||||
metadata,
|
||||
created_at,
|
||||
user,
|
||||
model,
|
||||
raw_response,
|
||||
} = response;
|
||||
|
||||
let response_id = id.clone();
|
||||
let response_id_str = response_id.0.clone();
|
||||
let previous_id = previous_response_id.map(|r| r.0);
|
||||
let json_tool_calls = serde_json::to_string(&tool_calls)?;
|
||||
let json_metadata = serde_json::to_string(&metadata)?;
|
||||
let json_raw_response = serde_json::to_string(&raw_response)?;
|
||||
|
||||
self.with_connection(move |conn| {
|
||||
conn.execute(
|
||||
"INSERT INTO responses (id, previous_response_id, input, instructions, output, \
|
||||
tool_calls, metadata, created_at, user_id, model, raw_response) \
|
||||
VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11)",
|
||||
&[
|
||||
&response_id_str,
|
||||
&previous_id,
|
||||
&input,
|
||||
&instructions,
|
||||
&output,
|
||||
&json_tool_calls,
|
||||
&json_metadata,
|
||||
&created_at,
|
||||
&user,
|
||||
&model,
|
||||
&json_raw_response,
|
||||
],
|
||||
)
|
||||
.map(|_| ())
|
||||
.map_err(map_oracle_error)
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(response_id)
|
||||
}
|
||||
|
||||
async fn get_response(
|
||||
&self,
|
||||
response_id: &ResponseId,
|
||||
) -> StorageResult<Option<StoredResponse>> {
|
||||
let id = response_id.0.clone();
|
||||
self.with_connection(move |conn| {
|
||||
let mut stmt = conn
|
||||
.statement(&format!("{} WHERE id = :1", SELECT_BASE))
|
||||
.build()
|
||||
.map_err(map_oracle_error)?;
|
||||
let mut rows = stmt.query(&[&id]).map_err(map_oracle_error)?;
|
||||
match rows.next() {
|
||||
Some(row) => {
|
||||
let row = row.map_err(map_oracle_error)?;
|
||||
OracleResponseStorage::build_response_from_row(&row).map(Some)
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn delete_response(&self, response_id: &ResponseId) -> StorageResult<()> {
|
||||
let id = response_id.0.clone();
|
||||
self.with_connection(move |conn| {
|
||||
conn.execute("DELETE FROM responses WHERE id = :1", &[&id])
|
||||
.map(|_| ())
|
||||
.map_err(map_oracle_error)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_response_chain(
|
||||
&self,
|
||||
response_id: &ResponseId,
|
||||
max_depth: Option<usize>,
|
||||
) -> StorageResult<ResponseChain> {
|
||||
let mut chain = ResponseChain::new();
|
||||
let mut current_id = Some(response_id.clone());
|
||||
let mut visited = 0usize;
|
||||
|
||||
while let Some(ref lookup_id) = current_id {
|
||||
if let Some(limit) = max_depth {
|
||||
if visited >= limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let fetched = self.get_response(lookup_id).await?;
|
||||
match fetched {
|
||||
Some(response) => {
|
||||
current_id = response.previous_response_id.clone();
|
||||
chain.responses.push(response);
|
||||
visited += 1;
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
chain.responses.reverse();
|
||||
Ok(chain)
|
||||
}
|
||||
|
||||
async fn list_user_responses(
|
||||
&self,
|
||||
user: &str,
|
||||
limit: Option<usize>,
|
||||
) -> StorageResult<Vec<StoredResponse>> {
|
||||
let user = user.to_string();
|
||||
|
||||
self.with_connection(move |conn| {
|
||||
let sql = if let Some(limit) = limit {
|
||||
format!(
|
||||
"SELECT * FROM ({} WHERE user_id = :1 ORDER BY created_at DESC) WHERE ROWNUM <= {}",
|
||||
SELECT_BASE, limit
|
||||
)
|
||||
} else {
|
||||
format!("{} WHERE user_id = :1 ORDER BY created_at DESC", SELECT_BASE)
|
||||
};
|
||||
|
||||
let mut stmt = conn.statement(&sql).build().map_err(map_oracle_error)?;
|
||||
let mut rows = stmt.query(&[&user]).map_err(map_oracle_error)?;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for row in &mut rows {
|
||||
let row = row.map_err(map_oracle_error)?;
|
||||
results.push(OracleResponseStorage::build_response_from_row(&row)?);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn delete_user_responses(&self, user: &str) -> StorageResult<usize> {
|
||||
let user = user.to_string();
|
||||
let affected = self
|
||||
.with_connection(move |conn| {
|
||||
conn.execute("DELETE FROM responses WHERE user_id = :1", &[&user])
|
||||
.map_err(map_oracle_error)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let deleted = affected.row_count().map_err(map_oracle_error)? as usize;
|
||||
Ok(deleted)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct OracleConnectionManager {
|
||||
params: Arc<OracleConnectParams>,
|
||||
}
|
||||
|
||||
impl OracleConnectionManager {
|
||||
fn new(config: Arc<OracleConfig>) -> Self {
|
||||
let params = OracleConnectParams {
|
||||
username: config.username.clone(),
|
||||
password: config.password.clone(),
|
||||
connect_descriptor: config.connect_descriptor.clone(),
|
||||
};
|
||||
|
||||
Self {
|
||||
params: Arc::new(params),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OracleConnectionManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OracleConnectionManager")
|
||||
.field("username", &self.params.username)
|
||||
.field("connect_descriptor", &self.params.connect_descriptor)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct OracleConnectParams {
|
||||
username: String,
|
||||
password: String,
|
||||
connect_descriptor: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OracleConnectParams {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OracleConnectParams")
|
||||
.field("username", &self.username)
|
||||
.field("connect_descriptor", &self.connect_descriptor)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Manager for OracleConnectionManager {
|
||||
type Type = Connection;
|
||||
type Error = oracle::Error;
|
||||
|
||||
fn create(
|
||||
&self,
|
||||
) -> impl std::future::Future<Output = Result<Connection, oracle::Error>> + Send {
|
||||
let params = self.params.clone();
|
||||
async move {
|
||||
let mut conn = Connection::connect(
|
||||
¶ms.username,
|
||||
¶ms.password,
|
||||
¶ms.connect_descriptor,
|
||||
)?;
|
||||
conn.set_autocommit(true);
|
||||
Ok(conn)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::manual_async_fn)]
|
||||
fn recycle(
|
||||
&self,
|
||||
conn: &mut Connection,
|
||||
_: &Metrics,
|
||||
) -> impl std::future::Future<Output = RecycleResult<Self::Error>> + Send {
|
||||
async move { conn.ping().map_err(RecycleError::Backend) }
|
||||
}
|
||||
}
|
||||
|
||||
fn configure_oracle_client(config: &OracleConfig) -> StorageResult<()> {
|
||||
if let Some(wallet_path) = &config.wallet_path {
|
||||
let wallet_path = Path::new(wallet_path);
|
||||
if !wallet_path.is_dir() {
|
||||
return Err(ResponseStorageError::StorageError(format!(
|
||||
"Oracle wallet/config path '{}' is not a directory",
|
||||
wallet_path.display()
|
||||
)));
|
||||
}
|
||||
|
||||
if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() {
|
||||
return Err(ResponseStorageError::StorageError(format!(
|
||||
"Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora",
|
||||
wallet_path.display()
|
||||
)));
|
||||
}
|
||||
|
||||
std::env::set_var("TNS_ADMIN", wallet_path);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn initialize_schema(config: &OracleConfig) -> StorageResult<()> {
|
||||
let conn = Connection::connect(
|
||||
&config.username,
|
||||
&config.password,
|
||||
&config.connect_descriptor,
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
|
||||
let exists: i64 = conn
|
||||
.query_row_as(
|
||||
"SELECT COUNT(*) FROM user_tables WHERE table_name = 'RESPONSES'",
|
||||
&[],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
|
||||
if exists == 0 {
|
||||
conn.execute(
|
||||
"CREATE TABLE responses (
|
||||
id VARCHAR2(64) PRIMARY KEY,
|
||||
previous_response_id VARCHAR2(64),
|
||||
input CLOB,
|
||||
instructions CLOB,
|
||||
output CLOB,
|
||||
tool_calls CLOB,
|
||||
metadata CLOB,
|
||||
created_at TIMESTAMP WITH TIME ZONE,
|
||||
user_id VARCHAR2(128),
|
||||
model VARCHAR2(128),
|
||||
raw_response CLOB
|
||||
)",
|
||||
&[],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
}
|
||||
|
||||
create_index_if_missing(
|
||||
&conn,
|
||||
"RESPONSES_PREV_IDX",
|
||||
"CREATE INDEX responses_prev_idx ON responses(previous_response_id)",
|
||||
)?;
|
||||
create_index_if_missing(
|
||||
&conn,
|
||||
"RESPONSES_USER_IDX",
|
||||
"CREATE INDEX responses_user_idx ON responses(user_id)",
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_index_if_missing(conn: &Connection, index_name: &str, ddl: &str) -> StorageResult<()> {
|
||||
let count: i64 = conn
|
||||
.query_row_as(
|
||||
"SELECT COUNT(*) FROM user_indexes WHERE table_name = 'RESPONSES' AND index_name = :1",
|
||||
&[&index_name],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
|
||||
if count == 0 {
|
||||
if let Err(err) = conn.execute(ddl, &[]) {
|
||||
if err.db_error().map(|db| db.code()) != Some(1408) {
|
||||
return Err(map_oracle_error(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_tool_calls(raw: Option<String>) -> StorageResult<Vec<Value>> {
|
||||
match raw {
|
||||
Some(s) if !s.is_empty() => {
|
||||
serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError)
|
||||
}
|
||||
_ => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_metadata(raw: Option<String>) -> StorageResult<HashMap<String, Value>> {
|
||||
match raw {
|
||||
Some(s) if !s.is_empty() => {
|
||||
serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError)
|
||||
}
|
||||
_ => Ok(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_raw_response(raw: Option<String>) -> StorageResult<Value> {
|
||||
match raw {
|
||||
Some(s) if !s.is_empty() => {
|
||||
serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError)
|
||||
}
|
||||
_ => Ok(Value::Null),
|
||||
}
|
||||
}
|
||||
|
||||
fn map_pool_error(err: PoolError<oracle::Error>) -> ResponseStorageError {
|
||||
match err {
|
||||
PoolError::Backend(e) => map_oracle_error(e),
|
||||
other => ResponseStorageError::StorageError(format!(
|
||||
"failed to obtain Oracle connection: {other}"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn map_oracle_error(err: oracle::Error) -> ResponseStorageError {
|
||||
if let Some(db_err) = err.db_error() {
|
||||
ResponseStorageError::StorageError(format!(
|
||||
"Oracle error (code {}): {}",
|
||||
db_err.code(),
|
||||
db_err.message()
|
||||
))
|
||||
} else {
|
||||
ResponseStorageError::StorageError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
trait OracleErrorExt {
|
||||
fn into_storage_error(self, context: &str) -> ResponseStorageError;
|
||||
}
|
||||
|
||||
impl OracleErrorExt for ResponseStorageError {
|
||||
fn into_storage_error(self, context: &str) -> ResponseStorageError {
|
||||
ResponseStorageError::StorageError(format!("{context}: {self}"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_handles_empty_input() {
|
||||
assert!(parse_tool_calls(None).unwrap().is_empty());
|
||||
assert!(parse_tool_calls(Some(String::new())).unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_round_trips() {
|
||||
let payload = json!([{ "type": "test", "value": 1 }]).to_string();
|
||||
let parsed = parse_tool_calls(Some(payload)).unwrap();
|
||||
assert_eq!(parsed.len(), 1);
|
||||
assert_eq!(parsed[0]["type"], "test");
|
||||
assert_eq!(parsed[0]["value"], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_metadata_defaults_to_empty_map() {
|
||||
assert!(parse_metadata(None).unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_metadata_round_trips() {
|
||||
let payload = json!({"key": "value", "nested": {"bool": true}}).to_string();
|
||||
let parsed = parse_metadata(Some(payload)).unwrap();
|
||||
assert_eq!(parsed.get("key").unwrap(), "value");
|
||||
assert_eq!(parsed["nested"]["bool"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_raw_response_handles_null() {
|
||||
assert_eq!(parse_raw_response(None).unwrap(), Value::Null);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_raw_response_round_trips() {
|
||||
let payload = json!({"id": "abc"}).to_string();
|
||||
let parsed = parse_raw_response(Some(payload)).unwrap();
|
||||
assert_eq!(parsed["id"], "abc");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user