[router] Add OpenAI backend support - core function (#10254)
This commit is contained in:
@@ -63,10 +63,7 @@ impl ServerHandler for MockSearchServer {
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2024_11_05,
|
||||
capabilities: ServerCapabilities::builder().enable_tools().build(),
|
||||
server_info: Implementation {
|
||||
name: "Mock MCP Server".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
server_info: Implementation::from_build_env(),
|
||||
instructions: Some("Mock server for testing".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
238
sgl-router/tests/common/mock_openai_server.rs
Normal file
238
sgl-router/tests/common/mock_openai_server.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
//! Mock servers for testing
|
||||
|
||||
#![allow(dead_code)]
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request, State},
|
||||
http::{HeaderValue, StatusCode},
|
||||
response::sse::{Event, KeepAlive},
|
||||
response::{IntoResponse, Response, Sse},
|
||||
routing::post,
|
||||
Json, Router,
|
||||
};
|
||||
use futures_util::stream::{self, StreamExt};
|
||||
use serde_json::json;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Mock OpenAI API server for testing
|
||||
pub struct MockOpenAIServer {
|
||||
addr: SocketAddr,
|
||||
_handle: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MockServerState {
|
||||
require_auth: bool,
|
||||
expected_auth: Option<String>,
|
||||
}
|
||||
|
||||
impl MockOpenAIServer {
|
||||
/// Create and start a new mock OpenAI server
|
||||
pub async fn new() -> Self {
|
||||
Self::new_with_auth(None).await
|
||||
}
|
||||
|
||||
/// Create and start a new mock OpenAI server with optional auth requirement
|
||||
pub async fn new_with_auth(expected_auth: Option<String>) -> Self {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let state = Arc::new(MockServerState {
|
||||
require_auth: expected_auth.is_some(),
|
||||
expected_auth,
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/v1/chat/completions", post(mock_chat_completions))
|
||||
.route("/v1/completions", post(mock_completions))
|
||||
.route("/v1/models", post(mock_models).get(mock_models))
|
||||
.with_state(state);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
// Give the server a moment to start
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
||||
|
||||
Self {
|
||||
addr,
|
||||
_handle: handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the base URL for this mock server
|
||||
pub fn base_url(&self) -> String {
|
||||
format!("http://{}", self.addr)
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock chat completions endpoint
|
||||
async fn mock_chat_completions(req: Request<Body>) -> Response {
|
||||
let (_, body) = req.into_parts();
|
||||
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
let request: serde_json::Value = match serde_json::from_slice(&body_bytes) {
|
||||
Ok(req) => req,
|
||||
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
// Extract model from request or use default (owned String to satisfy 'static in stream)
|
||||
let model: String = request
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("gpt-3.5-turbo")
|
||||
.to_string();
|
||||
|
||||
// If stream requested, return SSE
|
||||
let is_stream = request
|
||||
.get("stream")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_stream {
|
||||
let created = 1677652288u64;
|
||||
// Single chunk then [DONE]
|
||||
let model_chunk = model.clone();
|
||||
let event_stream = stream::once(async move {
|
||||
let chunk = json!({
|
||||
"id": "chatcmpl-123456789",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model_chunk,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": "Hello!"
|
||||
},
|
||||
"finish_reason": null
|
||||
}]
|
||||
});
|
||||
Ok::<_, std::convert::Infallible>(Event::default().data(chunk.to_string()))
|
||||
})
|
||||
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
|
||||
|
||||
Sse::new(event_stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
} else {
|
||||
// Create a mock non-streaming response
|
||||
let response = json!({
|
||||
"id": "chatcmpl-123456789",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! I'm a mock OpenAI assistant. How can I help you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 9,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 21
|
||||
}
|
||||
});
|
||||
|
||||
Json(response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock completions endpoint (legacy)
|
||||
async fn mock_completions(req: Request<Body>) -> Response {
|
||||
let (_, body) = req.into_parts();
|
||||
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
let request: serde_json::Value = match serde_json::from_slice(&body_bytes) {
|
||||
Ok(req) => req,
|
||||
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
let model = request["model"].as_str().unwrap_or("text-davinci-003");
|
||||
|
||||
let response = json!({
|
||||
"id": "cmpl-123456789",
|
||||
"object": "text_completion",
|
||||
"created": 1677652288,
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"text": " This is a mock completion response.",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 12
|
||||
}
|
||||
});
|
||||
|
||||
Json(response).into_response()
|
||||
}
|
||||
|
||||
/// Mock models endpoint
|
||||
async fn mock_models(State(state): State<Arc<MockServerState>>, req: Request<Body>) -> Response {
|
||||
// Optionally enforce Authorization header
|
||||
if state.require_auth {
|
||||
let auth = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.or_else(|| req.headers().get("Authorization"))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
let auth_ok = match (&state.expected_auth, auth) {
|
||||
(Some(expected), Some(got)) => &got == expected,
|
||||
(None, Some(_)) => true,
|
||||
_ => false,
|
||||
};
|
||||
if !auth_ok {
|
||||
let mut response = Response::new(Body::from(
|
||||
json!({
|
||||
"error": {
|
||||
"message": "Unauthorized",
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
));
|
||||
*response.status_mut() = StatusCode::UNAUTHORIZED;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("WWW-Authenticate", HeaderValue::from_static("Bearer"));
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
let response = json!({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4",
|
||||
"object": "model",
|
||||
"created": 1677610602,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-3.5-turbo",
|
||||
"object": "model",
|
||||
"created": 1677610602,
|
||||
"owned_by": "openai"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
Json(response).into_response()
|
||||
}
|
||||
0
sgl-router/tests/common/mock_worker.rs
Normal file → Executable file
0
sgl-router/tests/common/mock_worker.rs
Normal file → Executable file
@@ -2,6 +2,7 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
pub mod mock_mcp_server;
|
||||
pub mod mock_openai_server;
|
||||
pub mod mock_worker;
|
||||
pub mod test_app;
|
||||
|
||||
|
||||
419
sgl-router/tests/test_openai_routing.rs
Normal file
419
sgl-router/tests/test_openai_routing.rs
Normal file
@@ -0,0 +1,419 @@
|
||||
//! Comprehensive integration tests for OpenAI backend functionality
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{Method, StatusCode},
|
||||
routing::post,
|
||||
Router,
|
||||
};
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::{
|
||||
config::{RouterConfig, RoutingMode},
|
||||
protocols::spec::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, UserMessageContent,
|
||||
},
|
||||
routers::{openai_router::OpenAIRouter, RouterTrait},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tower::ServiceExt;
|
||||
|
||||
mod common;
|
||||
use common::mock_openai_server::MockOpenAIServer;
|
||||
|
||||
/// Helper function to create a minimal chat completion request for testing
|
||||
fn create_minimal_chat_request() -> ChatCompletionRequest {
|
||||
let val = json!({
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"max_tokens": 100
|
||||
});
|
||||
serde_json::from_value(val).unwrap()
|
||||
}
|
||||
|
||||
/// Helper function to create a minimal completion request for testing
|
||||
fn create_minimal_completion_request() -> CompletionRequest {
|
||||
CompletionRequest {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
prompt: sglang_router_rs::protocols::spec::StringOrArray::String("Hello".to_string()),
|
||||
suffix: None,
|
||||
max_tokens: Some(100),
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
n: None,
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
logprobs: None,
|
||||
echo: false,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
best_of: None,
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
top_k: None,
|
||||
min_p: None,
|
||||
min_tokens: None,
|
||||
repetition_penalty: None,
|
||||
regex: None,
|
||||
ebnf: None,
|
||||
json_schema: None,
|
||||
stop_token_ids: None,
|
||||
no_stop_trim: false,
|
||||
ignore_eos: false,
|
||||
skip_special_tokens: true,
|
||||
lora_path: None,
|
||||
session_params: None,
|
||||
return_hidden_states: false,
|
||||
other: serde_json::Map::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Basic Unit Tests =============
|
||||
|
||||
/// Test basic OpenAI router creation and configuration
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_creation() {
|
||||
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None).await;
|
||||
|
||||
assert!(router.is_ok(), "Router creation should succeed");
|
||||
|
||||
let router = router.unwrap();
|
||||
assert_eq!(router.router_type(), "openai");
|
||||
assert!(!router.is_pd_mode());
|
||||
}
|
||||
|
||||
/// Test health endpoints
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_health() {
|
||||
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let req = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/health")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = router.health(req).await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
/// Test server info endpoint
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_server_info() {
|
||||
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let req = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/info")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = router.get_server_info(req).await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let (_, body) = response.into_parts();
|
||||
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||
|
||||
assert!(body_str.contains("openai"));
|
||||
}
|
||||
|
||||
/// Test models endpoint
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_models() {
|
||||
// Use mock server for deterministic models response
|
||||
let mock_server = MockOpenAIServer::new().await;
|
||||
let router = OpenAIRouter::new(mock_server.base_url(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let req = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/models")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = router.get_models(req).await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let (_, body) = response.into_parts();
|
||||
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||
let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
|
||||
|
||||
assert_eq!(models["object"], "list");
|
||||
assert!(models["data"].is_array());
|
||||
}
|
||||
|
||||
/// Test router factory with OpenAI routing mode
|
||||
#[tokio::test]
|
||||
async fn test_router_factory_openai_mode() {
|
||||
let routing_mode = RoutingMode::OpenAI {
|
||||
worker_urls: vec!["https://api.openai.com".to_string()],
|
||||
};
|
||||
|
||||
let router_config =
|
||||
RouterConfig::new(routing_mode, sglang_router_rs::config::PolicyConfig::Random);
|
||||
|
||||
let app_context = common::create_test_context(router_config);
|
||||
|
||||
let router = sglang_router_rs::routers::RouterFactory::create_router(&app_context).await;
|
||||
assert!(
|
||||
router.is_ok(),
|
||||
"Router factory should create OpenAI router successfully"
|
||||
);
|
||||
|
||||
let router = router.unwrap();
|
||||
assert_eq!(router.router_type(), "openai");
|
||||
}
|
||||
|
||||
/// Test that unsupported endpoints return proper error codes
|
||||
#[tokio::test]
|
||||
async fn test_unsupported_endpoints() {
|
||||
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Test generate endpoint (SGLang-specific, should not be supported)
|
||||
let generate_request = GenerateRequest {
|
||||
prompt: None,
|
||||
text: Some("Hello world".to_string()),
|
||||
input_ids: None,
|
||||
parameters: None,
|
||||
sampling_params: None,
|
||||
stream: false,
|
||||
return_logprob: false,
|
||||
lora_path: None,
|
||||
session_params: None,
|
||||
return_hidden_states: false,
|
||||
rid: None,
|
||||
};
|
||||
|
||||
let response = router.route_generate(None, &generate_request).await;
|
||||
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||
|
||||
// Test completion endpoint (should also not be supported)
|
||||
let completion_request = create_minimal_completion_request();
|
||||
let response = router.route_completion(None, &completion_request).await;
|
||||
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
// ============= Mock Server E2E Tests =============
|
||||
|
||||
/// Test chat completion with mock OpenAI server
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_chat_completion_with_mock() {
|
||||
// Start a mock OpenAI server
|
||||
let mock_server = MockOpenAIServer::new().await;
|
||||
let base_url = mock_server.base_url();
|
||||
|
||||
// Create router pointing to mock server
|
||||
let router = OpenAIRouter::new(base_url, None).await.unwrap();
|
||||
|
||||
// Create a minimal chat completion request
|
||||
let mut chat_request = create_minimal_chat_request();
|
||||
chat_request.messages = vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text("Hello, how are you?".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
chat_request.temperature = Some(0.7);
|
||||
|
||||
// Route the request
|
||||
let response = router.route_chat(None, &chat_request).await;
|
||||
|
||||
// Should get a successful response from mock server
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let (_, body) = response.into_parts();
|
||||
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||
let chat_response: serde_json::Value = serde_json::from_str(&body_str).unwrap();
|
||||
|
||||
// Verify it's a valid chat completion response
|
||||
assert_eq!(chat_response["object"], "chat.completion");
|
||||
assert_eq!(chat_response["model"], "gpt-3.5-turbo");
|
||||
assert!(!chat_response["choices"].as_array().unwrap().is_empty());
|
||||
}
|
||||
|
||||
/// Test full E2E flow with Axum server
|
||||
#[tokio::test]
|
||||
async fn test_openai_e2e_with_server() {
|
||||
// Start mock OpenAI server
|
||||
let mock_server = MockOpenAIServer::new().await;
|
||||
let base_url = mock_server.base_url();
|
||||
|
||||
// Create router
|
||||
let router = OpenAIRouter::new(base_url, None).await.unwrap();
|
||||
|
||||
// Create Axum app with chat completions endpoint
|
||||
let app = Router::new().route(
|
||||
"/v1/chat/completions",
|
||||
post({
|
||||
let router = Arc::new(router);
|
||||
move |req: Request<Body>| {
|
||||
let router = router.clone();
|
||||
async move {
|
||||
let (parts, body) = req.into_parts();
|
||||
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||
|
||||
let chat_request: ChatCompletionRequest =
|
||||
serde_json::from_str(&body_str).unwrap();
|
||||
|
||||
router.route_chat(Some(&parts.headers), &chat_request).await
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Make a request to the server
|
||||
let request = Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri("/v1/chat/completions")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
json!({
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, world!"
|
||||
}
|
||||
],
|
||||
"max_tokens": 100
|
||||
})
|
||||
.to_string(),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let response_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Verify the response structure
|
||||
assert_eq!(response_json["object"], "chat.completion");
|
||||
assert_eq!(response_json["model"], "gpt-3.5-turbo");
|
||||
assert!(!response_json["choices"].as_array().unwrap().is_empty());
|
||||
}
|
||||
|
||||
/// Test streaming chat completions pass-through with mock server
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_chat_streaming_with_mock() {
|
||||
let mock_server = MockOpenAIServer::new().await;
|
||||
let base_url = mock_server.base_url();
|
||||
let router = OpenAIRouter::new(base_url, None).await.unwrap();
|
||||
|
||||
// Build a streaming chat request
|
||||
let val = json!({
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"max_tokens": 10,
|
||||
"stream": true
|
||||
});
|
||||
let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap();
|
||||
|
||||
let response = router.route_chat(None, &chat_request).await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
// Should be SSE
|
||||
let headers = response.headers();
|
||||
let ct = headers
|
||||
.get("content-type")
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.to_ascii_lowercase();
|
||||
assert!(ct.contains("text/event-stream"));
|
||||
|
||||
// Read entire stream body and assert chunks + DONE
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let text = String::from_utf8(body.to_vec()).unwrap();
|
||||
assert!(text.contains("chat.completion.chunk"));
|
||||
assert!(text.contains("[DONE]"));
|
||||
}
|
||||
|
||||
/// Test circuit breaker functionality
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_circuit_breaker() {
|
||||
// Create router with circuit breaker config
|
||||
let cb_config = sglang_router_rs::config::CircuitBreakerConfig {
|
||||
failure_threshold: 2,
|
||||
success_threshold: 1,
|
||||
timeout_duration_secs: 1,
|
||||
window_duration_secs: 10,
|
||||
};
|
||||
|
||||
let router = OpenAIRouter::new(
|
||||
"http://invalid-url-that-will-fail".to_string(),
|
||||
Some(cb_config),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let chat_request = create_minimal_chat_request();
|
||||
|
||||
// First few requests should fail and record failures
|
||||
for _ in 0..3 {
|
||||
let response = router.route_chat(None, &chat_request).await;
|
||||
// Should get either an error or circuit breaker response
|
||||
assert!(
|
||||
response.status() == StatusCode::INTERNAL_SERVER_ERROR
|
||||
|| response.status() == StatusCode::SERVICE_UNAVAILABLE
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Test that Authorization header is forwarded in /v1/models
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_models_auth_forwarding() {
|
||||
// Start a mock server that requires Authorization
|
||||
let expected_auth = "Bearer test-token".to_string();
|
||||
let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
|
||||
let router = OpenAIRouter::new(mock_server.base_url(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// 1) Without auth header -> expect 401
|
||||
let req = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/models")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = router.get_models(req).await;
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
|
||||
// 2) With auth header -> expect 200
|
||||
let req = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/models")
|
||||
.header("Authorization", expected_auth)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = router.get_models(req).await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let (_, body) = response.into_parts();
|
||||
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||
let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
|
||||
assert_eq!(models["object"], "list");
|
||||
}
|
||||
Reference in New Issue
Block a user