[router] Support multiple worker URLs for OpenAI router (#11723)
@@ -165,18 +165,14 @@ impl ConfigValidator {
                 }
             }
             RoutingMode::OpenAI { worker_urls } => {
-                // Require exactly one worker URL for OpenAI router
-                if worker_urls.len() != 1 {
+                // Require at least one worker URL for OpenAI router
+                if worker_urls.is_empty() {
                     return Err(ConfigError::ValidationFailed {
-                        reason: "OpenAI mode requires exactly one --worker-urls entry".to_string(),
+                        reason: "OpenAI mode requires at least one --worker-urls entry".to_string(),
                     });
                 }
-                // Validate URL format
-                if let Err(e) = url::Url::parse(&worker_urls[0]) {
-                    return Err(ConfigError::ValidationFailed {
-                        reason: format!("Invalid OpenAI worker URL '{}': {}", &worker_urls[0], e),
-                    });
-                }
+                // Validate URLs
+                Self::validate_urls(worker_urls)?;
             }
         }
         Ok(())
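Note: the body of `Self::validate_urls` is outside this diff. A minimal sketch of the contract the new branch relies on — every entry must parse as a URL — assuming a simplified `String` error type (the real helper returns `ConfigError`):

```rust
use url::Url;

/// Hypothetical stand-in for the shared helper the new branch calls:
/// parse every entry and surface the first failure with its offending URL.
fn validate_urls(worker_urls: &[String]) -> Result<(), String> {
    for worker_url in worker_urls {
        if let Err(e) = Url::parse(worker_url) {
            return Err(format!("Invalid worker URL '{}': {}", worker_url, e));
        }
    }
    Ok(())
}

fn main() {
    // Multiple OpenAI-compatible endpoints are now accepted.
    let urls = vec![
        "https://api.openai.com".to_string(),
        "https://api.x.ai".to_string(),
    ];
    assert!(validate_urls(&urls).is_ok());
    assert!(validate_urls(&["not a url".to_string()]).is_err());
}
```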
@@ -8,8 +8,8 @@ use serde_json::Value;

 // Import shared types from common module
 use super::common::{
-    default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo, StringOrArray, ToolChoice,
-    UsageInfo,
+    default_model, default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo,
+    StringOrArray, ToolChoice, UsageInfo,
 };

 // ============================================================================
@@ -452,9 +452,9 @@ pub struct ResponsesRequest {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub metadata: Option<HashMap<String, Value>>,

-    /// Model to use (optional to match vLLM)
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub model: Option<String>,
+    /// Model to use
+    #[serde(default = "default_model")]
+    pub model: String,

     /// Optional conversation id to persist input/output as items
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -565,7 +565,7 @@ impl Default for ResponsesRequest {
             max_output_tokens: None,
             max_tool_calls: None,
             metadata: None,
-            model: None,
+            model: default_model(),
             conversation: None,
             parallel_tool_calls: None,
             previous_response_id: None,
@@ -598,7 +598,7 @@ impl GenerationRequest for ResponsesRequest {
     }

     fn get_model(&self) -> Option<&str> {
-        self.model.as_deref()
+        Some(self.model.as_str())
     }

     fn extract_text_for_routing(&self) -> String {
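The switch from `Option<String>` to `String` with `#[serde(default = "default_model")]` means a request body that omits `model` still deserializes, instead of leaving an `Option` to unwrap at every call site. A self-contained sketch — the real `default_model` lives in the common module and its value is not shown in this diff, so `"default"` is a placeholder:

```rust
use serde::Deserialize;

// Placeholder: the commit's actual default_model() value is not shown here.
fn default_model() -> String {
    "default".to_string()
}

#[derive(Deserialize, Debug)]
struct ResponsesRequestLite {
    #[serde(default = "default_model")]
    model: String,
    input: String,
}

fn main() {
    // "model" is absent from the JSON, yet deserialization succeeds.
    let req: ResponsesRequestLite =
        serde_json::from_str(r#"{"input": "Say hi"}"#).unwrap();
    assert_eq!(req.model, "default");
}
```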
@@ -55,7 +55,7 @@ impl RouterFactory {
                 )
                 .await
             }
-            RoutingMode::OpenAI { worker_urls, .. } => {
+            RoutingMode::OpenAI { worker_urls } => {
                 Self::create_openai_router(worker_urls.clone(), ctx).await
             }
         },
@@ -122,13 +122,12 @@ impl RouterFactory {
         worker_urls: Vec<String>,
         ctx: &Arc<AppContext>,
     ) -> Result<Box<dyn RouterTrait>, String> {
-        let base_url = worker_urls
-            .first()
-            .cloned()
-            .ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
+        if worker_urls.is_empty() {
+            return Err("OpenAI mode requires at least one worker URL".to_string());
+        }

         let router = OpenAIRouter::new(
-            base_url,
+            worker_urls,
             Some(ctx.router_config.circuit_breaker.clone()),
             ctx.response_storage.clone(),
             ctx.conversation_storage.clone(),
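The factory no longer collapses the list to its first entry. A self-contained sketch contrasting the two behaviors (`old_pick` and `new_pick` are illustrative names, not code from this commit):

```rust
// Illustrative sketch, not code from the commit.
// Old behavior: only the first URL survived; the rest were silently dropped.
fn old_pick(worker_urls: &[String]) -> Result<String, String> {
    worker_urls
        .first()
        .cloned()
        .ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())
}

// New behavior: the whole list is kept and handed to the router.
fn new_pick(worker_urls: Vec<String>) -> Result<Vec<String>, String> {
    if worker_urls.is_empty() {
        return Err("OpenAI mode requires at least one worker URL".to_string());
    }
    Ok(worker_urls)
}

fn main() {
    let urls = vec![
        "https://api.openai.com".to_string(),
        "https://api.x.ai".to_string(),
    ];
    // Before: only the first endpoint was ever used.
    assert_eq!(old_pick(&urls).unwrap(), "https://api.openai.com");
    // After: all endpoints reach the router and participate in model discovery.
    assert_eq!(new_pick(urls).unwrap().len(), 2);
}
```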
@@ -39,7 +39,7 @@ pub(super) fn build_stored_response(
         .get("model")
         .and_then(|v| v.as_str())
         .map(|s| s.to_string())
-        .or_else(|| original_body.model.clone());
+        .or_else(|| Some(original_body.model.clone()));

     stored_response.user = response_json
         .get("user")
@@ -143,9 +143,10 @@ pub(super) fn patch_streaming_response_json(
         .map(|s| s.is_empty())
        .unwrap_or(true)
     {
-        if let Some(model) = &original_body.model {
-            obj.insert("model".to_string(), Value::String(model.clone()));
-        }
+        obj.insert(
+            "model".to_string(),
+            Value::String(original_body.model.clone()),
+        );
     }

     if obj.get("user").map(|v| v.is_null()).unwrap_or(false) {
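Since `original_body.model` is now a plain `String`, the patch unconditionally fills an empty or missing `model` field in the streamed payload. A standalone sketch of that step (`patch_model` is an illustrative name):

```rust
use serde_json::{json, Value};

// Standalone sketch of the patch step above: when the upstream streaming
// payload carries no usable "model", fill it from the original request.
// `model` plays the role of `original_body.model`, now a plain String.
fn patch_model(response: &mut Value, model: &str) {
    if let Some(obj) = response.as_object_mut() {
        let missing = obj
            .get("model")
            .and_then(|v| v.as_str())
            .map(|s| s.is_empty())
            .unwrap_or(true);
        if missing {
            obj.insert("model".to_string(), Value::String(model.to_string()));
        }
    }
}

fn main() {
    let mut resp = json!({ "id": "resp_1", "model": "" });
    patch_model(&mut resp, "gpt-4o-mini");
    assert_eq!(resp["model"], "gpt-4o-mini");
}
```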
@@ -3,6 +3,7 @@
 use std::{
     any::Any,
     sync::{atomic::AtomicBool, Arc},
+    time::{Duration, Instant},
 };

 use axum::{
@@ -12,6 +13,7 @@ use axum::{
     response::{IntoResponse, Response},
     Json,
 };
+use dashmap::DashMap;
 use futures_util::StreamExt;
 use serde_json::{json, to_value, Value};
 use tokio::sync::mpsc;
@@ -31,6 +33,7 @@ use super::{
     },
     responses::{mask_tools_as_mcp, patch_streaming_response_json},
     streaming::handle_streaming_response,
+    utils::{apply_provider_headers, extract_auth_header, probe_endpoint_for_model},
 };
 use crate::{
     config::CircuitBreakerConfig,
@@ -59,12 +62,21 @@ use crate::{
 // OpenAIRouter Struct
 // ============================================================================

+/// Cached endpoint information
+#[derive(Clone, Debug)]
+struct CachedEndpoint {
+    url: String,
+    cached_at: Instant,
+}
+
 /// Router for OpenAI backend
 pub struct OpenAIRouter {
     /// HTTP client for upstream OpenAI-compatible API
     client: reqwest::Client,
-    /// Base URL for identification (no trailing slash)
-    base_url: String,
+    /// Multiple OpenAI-compatible API endpoints (OpenAI, xAI, etc.)
+    worker_urls: Vec<String>,
+    /// Model cache: model_id -> endpoint URL
+    model_cache: Arc<DashMap<String, CachedEndpoint>>,
     /// Circuit breaker
     circuit_breaker: CircuitBreaker,
     /// Health status
@@ -82,7 +94,7 @@ pub struct OpenAIRouter {
 impl std::fmt::Debug for OpenAIRouter {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("OpenAIRouter")
-            .field("base_url", &self.base_url)
+            .field("worker_urls", &self.worker_urls)
             .field("healthy", &self.healthy)
             .finish()
     }
@@ -92,28 +104,35 @@ impl OpenAIRouter {
     /// Maximum number of conversation items to attach as input when a conversation is provided
     const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100;

+    /// Model discovery cache TTL (1 hour)
+    const MODEL_CACHE_TTL_SECS: u64 = 3600;
+
     /// Create a new OpenAI router
     pub async fn new(
-        base_url: String,
+        worker_urls: Vec<String>,
         circuit_breaker_config: Option<CircuitBreakerConfig>,
         response_storage: SharedResponseStorage,
         conversation_storage: SharedConversationStorage,
         conversation_item_storage: SharedConversationItemStorage,
     ) -> Result<Self, String> {
         let client = reqwest::Client::builder()
-            .timeout(std::time::Duration::from_secs(300))
+            .timeout(Duration::from_secs(300))
             .build()
             .map_err(|e| format!("Failed to create HTTP client: {}", e))?;

-        let base_url = base_url.trim_end_matches('/').to_string();
+        // Normalize URLs (remove trailing slashes)
+        let worker_urls: Vec<String> = worker_urls
+            .into_iter()
+            .map(|url| url.trim_end_matches('/').to_string())
+            .collect();

         // Convert circuit breaker config
         let core_cb_config = circuit_breaker_config
             .map(|cb| CoreCircuitBreakerConfig {
                 failure_threshold: cb.failure_threshold,
                 success_threshold: cb.success_threshold,
-                timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
-                window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
+                timeout_duration: Duration::from_secs(cb.timeout_duration_secs),
+                window_duration: Duration::from_secs(cb.window_duration_secs),
             })
             .unwrap_or_default();
@@ -141,7 +160,8 @@ impl OpenAIRouter {

         Ok(Self {
             client,
-            base_url,
+            worker_urls,
+            model_cache: Arc::new(DashMap::new()),
             circuit_breaker,
             healthy: AtomicBool::new(true),
             response_storage,
@@ -151,6 +171,67 @@ impl OpenAIRouter {
         })
     }

+    /// Discover which endpoint has the model
+    async fn find_endpoint_for_model(
+        &self,
+        model_id: &str,
+        auth_header: Option<&str>,
+    ) -> Result<String, Response> {
+        // Single endpoint - fast path
+        if self.worker_urls.len() == 1 {
+            return Ok(self.worker_urls[0].clone());
+        }
+
+        // Check cache
+        if let Some(entry) = self.model_cache.get(model_id) {
+            if entry.cached_at.elapsed() < Duration::from_secs(Self::MODEL_CACHE_TTL_SECS) {
+                return Ok(entry.url.clone());
+            }
+        }
+
+        // Probe all endpoints in parallel
+        let mut handles = vec![];
+        let model = model_id.to_string();
+        let auth = auth_header.map(|s| s.to_string());
+
+        for url in &self.worker_urls {
+            let handle = tokio::spawn(probe_endpoint_for_model(
+                self.client.clone(),
+                url.clone(),
+                model.clone(),
+                auth.clone(),
+            ));
+            handles.push(handle);
+        }
+
+        // Return first successful endpoint
+        for handle in handles {
+            if let Ok(Ok(url)) = handle.await {
+                // Cache it
+                self.model_cache.insert(
+                    model_id.to_string(),
+                    CachedEndpoint {
+                        url: url.clone(),
+                        cached_at: Instant::now(),
+                    },
+                );
+                return Ok(url);
+            }
+        }
+
+        // Model not found on any endpoint
+        Err((
+            StatusCode::NOT_FOUND,
+            Json(json!({
+                "error": {
+                    "message": format!("Model '{}' not found on any endpoint", model_id),
+                    "type": "model_not_found",
+                }
+            })),
+        )
+        .into_response())
+    }
+
     /// Handle non-streaming response with optional MCP tool loop
     async fn handle_non_streaming_response(
         &self,
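Worth noting: the probes are spawned concurrently but awaited in list order, so "first successful endpoint" means the first *listed* endpoint that has the model, which gives deterministic precedence when several endpoints serve the same model name. A runnable sketch of the pattern, where `fake_probe` stands in for `probe_endpoint_for_model`:

```rust
use std::time::{Duration, Instant};

// Standalone sketch, not code from the commit: probes run concurrently via
// tokio::spawn, then are awaited in list order, so the winner is the first
// listed endpoint that succeeds, not the first to answer.
async fn fake_probe(url: String, has_model: bool) -> Result<String, ()> {
    tokio::time::sleep(Duration::from_millis(50)).await;
    if has_model { Ok(url) } else { Err(()) }
}

#[tokio::main]
async fn main() {
    let endpoints = [("https://a.example", false), ("https://b.example", true)];
    let started = Instant::now();

    let handles: Vec<_> = endpoints
        .iter()
        .map(|(url, ok)| tokio::spawn(fake_probe(url.to_string(), *ok)))
        .collect();

    for handle in handles {
        if let Ok(Ok(url)) = handle.await {
            // Both probes ran concurrently: total time ~50ms, not ~100ms.
            println!("found on {} after {:?}", url, started.elapsed());
            return;
        }
    }
    println!("model not found on any endpoint");
}
```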
@@ -282,85 +363,145 @@ impl crate::routers::RouterTrait for OpenAIRouter {
     }

     async fn health_generate(&self, _req: Request<Body>) -> Response {
-        // Simple upstream probe: GET {base}/v1/models without auth
-        let url = format!("{}/v1/models", self.base_url);
-        match self
-            .client
-            .get(&url)
-            .timeout(std::time::Duration::from_secs(2))
-            .send()
-            .await
-        {
-            Ok(resp) => {
-                let code = resp.status();
-                // Treat success and auth-required as healthy (endpoint reachable)
-                if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
-                    (StatusCode::OK, "OK").into_response()
-                } else {
-                    (
-                        StatusCode::SERVICE_UNAVAILABLE,
-                        format!("Upstream status: {}", code),
-                    )
-                        .into_response()
-                }
-            }
-            Err(e) => (
-                StatusCode::SERVICE_UNAVAILABLE,
-                format!("Upstream error: {}", e),
-            )
-                .into_response(),
-        }
+        // Check all endpoints in parallel - only healthy if ALL are healthy
+        if self.worker_urls.is_empty() {
+            return (StatusCode::SERVICE_UNAVAILABLE, "No endpoints configured").into_response();
+        }
+
+        let mut handles = vec![];
+        for url in &self.worker_urls {
+            let url = url.clone();
+            let client = self.client.clone();
+
+            let handle = tokio::spawn(async move {
+                let probe_url = format!("{}/v1/models", url);
+                match client
+                    .get(&probe_url)
+                    .timeout(Duration::from_secs(2))
+                    .send()
+                    .await
+                {
+                    Ok(resp) => {
+                        let code = resp.status();
+                        // Treat success and auth-required as healthy (endpoint reachable)
+                        if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
+                            Ok(())
+                        } else {
+                            Err(format!("Endpoint {} returned status {}", url, code))
+                        }
+                    }
+                    Err(e) => Err(format!("Endpoint {} error: {}", url, e)),
+                }
+            });
+
+            handles.push(handle);
+        }
+
+        // Collect all results
+        let mut errors = Vec::new();
+        for handle in handles {
+            match handle.await {
+                Ok(Ok(())) => (),
+                Ok(Err(e)) => errors.push(e),
+                Err(e) => errors.push(format!("Task join error: {}", e)),
+            }
+        }
+
+        if errors.is_empty() {
+            (StatusCode::OK, "OK").into_response()
+        } else {
+            (
+                StatusCode::SERVICE_UNAVAILABLE,
+                format!("Some endpoints unhealthy: {}", errors.join(", ")),
+            )
+                .into_response()
+        }
     }

     async fn get_server_info(&self, _req: Request<Body>) -> Response {
         let info = json!({
             "router_type": "openai",
-            "workers": 1,
-            "base_url": &self.base_url
+            "workers": self.worker_urls.len(),
+            "worker_urls": &self.worker_urls
         });
         (StatusCode::OK, info.to_string()).into_response()
     }

     async fn get_models(&self, req: Request<Body>) -> Response {
-        // Proxy to upstream /v1/models; forward Authorization header if provided
-        let headers = req.headers();
-
-        let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
-
-        if let Some(auth) = headers
-            .get("authorization")
-            .or_else(|| headers.get("Authorization"))
-        {
-            upstream = upstream.header("Authorization", auth);
-        }
-
-        match upstream.send().await {
-            Ok(res) => {
-                let status = StatusCode::from_u16(res.status().as_u16())
-                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
-                let content_type = res.headers().get(CONTENT_TYPE).cloned();
-                match res.bytes().await {
-                    Ok(body) => {
-                        let mut response = Response::new(Body::from(body));
-                        *response.status_mut() = status;
-                        if let Some(ct) = content_type {
-                            response.headers_mut().insert(CONTENT_TYPE, ct);
-                        }
-                        response
-                    }
-                    Err(e) => (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        format!("Failed to read upstream response: {}", e),
-                    )
-                        .into_response(),
-                }
-            }
-            Err(e) => (
-                StatusCode::BAD_GATEWAY,
-                format!("Failed to contact upstream: {}", e),
-            )
-                .into_response(),
-        }
+        // Aggregate models from all endpoints
+        if self.worker_urls.is_empty() {
+            return (StatusCode::SERVICE_UNAVAILABLE, "No endpoints configured").into_response();
+        }
+
+        let headers = req.headers();
+        let auth = headers
+            .get("authorization")
+            .or_else(|| headers.get("Authorization"));
+
+        // Query all endpoints in parallel
+        let mut handles = vec![];
+        for url in &self.worker_urls {
+            let url = url.clone();
+            let client = self.client.clone();
+            let auth = auth.cloned();
+
+            let handle = tokio::spawn(async move {
+                let models_url = format!("{}/v1/models", url);
+                let req = client.get(&models_url);
+
+                // Apply provider-specific headers (handles Anthropic, xAI, OpenAI, etc.)
+                let req = apply_provider_headers(req, &url, auth.as_ref());
+
+                match req.send().await {
+                    Ok(res) => {
+                        if res.status().is_success() {
+                            match res.json::<Value>().await {
+                                Ok(json) => Ok(json),
+                                Err(e) => {
+                                    tracing::warn!(
+                                        "Failed to parse models response from '{}': {}",
+                                        url,
+                                        e
+                                    );
+                                    Err(())
+                                }
+                            }
+                        } else {
+                            tracing::warn!(
+                                "Getting models from '{}' failed with status: {}",
+                                url,
+                                res.status()
+                            );
+                            Err(())
+                        }
+                    }
+                    Err(e) => {
+                        tracing::warn!("Request to get models from '{}' failed: {}", url, e);
+                        Err(())
+                    }
+                }
+            });
+
+            handles.push(handle);
+        }
+
+        // Collect all model lists
+        let mut all_models = Vec::new();
+        for handle in handles {
+            if let Ok(Ok(json)) = handle.await {
+                if let Some(data) = json.get("data").and_then(|v| v.as_array()) {
+                    all_models.extend_from_slice(data);
+                }
+            }
+        }
+
+        // Return aggregated models
+        let response_json = json!({
+            "object": "list",
+            "data": all_models
+        });
+
+        (StatusCode::OK, Json(response_json)).into_response()
     }

     async fn get_model_info(&self, _req: Request<Body>) -> Response {
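The aggregation step concatenates each endpoint's `data` array and silently skips endpoints that fail, which is why the auth-forwarding test at the end of this diff now expects 200 with an empty list rather than 401. A standalone sketch of the merge (`merge_model_lists` is an illustrative name):

```rust
use serde_json::{json, Value};

// Standalone sketch: each endpoint returns an OpenAI-style
// {"object": "list", "data": [...]} payload; the router concatenates the
// `data` arrays and drops endpoints that errored.
fn merge_model_lists(responses: &[Value]) -> Value {
    let mut all_models = Vec::new();
    for json in responses {
        if let Some(data) = json.get("data").and_then(|v| v.as_array()) {
            all_models.extend_from_slice(data);
        }
    }
    json!({ "object": "list", "data": all_models })
}

fn main() {
    let a = json!({ "object": "list", "data": [{ "id": "gpt-4o-mini" }] });
    let b = json!({ "object": "list", "data": [{ "id": "grok-2" }] });
    let merged = merge_model_lists(&[a, b]);
    assert_eq!(merged["data"].as_array().unwrap().len(), 2);
}
```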
@@ -396,6 +537,18 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
         }

+        // Extract auth header
+        let auth = extract_auth_header(headers);
+
+        // Find endpoint for model
+        let base_url = match self
+            .find_endpoint_for_model(body.model.as_str(), auth)
+            .await
+        {
+            Ok(url) => url,
+            Err(response) => return response,
+        };
+
         // Serialize request body, removing SGLang-only fields
         let mut payload = match to_value(body) {
             Ok(v) => v,
@@ -431,9 +584,14 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             ] {
                 obj.remove(key);
             }
+
+            // Remove logprobs if false (Gemini doesn't accept it)
+            if obj.get("logprobs").and_then(|v| v.as_bool()) == Some(false) {
+                obj.remove("logprobs");
+            }
         }

-        let url = format!("{}/v1/chat/completions", self.base_url);
+        let url = format!("{}/v1/chat/completions", base_url);
         let mut req = self.client.post(&url).json(&payload);

         // Forward Authorization header if provided
@@ -534,7 +692,17 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         body: &ResponsesRequest,
         model_id: Option<&str>,
     ) -> Response {
-        let url = format!("{}/v1/responses", self.base_url);
+        // Extract auth header
+        let auth = extract_auth_header(headers);
+
+        // Find endpoint for model (use model_id if provided, otherwise use body.model)
+        let model = model_id.unwrap_or(body.model.as_str());
+        let base_url = match self.find_endpoint_for_model(model, auth).await {
+            Ok(url) => url,
+            Err(response) => return response,
+        };
+
+        let url = format!("{}/v1/responses", base_url);

         // Validate mutually exclusive params: previous_response_id and conversation
         // TODO: this validation logic should move to the right place; we also need a proper error message module
@@ -556,7 +724,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         // Clone the body for validation and logic, but we'll build payload differently
         let mut request_body = body.clone();
         if let Some(model) = model_id {
-            request_body.model = Some(model.to_string());
+            request_body.model = model.to_string();
         }
         // Do not forward conversation field upstream; retain for local persistence only
         request_body.conversation = None;
@@ -847,34 +1015,12 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         }
     }

-    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
-        // Forward cancellation to upstream
-        let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id);
-        let mut req = self.client.post(&url);
-
-        if let Some(h) = headers {
-            req = apply_request_headers(h, req, false);
-        }
-
-        match req.send().await {
-            Ok(resp) => {
-                let status = StatusCode::from_u16(resp.status().as_u16())
-                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
-                match resp.text().await {
-                    Ok(body) => (status, body).into_response(),
-                    Err(e) => (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        format!("Failed to read response: {}", e),
-                    )
-                        .into_response(),
-                }
-            }
-            Err(e) => (
-                StatusCode::BAD_GATEWAY,
-                format!("Failed to contact upstream: {}", e),
-            )
-                .into_response(),
-        }
-    }
+    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "Cancel response not implemented for OpenAI router",
+        )
+            .into_response()
+    }

     async fn route_embeddings(
@@ -2,6 +2,8 @@

 use std::collections::HashMap;

+use axum::http::{HeaderMap, HeaderValue};
+
 // ============================================================================
 // SSE Event Type Constants
 // ============================================================================
@@ -93,6 +95,131 @@ impl OutputIndexMapper {
     }
 }

+// ============================================================================
+// Provider Detection and Header Handling
+// ============================================================================
+
+/// Extract authorization header from request headers
+/// Checks both "authorization" and "Authorization" (case variations)
+pub fn extract_auth_header(headers: Option<&HeaderMap>) -> Option<&str> {
+    headers.and_then(|h| {
+        h.get("authorization")
+            .or_else(|| h.get("Authorization"))
+            .and_then(|v| v.to_str().ok())
+    })
+}
+
+/// API provider types
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ApiProvider {
+    Anthropic,
+    Xai,
+    OpenAi,
+    Gemini,
+    Generic,
+}
+
+impl ApiProvider {
+    /// Detect provider type from URL
+    pub fn from_url(url: &str) -> Self {
+        if url.contains("anthropic") {
+            ApiProvider::Anthropic
+        } else if url.contains("x.ai") {
+            ApiProvider::Xai
+        } else if url.contains("openai.com") {
+            ApiProvider::OpenAi
+        } else if url.contains("googleapis.com") {
+            ApiProvider::Gemini
+        } else {
+            ApiProvider::Generic
+        }
+    }
+}
+
+/// Apply provider-specific headers to request
+pub fn apply_provider_headers(
+    mut req: reqwest::RequestBuilder,
+    url: &str,
+    auth_header: Option<&HeaderValue>,
+) -> reqwest::RequestBuilder {
+    let provider = ApiProvider::from_url(url);
+
+    match provider {
+        ApiProvider::Anthropic => {
+            // Anthropic requires x-api-key instead of Authorization
+            // Extract Bearer token and use as x-api-key
+            if let Some(auth) = auth_header {
+                if let Ok(auth_str) = auth.to_str() {
+                    let api_key = auth_str.strip_prefix("Bearer ").unwrap_or(auth_str);
+                    req = req
+                        .header("x-api-key", api_key)
+                        .header("anthropic-version", "2023-06-01");
+                }
+            }
+        }
+        ApiProvider::Gemini | ApiProvider::Xai | ApiProvider::OpenAi | ApiProvider::Generic => {
+            // Standard OpenAI-compatible: use Authorization header as-is
+            if let Some(auth) = auth_header {
+                req = req.header("Authorization", auth);
+            }
+        }
+    }
+
+    req
+}
+
+/// Probe a single endpoint to check if it has the model
+/// Returns Ok(url) if model found, Err(()) otherwise
+pub async fn probe_endpoint_for_model(
+    client: reqwest::Client,
+    url: String,
+    model: String,
+    auth: Option<String>,
+) -> Result<String, ()> {
+    use tracing::debug;
+
+    let probe_url = format!("{}/v1/models/{}", url, model);
+    let req = client
+        .get(&probe_url)
+        .timeout(std::time::Duration::from_secs(5));
+
+    // Apply provider-specific headers (handles Anthropic, xAI, OpenAI, etc.)
+    let auth_header_value = auth.as_ref().and_then(|a| HeaderValue::from_str(a).ok());
+    let req = apply_provider_headers(req, &url, auth_header_value.as_ref());
+
+    match req.send().await {
+        Ok(resp) => {
+            let status = resp.status();
+            if status.is_success() {
+                debug!(
+                    url = %url,
+                    model = %model,
+                    status = %status,
+                    "Model found on endpoint"
+                );
+                Ok(url)
+            } else {
+                debug!(
+                    url = %url,
+                    model = %model,
+                    status = %status,
+                    "Model not found on endpoint (unsuccessful status)"
+                );
+                Err(())
+            }
+        }
+        Err(e) => {
+            debug!(
+                url = %url,
+                model = %model,
+                error = %e,
+                "Probe request to endpoint failed"
+            );
+            Err(())
+        }
+    }
+}
+
 // ============================================================================
 // Re-export FunctionCallInProgress from mcp module
 // ============================================================================
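A test-style sketch of the detection rules above (not part of the commit): the matching is substring-based, so any URL containing `anthropic` takes the `x-api-key` path.

```rust
// Test-style sketch, not part of the commit; assumes placement in the same
// module as ApiProvider. These cases mirror the match arms in from_url.
#[cfg(test)]
mod provider_detection_tests {
    use super::ApiProvider;

    #[test]
    fn detects_providers_from_url() {
        assert_eq!(ApiProvider::from_url("https://api.anthropic.com"), ApiProvider::Anthropic);
        assert_eq!(ApiProvider::from_url("https://api.x.ai/v1"), ApiProvider::Xai);
        assert_eq!(ApiProvider::from_url("https://api.openai.com"), ApiProvider::OpenAi);
        assert_eq!(
            ApiProvider::from_url("https://generativelanguage.googleapis.com"),
            ApiProvider::Gemini
        );
        assert_eq!(ApiProvider::from_url("http://localhost:8000"), ApiProvider::Generic);
    }

    #[test]
    fn strips_bearer_prefix_for_anthropic() {
        // Same transformation apply_provider_headers performs before
        // setting x-api-key.
        let auth = "Bearer sk-ant-example";
        assert_eq!(auth.strip_prefix("Bearer ").unwrap_or(auth), "sk-ant-example");
    }
}
```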
@@ -410,7 +410,7 @@ impl RouterTrait for RouterManager {
         body: &ResponsesRequest,
         model_id: Option<&str>,
     ) -> Response {
-        let selected_model = body.model.as_deref().or(model_id);
+        let selected_model = model_id.or(Some(body.model.as_str()));
         let router = self.select_router_for_request(headers, selected_model);

         if let Some(router) = router {
@@ -100,7 +100,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
         max_output_tokens: Some(64),
         max_tool_calls: None,
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -134,7 +134,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
     };

     let resp = router
-        .route_responses(None, &req, req.model.as_deref())
+        .route_responses(None, &req, Some(req.model.as_str()))
         .await;

     assert_eq!(resp.status(), StatusCode::OK);
@@ -349,7 +349,7 @@ fn test_responses_request_creation() {
         max_output_tokens: Some(100),
         max_tool_calls: None,
         metadata: None,
-        model: Some("test-model".to_string()),
+        model: "test-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
@@ -397,7 +397,7 @@ fn test_responses_request_sglang_extensions() {
         max_output_tokens: Some(50),
         max_tool_calls: None,
         metadata: None,
-        model: Some("test-model".to_string()),
+        model: "test-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -506,7 +506,7 @@ fn test_json_serialization() {
         max_output_tokens: Some(200),
         max_tool_calls: Some(5),
         metadata: None,
-        model: Some("gpt-4".to_string()),
+        model: "gpt-4".to_string(),
         parallel_tool_calls: Some(false),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
@@ -545,7 +545,7 @@ fn test_json_serialization() {
         parsed.request_id,
         Some("resp_comprehensive_test".to_string())
     );
-    assert_eq!(parsed.model, Some("gpt-4".to_string()));
+    assert_eq!(parsed.model, "gpt-4");
     assert_eq!(parsed.background, Some(true));
     assert_eq!(parsed.stream, Some(true));
     assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
@@ -636,7 +636,7 @@ async fn test_multi_turn_loop_with_mcp() {
         max_output_tokens: Some(128),
         max_tool_calls: None, // No limit - test unlimited
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -812,7 +812,7 @@ async fn test_max_tool_calls_limit() {
         max_output_tokens: Some(128),
         max_tool_calls: Some(1), // Limit to 1 call
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -1006,7 +1006,7 @@ async fn test_streaming_with_mcp_tool_calls() {
         max_output_tokens: Some(256),
         max_tool_calls: Some(3),
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -1287,7 +1287,7 @@ async fn test_streaming_multi_turn_with_mcp() {
         max_output_tokens: Some(512),
         max_tool_calls: Some(5), // Allow multiple rounds
         metadata: None,
-        model: Some("mock-model".to_string()),
+        model: "mock-model".to_string(),
         parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
@@ -99,7 +99,7 @@ fn create_minimal_completion_request() -> CompletionRequest {
 #[tokio::test]
 async fn test_openai_router_creation() {
     let router = OpenAIRouter::new(
-        "https://api.openai.com".to_string(),
+        vec!["https://api.openai.com".to_string()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -118,7 +118,7 @@ async fn test_openai_router_creation() {
 #[tokio::test]
 async fn test_openai_router_server_info() {
     let router = OpenAIRouter::new(
-        "https://api.openai.com".to_string(),
+        vec!["https://api.openai.com".to_string()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -149,7 +149,7 @@ async fn test_openai_router_models() {
     // Use mock server for deterministic models response
     let mock_server = MockOpenAIServer::new().await;
     let router = OpenAIRouter::new(
-        mock_server.base_url(),
+        vec![mock_server.base_url()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -229,7 +229,7 @@ async fn test_openai_router_responses_with_mock() {
     let storage = Arc::new(MemoryResponseStorage::new());

     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         storage.clone(),
         Arc::new(MemoryConversationStorage::new()),
@@ -239,7 +239,7 @@ async fn test_openai_router_responses_with_mock() {
     .unwrap();

     let request1 = ResponsesRequest {
-        model: Some("gpt-4o-mini".to_string()),
+        model: "gpt-4o-mini".to_string(),
         input: ResponseInput::Text("Say hi".to_string()),
         store: Some(true),
         ..Default::default()
@@ -255,7 +255,7 @@ async fn test_openai_router_responses_with_mock() {
     assert_eq!(body1["previous_response_id"], serde_json::Value::Null);

     let request2 = ResponsesRequest {
-        model: Some("gpt-4o-mini".to_string()),
+        model: "gpt-4o-mini".to_string(),
         input: ResponseInput::Text("Thanks".to_string()),
         store: Some(true),
         previous_response_id: Some(resp1_id.clone()),
@@ -490,7 +490,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
     storage.store_response(previous).await.unwrap();

     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         storage.clone(),
         Arc::new(MemoryConversationStorage::new()),
@@ -503,7 +503,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
     metadata.insert("topic".to_string(), json!("unicorns"));

     let request = ResponsesRequest {
-        model: Some("gpt-5-nano".to_string()),
+        model: "gpt-5-nano".to_string(),
         input: ResponseInput::Text("Tell me a bedtime story.".to_string()),
         instructions: Some("Be kind".to_string()),
         metadata: Some(metadata),
@@ -595,7 +595,7 @@ async fn test_router_factory_openai_mode() {
 #[tokio::test]
 async fn test_unsupported_endpoints() {
     let router = OpenAIRouter::new(
-        "https://api.openai.com".to_string(),
+        vec!["https://api.openai.com".to_string()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -660,7 +660,7 @@ async fn test_openai_router_chat_completion_with_mock() {

     // Create router pointing to mock server
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -702,7 +702,7 @@ async fn test_openai_e2e_with_server() {

     // Create router
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -773,7 +773,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
     let mock_server = MockOpenAIServer::new().await;
     let base_url = mock_server.base_url();
     let router = OpenAIRouter::new(
-        base_url,
+        vec![base_url],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -827,7 +827,7 @@ async fn test_openai_router_circuit_breaker() {
     };

     let router = OpenAIRouter::new(
-        "http://invalid-url-that-will-fail".to_string(),
+        vec!["http://invalid-url-that-will-fail".to_string()],
         Some(cb_config),
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -856,7 +856,7 @@ async fn test_openai_router_models_auth_forwarding() {
     let expected_auth = "Bearer test-token".to_string();
     let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
     let router = OpenAIRouter::new(
-        mock_server.base_url(),
+        vec![mock_server.base_url()],
         None,
         Arc::new(MemoryResponseStorage::new()),
         Arc::new(MemoryConversationStorage::new()),
@@ -865,7 +865,8 @@ async fn test_openai_router_models_auth_forwarding() {
     .await
     .unwrap();

-    // 1) Without auth header -> expect 401
+    // 1) Without auth header -> expect 200 with empty model list
+    // (multi-endpoint aggregation silently skips failed endpoints)
     let req = Request::builder()
         .method(Method::GET)
         .uri("/models")
@@ -873,7 +874,13 @@ async fn test_openai_router_models_auth_forwarding() {
     .unwrap();

     let response = router.get_models(req).await;
-    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
+    assert_eq!(response.status(), StatusCode::OK);
+    let (_, body) = response.into_parts();
+    let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
+    let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
+    let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
+    assert_eq!(models["object"], "list");
+    assert_eq!(models["data"].as_array().unwrap().len(), 0); // Empty when auth fails

     // 2) With auth header -> expect 200
     let req = Request::builder()