From a73eb8cd200d3f3808e16a727c06b9fef4c09828 Mon Sep 17 00:00:00 2001 From: Keyang Ru Date: Wed, 24 Sep 2025 20:59:32 -0700 Subject: [PATCH] [router] Support Oracle DB(ATP) Data Connector (#10845) --- sgl-router/Cargo.toml | 3 +- sgl-router/src/config/types.rs | 71 +++ sgl-router/src/config/validation.rs | 6 + sgl-router/src/data_connector/mod.rs | 2 + .../data_connector/response_oracle_store.rs | 548 ++++++++++++++++++ sgl-router/src/lib.rs | 1 + sgl-router/src/main.rs | 144 ++++- sgl-router/src/server.rs | 15 +- sgl-router/tests/api_endpoints_test.rs | 4 + sgl-router/tests/test_openai_routing.rs | 70 ++- sgl-router/tests/test_pd_routing.rs | 1 + 11 files changed, 854 insertions(+), 11 deletions(-) create mode 100644 sgl-router/src/data_connector/response_oracle_store.rs diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 5fc9c1a8d..18f846ed5 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -19,7 +19,7 @@ name = "sglang-router" path = "src/main.rs" [dependencies] -clap = { version = "4", features = ["derive"] } +clap = { version = "4", features = ["derive", "env"] } axum = { version = "0.8.4", features = ["macros", "ws", "tracing"] } tower = { version = "0.5", features = ["full"] } tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] } @@ -69,6 +69,7 @@ rmcp = { version = "0.6.3", features = ["client", "server", "reqwest", "auth"] } serde_yaml = "0.9" +oracle = { version = "0.6.3", features = ["chrono"] } subtle = "2.6" # gRPC and Protobuf dependencies diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index a46c23dbd..6c0029f67 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -70,6 +70,9 @@ pub struct RouterConfig { /// History backend configuration (memory or none, default: memory) #[serde(default = "default_history_backend")] pub history_backend: HistoryBackend, + /// Oracle history backend configuration (required when `history_backend` = "oracle") + #[serde(skip_serializing_if = "Option::is_none")] + pub oracle: Option, } fn default_history_backend() -> HistoryBackend { @@ -84,6 +87,70 @@ pub enum HistoryBackend { Memory, /// No history storage None, + /// Oracle ATP-backed storage + Oracle, +} + +/// Oracle history backend configuration +#[derive(Clone, Serialize, Deserialize, PartialEq)] +pub struct OracleConfig { + /// Directory containing the ATP wallet or TLS config files (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub wallet_path: Option, + /// Connection descriptor / DSN (e.g. `tcps://host:port/service`) + pub connect_descriptor: String, + /// Database username + pub username: String, + /// Database password + pub password: String, + /// Minimum number of pooled connections to keep ready + #[serde(default = "default_pool_min")] + pub pool_min: usize, + /// Maximum number of pooled connections + #[serde(default = "default_pool_max")] + pub pool_max: usize, + /// Maximum time to wait for a connection from the pool (seconds) + #[serde(default = "default_pool_timeout_secs")] + pub pool_timeout_secs: u64, +} + +impl OracleConfig { + pub fn default_pool_min() -> usize { + default_pool_min() + } + + pub fn default_pool_max() -> usize { + default_pool_max() + } + + pub fn default_pool_timeout_secs() -> u64 { + default_pool_timeout_secs() + } +} + +fn default_pool_min() -> usize { + 1 +} + +fn default_pool_max() -> usize { + 16 +} + +fn default_pool_timeout_secs() -> u64 { + 30 +} + +impl std::fmt::Debug for OracleConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OracleConfig") + .field("wallet_path", &self.wallet_path) + .field("connect_descriptor", &self.connect_descriptor) + .field("username", &self.username) + .field("pool_min", &self.pool_min) + .field("pool_max", &self.pool_max) + .field("pool_timeout_secs", &self.pool_timeout_secs) + .finish() + } } #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] @@ -381,6 +448,7 @@ impl Default for RouterConfig { model_path: None, tokenizer_path: None, history_backend: default_history_backend(), + oracle: None, } } } @@ -948,6 +1016,7 @@ mod tests { model_path: None, tokenizer_path: None, history_backend: default_history_backend(), + oracle: None, }; assert!(config.mode.is_pd_mode()); @@ -1012,6 +1081,7 @@ mod tests { model_path: None, tokenizer_path: None, history_backend: default_history_backend(), + oracle: None, }; assert!(!config.mode.is_pd_mode()); @@ -1072,6 +1142,7 @@ mod tests { model_path: None, tokenizer_path: None, history_backend: default_history_backend(), + oracle: None, }; assert!(config.has_service_discovery()); diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs index 710ad3fc8..2d94325aa 100644 --- a/sgl-router/src/config/validation.rs +++ b/sgl-router/src/config/validation.rs @@ -29,6 +29,12 @@ impl ConfigValidator { Self::validate_retry(&retry_cfg)?; Self::validate_circuit_breaker(&cb_cfg)?; + if config.history_backend == HistoryBackend::Oracle && config.oracle.is_none() { + return Err(ConfigError::MissingRequired { + field: "oracle".to_string(), + }); + } + Ok(()) } diff --git a/sgl-router/src/data_connector/mod.rs b/sgl-router/src/data_connector/mod.rs index 1b4bf2073..e79cacea8 100644 --- a/sgl-router/src/data_connector/mod.rs +++ b/sgl-router/src/data_connector/mod.rs @@ -1,10 +1,12 @@ // Data connector module for response storage pub mod response_memory_store; pub mod response_noop_store; +pub mod response_oracle_store; pub mod responses; pub use response_memory_store::MemoryResponseStorage; pub use response_noop_store::NoOpResponseStorage; +pub use response_oracle_store::OracleResponseStorage; pub use responses::{ ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, SharedResponseStorage, StoredResponse, diff --git a/sgl-router/src/data_connector/response_oracle_store.rs b/sgl-router/src/data_connector/response_oracle_store.rs new file mode 100644 index 000000000..2622b59e2 --- /dev/null +++ b/sgl-router/src/data_connector/response_oracle_store.rs @@ -0,0 +1,548 @@ +use crate::config::OracleConfig; +use crate::data_connector::responses::{ + ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult, + StoredResponse, +}; +use async_trait::async_trait; +use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; +use oracle::{Connection, Row}; +use serde_json::Value; +use std::collections::HashMap; +use std::path::Path; +use std::sync::Arc; +use std::time::Duration; + +const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \ + tool_calls, metadata, created_at, user_id, model, raw_response FROM responses"; + +#[derive(Clone)] +pub struct OracleResponseStorage { + pool: Pool, +} + +impl OracleResponseStorage { + pub fn new(config: OracleConfig) -> StorageResult { + let config = Arc::new(config); + configure_oracle_client(&config)?; + initialize_schema(&config)?; + + let manager = OracleConnectionManager::new(config.clone()); + let mut builder = Pool::builder(manager) + .max_size(config.pool_max) + .runtime(deadpool::Runtime::Tokio1); + + if config.pool_timeout_secs > 0 { + builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs))); + } + + let pool = builder.build().map_err(|err| { + ResponseStorageError::StorageError(format!( + "failed to build Oracle connection pool: {err}" + )) + })?; + + Ok(Self { pool }) + } + + async fn with_connection(&self, func: F) -> StorageResult + where + F: FnOnce(&Connection) -> StorageResult + Send + 'static, + T: Send + 'static, + { + let connection = self.pool.get().await.map_err(map_pool_error)?; + + tokio::task::spawn_blocking(move || { + let result = func(&connection); + drop(connection); + result + }) + .await + .map_err(|err| { + ResponseStorageError::StorageError(format!( + "failed to execute Oracle query task: {err}" + )) + })? + } + + fn build_response_from_row(row: &Row) -> StorageResult { + let id: String = row + .get(0) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch id"))?; + let previous: Option = row.get(1).map_err(|err| { + map_oracle_error(err).into_storage_error("fetch previous_response_id") + })?; + let input: String = row + .get(2) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch input"))?; + let instructions: Option = row + .get(3) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch instructions"))?; + let output: String = row + .get(4) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch output"))?; + let tool_calls_json: Option = row + .get(5) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch tool_calls"))?; + let metadata_json: Option = row + .get(6) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch metadata"))?; + let created_at: chrono::DateTime = row + .get(7) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch created_at"))?; + let user_id: Option = row + .get(8) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch user_id"))?; + let model: Option = row + .get(9) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch model"))?; + let raw_response_json: Option = row + .get(10) + .map_err(|err| map_oracle_error(err).into_storage_error("fetch raw_response"))?; + + let previous_response_id = previous.map(ResponseId); + let tool_calls = parse_tool_calls(tool_calls_json)?; + let metadata = parse_metadata(metadata_json)?; + let raw_response = parse_raw_response(raw_response_json)?; + + Ok(StoredResponse { + id: ResponseId(id), + previous_response_id, + input, + instructions, + output, + tool_calls, + metadata, + created_at, + user: user_id, + model, + raw_response, + }) + } +} + +#[async_trait] +impl ResponseStorage for OracleResponseStorage { + async fn store_response(&self, response: StoredResponse) -> StorageResult { + let StoredResponse { + id, + previous_response_id, + input, + instructions, + output, + tool_calls, + metadata, + created_at, + user, + model, + raw_response, + } = response; + + let response_id = id.clone(); + let response_id_str = response_id.0.clone(); + let previous_id = previous_response_id.map(|r| r.0); + let json_tool_calls = serde_json::to_string(&tool_calls)?; + let json_metadata = serde_json::to_string(&metadata)?; + let json_raw_response = serde_json::to_string(&raw_response)?; + + self.with_connection(move |conn| { + conn.execute( + "INSERT INTO responses (id, previous_response_id, input, instructions, output, \ + tool_calls, metadata, created_at, user_id, model, raw_response) \ + VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11)", + &[ + &response_id_str, + &previous_id, + &input, + &instructions, + &output, + &json_tool_calls, + &json_metadata, + &created_at, + &user, + &model, + &json_raw_response, + ], + ) + .map(|_| ()) + .map_err(map_oracle_error) + }) + .await?; + + Ok(response_id) + } + + async fn get_response( + &self, + response_id: &ResponseId, + ) -> StorageResult> { + let id = response_id.0.clone(); + self.with_connection(move |conn| { + let mut stmt = conn + .statement(&format!("{} WHERE id = :1", SELECT_BASE)) + .build() + .map_err(map_oracle_error)?; + let mut rows = stmt.query(&[&id]).map_err(map_oracle_error)?; + match rows.next() { + Some(row) => { + let row = row.map_err(map_oracle_error)?; + OracleResponseStorage::build_response_from_row(&row).map(Some) + } + None => Ok(None), + } + }) + .await + } + + async fn delete_response(&self, response_id: &ResponseId) -> StorageResult<()> { + let id = response_id.0.clone(); + self.with_connection(move |conn| { + conn.execute("DELETE FROM responses WHERE id = :1", &[&id]) + .map(|_| ()) + .map_err(map_oracle_error) + }) + .await + } + + async fn get_response_chain( + &self, + response_id: &ResponseId, + max_depth: Option, + ) -> StorageResult { + let mut chain = ResponseChain::new(); + let mut current_id = Some(response_id.clone()); + let mut visited = 0usize; + + while let Some(ref lookup_id) = current_id { + if let Some(limit) = max_depth { + if visited >= limit { + break; + } + } + + let fetched = self.get_response(lookup_id).await?; + match fetched { + Some(response) => { + current_id = response.previous_response_id.clone(); + chain.responses.push(response); + visited += 1; + } + None => break, + } + } + + chain.responses.reverse(); + Ok(chain) + } + + async fn list_user_responses( + &self, + user: &str, + limit: Option, + ) -> StorageResult> { + let user = user.to_string(); + + self.with_connection(move |conn| { + let sql = if let Some(limit) = limit { + format!( + "SELECT * FROM ({} WHERE user_id = :1 ORDER BY created_at DESC) WHERE ROWNUM <= {}", + SELECT_BASE, limit + ) + } else { + format!("{} WHERE user_id = :1 ORDER BY created_at DESC", SELECT_BASE) + }; + + let mut stmt = conn.statement(&sql).build().map_err(map_oracle_error)?; + let mut rows = stmt.query(&[&user]).map_err(map_oracle_error)?; + let mut results = Vec::new(); + + for row in &mut rows { + let row = row.map_err(map_oracle_error)?; + results.push(OracleResponseStorage::build_response_from_row(&row)?); + } + + Ok(results) + }) + .await + } + + async fn delete_user_responses(&self, user: &str) -> StorageResult { + let user = user.to_string(); + let affected = self + .with_connection(move |conn| { + conn.execute("DELETE FROM responses WHERE user_id = :1", &[&user]) + .map_err(map_oracle_error) + }) + .await?; + + let deleted = affected.row_count().map_err(map_oracle_error)? as usize; + Ok(deleted) + } +} + +#[derive(Clone)] +struct OracleConnectionManager { + params: Arc, +} + +impl OracleConnectionManager { + fn new(config: Arc) -> Self { + let params = OracleConnectParams { + username: config.username.clone(), + password: config.password.clone(), + connect_descriptor: config.connect_descriptor.clone(), + }; + + Self { + params: Arc::new(params), + } + } +} + +impl std::fmt::Debug for OracleConnectionManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OracleConnectionManager") + .field("username", &self.params.username) + .field("connect_descriptor", &self.params.connect_descriptor) + .finish() + } +} + +#[derive(Clone)] +struct OracleConnectParams { + username: String, + password: String, + connect_descriptor: String, +} + +impl std::fmt::Debug for OracleConnectParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OracleConnectParams") + .field("username", &self.username) + .field("connect_descriptor", &self.connect_descriptor) + .finish() + } +} + +#[async_trait] +impl Manager for OracleConnectionManager { + type Type = Connection; + type Error = oracle::Error; + + fn create( + &self, + ) -> impl std::future::Future> + Send { + let params = self.params.clone(); + async move { + let mut conn = Connection::connect( + ¶ms.username, + ¶ms.password, + ¶ms.connect_descriptor, + )?; + conn.set_autocommit(true); + Ok(conn) + } + } + + #[allow(clippy::manual_async_fn)] + fn recycle( + &self, + conn: &mut Connection, + _: &Metrics, + ) -> impl std::future::Future> + Send { + async move { conn.ping().map_err(RecycleError::Backend) } + } +} + +fn configure_oracle_client(config: &OracleConfig) -> StorageResult<()> { + if let Some(wallet_path) = &config.wallet_path { + let wallet_path = Path::new(wallet_path); + if !wallet_path.is_dir() { + return Err(ResponseStorageError::StorageError(format!( + "Oracle wallet/config path '{}' is not a directory", + wallet_path.display() + ))); + } + + if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() { + return Err(ResponseStorageError::StorageError(format!( + "Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora", + wallet_path.display() + ))); + } + + std::env::set_var("TNS_ADMIN", wallet_path); + } + Ok(()) +} + +fn initialize_schema(config: &OracleConfig) -> StorageResult<()> { + let conn = Connection::connect( + &config.username, + &config.password, + &config.connect_descriptor, + ) + .map_err(map_oracle_error)?; + + let exists: i64 = conn + .query_row_as( + "SELECT COUNT(*) FROM user_tables WHERE table_name = 'RESPONSES'", + &[], + ) + .map_err(map_oracle_error)?; + + if exists == 0 { + conn.execute( + "CREATE TABLE responses ( + id VARCHAR2(64) PRIMARY KEY, + previous_response_id VARCHAR2(64), + input CLOB, + instructions CLOB, + output CLOB, + tool_calls CLOB, + metadata CLOB, + created_at TIMESTAMP WITH TIME ZONE, + user_id VARCHAR2(128), + model VARCHAR2(128), + raw_response CLOB + )", + &[], + ) + .map_err(map_oracle_error)?; + } + + create_index_if_missing( + &conn, + "RESPONSES_PREV_IDX", + "CREATE INDEX responses_prev_idx ON responses(previous_response_id)", + )?; + create_index_if_missing( + &conn, + "RESPONSES_USER_IDX", + "CREATE INDEX responses_user_idx ON responses(user_id)", + )?; + + Ok(()) +} + +fn create_index_if_missing(conn: &Connection, index_name: &str, ddl: &str) -> StorageResult<()> { + let count: i64 = conn + .query_row_as( + "SELECT COUNT(*) FROM user_indexes WHERE table_name = 'RESPONSES' AND index_name = :1", + &[&index_name], + ) + .map_err(map_oracle_error)?; + + if count == 0 { + if let Err(err) = conn.execute(ddl, &[]) { + if err.db_error().map(|db| db.code()) != Some(1408) { + return Err(map_oracle_error(err)); + } + } + } + + Ok(()) +} + +fn parse_tool_calls(raw: Option) -> StorageResult> { + match raw { + Some(s) if !s.is_empty() => { + serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError) + } + _ => Ok(Vec::new()), + } +} + +fn parse_metadata(raw: Option) -> StorageResult> { + match raw { + Some(s) if !s.is_empty() => { + serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError) + } + _ => Ok(HashMap::new()), + } +} + +fn parse_raw_response(raw: Option) -> StorageResult { + match raw { + Some(s) if !s.is_empty() => { + serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError) + } + _ => Ok(Value::Null), + } +} + +fn map_pool_error(err: PoolError) -> ResponseStorageError { + match err { + PoolError::Backend(e) => map_oracle_error(e), + other => ResponseStorageError::StorageError(format!( + "failed to obtain Oracle connection: {other}" + )), + } +} + +fn map_oracle_error(err: oracle::Error) -> ResponseStorageError { + if let Some(db_err) = err.db_error() { + ResponseStorageError::StorageError(format!( + "Oracle error (code {}): {}", + db_err.code(), + db_err.message() + )) + } else { + ResponseStorageError::StorageError(err.to_string()) + } +} + +trait OracleErrorExt { + fn into_storage_error(self, context: &str) -> ResponseStorageError; +} + +impl OracleErrorExt for ResponseStorageError { + fn into_storage_error(self, context: &str) -> ResponseStorageError { + ResponseStorageError::StorageError(format!("{context}: {self}")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_tool_calls_handles_empty_input() { + assert!(parse_tool_calls(None).unwrap().is_empty()); + assert!(parse_tool_calls(Some(String::new())).unwrap().is_empty()); + } + + #[test] + fn parse_tool_calls_round_trips() { + let payload = json!([{ "type": "test", "value": 1 }]).to_string(); + let parsed = parse_tool_calls(Some(payload)).unwrap(); + assert_eq!(parsed.len(), 1); + assert_eq!(parsed[0]["type"], "test"); + assert_eq!(parsed[0]["value"], 1); + } + + #[test] + fn parse_metadata_defaults_to_empty_map() { + assert!(parse_metadata(None).unwrap().is_empty()); + } + + #[test] + fn parse_metadata_round_trips() { + let payload = json!({"key": "value", "nested": {"bool": true}}).to_string(); + let parsed = parse_metadata(Some(payload)).unwrap(); + assert_eq!(parsed.get("key").unwrap(), "value"); + assert_eq!(parsed["nested"]["bool"], true); + } + + #[test] + fn parse_raw_response_handles_null() { + assert_eq!(parse_raw_response(None).unwrap(), Value::Null); + } + + #[test] + fn parse_raw_response_round_trips() { + let payload = json!({"id": "abc"}).to_string(); + let parsed = parse_raw_response(Some(payload)).unwrap(); + assert_eq!(parsed["id"], "abc"); + } +} diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 1907ed2e5..36c6a02d7 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -231,6 +231,7 @@ impl Router { model_path: self.model_path.clone(), tokenizer_path: self.tokenizer_path.clone(), history_backend: config::HistoryBackend::Memory, + oracle: None, }) } } diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index 243c32b33..3380ab8a1 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -1,8 +1,8 @@ use clap::{ArgAction, Parser, ValueEnum}; use sglang_router_rs::config::{ CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig, - HealthCheckConfig, HistoryBackend, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, - RoutingMode, + HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, + RouterConfig, RoutingMode, }; use sglang_router_rs::metrics::PrometheusConfig; use sglang_router_rs::server::{self, ServerConfig}; @@ -314,9 +314,46 @@ struct CliArgs { #[arg(long)] tokenizer_path: Option, - /// History backend configuration (memory or none) - #[arg(long, default_value = "memory", value_parser = ["memory", "none"])] + /// History backend configuration (memory, none, or oracle) + #[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])] history_backend: String, + + /// Directory containing the Oracle ATP wallet/config files (optional) + #[arg(long, env = "ATP_WALLET_PATH")] + oracle_wallet_path: Option, + + /// Wallet TNS alias to use (e.g. `_low`) + #[arg(long, env = "ATP_TNS_ALIAS")] + oracle_tns_alias: Option, + + /// Oracle connection descriptor / DSN (e.g. `tcps://host:port/service_name`) + #[arg(long, env = "ATP_DSN")] + oracle_dsn: Option, + + /// Oracle ATP username + #[arg(long, env = "ATP_USER")] + oracle_user: Option, + + /// Oracle ATP password + #[arg(long, env = "ATP_PASSWORD")] + oracle_password: Option, + + /// Minimum number of pooled ATP connections (defaults to 1 when omitted) + #[arg(long, env = "ATP_POOL_MIN")] + oracle_pool_min: Option, + + /// Maximum number of pooled ATP connections (defaults to 16 when omitted) + #[arg(long, env = "ATP_POOL_MAX")] + oracle_pool_max: Option, + + /// Connection acquisition timeout in seconds (defaults to 30 when omitted) + #[arg(long, env = "ATP_POOL_TIMEOUT_SECS")] + oracle_pool_timeout_secs: Option, +} + +enum OracleConnectSource { + Dsn { descriptor: String }, + Wallet { path: String, alias: String }, } impl CliArgs { @@ -364,6 +401,87 @@ impl CliArgs { } } + fn resolve_oracle_connect_details(&self) -> ConfigResult { + if let Some(dsn) = self.oracle_dsn.clone() { + return Ok(OracleConnectSource::Dsn { descriptor: dsn }); + } + + let wallet_path = self + .oracle_wallet_path + .clone() + .ok_or(ConfigError::MissingRequired { + field: "oracle_wallet_path or ATP_WALLET_PATH".to_string(), + })?; + + let tns_alias = self + .oracle_tns_alias + .clone() + .ok_or(ConfigError::MissingRequired { + field: "oracle_tns_alias or ATP_TNS_ALIAS".to_string(), + })?; + + Ok(OracleConnectSource::Wallet { + path: wallet_path, + alias: tns_alias, + }) + } + + fn build_oracle_config(&self) -> ConfigResult { + let (wallet_path, connect_descriptor) = match self.resolve_oracle_connect_details()? { + OracleConnectSource::Dsn { descriptor } => (None, descriptor), + OracleConnectSource::Wallet { path, alias } => (Some(path), alias), + }; + let username = self + .oracle_user + .clone() + .ok_or(ConfigError::MissingRequired { + field: "oracle_user or ATP_USER".to_string(), + })?; + let password = self + .oracle_password + .clone() + .ok_or(ConfigError::MissingRequired { + field: "oracle_password or ATP_PASSWORD".to_string(), + })?; + + let pool_min = self + .oracle_pool_min + .unwrap_or_else(OracleConfig::default_pool_min); + let pool_max = self + .oracle_pool_max + .unwrap_or_else(OracleConfig::default_pool_max); + + if pool_min == 0 { + return Err(ConfigError::InvalidValue { + field: "oracle_pool_min".to_string(), + value: pool_min.to_string(), + reason: "pool minimum must be at least 1".to_string(), + }); + } + + if pool_max < pool_min { + return Err(ConfigError::InvalidValue { + field: "oracle_pool_max".to_string(), + value: pool_max.to_string(), + reason: "pool maximum must be greater than or equal to minimum".to_string(), + }); + } + + let pool_timeout_secs = self + .oracle_pool_timeout_secs + .unwrap_or_else(OracleConfig::default_pool_timeout_secs); + + Ok(OracleConfig { + wallet_path, + connect_descriptor, + username, + password, + pool_min, + pool_max, + pool_timeout_secs, + }) + } + /// Convert CLI arguments to RouterConfig fn to_router_config( &self, @@ -459,6 +577,18 @@ impl CliArgs { _ => Self::determine_connection_mode(&all_urls), }; + let history_backend = match self.history_backend.as_str() { + "none" => HistoryBackend::None, + "oracle" => HistoryBackend::Oracle, + _ => HistoryBackend::Memory, + }; + + let oracle = if history_backend == HistoryBackend::Oracle { + Some(self.build_oracle_config()?) + } else { + None + }; + // Build RouterConfig Ok(RouterConfig { mode, @@ -511,10 +641,8 @@ impl CliArgs { rate_limit_tokens_per_second: None, model_path: self.model_path.clone(), tokenizer_path: self.tokenizer_path.clone(), - history_backend: match self.history_backend.as_str() { - "none" => HistoryBackend::None, - _ => HistoryBackend::Memory, - }, + history_backend, + oracle, }) } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 7b9a5dd4a..7a7c0be65 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,7 +1,9 @@ use crate::{ config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, core::{WorkerManager, WorkerRegistry, WorkerType}, - data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage}, + data_connector::{ + MemoryResponseStorage, NoOpResponseStorage, OracleResponseStorage, SharedResponseStorage, + }, logging::{self, LoggingConfig}, metrics::{self, PrometheusConfig}, middleware::{self, AuthConfig, QueuedRequest, TokenBucket}, @@ -92,6 +94,17 @@ impl AppContext { let response_storage: SharedResponseStorage = match router_config.history_backend { HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()), HistoryBackend::None => Arc::new(NoOpResponseStorage::new()), + HistoryBackend::Oracle => { + let oracle_cfg = router_config.oracle.clone().ok_or_else(|| { + "oracle configuration is required when history_backend=oracle".to_string() + })?; + + let storage = OracleResponseStorage::new(oracle_cfg).map_err(|err| { + format!("failed to initialize Oracle response storage: {err}") + })?; + + Arc::new(storage) + } }; Ok(Self { diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index b5515b261..b1d62af89 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -62,6 +62,7 @@ impl TestContext { model_path: None, tokenizer_path: None, history_backend: sglang_router_rs::config::HistoryBackend::Memory, + oracle: None, }; Self::new_with_config(config, worker_configs).await @@ -1401,6 +1402,7 @@ mod error_tests { model_path: None, tokenizer_path: None, history_backend: sglang_router_rs::config::HistoryBackend::Memory, + oracle: None, }; let ctx = TestContext::new_with_config( @@ -1760,6 +1762,7 @@ mod pd_mode_tests { model_path: None, tokenizer_path: None, history_backend: sglang_router_rs::config::HistoryBackend::Memory, + oracle: None, }; // Create app context @@ -1923,6 +1926,7 @@ mod request_id_tests { model_path: None, tokenizer_path: None, history_backend: sglang_router_rs::config::HistoryBackend::Memory, + oracle: None, }; let ctx = TestContext::new_with_config( diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index feeab95b9..a77faac15 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ -10,7 +10,9 @@ use axum::{ }; use serde_json::json; use sglang_router_rs::{ - config::{RouterConfig, RoutingMode}, + config::{ + ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode, + }, data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage, StoredResponse}, protocols::spec::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput, @@ -823,3 +825,69 @@ async fn test_openai_router_models_auth_forwarding() { let models: serde_json::Value = serde_json::from_str(&body_str).unwrap(); assert_eq!(models["object"], "list"); } + +#[test] +fn oracle_config_validation_requires_config_when_enabled() { + let config = RouterConfig { + mode: RoutingMode::OpenAI { + worker_urls: vec!["https://api.openai.com".to_string()], + }, + history_backend: HistoryBackend::Oracle, + oracle: None, + ..Default::default() + }; + + let err = + ConfigValidator::validate(&config).expect_err("config should fail without oracle details"); + + match err { + ConfigError::MissingRequired { field } => { + assert_eq!(field, "oracle"); + } + other => panic!("unexpected error: {:?}", other), + } +} + +#[test] +fn oracle_config_validation_accepts_dsn_only() { + let config = RouterConfig { + mode: RoutingMode::OpenAI { + worker_urls: vec!["https://api.openai.com".to_string()], + }, + history_backend: HistoryBackend::Oracle, + oracle: Some(OracleConfig { + wallet_path: None, + connect_descriptor: "tcps://db.example.com:1522/service".to_string(), + username: "scott".to_string(), + password: "tiger".to_string(), + pool_min: 1, + pool_max: 4, + pool_timeout_secs: 30, + }), + ..Default::default() + }; + + ConfigValidator::validate(&config).expect("dsn-based config should validate"); +} + +#[test] +fn oracle_config_validation_accepts_wallet_alias() { + let config = RouterConfig { + mode: RoutingMode::OpenAI { + worker_urls: vec!["https://api.openai.com".to_string()], + }, + history_backend: HistoryBackend::Oracle, + oracle: Some(OracleConfig { + wallet_path: Some("/etc/sglang/oracle-wallet".to_string()), + connect_descriptor: "db_low".to_string(), + username: "app_user".to_string(), + password: "secret".to_string(), + pool_min: 1, + pool_max: 8, + pool_timeout_secs: 45, + }), + ..Default::default() + }; + + ConfigValidator::validate(&config).expect("wallet-based config should validate"); +} diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 9bce02740..044aad4a6 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -208,6 +208,7 @@ mod test_pd_routing { model_path: None, tokenizer_path: None, history_backend: sglang_router_rs::config::HistoryBackend::Memory, + oracle: None, }; let app_context =