[router] Support Oracle DB(ATP) Data Connector (#10845)
This commit is contained in:
@@ -19,7 +19,7 @@ name = "sglang-router"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
clap = { version = "4", features = ["derive", "env"] }
|
||||
axum = { version = "0.8.4", features = ["macros", "ws", "tracing"] }
|
||||
tower = { version = "0.5", features = ["full"] }
|
||||
tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] }
|
||||
@@ -69,6 +69,7 @@ rmcp = { version = "0.6.3", features = ["client", "server",
|
||||
"reqwest",
|
||||
"auth"] }
|
||||
serde_yaml = "0.9"
|
||||
oracle = { version = "0.6.3", features = ["chrono"] }
|
||||
subtle = "2.6"
|
||||
|
||||
# gRPC and Protobuf dependencies
|
||||
|
||||
@@ -70,6 +70,9 @@ pub struct RouterConfig {
|
||||
/// History backend configuration (memory, none, or oracle; default: memory)
|
||||
#[serde(default = "default_history_backend")]
|
||||
pub history_backend: HistoryBackend,
|
||||
/// Oracle history backend configuration (required when `history_backend` = "oracle")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oracle: Option<OracleConfig>,
|
||||
}
|
||||
|
||||
fn default_history_backend() -> HistoryBackend {
|
||||
@@ -84,6 +87,70 @@ pub enum HistoryBackend {
|
||||
Memory,
|
||||
/// No history storage
|
||||
None,
|
||||
/// Oracle ATP-backed storage
|
||||
Oracle,
|
||||
}
|
||||
|
||||
/// Oracle history backend configuration
|
||||
#[derive(Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct OracleConfig {
|
||||
/// Directory containing the ATP wallet or TLS config files (optional)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub wallet_path: Option<String>,
|
||||
/// Connection descriptor / DSN (e.g. `tcps://host:port/service`)
|
||||
pub connect_descriptor: String,
|
||||
/// Database username
|
||||
pub username: String,
|
||||
/// Database password
|
||||
pub password: String,
|
||||
/// Minimum number of pooled connections to keep ready
|
||||
#[serde(default = "default_pool_min")]
|
||||
pub pool_min: usize,
|
||||
/// Maximum number of pooled connections
|
||||
#[serde(default = "default_pool_max")]
|
||||
pub pool_max: usize,
|
||||
/// Maximum time to wait for a connection from the pool (seconds)
|
||||
#[serde(default = "default_pool_timeout_secs")]
|
||||
pub pool_timeout_secs: u64,
|
||||
}
|
||||
|
||||
impl OracleConfig {
|
||||
pub fn default_pool_min() -> usize {
|
||||
default_pool_min()
|
||||
}
|
||||
|
||||
pub fn default_pool_max() -> usize {
|
||||
default_pool_max()
|
||||
}
|
||||
|
||||
pub fn default_pool_timeout_secs() -> u64 {
|
||||
default_pool_timeout_secs()
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde default for `OracleConfig::pool_min`.
fn default_pool_min() -> usize {
    const MIN_CONNECTIONS: usize = 1;
    MIN_CONNECTIONS
}
|
||||
|
||||
/// Serde default for `OracleConfig::pool_max`.
fn default_pool_max() -> usize {
    const MAX_CONNECTIONS: usize = 16;
    MAX_CONNECTIONS
}
|
||||
|
||||
/// Serde default for `OracleConfig::pool_timeout_secs`, in seconds.
fn default_pool_timeout_secs() -> u64 {
    const ACQUIRE_TIMEOUT_SECS: u64 = 30;
    ACQUIRE_TIMEOUT_SECS
}
|
||||
|
||||
impl std::fmt::Debug for OracleConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OracleConfig")
|
||||
.field("wallet_path", &self.wallet_path)
|
||||
.field("connect_descriptor", &self.connect_descriptor)
|
||||
.field("username", &self.username)
|
||||
.field("pool_min", &self.pool_min)
|
||||
.field("pool_max", &self.pool_max)
|
||||
.field("pool_timeout_secs", &self.pool_timeout_secs)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
|
||||
@@ -381,6 +448,7 @@ impl Default for RouterConfig {
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
oracle: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -948,6 +1016,7 @@ mod tests {
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
assert!(config.mode.is_pd_mode());
|
||||
@@ -1012,6 +1081,7 @@ mod tests {
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
assert!(!config.mode.is_pd_mode());
|
||||
@@ -1072,6 +1142,7 @@ mod tests {
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: default_history_backend(),
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
assert!(config.has_service_discovery());
|
||||
|
||||
@@ -29,6 +29,12 @@ impl ConfigValidator {
|
||||
Self::validate_retry(&retry_cfg)?;
|
||||
Self::validate_circuit_breaker(&cb_cfg)?;
|
||||
|
||||
if config.history_backend == HistoryBackend::Oracle && config.oracle.is_none() {
|
||||
return Err(ConfigError::MissingRequired {
|
||||
field: "oracle".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
// Data connector module for response storage
|
||||
pub mod response_memory_store;
|
||||
pub mod response_noop_store;
|
||||
pub mod response_oracle_store;
|
||||
pub mod responses;
|
||||
|
||||
pub use response_memory_store::MemoryResponseStorage;
|
||||
pub use response_noop_store::NoOpResponseStorage;
|
||||
pub use response_oracle_store::OracleResponseStorage;
|
||||
pub use responses::{
|
||||
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, SharedResponseStorage,
|
||||
StoredResponse,
|
||||
|
||||
548
sgl-router/src/data_connector/response_oracle_store.rs
Normal file
548
sgl-router/src/data_connector/response_oracle_store.rs
Normal file
@@ -0,0 +1,548 @@
|
||||
use crate::config::OracleConfig;
|
||||
use crate::data_connector::responses::{
|
||||
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
|
||||
StoredResponse,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
|
||||
use oracle::{Connection, Row};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
|
||||
tool_calls, metadata, created_at, user_id, model, raw_response FROM responses";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OracleResponseStorage {
|
||||
pool: Pool<OracleConnectionManager>,
|
||||
}
|
||||
|
||||
impl OracleResponseStorage {
|
||||
pub fn new(config: OracleConfig) -> StorageResult<Self> {
|
||||
let config = Arc::new(config);
|
||||
configure_oracle_client(&config)?;
|
||||
initialize_schema(&config)?;
|
||||
|
||||
let manager = OracleConnectionManager::new(config.clone());
|
||||
let mut builder = Pool::builder(manager)
|
||||
.max_size(config.pool_max)
|
||||
.runtime(deadpool::Runtime::Tokio1);
|
||||
|
||||
if config.pool_timeout_secs > 0 {
|
||||
builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs)));
|
||||
}
|
||||
|
||||
let pool = builder.build().map_err(|err| {
|
||||
ResponseStorageError::StorageError(format!(
|
||||
"failed to build Oracle connection pool: {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
async fn with_connection<F, T>(&self, func: F) -> StorageResult<T>
|
||||
where
|
||||
F: FnOnce(&Connection) -> StorageResult<T> + Send + 'static,
|
||||
T: Send + 'static,
|
||||
{
|
||||
let connection = self.pool.get().await.map_err(map_pool_error)?;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let result = func(&connection);
|
||||
drop(connection);
|
||||
result
|
||||
})
|
||||
.await
|
||||
.map_err(|err| {
|
||||
ResponseStorageError::StorageError(format!(
|
||||
"failed to execute Oracle query task: {err}"
|
||||
))
|
||||
})?
|
||||
}
|
||||
|
||||
fn build_response_from_row(row: &Row) -> StorageResult<StoredResponse> {
|
||||
let id: String = row
|
||||
.get(0)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch id"))?;
|
||||
let previous: Option<String> = row.get(1).map_err(|err| {
|
||||
map_oracle_error(err).into_storage_error("fetch previous_response_id")
|
||||
})?;
|
||||
let input: String = row
|
||||
.get(2)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch input"))?;
|
||||
let instructions: Option<String> = row
|
||||
.get(3)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch instructions"))?;
|
||||
let output: String = row
|
||||
.get(4)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch output"))?;
|
||||
let tool_calls_json: Option<String> = row
|
||||
.get(5)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch tool_calls"))?;
|
||||
let metadata_json: Option<String> = row
|
||||
.get(6)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch metadata"))?;
|
||||
let created_at: chrono::DateTime<chrono::Utc> = row
|
||||
.get(7)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch created_at"))?;
|
||||
let user_id: Option<String> = row
|
||||
.get(8)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch user_id"))?;
|
||||
let model: Option<String> = row
|
||||
.get(9)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch model"))?;
|
||||
let raw_response_json: Option<String> = row
|
||||
.get(10)
|
||||
.map_err(|err| map_oracle_error(err).into_storage_error("fetch raw_response"))?;
|
||||
|
||||
let previous_response_id = previous.map(ResponseId);
|
||||
let tool_calls = parse_tool_calls(tool_calls_json)?;
|
||||
let metadata = parse_metadata(metadata_json)?;
|
||||
let raw_response = parse_raw_response(raw_response_json)?;
|
||||
|
||||
Ok(StoredResponse {
|
||||
id: ResponseId(id),
|
||||
previous_response_id,
|
||||
input,
|
||||
instructions,
|
||||
output,
|
||||
tool_calls,
|
||||
metadata,
|
||||
created_at,
|
||||
user: user_id,
|
||||
model,
|
||||
raw_response,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ResponseStorage for OracleResponseStorage {
|
||||
async fn store_response(&self, response: StoredResponse) -> StorageResult<ResponseId> {
|
||||
let StoredResponse {
|
||||
id,
|
||||
previous_response_id,
|
||||
input,
|
||||
instructions,
|
||||
output,
|
||||
tool_calls,
|
||||
metadata,
|
||||
created_at,
|
||||
user,
|
||||
model,
|
||||
raw_response,
|
||||
} = response;
|
||||
|
||||
let response_id = id.clone();
|
||||
let response_id_str = response_id.0.clone();
|
||||
let previous_id = previous_response_id.map(|r| r.0);
|
||||
let json_tool_calls = serde_json::to_string(&tool_calls)?;
|
||||
let json_metadata = serde_json::to_string(&metadata)?;
|
||||
let json_raw_response = serde_json::to_string(&raw_response)?;
|
||||
|
||||
self.with_connection(move |conn| {
|
||||
conn.execute(
|
||||
"INSERT INTO responses (id, previous_response_id, input, instructions, output, \
|
||||
tool_calls, metadata, created_at, user_id, model, raw_response) \
|
||||
VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11)",
|
||||
&[
|
||||
&response_id_str,
|
||||
&previous_id,
|
||||
&input,
|
||||
&instructions,
|
||||
&output,
|
||||
&json_tool_calls,
|
||||
&json_metadata,
|
||||
&created_at,
|
||||
&user,
|
||||
&model,
|
||||
&json_raw_response,
|
||||
],
|
||||
)
|
||||
.map(|_| ())
|
||||
.map_err(map_oracle_error)
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(response_id)
|
||||
}
|
||||
|
||||
async fn get_response(
|
||||
&self,
|
||||
response_id: &ResponseId,
|
||||
) -> StorageResult<Option<StoredResponse>> {
|
||||
let id = response_id.0.clone();
|
||||
self.with_connection(move |conn| {
|
||||
let mut stmt = conn
|
||||
.statement(&format!("{} WHERE id = :1", SELECT_BASE))
|
||||
.build()
|
||||
.map_err(map_oracle_error)?;
|
||||
let mut rows = stmt.query(&[&id]).map_err(map_oracle_error)?;
|
||||
match rows.next() {
|
||||
Some(row) => {
|
||||
let row = row.map_err(map_oracle_error)?;
|
||||
OracleResponseStorage::build_response_from_row(&row).map(Some)
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn delete_response(&self, response_id: &ResponseId) -> StorageResult<()> {
|
||||
let id = response_id.0.clone();
|
||||
self.with_connection(move |conn| {
|
||||
conn.execute("DELETE FROM responses WHERE id = :1", &[&id])
|
||||
.map(|_| ())
|
||||
.map_err(map_oracle_error)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_response_chain(
|
||||
&self,
|
||||
response_id: &ResponseId,
|
||||
max_depth: Option<usize>,
|
||||
) -> StorageResult<ResponseChain> {
|
||||
let mut chain = ResponseChain::new();
|
||||
let mut current_id = Some(response_id.clone());
|
||||
let mut visited = 0usize;
|
||||
|
||||
while let Some(ref lookup_id) = current_id {
|
||||
if let Some(limit) = max_depth {
|
||||
if visited >= limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let fetched = self.get_response(lookup_id).await?;
|
||||
match fetched {
|
||||
Some(response) => {
|
||||
current_id = response.previous_response_id.clone();
|
||||
chain.responses.push(response);
|
||||
visited += 1;
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
chain.responses.reverse();
|
||||
Ok(chain)
|
||||
}
|
||||
|
||||
async fn list_user_responses(
|
||||
&self,
|
||||
user: &str,
|
||||
limit: Option<usize>,
|
||||
) -> StorageResult<Vec<StoredResponse>> {
|
||||
let user = user.to_string();
|
||||
|
||||
self.with_connection(move |conn| {
|
||||
let sql = if let Some(limit) = limit {
|
||||
format!(
|
||||
"SELECT * FROM ({} WHERE user_id = :1 ORDER BY created_at DESC) WHERE ROWNUM <= {}",
|
||||
SELECT_BASE, limit
|
||||
)
|
||||
} else {
|
||||
format!("{} WHERE user_id = :1 ORDER BY created_at DESC", SELECT_BASE)
|
||||
};
|
||||
|
||||
let mut stmt = conn.statement(&sql).build().map_err(map_oracle_error)?;
|
||||
let mut rows = stmt.query(&[&user]).map_err(map_oracle_error)?;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for row in &mut rows {
|
||||
let row = row.map_err(map_oracle_error)?;
|
||||
results.push(OracleResponseStorage::build_response_from_row(&row)?);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn delete_user_responses(&self, user: &str) -> StorageResult<usize> {
|
||||
let user = user.to_string();
|
||||
let affected = self
|
||||
.with_connection(move |conn| {
|
||||
conn.execute("DELETE FROM responses WHERE user_id = :1", &[&user])
|
||||
.map_err(map_oracle_error)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let deleted = affected.row_count().map_err(map_oracle_error)? as usize;
|
||||
Ok(deleted)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct OracleConnectionManager {
|
||||
params: Arc<OracleConnectParams>,
|
||||
}
|
||||
|
||||
impl OracleConnectionManager {
|
||||
fn new(config: Arc<OracleConfig>) -> Self {
|
||||
let params = OracleConnectParams {
|
||||
username: config.username.clone(),
|
||||
password: config.password.clone(),
|
||||
connect_descriptor: config.connect_descriptor.clone(),
|
||||
};
|
||||
|
||||
Self {
|
||||
params: Arc::new(params),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OracleConnectionManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OracleConnectionManager")
|
||||
.field("username", &self.params.username)
|
||||
.field("connect_descriptor", &self.params.connect_descriptor)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// The three values passed to `Connection::connect`.
#[derive(Clone)]
struct OracleConnectParams {
    /// Database account name.
    username: String,
    /// Database account password (left out of the `Debug` output).
    password: String,
    /// DSN or TNS alias to connect to.
    connect_descriptor: String,
}
|
||||
|
||||
impl std::fmt::Debug for OracleConnectParams {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OracleConnectParams")
|
||||
.field("username", &self.username)
|
||||
.field("connect_descriptor", &self.connect_descriptor)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Manager for OracleConnectionManager {
|
||||
type Type = Connection;
|
||||
type Error = oracle::Error;
|
||||
|
||||
fn create(
|
||||
&self,
|
||||
) -> impl std::future::Future<Output = Result<Connection, oracle::Error>> + Send {
|
||||
let params = self.params.clone();
|
||||
async move {
|
||||
let mut conn = Connection::connect(
|
||||
¶ms.username,
|
||||
¶ms.password,
|
||||
¶ms.connect_descriptor,
|
||||
)?;
|
||||
conn.set_autocommit(true);
|
||||
Ok(conn)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::manual_async_fn)]
|
||||
fn recycle(
|
||||
&self,
|
||||
conn: &mut Connection,
|
||||
_: &Metrics,
|
||||
) -> impl std::future::Future<Output = RecycleResult<Self::Error>> + Send {
|
||||
async move { conn.ping().map_err(RecycleError::Backend) }
|
||||
}
|
||||
}
|
||||
|
||||
fn configure_oracle_client(config: &OracleConfig) -> StorageResult<()> {
|
||||
if let Some(wallet_path) = &config.wallet_path {
|
||||
let wallet_path = Path::new(wallet_path);
|
||||
if !wallet_path.is_dir() {
|
||||
return Err(ResponseStorageError::StorageError(format!(
|
||||
"Oracle wallet/config path '{}' is not a directory",
|
||||
wallet_path.display()
|
||||
)));
|
||||
}
|
||||
|
||||
if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() {
|
||||
return Err(ResponseStorageError::StorageError(format!(
|
||||
"Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora",
|
||||
wallet_path.display()
|
||||
)));
|
||||
}
|
||||
|
||||
std::env::set_var("TNS_ADMIN", wallet_path);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn initialize_schema(config: &OracleConfig) -> StorageResult<()> {
|
||||
let conn = Connection::connect(
|
||||
&config.username,
|
||||
&config.password,
|
||||
&config.connect_descriptor,
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
|
||||
let exists: i64 = conn
|
||||
.query_row_as(
|
||||
"SELECT COUNT(*) FROM user_tables WHERE table_name = 'RESPONSES'",
|
||||
&[],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
|
||||
if exists == 0 {
|
||||
conn.execute(
|
||||
"CREATE TABLE responses (
|
||||
id VARCHAR2(64) PRIMARY KEY,
|
||||
previous_response_id VARCHAR2(64),
|
||||
input CLOB,
|
||||
instructions CLOB,
|
||||
output CLOB,
|
||||
tool_calls CLOB,
|
||||
metadata CLOB,
|
||||
created_at TIMESTAMP WITH TIME ZONE,
|
||||
user_id VARCHAR2(128),
|
||||
model VARCHAR2(128),
|
||||
raw_response CLOB
|
||||
)",
|
||||
&[],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
}
|
||||
|
||||
create_index_if_missing(
|
||||
&conn,
|
||||
"RESPONSES_PREV_IDX",
|
||||
"CREATE INDEX responses_prev_idx ON responses(previous_response_id)",
|
||||
)?;
|
||||
create_index_if_missing(
|
||||
&conn,
|
||||
"RESPONSES_USER_IDX",
|
||||
"CREATE INDEX responses_user_idx ON responses(user_id)",
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_index_if_missing(conn: &Connection, index_name: &str, ddl: &str) -> StorageResult<()> {
|
||||
let count: i64 = conn
|
||||
.query_row_as(
|
||||
"SELECT COUNT(*) FROM user_indexes WHERE table_name = 'RESPONSES' AND index_name = :1",
|
||||
&[&index_name],
|
||||
)
|
||||
.map_err(map_oracle_error)?;
|
||||
|
||||
if count == 0 {
|
||||
if let Err(err) = conn.execute(ddl, &[]) {
|
||||
if err.db_error().map(|db| db.code()) != Some(1408) {
|
||||
return Err(map_oracle_error(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_tool_calls(raw: Option<String>) -> StorageResult<Vec<Value>> {
|
||||
match raw {
|
||||
Some(s) if !s.is_empty() => {
|
||||
serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError)
|
||||
}
|
||||
_ => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_metadata(raw: Option<String>) -> StorageResult<HashMap<String, Value>> {
|
||||
match raw {
|
||||
Some(s) if !s.is_empty() => {
|
||||
serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError)
|
||||
}
|
||||
_ => Ok(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_raw_response(raw: Option<String>) -> StorageResult<Value> {
|
||||
match raw {
|
||||
Some(s) if !s.is_empty() => {
|
||||
serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError)
|
||||
}
|
||||
_ => Ok(Value::Null),
|
||||
}
|
||||
}
|
||||
|
||||
fn map_pool_error(err: PoolError<oracle::Error>) -> ResponseStorageError {
|
||||
match err {
|
||||
PoolError::Backend(e) => map_oracle_error(e),
|
||||
other => ResponseStorageError::StorageError(format!(
|
||||
"failed to obtain Oracle connection: {other}"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn map_oracle_error(err: oracle::Error) -> ResponseStorageError {
|
||||
if let Some(db_err) = err.db_error() {
|
||||
ResponseStorageError::StorageError(format!(
|
||||
"Oracle error (code {}): {}",
|
||||
db_err.code(),
|
||||
db_err.message()
|
||||
))
|
||||
} else {
|
||||
ResponseStorageError::StorageError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
trait OracleErrorExt {
|
||||
fn into_storage_error(self, context: &str) -> ResponseStorageError;
|
||||
}
|
||||
|
||||
impl OracleErrorExt for ResponseStorageError {
|
||||
fn into_storage_error(self, context: &str) -> ResponseStorageError {
|
||||
ResponseStorageError::StorageError(format!("{context}: {self}"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Unit tests for the JSON column parsers; none of them needs a database.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn parse_tool_calls_handles_empty_input() {
        let from_null = parse_tool_calls(None).unwrap();
        assert!(from_null.is_empty());

        let from_empty = parse_tool_calls(Some(String::new())).unwrap();
        assert!(from_empty.is_empty());
    }

    #[test]
    fn parse_tool_calls_round_trips() {
        let encoded = json!([{ "type": "test", "value": 1 }]).to_string();
        let decoded = parse_tool_calls(Some(encoded)).unwrap();

        assert_eq!(decoded.len(), 1);
        let first = &decoded[0];
        assert_eq!(first["type"], "test");
        assert_eq!(first["value"], 1);
    }

    #[test]
    fn parse_metadata_defaults_to_empty_map() {
        let decoded = parse_metadata(None).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn parse_metadata_round_trips() {
        let encoded = json!({"key": "value", "nested": {"bool": true}}).to_string();
        let decoded = parse_metadata(Some(encoded)).unwrap();

        assert_eq!(decoded.get("key").unwrap(), "value");
        assert_eq!(decoded["nested"]["bool"], true);
    }

    #[test]
    fn parse_raw_response_handles_null() {
        let decoded = parse_raw_response(None).unwrap();
        assert_eq!(decoded, Value::Null);
    }

    #[test]
    fn parse_raw_response_round_trips() {
        let encoded = json!({"id": "abc"}).to_string();
        let decoded = parse_raw_response(Some(encoded)).unwrap();
        assert_eq!(decoded["id"], "abc");
    }
}
|
||||
@@ -231,6 +231,7 @@ impl Router {
|
||||
model_path: self.model_path.clone(),
|
||||
tokenizer_path: self.tokenizer_path.clone(),
|
||||
history_backend: config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use clap::{ArgAction, Parser, ValueEnum};
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
|
||||
HealthCheckConfig, HistoryBackend, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig,
|
||||
RoutingMode,
|
||||
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
|
||||
RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::metrics::PrometheusConfig;
|
||||
use sglang_router_rs::server::{self, ServerConfig};
|
||||
@@ -314,9 +314,46 @@ struct CliArgs {
|
||||
#[arg(long)]
|
||||
tokenizer_path: Option<String>,
|
||||
|
||||
/// History backend configuration (memory or none)
|
||||
#[arg(long, default_value = "memory", value_parser = ["memory", "none"])]
|
||||
/// History backend configuration (memory, none, or oracle)
|
||||
#[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
|
||||
history_backend: String,
|
||||
|
||||
/// Directory containing the Oracle ATP wallet/config files (optional)
|
||||
#[arg(long, env = "ATP_WALLET_PATH")]
|
||||
oracle_wallet_path: Option<String>,
|
||||
|
||||
/// Wallet TNS alias to use (e.g. `<db_name>_low`)
|
||||
#[arg(long, env = "ATP_TNS_ALIAS")]
|
||||
oracle_tns_alias: Option<String>,
|
||||
|
||||
/// Oracle connection descriptor / DSN (e.g. `tcps://host:port/service_name`)
|
||||
#[arg(long, env = "ATP_DSN")]
|
||||
oracle_dsn: Option<String>,
|
||||
|
||||
/// Oracle ATP username
|
||||
#[arg(long, env = "ATP_USER")]
|
||||
oracle_user: Option<String>,
|
||||
|
||||
/// Oracle ATP password
|
||||
#[arg(long, env = "ATP_PASSWORD")]
|
||||
oracle_password: Option<String>,
|
||||
|
||||
/// Minimum number of pooled ATP connections (defaults to 1 when omitted)
|
||||
#[arg(long, env = "ATP_POOL_MIN")]
|
||||
oracle_pool_min: Option<usize>,
|
||||
|
||||
/// Maximum number of pooled ATP connections (defaults to 16 when omitted)
|
||||
#[arg(long, env = "ATP_POOL_MAX")]
|
||||
oracle_pool_max: Option<usize>,
|
||||
|
||||
/// Connection acquisition timeout in seconds (defaults to 30 when omitted)
|
||||
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
|
||||
oracle_pool_timeout_secs: Option<u64>,
|
||||
}
|
||||
|
||||
/// How the CLI arguments say to reach the database.
enum OracleConnectSource {
    /// A full connection descriptor / DSN was given; no wallet directory is used.
    Dsn { descriptor: String },
    /// A wallet directory plus the TNS alias to use as the connect descriptor.
    Wallet { path: String, alias: String },
}
|
||||
|
||||
impl CliArgs {
|
||||
@@ -364,6 +401,87 @@ impl CliArgs {
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_oracle_connect_details(&self) -> ConfigResult<OracleConnectSource> {
|
||||
if let Some(dsn) = self.oracle_dsn.clone() {
|
||||
return Ok(OracleConnectSource::Dsn { descriptor: dsn });
|
||||
}
|
||||
|
||||
let wallet_path = self
|
||||
.oracle_wallet_path
|
||||
.clone()
|
||||
.ok_or(ConfigError::MissingRequired {
|
||||
field: "oracle_wallet_path or ATP_WALLET_PATH".to_string(),
|
||||
})?;
|
||||
|
||||
let tns_alias = self
|
||||
.oracle_tns_alias
|
||||
.clone()
|
||||
.ok_or(ConfigError::MissingRequired {
|
||||
field: "oracle_tns_alias or ATP_TNS_ALIAS".to_string(),
|
||||
})?;
|
||||
|
||||
Ok(OracleConnectSource::Wallet {
|
||||
path: wallet_path,
|
||||
alias: tns_alias,
|
||||
})
|
||||
}
|
||||
|
||||
fn build_oracle_config(&self) -> ConfigResult<OracleConfig> {
|
||||
let (wallet_path, connect_descriptor) = match self.resolve_oracle_connect_details()? {
|
||||
OracleConnectSource::Dsn { descriptor } => (None, descriptor),
|
||||
OracleConnectSource::Wallet { path, alias } => (Some(path), alias),
|
||||
};
|
||||
let username = self
|
||||
.oracle_user
|
||||
.clone()
|
||||
.ok_or(ConfigError::MissingRequired {
|
||||
field: "oracle_user or ATP_USER".to_string(),
|
||||
})?;
|
||||
let password = self
|
||||
.oracle_password
|
||||
.clone()
|
||||
.ok_or(ConfigError::MissingRequired {
|
||||
field: "oracle_password or ATP_PASSWORD".to_string(),
|
||||
})?;
|
||||
|
||||
let pool_min = self
|
||||
.oracle_pool_min
|
||||
.unwrap_or_else(OracleConfig::default_pool_min);
|
||||
let pool_max = self
|
||||
.oracle_pool_max
|
||||
.unwrap_or_else(OracleConfig::default_pool_max);
|
||||
|
||||
if pool_min == 0 {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "oracle_pool_min".to_string(),
|
||||
value: pool_min.to_string(),
|
||||
reason: "pool minimum must be at least 1".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if pool_max < pool_min {
|
||||
return Err(ConfigError::InvalidValue {
|
||||
field: "oracle_pool_max".to_string(),
|
||||
value: pool_max.to_string(),
|
||||
reason: "pool maximum must be greater than or equal to minimum".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let pool_timeout_secs = self
|
||||
.oracle_pool_timeout_secs
|
||||
.unwrap_or_else(OracleConfig::default_pool_timeout_secs);
|
||||
|
||||
Ok(OracleConfig {
|
||||
wallet_path,
|
||||
connect_descriptor,
|
||||
username,
|
||||
password,
|
||||
pool_min,
|
||||
pool_max,
|
||||
pool_timeout_secs,
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert CLI arguments to RouterConfig
|
||||
fn to_router_config(
|
||||
&self,
|
||||
@@ -459,6 +577,18 @@ impl CliArgs {
|
||||
_ => Self::determine_connection_mode(&all_urls),
|
||||
};
|
||||
|
||||
let history_backend = match self.history_backend.as_str() {
|
||||
"none" => HistoryBackend::None,
|
||||
"oracle" => HistoryBackend::Oracle,
|
||||
_ => HistoryBackend::Memory,
|
||||
};
|
||||
|
||||
let oracle = if history_backend == HistoryBackend::Oracle {
|
||||
Some(self.build_oracle_config()?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build RouterConfig
|
||||
Ok(RouterConfig {
|
||||
mode,
|
||||
@@ -511,10 +641,8 @@ impl CliArgs {
|
||||
rate_limit_tokens_per_second: None,
|
||||
model_path: self.model_path.clone(),
|
||||
tokenizer_path: self.tokenizer_path.clone(),
|
||||
history_backend: match self.history_backend.as_str() {
|
||||
"none" => HistoryBackend::None,
|
||||
_ => HistoryBackend::Memory,
|
||||
},
|
||||
history_backend,
|
||||
oracle,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use crate::{
|
||||
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
||||
core::{WorkerManager, WorkerRegistry, WorkerType},
|
||||
data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage},
|
||||
data_connector::{
|
||||
MemoryResponseStorage, NoOpResponseStorage, OracleResponseStorage, SharedResponseStorage,
|
||||
},
|
||||
logging::{self, LoggingConfig},
|
||||
metrics::{self, PrometheusConfig},
|
||||
middleware::{self, AuthConfig, QueuedRequest, TokenBucket},
|
||||
@@ -92,6 +94,17 @@ impl AppContext {
|
||||
let response_storage: SharedResponseStorage = match router_config.history_backend {
|
||||
HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()),
|
||||
HistoryBackend::None => Arc::new(NoOpResponseStorage::new()),
|
||||
HistoryBackend::Oracle => {
|
||||
let oracle_cfg = router_config.oracle.clone().ok_or_else(|| {
|
||||
"oracle configuration is required when history_backend=oracle".to_string()
|
||||
})?;
|
||||
|
||||
let storage = OracleResponseStorage::new(oracle_cfg).map_err(|err| {
|
||||
format!("failed to initialize Oracle response storage: {err}")
|
||||
})?;
|
||||
|
||||
Arc::new(storage)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
|
||||
@@ -62,6 +62,7 @@ impl TestContext {
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
Self::new_with_config(config, worker_configs).await
|
||||
@@ -1401,6 +1402,7 @@ mod error_tests {
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
let ctx = TestContext::new_with_config(
|
||||
@@ -1760,6 +1762,7 @@ mod pd_mode_tests {
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
// Create app context
|
||||
@@ -1923,6 +1926,7 @@ mod request_id_tests {
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
let ctx = TestContext::new_with_config(
|
||||
|
||||
@@ -10,7 +10,9 @@ use axum::{
|
||||
};
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::{
|
||||
config::{RouterConfig, RoutingMode},
|
||||
config::{
|
||||
ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode,
|
||||
},
|
||||
data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage, StoredResponse},
|
||||
protocols::spec::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
|
||||
@@ -823,3 +825,69 @@ async fn test_openai_router_models_auth_forwarding() {
|
||||
let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
|
||||
assert_eq!(models["object"], "list");
|
||||
}
|
||||
|
||||
/// Selecting the Oracle backend without an `oracle` section must be rejected.
#[test]
fn oracle_config_validation_requires_config_when_enabled() {
    let mode = RoutingMode::OpenAI {
        worker_urls: vec!["https://api.openai.com".to_string()],
    };
    let config = RouterConfig {
        mode,
        history_backend: HistoryBackend::Oracle,
        oracle: None,
        ..Default::default()
    };

    let outcome = ConfigValidator::validate(&config);

    match outcome.expect_err("config should fail without oracle details") {
        ConfigError::MissingRequired { field } => assert_eq!(field, "oracle"),
        other => panic!("unexpected error: {:?}", other),
    }
}
|
||||
|
||||
/// A DSN-based Oracle configuration with no wallet path passes validation.
#[test]
fn oracle_config_validation_accepts_dsn_only() {
    let oracle = OracleConfig {
        wallet_path: None,
        connect_descriptor: "tcps://db.example.com:1522/service".to_string(),
        username: "scott".to_string(),
        password: "tiger".to_string(),
        pool_min: 1,
        pool_max: 4,
        pool_timeout_secs: 30,
    };
    let mode = RoutingMode::OpenAI {
        worker_urls: vec!["https://api.openai.com".to_string()],
    };
    let config = RouterConfig {
        mode,
        history_backend: HistoryBackend::Oracle,
        oracle: Some(oracle),
        ..Default::default()
    };

    ConfigValidator::validate(&config).expect("dsn-based config should validate");
}
|
||||
|
||||
/// A wallet-based Oracle configuration (wallet path plus alias) passes validation.
#[test]
fn oracle_config_validation_accepts_wallet_alias() {
    let oracle = OracleConfig {
        wallet_path: Some("/etc/sglang/oracle-wallet".to_string()),
        connect_descriptor: "db_low".to_string(),
        username: "app_user".to_string(),
        password: "secret".to_string(),
        pool_min: 1,
        pool_max: 8,
        pool_timeout_secs: 45,
    };
    let mode = RoutingMode::OpenAI {
        worker_urls: vec!["https://api.openai.com".to_string()],
    };
    let config = RouterConfig {
        mode,
        history_backend: HistoryBackend::Oracle,
        oracle: Some(oracle),
        ..Default::default()
    };

    ConfigValidator::validate(&config).expect("wallet-based config should validate");
}
|
||||
|
||||
@@ -208,6 +208,7 @@ mod test_pd_routing {
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
|
||||
oracle: None,
|
||||
};
|
||||
|
||||
let app_context =
|
||||
|
||||
Reference in New Issue
Block a user