[router] Support multiple worker URLs for OpenAI router (#11723)
This commit is contained in:
@@ -165,18 +165,14 @@ impl ConfigValidator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
RoutingMode::OpenAI { worker_urls } => {
|
RoutingMode::OpenAI { worker_urls } => {
|
||||||
// Require exactly one worker URL for OpenAI router
|
// Require at least one worker URL for OpenAI router
|
||||||
if worker_urls.len() != 1 {
|
if worker_urls.is_empty() {
|
||||||
return Err(ConfigError::ValidationFailed {
|
return Err(ConfigError::ValidationFailed {
|
||||||
reason: "OpenAI mode requires exactly one --worker-urls entry".to_string(),
|
reason: "OpenAI mode requires at least one --worker-urls entry".to_string(),
|
||||||
});
|
|
||||||
}
|
|
||||||
// Validate URL format
|
|
||||||
if let Err(e) = url::Url::parse(&worker_urls[0]) {
|
|
||||||
return Err(ConfigError::ValidationFailed {
|
|
||||||
reason: format!("Invalid OpenAI worker URL '{}': {}", &worker_urls[0], e),
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
// Validate URLs
|
||||||
|
Self::validate_urls(worker_urls)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ use serde_json::Value;
|
|||||||
|
|
||||||
// Import shared types from common module
|
// Import shared types from common module
|
||||||
use super::common::{
|
use super::common::{
|
||||||
default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo, StringOrArray, ToolChoice,
|
default_model, default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo,
|
||||||
UsageInfo,
|
StringOrArray, ToolChoice, UsageInfo,
|
||||||
};
|
};
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
@@ -452,9 +452,9 @@ pub struct ResponsesRequest {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub metadata: Option<HashMap<String, Value>>,
|
pub metadata: Option<HashMap<String, Value>>,
|
||||||
|
|
||||||
/// Model to use (optional to match vLLM)
|
/// Model to use
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(default = "default_model")]
|
||||||
pub model: Option<String>,
|
pub model: String,
|
||||||
|
|
||||||
/// Optional conversation id to persist input/output as items
|
/// Optional conversation id to persist input/output as items
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
@@ -565,7 +565,7 @@ impl Default for ResponsesRequest {
|
|||||||
max_output_tokens: None,
|
max_output_tokens: None,
|
||||||
max_tool_calls: None,
|
max_tool_calls: None,
|
||||||
metadata: None,
|
metadata: None,
|
||||||
model: None,
|
model: default_model(),
|
||||||
conversation: None,
|
conversation: None,
|
||||||
parallel_tool_calls: None,
|
parallel_tool_calls: None,
|
||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
@@ -598,7 +598,7 @@ impl GenerationRequest for ResponsesRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get_model(&self) -> Option<&str> {
|
fn get_model(&self) -> Option<&str> {
|
||||||
self.model.as_deref()
|
Some(self.model.as_str())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_text_for_routing(&self) -> String {
|
fn extract_text_for_routing(&self) -> String {
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ impl RouterFactory {
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
RoutingMode::OpenAI { worker_urls, .. } => {
|
RoutingMode::OpenAI { worker_urls } => {
|
||||||
Self::create_openai_router(worker_urls.clone(), ctx).await
|
Self::create_openai_router(worker_urls.clone(), ctx).await
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -122,13 +122,12 @@ impl RouterFactory {
|
|||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
ctx: &Arc<AppContext>,
|
ctx: &Arc<AppContext>,
|
||||||
) -> Result<Box<dyn RouterTrait>, String> {
|
) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
let base_url = worker_urls
|
if worker_urls.is_empty() {
|
||||||
.first()
|
return Err("OpenAI mode requires at least one worker URL".to_string());
|
||||||
.cloned()
|
}
|
||||||
.ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
|
|
||||||
|
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
base_url,
|
worker_urls,
|
||||||
Some(ctx.router_config.circuit_breaker.clone()),
|
Some(ctx.router_config.circuit_breaker.clone()),
|
||||||
ctx.response_storage.clone(),
|
ctx.response_storage.clone(),
|
||||||
ctx.conversation_storage.clone(),
|
ctx.conversation_storage.clone(),
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ pub(super) fn build_stored_response(
|
|||||||
.get("model")
|
.get("model")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
.or_else(|| original_body.model.clone());
|
.or_else(|| Some(original_body.model.clone()));
|
||||||
|
|
||||||
stored_response.user = response_json
|
stored_response.user = response_json
|
||||||
.get("user")
|
.get("user")
|
||||||
@@ -143,9 +143,10 @@ pub(super) fn patch_streaming_response_json(
|
|||||||
.map(|s| s.is_empty())
|
.map(|s| s.is_empty())
|
||||||
.unwrap_or(true)
|
.unwrap_or(true)
|
||||||
{
|
{
|
||||||
if let Some(model) = &original_body.model {
|
obj.insert(
|
||||||
obj.insert("model".to_string(), Value::String(model.clone()));
|
"model".to_string(),
|
||||||
}
|
Value::String(original_body.model.clone()),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if obj.get("user").map(|v| v.is_null()).unwrap_or(false) {
|
if obj.get("user").map(|v| v.is_null()).unwrap_or(false) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
use std::{
|
use std::{
|
||||||
any::Any,
|
any::Any,
|
||||||
sync::{atomic::AtomicBool, Arc},
|
sync::{atomic::AtomicBool, Arc},
|
||||||
|
time::{Duration, Instant},
|
||||||
};
|
};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
@@ -12,6 +13,7 @@ use axum::{
|
|||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
Json,
|
Json,
|
||||||
};
|
};
|
||||||
|
use dashmap::DashMap;
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use serde_json::{json, to_value, Value};
|
use serde_json::{json, to_value, Value};
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
@@ -31,6 +33,7 @@ use super::{
|
|||||||
},
|
},
|
||||||
responses::{mask_tools_as_mcp, patch_streaming_response_json},
|
responses::{mask_tools_as_mcp, patch_streaming_response_json},
|
||||||
streaming::handle_streaming_response,
|
streaming::handle_streaming_response,
|
||||||
|
utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model},
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
config::CircuitBreakerConfig,
|
config::CircuitBreakerConfig,
|
||||||
@@ -59,12 +62,21 @@ use crate::{
|
|||||||
// OpenAIRouter Struct
|
// OpenAIRouter Struct
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
|
/// Cached endpoint information
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct CachedEndpoint {
|
||||||
|
url: String,
|
||||||
|
cached_at: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
/// Router for OpenAI backend
|
/// Router for OpenAI backend
|
||||||
pub struct OpenAIRouter {
|
pub struct OpenAIRouter {
|
||||||
/// HTTP client for upstream OpenAI-compatible API
|
/// HTTP client for upstream OpenAI-compatible API
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
/// Base URL for identification (no trailing slash)
|
/// Multiple OpenAI-compatible API endpoints (OpenAI, xAI, etc.)
|
||||||
base_url: String,
|
worker_urls: Vec<String>,
|
||||||
|
/// Model cache: model_id -> endpoint URL
|
||||||
|
model_cache: Arc<DashMap<String, CachedEndpoint>>,
|
||||||
/// Circuit breaker
|
/// Circuit breaker
|
||||||
circuit_breaker: CircuitBreaker,
|
circuit_breaker: CircuitBreaker,
|
||||||
/// Health status
|
/// Health status
|
||||||
@@ -82,7 +94,7 @@ pub struct OpenAIRouter {
|
|||||||
impl std::fmt::Debug for OpenAIRouter {
|
impl std::fmt::Debug for OpenAIRouter {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
f.debug_struct("OpenAIRouter")
|
f.debug_struct("OpenAIRouter")
|
||||||
.field("base_url", &self.base_url)
|
.field("worker_urls", &self.worker_urls)
|
||||||
.field("healthy", &self.healthy)
|
.field("healthy", &self.healthy)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
@@ -92,28 +104,35 @@ impl OpenAIRouter {
|
|||||||
/// Maximum number of conversation items to attach as input when a conversation is provided
|
/// Maximum number of conversation items to attach as input when a conversation is provided
|
||||||
const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100;
|
const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100;
|
||||||
|
|
||||||
|
/// Model discovery cache TTL (1 hour)
|
||||||
|
const MODEL_CACHE_TTL_SECS: u64 = 3600;
|
||||||
|
|
||||||
/// Create a new OpenAI router
|
/// Create a new OpenAI router
|
||||||
pub async fn new(
|
pub async fn new(
|
||||||
base_url: String,
|
worker_urls: Vec<String>,
|
||||||
circuit_breaker_config: Option<CircuitBreakerConfig>,
|
circuit_breaker_config: Option<CircuitBreakerConfig>,
|
||||||
response_storage: SharedResponseStorage,
|
response_storage: SharedResponseStorage,
|
||||||
conversation_storage: SharedConversationStorage,
|
conversation_storage: SharedConversationStorage,
|
||||||
conversation_item_storage: SharedConversationItemStorage,
|
conversation_item_storage: SharedConversationItemStorage,
|
||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(300))
|
.timeout(Duration::from_secs(300))
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
let base_url = base_url.trim_end_matches('/').to_string();
|
// Normalize URLs (remove trailing slashes)
|
||||||
|
let worker_urls: Vec<String> = worker_urls
|
||||||
|
.into_iter()
|
||||||
|
.map(|url| url.trim_end_matches('/').to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
// Convert circuit breaker config
|
// Convert circuit breaker config
|
||||||
let core_cb_config = circuit_breaker_config
|
let core_cb_config = circuit_breaker_config
|
||||||
.map(|cb| CoreCircuitBreakerConfig {
|
.map(|cb| CoreCircuitBreakerConfig {
|
||||||
failure_threshold: cb.failure_threshold,
|
failure_threshold: cb.failure_threshold,
|
||||||
success_threshold: cb.success_threshold,
|
success_threshold: cb.success_threshold,
|
||||||
timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
|
timeout_duration: Duration::from_secs(cb.timeout_duration_secs),
|
||||||
window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
|
window_duration: Duration::from_secs(cb.window_duration_secs),
|
||||||
})
|
})
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
@@ -141,7 +160,8 @@ impl OpenAIRouter {
|
|||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client,
|
client,
|
||||||
base_url,
|
worker_urls,
|
||||||
|
model_cache: Arc::new(DashMap::new()),
|
||||||
circuit_breaker,
|
circuit_breaker,
|
||||||
healthy: AtomicBool::new(true),
|
healthy: AtomicBool::new(true),
|
||||||
response_storage,
|
response_storage,
|
||||||
@@ -151,6 +171,67 @@ impl OpenAIRouter {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Discover which endpoint has the model
|
||||||
|
async fn find_endpoint_for_model(
|
||||||
|
&self,
|
||||||
|
model_id: &str,
|
||||||
|
auth_header: Option<&str>,
|
||||||
|
) -> Result<String, Response> {
|
||||||
|
// Single endpoint - fast path
|
||||||
|
if self.worker_urls.len() == 1 {
|
||||||
|
return Ok(self.worker_urls[0].clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check cache
|
||||||
|
if let Some(entry) = self.model_cache.get(model_id) {
|
||||||
|
if entry.cached_at.elapsed() < Duration::from_secs(Self::MODEL_CACHE_TTL_SECS) {
|
||||||
|
return Ok(entry.url.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Probe all endpoints in parallel
|
||||||
|
let mut handles = vec![];
|
||||||
|
let model = model_id.to_string();
|
||||||
|
let auth = auth_header.map(|s| s.to_string());
|
||||||
|
|
||||||
|
for url in &self.worker_urls {
|
||||||
|
let handle = tokio::spawn(probe_endpoint_for_model(
|
||||||
|
self.client.clone(),
|
||||||
|
url.clone(),
|
||||||
|
model.clone(),
|
||||||
|
auth.clone(),
|
||||||
|
));
|
||||||
|
handles.push(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return first successful endpoint
|
||||||
|
for handle in handles {
|
||||||
|
if let Ok(Ok(url)) = handle.await {
|
||||||
|
// Cache it
|
||||||
|
self.model_cache.insert(
|
||||||
|
model_id.to_string(),
|
||||||
|
CachedEndpoint {
|
||||||
|
url: url.clone(),
|
||||||
|
cached_at: Instant::now(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
return Ok(url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model not found on any endpoint
|
||||||
|
Err((
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!("Model '{}' not found on any endpoint", model_id),
|
||||||
|
"type": "model_not_found",
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response())
|
||||||
|
}
|
||||||
|
|
||||||
/// Handle non-streaming response with optional MCP tool loop
|
/// Handle non-streaming response with optional MCP tool loop
|
||||||
async fn handle_non_streaming_response(
|
async fn handle_non_streaming_response(
|
||||||
&self,
|
&self,
|
||||||
@@ -282,85 +363,145 @@ impl crate::routers::RouterTrait for OpenAIRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
// Simple upstream probe: GET {base}/v1/models without auth
|
// Check all endpoints in parallel - only healthy if ALL are healthy
|
||||||
let url = format!("{}/v1/models", self.base_url);
|
if self.worker_urls.is_empty() {
|
||||||
match self
|
return (StatusCode::SERVICE_UNAVAILABLE, "No endpoints configured").into_response();
|
||||||
.client
|
}
|
||||||
.get(&url)
|
|
||||||
.timeout(std::time::Duration::from_secs(2))
|
let mut handles = vec![];
|
||||||
.send()
|
for url in &self.worker_urls {
|
||||||
.await
|
let url = url.clone();
|
||||||
{
|
let client = self.client.clone();
|
||||||
Ok(resp) => {
|
|
||||||
let code = resp.status();
|
let handle = tokio::spawn(async move {
|
||||||
// Treat success and auth-required as healthy (endpoint reachable)
|
let probe_url = format!("{}/v1/models", url);
|
||||||
if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
|
match client
|
||||||
(StatusCode::OK, "OK").into_response()
|
.get(&probe_url)
|
||||||
} else {
|
.timeout(Duration::from_secs(2))
|
||||||
(
|
.send()
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
.await
|
||||||
format!("Upstream status: {}", code),
|
{
|
||||||
)
|
Ok(resp) => {
|
||||||
.into_response()
|
let code = resp.status();
|
||||||
|
// Treat success and auth-required as healthy (endpoint reachable)
|
||||||
|
if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("Endpoint {} returned status {}", url, code))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => Err(format!("Endpoint {} error: {}", url, e)),
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
handles.push(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect all results
|
||||||
|
let mut errors = Vec::new();
|
||||||
|
for handle in handles {
|
||||||
|
match handle.await {
|
||||||
|
Ok(Ok(())) => (),
|
||||||
|
Ok(Err(e)) => errors.push(e),
|
||||||
|
Err(e) => errors.push(format!("Task join error: {}", e)),
|
||||||
}
|
}
|
||||||
Err(e) => (
|
}
|
||||||
|
|
||||||
|
if errors.is_empty() {
|
||||||
|
(StatusCode::OK, "OK").into_response()
|
||||||
|
} else {
|
||||||
|
(
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
format!("Upstream error: {}", e),
|
format!("Some endpoints unhealthy: {}", errors.join(", ")),
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||||
let info = json!({
|
let info = json!({
|
||||||
"router_type": "openai",
|
"router_type": "openai",
|
||||||
"workers": 1,
|
"workers": self.worker_urls.len(),
|
||||||
"base_url": &self.base_url
|
"worker_urls": &self.worker_urls
|
||||||
});
|
});
|
||||||
(StatusCode::OK, info.to_string()).into_response()
|
(StatusCode::OK, info.to_string()).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_models(&self, req: Request<Body>) -> Response {
|
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||||
// Proxy to upstream /v1/models; forward Authorization header if provided
|
// Aggregate models from all endpoints
|
||||||
let headers = req.headers();
|
if self.worker_urls.is_empty() {
|
||||||
|
return (StatusCode::SERVICE_UNAVAILABLE, "No endpoints configured").into_response();
|
||||||
let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
|
|
||||||
|
|
||||||
if let Some(auth) = headers
|
|
||||||
.get("authorization")
|
|
||||||
.or_else(|| headers.get("Authorization"))
|
|
||||||
{
|
|
||||||
upstream = upstream.header("Authorization", auth);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
match upstream.send().await {
|
let headers = req.headers();
|
||||||
Ok(res) => {
|
let auth = headers
|
||||||
let status = StatusCode::from_u16(res.status().as_u16())
|
.get("authorization")
|
||||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
.or_else(|| headers.get("Authorization"));
|
||||||
let content_type = res.headers().get(CONTENT_TYPE).cloned();
|
|
||||||
match res.bytes().await {
|
// Query all endpoints in parallel
|
||||||
Ok(body) => {
|
let mut handles = vec![];
|
||||||
let mut response = Response::new(Body::from(body));
|
for url in &self.worker_urls {
|
||||||
*response.status_mut() = status;
|
let url = url.clone();
|
||||||
if let Some(ct) = content_type {
|
let client = self.client.clone();
|
||||||
response.headers_mut().insert(CONTENT_TYPE, ct);
|
let auth = auth.cloned();
|
||||||
|
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
let models_url = format!("{}/v1/models", url);
|
||||||
|
let req = client.get(&models_url);
|
||||||
|
|
||||||
|
// Apply provider-specific headers (handles Anthropic, xAI, OpenAI, etc.)
|
||||||
|
let req = apply_provider_headers(req, &url, auth.as_ref());
|
||||||
|
|
||||||
|
match req.send().await {
|
||||||
|
Ok(res) => {
|
||||||
|
if res.status().is_success() {
|
||||||
|
match res.json::<Value>().await {
|
||||||
|
Ok(json) => Ok(json),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to parse models response from '{}': {}",
|
||||||
|
url,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
Err(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tracing::warn!(
|
||||||
|
"Getting models from '{}' failed with status: {}",
|
||||||
|
url,
|
||||||
|
res.status()
|
||||||
|
);
|
||||||
|
Err(())
|
||||||
}
|
}
|
||||||
response
|
|
||||||
}
|
}
|
||||||
Err(e) => (
|
Err(e) => {
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
tracing::warn!("Request to get models from '{}' failed: {}", url, e);
|
||||||
format!("Failed to read upstream response: {}", e),
|
Err(())
|
||||||
)
|
}
|
||||||
.into_response(),
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
handles.push(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect all model lists
|
||||||
|
let mut all_models = Vec::new();
|
||||||
|
for handle in handles {
|
||||||
|
if let Ok(Ok(json)) = handle.await {
|
||||||
|
if let Some(data) = json.get("data").and_then(|v| v.as_array()) {
|
||||||
|
all_models.extend_from_slice(data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => (
|
|
||||||
StatusCode::BAD_GATEWAY,
|
|
||||||
format!("Failed to contact upstream: {}", e),
|
|
||||||
)
|
|
||||||
.into_response(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return aggregated models
|
||||||
|
let response_json = json!({
|
||||||
|
"object": "list",
|
||||||
|
"data": all_models
|
||||||
|
});
|
||||||
|
|
||||||
|
(StatusCode::OK, Json(response_json)).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||||
@@ -396,6 +537,18 @@ impl crate::routers::RouterTrait for OpenAIRouter {
|
|||||||
return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
|
return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract auth header
|
||||||
|
let auth = extract_auth_header(headers);
|
||||||
|
|
||||||
|
// Find endpoint for model
|
||||||
|
let base_url = match self
|
||||||
|
.find_endpoint_for_model(body.model.as_str(), auth)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(url) => url,
|
||||||
|
Err(response) => return response,
|
||||||
|
};
|
||||||
|
|
||||||
// Serialize request body, removing SGLang-only fields
|
// Serialize request body, removing SGLang-only fields
|
||||||
let mut payload = match to_value(body) {
|
let mut payload = match to_value(body) {
|
||||||
Ok(v) => v,
|
Ok(v) => v,
|
||||||
@@ -431,9 +584,14 @@ impl crate::routers::RouterTrait for OpenAIRouter {
|
|||||||
] {
|
] {
|
||||||
obj.remove(key);
|
obj.remove(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove logprobs if false (Gemini don't accept it)
|
||||||
|
if obj.get("logprobs").and_then(|v| v.as_bool()) == Some(false) {
|
||||||
|
obj.remove("logprobs");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let url = format!("{}/v1/chat/completions", self.base_url);
|
let url = format!("{}/v1/chat/completions", base_url);
|
||||||
let mut req = self.client.post(&url).json(&payload);
|
let mut req = self.client.post(&url).json(&payload);
|
||||||
|
|
||||||
// Forward Authorization header if provided
|
// Forward Authorization header if provided
|
||||||
@@ -534,7 +692,17 @@ impl crate::routers::RouterTrait for OpenAIRouter {
|
|||||||
body: &ResponsesRequest,
|
body: &ResponsesRequest,
|
||||||
model_id: Option<&str>,
|
model_id: Option<&str>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let url = format!("{}/v1/responses", self.base_url);
|
// Extract auth header
|
||||||
|
let auth = extract_auth_header(headers);
|
||||||
|
|
||||||
|
// Find endpoint for model (use model_id if provided, otherwise use body.model)
|
||||||
|
let model = model_id.unwrap_or(body.model.as_str());
|
||||||
|
let base_url = match self.find_endpoint_for_model(model, auth).await {
|
||||||
|
Ok(url) => url,
|
||||||
|
Err(response) => return response,
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = format!("{}/v1/responses", base_url);
|
||||||
|
|
||||||
// Validate mutually exclusive params: previous_response_id and conversation
|
// Validate mutually exclusive params: previous_response_id and conversation
|
||||||
// TODO: this validation logic should move the right place, also we need a proper error message module
|
// TODO: this validation logic should move the right place, also we need a proper error message module
|
||||||
@@ -556,7 +724,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
|
|||||||
// Clone the body for validation and logic, but we'll build payload differently
|
// Clone the body for validation and logic, but we'll build payload differently
|
||||||
let mut request_body = body.clone();
|
let mut request_body = body.clone();
|
||||||
if let Some(model) = model_id {
|
if let Some(model) = model_id {
|
||||||
request_body.model = Some(model.to_string());
|
request_body.model = model.to_string();
|
||||||
}
|
}
|
||||||
// Do not forward conversation field upstream; retain for local persistence only
|
// Do not forward conversation field upstream; retain for local persistence only
|
||||||
request_body.conversation = None;
|
request_body.conversation = None;
|
||||||
@@ -847,34 +1015,12 @@ impl crate::routers::RouterTrait for OpenAIRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
|
async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
// Forward cancellation to upstream
|
(
|
||||||
let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id);
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
let mut req = self.client.post(&url);
|
"Cancel response not implemented for OpenAI router",
|
||||||
|
)
|
||||||
if let Some(h) = headers {
|
.into_response()
|
||||||
req = apply_request_headers(h, req, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
match req.send().await {
|
|
||||||
Ok(resp) => {
|
|
||||||
let status = StatusCode::from_u16(resp.status().as_u16())
|
|
||||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
|
||||||
match resp.text().await {
|
|
||||||
Ok(body) => (status, body).into_response(),
|
|
||||||
Err(e) => (
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Failed to read response: {}", e),
|
|
||||||
)
|
|
||||||
.into_response(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => (
|
|
||||||
StatusCode::BAD_GATEWAY,
|
|
||||||
format!("Failed to contact upstream: {}", e),
|
|
||||||
)
|
|
||||||
.into_response(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_embeddings(
|
async fn route_embeddings(
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use axum::http::{HeaderMap, HeaderValue};
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// SSE Event Type Constants
|
// SSE Event Type Constants
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
@@ -93,6 +95,131 @@ impl OutputIndexMapper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Provider Detection and Header Handling
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/// Extract authorization header from request headers
|
||||||
|
/// Checks both "authorization" and "Authorization" (case variations)
|
||||||
|
pub fn extract_auth_header(headers: Option<&HeaderMap>) -> Option<&str> {
|
||||||
|
headers.and_then(|h| {
|
||||||
|
h.get("authorization")
|
||||||
|
.or_else(|| h.get("Authorization"))
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// API provider types
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ApiProvider {
|
||||||
|
Anthropic,
|
||||||
|
Xai,
|
||||||
|
OpenAi,
|
||||||
|
Gemini,
|
||||||
|
Generic,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ApiProvider {
|
||||||
|
/// Detect provider type from URL
|
||||||
|
pub fn from_url(url: &str) -> Self {
|
||||||
|
if url.contains("anthropic") {
|
||||||
|
ApiProvider::Anthropic
|
||||||
|
} else if url.contains("x.ai") {
|
||||||
|
ApiProvider::Xai
|
||||||
|
} else if url.contains("openai.com") {
|
||||||
|
ApiProvider::OpenAi
|
||||||
|
} else if url.contains("googleapis.com") {
|
||||||
|
ApiProvider::Gemini
|
||||||
|
} else {
|
||||||
|
ApiProvider::Generic
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply provider-specific headers to request
|
||||||
|
pub fn apply_provider_headers(
|
||||||
|
mut req: reqwest::RequestBuilder,
|
||||||
|
url: &str,
|
||||||
|
auth_header: Option<&HeaderValue>,
|
||||||
|
) -> reqwest::RequestBuilder {
|
||||||
|
let provider = ApiProvider::from_url(url);
|
||||||
|
|
||||||
|
match provider {
|
||||||
|
ApiProvider::Anthropic => {
|
||||||
|
// Anthropic requires x-api-key instead of Authorization
|
||||||
|
// Extract Bearer token and use as x-api-key
|
||||||
|
if let Some(auth) = auth_header {
|
||||||
|
if let Ok(auth_str) = auth.to_str() {
|
||||||
|
let api_key = auth_str.strip_prefix("Bearer ").unwrap_or(auth_str);
|
||||||
|
req = req
|
||||||
|
.header("x-api-key", api_key)
|
||||||
|
.header("anthropic-version", "2023-06-01");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ApiProvider::Gemini | ApiProvider::Xai | ApiProvider::OpenAi | ApiProvider::Generic => {
|
||||||
|
// Standard OpenAI-compatible: use Authorization header as-is
|
||||||
|
if let Some(auth) = auth_header {
|
||||||
|
req = req.header("Authorization", auth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Probe a single endpoint to check if it has the model
|
||||||
|
/// Returns Ok(url) if model found, Err(()) otherwise
|
||||||
|
pub async fn probe_endpoint_for_model(
|
||||||
|
client: reqwest::Client,
|
||||||
|
url: String,
|
||||||
|
model: String,
|
||||||
|
auth: Option<String>,
|
||||||
|
) -> Result<String, ()> {
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
let probe_url = format!("{}/v1/models/{}", url, model);
|
||||||
|
let req = client
|
||||||
|
.get(&probe_url)
|
||||||
|
.timeout(std::time::Duration::from_secs(5));
|
||||||
|
|
||||||
|
// Apply provider-specific headers (handles Anthropic, xAI, OpenAI, etc.)
|
||||||
|
let auth_header_value = auth.as_ref().and_then(|a| HeaderValue::from_str(a).ok());
|
||||||
|
let req = apply_provider_headers(req, &url, auth_header_value.as_ref());
|
||||||
|
|
||||||
|
match req.send().await {
|
||||||
|
Ok(resp) => {
|
||||||
|
let status = resp.status();
|
||||||
|
if status.is_success() {
|
||||||
|
debug!(
|
||||||
|
url = %url,
|
||||||
|
model = %model,
|
||||||
|
status = %status,
|
||||||
|
"Model found on endpoint"
|
||||||
|
);
|
||||||
|
Ok(url)
|
||||||
|
} else {
|
||||||
|
debug!(
|
||||||
|
url = %url,
|
||||||
|
model = %model,
|
||||||
|
status = %status,
|
||||||
|
"Model not found on endpoint (unsuccessful status)"
|
||||||
|
);
|
||||||
|
Err(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!(
|
||||||
|
url = %url,
|
||||||
|
model = %model,
|
||||||
|
error = %e,
|
||||||
|
"Probe request to endpoint failed"
|
||||||
|
);
|
||||||
|
Err(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// Re-export FunctionCallInProgress from mcp module
|
// Re-export FunctionCallInProgress from mcp module
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|||||||
@@ -410,7 +410,7 @@ impl RouterTrait for RouterManager {
|
|||||||
body: &ResponsesRequest,
|
body: &ResponsesRequest,
|
||||||
model_id: Option<&str>,
|
model_id: Option<&str>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let selected_model = body.model.as_deref().or(model_id);
|
let selected_model = model_id.or(Some(body.model.as_str()));
|
||||||
let router = self.select_router_for_request(headers, selected_model);
|
let router = self.select_router_for_request(headers, selected_model);
|
||||||
|
|
||||||
if let Some(router) = router {
|
if let Some(router) = router {
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
|||||||
max_output_tokens: Some(64),
|
max_output_tokens: Some(64),
|
||||||
max_tool_calls: None,
|
max_tool_calls: None,
|
||||||
metadata: None,
|
metadata: None,
|
||||||
model: Some("mock-model".to_string()),
|
model: "mock-model".to_string(),
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
reasoning: None,
|
reasoning: None,
|
||||||
@@ -134,7 +134,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let resp = router
|
let resp = router
|
||||||
.route_responses(None, &req, req.model.as_deref())
|
.route_responses(None, &req, Some(req.model.as_str()))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
@@ -349,7 +349,7 @@ fn test_responses_request_creation() {
|
|||||||
max_output_tokens: Some(100),
|
max_output_tokens: Some(100),
|
||||||
max_tool_calls: None,
|
max_tool_calls: None,
|
||||||
metadata: None,
|
metadata: None,
|
||||||
model: Some("test-model".to_string()),
|
model: "test-model".to_string(),
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
reasoning: Some(ResponseReasoningParam {
|
reasoning: Some(ResponseReasoningParam {
|
||||||
@@ -397,7 +397,7 @@ fn test_responses_request_sglang_extensions() {
|
|||||||
max_output_tokens: Some(50),
|
max_output_tokens: Some(50),
|
||||||
max_tool_calls: None,
|
max_tool_calls: None,
|
||||||
metadata: None,
|
metadata: None,
|
||||||
model: Some("test-model".to_string()),
|
model: "test-model".to_string(),
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
reasoning: None,
|
reasoning: None,
|
||||||
@@ -506,7 +506,7 @@ fn test_json_serialization() {
|
|||||||
max_output_tokens: Some(200),
|
max_output_tokens: Some(200),
|
||||||
max_tool_calls: Some(5),
|
max_tool_calls: Some(5),
|
||||||
metadata: None,
|
metadata: None,
|
||||||
model: Some("gpt-4".to_string()),
|
model: "gpt-4".to_string(),
|
||||||
parallel_tool_calls: Some(false),
|
parallel_tool_calls: Some(false),
|
||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
reasoning: Some(ResponseReasoningParam {
|
reasoning: Some(ResponseReasoningParam {
|
||||||
@@ -545,7 +545,7 @@ fn test_json_serialization() {
|
|||||||
parsed.request_id,
|
parsed.request_id,
|
||||||
Some("resp_comprehensive_test".to_string())
|
Some("resp_comprehensive_test".to_string())
|
||||||
);
|
);
|
||||||
assert_eq!(parsed.model, Some("gpt-4".to_string()));
|
assert_eq!(parsed.model, "gpt-4");
|
||||||
assert_eq!(parsed.background, Some(true));
|
assert_eq!(parsed.background, Some(true));
|
||||||
assert_eq!(parsed.stream, Some(true));
|
assert_eq!(parsed.stream, Some(true));
|
||||||
assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
|
assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
|
||||||
@@ -636,7 +636,7 @@ async fn test_multi_turn_loop_with_mcp() {
|
|||||||
max_output_tokens: Some(128),
|
max_output_tokens: Some(128),
|
||||||
max_tool_calls: None, // No limit - test unlimited
|
max_tool_calls: None, // No limit - test unlimited
|
||||||
metadata: None,
|
metadata: None,
|
||||||
model: Some("mock-model".to_string()),
|
model: "mock-model".to_string(),
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
reasoning: None,
|
reasoning: None,
|
||||||
@@ -812,7 +812,7 @@ async fn test_max_tool_calls_limit() {
|
|||||||
max_output_tokens: Some(128),
|
max_output_tokens: Some(128),
|
||||||
max_tool_calls: Some(1), // Limit to 1 call
|
max_tool_calls: Some(1), // Limit to 1 call
|
||||||
metadata: None,
|
metadata: None,
|
||||||
model: Some("mock-model".to_string()),
|
model: "mock-model".to_string(),
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
reasoning: None,
|
reasoning: None,
|
||||||
@@ -1006,7 +1006,7 @@ async fn test_streaming_with_mcp_tool_calls() {
|
|||||||
max_output_tokens: Some(256),
|
max_output_tokens: Some(256),
|
||||||
max_tool_calls: Some(3),
|
max_tool_calls: Some(3),
|
||||||
metadata: None,
|
metadata: None,
|
||||||
model: Some("mock-model".to_string()),
|
model: "mock-model".to_string(),
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
reasoning: None,
|
reasoning: None,
|
||||||
@@ -1287,7 +1287,7 @@ async fn test_streaming_multi_turn_with_mcp() {
|
|||||||
max_output_tokens: Some(512),
|
max_output_tokens: Some(512),
|
||||||
max_tool_calls: Some(5), // Allow multiple rounds
|
max_tool_calls: Some(5), // Allow multiple rounds
|
||||||
metadata: None,
|
metadata: None,
|
||||||
model: Some("mock-model".to_string()),
|
model: "mock-model".to_string(),
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
previous_response_id: None,
|
previous_response_id: None,
|
||||||
reasoning: None,
|
reasoning: None,
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ fn create_minimal_completion_request() -> CompletionRequest {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_openai_router_creation() {
|
async fn test_openai_router_creation() {
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
"https://api.openai.com".to_string(),
|
vec!["https://api.openai.com".to_string()],
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -118,7 +118,7 @@ async fn test_openai_router_creation() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_openai_router_server_info() {
|
async fn test_openai_router_server_info() {
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
"https://api.openai.com".to_string(),
|
vec!["https://api.openai.com".to_string()],
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -149,7 +149,7 @@ async fn test_openai_router_models() {
|
|||||||
// Use mock server for deterministic models response
|
// Use mock server for deterministic models response
|
||||||
let mock_server = MockOpenAIServer::new().await;
|
let mock_server = MockOpenAIServer::new().await;
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
mock_server.base_url(),
|
vec![mock_server.base_url()],
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -229,7 +229,7 @@ async fn test_openai_router_responses_with_mock() {
|
|||||||
let storage = Arc::new(MemoryResponseStorage::new());
|
let storage = Arc::new(MemoryResponseStorage::new());
|
||||||
|
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
base_url,
|
vec![base_url],
|
||||||
None,
|
None,
|
||||||
storage.clone(),
|
storage.clone(),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -239,7 +239,7 @@ async fn test_openai_router_responses_with_mock() {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let request1 = ResponsesRequest {
|
let request1 = ResponsesRequest {
|
||||||
model: Some("gpt-4o-mini".to_string()),
|
model: "gpt-4o-mini".to_string(),
|
||||||
input: ResponseInput::Text("Say hi".to_string()),
|
input: ResponseInput::Text("Say hi".to_string()),
|
||||||
store: Some(true),
|
store: Some(true),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -255,7 +255,7 @@ async fn test_openai_router_responses_with_mock() {
|
|||||||
assert_eq!(body1["previous_response_id"], serde_json::Value::Null);
|
assert_eq!(body1["previous_response_id"], serde_json::Value::Null);
|
||||||
|
|
||||||
let request2 = ResponsesRequest {
|
let request2 = ResponsesRequest {
|
||||||
model: Some("gpt-4o-mini".to_string()),
|
model: "gpt-4o-mini".to_string(),
|
||||||
input: ResponseInput::Text("Thanks".to_string()),
|
input: ResponseInput::Text("Thanks".to_string()),
|
||||||
store: Some(true),
|
store: Some(true),
|
||||||
previous_response_id: Some(resp1_id.clone()),
|
previous_response_id: Some(resp1_id.clone()),
|
||||||
@@ -490,7 +490,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
|
|||||||
storage.store_response(previous).await.unwrap();
|
storage.store_response(previous).await.unwrap();
|
||||||
|
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
base_url,
|
vec![base_url],
|
||||||
None,
|
None,
|
||||||
storage.clone(),
|
storage.clone(),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -503,7 +503,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
|
|||||||
metadata.insert("topic".to_string(), json!("unicorns"));
|
metadata.insert("topic".to_string(), json!("unicorns"));
|
||||||
|
|
||||||
let request = ResponsesRequest {
|
let request = ResponsesRequest {
|
||||||
model: Some("gpt-5-nano".to_string()),
|
model: "gpt-5-nano".to_string(),
|
||||||
input: ResponseInput::Text("Tell me a bedtime story.".to_string()),
|
input: ResponseInput::Text("Tell me a bedtime story.".to_string()),
|
||||||
instructions: Some("Be kind".to_string()),
|
instructions: Some("Be kind".to_string()),
|
||||||
metadata: Some(metadata),
|
metadata: Some(metadata),
|
||||||
@@ -595,7 +595,7 @@ async fn test_router_factory_openai_mode() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_unsupported_endpoints() {
|
async fn test_unsupported_endpoints() {
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
"https://api.openai.com".to_string(),
|
vec!["https://api.openai.com".to_string()],
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -660,7 +660,7 @@ async fn test_openai_router_chat_completion_with_mock() {
|
|||||||
|
|
||||||
// Create router pointing to mock server
|
// Create router pointing to mock server
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
base_url,
|
vec![base_url],
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -702,7 +702,7 @@ async fn test_openai_e2e_with_server() {
|
|||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
base_url,
|
vec![base_url],
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -773,7 +773,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
|
|||||||
let mock_server = MockOpenAIServer::new().await;
|
let mock_server = MockOpenAIServer::new().await;
|
||||||
let base_url = mock_server.base_url();
|
let base_url = mock_server.base_url();
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
base_url,
|
vec![base_url],
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -827,7 +827,7 @@ async fn test_openai_router_circuit_breaker() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
"http://invalid-url-that-will-fail".to_string(),
|
vec!["http://invalid-url-that-will-fail".to_string()],
|
||||||
Some(cb_config),
|
Some(cb_config),
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -856,7 +856,7 @@ async fn test_openai_router_models_auth_forwarding() {
|
|||||||
let expected_auth = "Bearer test-token".to_string();
|
let expected_auth = "Bearer test-token".to_string();
|
||||||
let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
|
let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
|
||||||
let router = OpenAIRouter::new(
|
let router = OpenAIRouter::new(
|
||||||
mock_server.base_url(),
|
vec![mock_server.base_url()],
|
||||||
None,
|
None,
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
Arc::new(MemoryResponseStorage::new()),
|
||||||
Arc::new(MemoryConversationStorage::new()),
|
Arc::new(MemoryConversationStorage::new()),
|
||||||
@@ -865,7 +865,8 @@ async fn test_openai_router_models_auth_forwarding() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// 1) Without auth header -> expect 401
|
// 1) Without auth header -> expect 200 with empty model list
|
||||||
|
// (multi-endpoint aggregation silently skips failed endpoints)
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method(Method::GET)
|
.method(Method::GET)
|
||||||
.uri("/models")
|
.uri("/models")
|
||||||
@@ -873,7 +874,13 @@ async fn test_openai_router_models_auth_forwarding() {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let response = router.get_models(req).await;
|
let response = router.get_models(req).await;
|
||||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
let (_, body) = response.into_parts();
|
||||||
|
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||||
|
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||||
|
let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
|
||||||
|
assert_eq!(models["object"], "list");
|
||||||
|
assert_eq!(models["data"].as_array().unwrap().len(), 0); // Empty when auth fails
|
||||||
|
|
||||||
// 2) With auth header -> expect 200
|
// 2) With auth header -> expect 200
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
|
|||||||
Reference in New Issue
Block a user