[router] Support multiple worker URLs for OpenAI router (#11723)

This commit is contained in:
Keyang Ru
2025-10-22 09:27:58 -07:00
committed by GitHub
parent 1d097aac87
commit 77258ce039
9 changed files with 426 additions and 150 deletions

View File

@@ -100,7 +100,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
max_output_tokens: Some(64),
max_tool_calls: None,
metadata: None,
model: Some("mock-model".to_string()),
model: "mock-model".to_string(),
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
@@ -134,7 +134,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
};
let resp = router
.route_responses(None, &req, req.model.as_deref())
.route_responses(None, &req, Some(req.model.as_str()))
.await;
assert_eq!(resp.status(), StatusCode::OK);
@@ -349,7 +349,7 @@ fn test_responses_request_creation() {
max_output_tokens: Some(100),
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
model: "test-model".to_string(),
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: Some(ResponseReasoningParam {
@@ -397,7 +397,7 @@ fn test_responses_request_sglang_extensions() {
max_output_tokens: Some(50),
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
model: "test-model".to_string(),
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
@@ -506,7 +506,7 @@ fn test_json_serialization() {
max_output_tokens: Some(200),
max_tool_calls: Some(5),
metadata: None,
model: Some("gpt-4".to_string()),
model: "gpt-4".to_string(),
parallel_tool_calls: Some(false),
previous_response_id: None,
reasoning: Some(ResponseReasoningParam {
@@ -545,7 +545,7 @@ fn test_json_serialization() {
parsed.request_id,
Some("resp_comprehensive_test".to_string())
);
assert_eq!(parsed.model, Some("gpt-4".to_string()));
assert_eq!(parsed.model, "gpt-4");
assert_eq!(parsed.background, Some(true));
assert_eq!(parsed.stream, Some(true));
assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
@@ -636,7 +636,7 @@ async fn test_multi_turn_loop_with_mcp() {
max_output_tokens: Some(128),
max_tool_calls: None, // No limit - test unlimited
metadata: None,
model: Some("mock-model".to_string()),
model: "mock-model".to_string(),
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
@@ -812,7 +812,7 @@ async fn test_max_tool_calls_limit() {
max_output_tokens: Some(128),
max_tool_calls: Some(1), // Limit to 1 call
metadata: None,
model: Some("mock-model".to_string()),
model: "mock-model".to_string(),
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
@@ -1006,7 +1006,7 @@ async fn test_streaming_with_mcp_tool_calls() {
max_output_tokens: Some(256),
max_tool_calls: Some(3),
metadata: None,
model: Some("mock-model".to_string()),
model: "mock-model".to_string(),
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
@@ -1287,7 +1287,7 @@ async fn test_streaming_multi_turn_with_mcp() {
max_output_tokens: Some(512),
max_tool_calls: Some(5), // Allow multiple rounds
metadata: None,
model: Some("mock-model".to_string()),
model: "mock-model".to_string(),
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,

View File

@@ -99,7 +99,7 @@ fn create_minimal_completion_request() -> CompletionRequest {
#[tokio::test]
async fn test_openai_router_creation() {
let router = OpenAIRouter::new(
"https://api.openai.com".to_string(),
vec!["https://api.openai.com".to_string()],
None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
@@ -118,7 +118,7 @@ async fn test_openai_router_creation() {
#[tokio::test]
async fn test_openai_router_server_info() {
let router = OpenAIRouter::new(
"https://api.openai.com".to_string(),
vec!["https://api.openai.com".to_string()],
None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
@@ -149,7 +149,7 @@ async fn test_openai_router_models() {
// Use mock server for deterministic models response
let mock_server = MockOpenAIServer::new().await;
let router = OpenAIRouter::new(
mock_server.base_url(),
vec![mock_server.base_url()],
None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
@@ -229,7 +229,7 @@ async fn test_openai_router_responses_with_mock() {
let storage = Arc::new(MemoryResponseStorage::new());
let router = OpenAIRouter::new(
base_url,
vec![base_url],
None,
storage.clone(),
Arc::new(MemoryConversationStorage::new()),
@@ -239,7 +239,7 @@ async fn test_openai_router_responses_with_mock() {
.unwrap();
let request1 = ResponsesRequest {
model: Some("gpt-4o-mini".to_string()),
model: "gpt-4o-mini".to_string(),
input: ResponseInput::Text("Say hi".to_string()),
store: Some(true),
..Default::default()
@@ -255,7 +255,7 @@ async fn test_openai_router_responses_with_mock() {
assert_eq!(body1["previous_response_id"], serde_json::Value::Null);
let request2 = ResponsesRequest {
model: Some("gpt-4o-mini".to_string()),
model: "gpt-4o-mini".to_string(),
input: ResponseInput::Text("Thanks".to_string()),
store: Some(true),
previous_response_id: Some(resp1_id.clone()),
@@ -490,7 +490,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
storage.store_response(previous).await.unwrap();
let router = OpenAIRouter::new(
base_url,
vec![base_url],
None,
storage.clone(),
Arc::new(MemoryConversationStorage::new()),
@@ -503,7 +503,7 @@ async fn test_openai_router_responses_streaming_with_mock() {
metadata.insert("topic".to_string(), json!("unicorns"));
let request = ResponsesRequest {
model: Some("gpt-5-nano".to_string()),
model: "gpt-5-nano".to_string(),
input: ResponseInput::Text("Tell me a bedtime story.".to_string()),
instructions: Some("Be kind".to_string()),
metadata: Some(metadata),
@@ -595,7 +595,7 @@ async fn test_router_factory_openai_mode() {
#[tokio::test]
async fn test_unsupported_endpoints() {
let router = OpenAIRouter::new(
"https://api.openai.com".to_string(),
vec!["https://api.openai.com".to_string()],
None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
@@ -660,7 +660,7 @@ async fn test_openai_router_chat_completion_with_mock() {
// Create router pointing to mock server
let router = OpenAIRouter::new(
base_url,
vec![base_url],
None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
@@ -702,7 +702,7 @@ async fn test_openai_e2e_with_server() {
// Create router
let router = OpenAIRouter::new(
base_url,
vec![base_url],
None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
@@ -773,7 +773,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
let mock_server = MockOpenAIServer::new().await;
let base_url = mock_server.base_url();
let router = OpenAIRouter::new(
base_url,
vec![base_url],
None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
@@ -827,7 +827,7 @@ async fn test_openai_router_circuit_breaker() {
};
let router = OpenAIRouter::new(
"http://invalid-url-that-will-fail".to_string(),
vec!["http://invalid-url-that-will-fail".to_string()],
Some(cb_config),
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
@@ -856,7 +856,7 @@ async fn test_openai_router_models_auth_forwarding() {
let expected_auth = "Bearer test-token".to_string();
let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
let router = OpenAIRouter::new(
mock_server.base_url(),
vec![mock_server.base_url()],
None,
Arc::new(MemoryResponseStorage::new()),
Arc::new(MemoryConversationStorage::new()),
@@ -865,7 +865,8 @@ async fn test_openai_router_models_auth_forwarding() {
.await
.unwrap();
// 1) Without auth header -> expect 401
// 1) Without auth header -> expect 200 with empty model list
// (multi-endpoint aggregation silently skips failed endpoints)
let req = Request::builder()
.method(Method::GET)
.uri("/models")
@@ -873,7 +874,13 @@ async fn test_openai_router_models_auth_forwarding() {
.unwrap();
let response = router.get_models(req).await;
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
assert_eq!(response.status(), StatusCode::OK);
let (_, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
assert_eq!(models["object"], "list");
assert_eq!(models["data"].as_array().unwrap().len(), 0); // Empty when auth fails
// 2) With auth header -> expect 200
let req = Request::builder()