[router] responses api POST and GET with local storage (#10581)
Co-authored-by: key4ng <rukeyang@gmail.com>
This commit is contained in:
@@ -5,17 +5,23 @@ use axum::{
|
||||
extract::Request,
|
||||
http::{Method, StatusCode},
|
||||
routing::post,
|
||||
Router,
|
||||
Json, Router,
|
||||
};
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::{
|
||||
config::{RouterConfig, RoutingMode},
|
||||
data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage},
|
||||
protocols::spec::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, UserMessageContent,
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
|
||||
ResponsesGetParams, ResponsesRequest, UserMessageContent,
|
||||
},
|
||||
routers::{openai_router::OpenAIRouter, RouterTrait},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use tokio::net::TcpListener;
|
||||
use tower::ServiceExt;
|
||||
|
||||
mod common;
|
||||
@@ -78,7 +84,12 @@ fn create_minimal_completion_request() -> CompletionRequest {
|
||||
/// Test basic OpenAI router creation and configuration
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_creation() {
|
||||
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None).await;
|
||||
let router = OpenAIRouter::new(
|
||||
"https://api.openai.com".to_string(),
|
||||
None,
|
||||
Arc::new(MemoryResponseStorage::new()),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(router.is_ok(), "Router creation should succeed");
|
||||
|
||||
@@ -90,9 +101,13 @@ async fn test_openai_router_creation() {
|
||||
/// Test health endpoints
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_health() {
|
||||
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let router = OpenAIRouter::new(
|
||||
"https://api.openai.com".to_string(),
|
||||
None,
|
||||
Arc::new(MemoryResponseStorage::new()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let req = Request::builder()
|
||||
.method(Method::GET)
|
||||
@@ -107,9 +122,13 @@ async fn test_openai_router_health() {
|
||||
/// Test server info endpoint
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_server_info() {
|
||||
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let router = OpenAIRouter::new(
|
||||
"https://api.openai.com".to_string(),
|
||||
None,
|
||||
Arc::new(MemoryResponseStorage::new()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let req = Request::builder()
|
||||
.method(Method::GET)
|
||||
@@ -132,9 +151,13 @@ async fn test_openai_router_server_info() {
|
||||
async fn test_openai_router_models() {
|
||||
// Use mock server for deterministic models response
|
||||
let mock_server = MockOpenAIServer::new().await;
|
||||
let router = OpenAIRouter::new(mock_server.base_url(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let router = OpenAIRouter::new(
|
||||
mock_server.base_url(),
|
||||
None,
|
||||
Arc::new(MemoryResponseStorage::new()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let req = Request::builder()
|
||||
.method(Method::GET)
|
||||
@@ -154,6 +177,138 @@ async fn test_openai_router_models() {
|
||||
assert!(models["data"].is_array());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_responses_with_mock() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let counter = Arc::new(AtomicUsize::new(0));
|
||||
let counter_clone = counter.clone();
|
||||
|
||||
let app = Router::new().route(
|
||||
"/v1/responses",
|
||||
post({
|
||||
move |Json(request): Json<serde_json::Value>| {
|
||||
let counter = counter_clone.clone();
|
||||
async move {
|
||||
let idx = counter.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
let model = request
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("gpt-4o-mini")
|
||||
.to_string();
|
||||
let id = format!("resp_mock_{idx}");
|
||||
let response = json!({
|
||||
"id": id,
|
||||
"object": "response",
|
||||
"created_at": 1_700_000_000 + idx as i64,
|
||||
"status": "completed",
|
||||
"model": model,
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"id": format!("msg_{idx}"),
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": format!("mock_output_{idx}"),
|
||||
"annotations": []
|
||||
}]
|
||||
}],
|
||||
"metadata": {}
|
||||
});
|
||||
Json(response)
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
let base_url = format!("http://{}", addr);
|
||||
let storage = Arc::new(MemoryResponseStorage::new());
|
||||
|
||||
let router = OpenAIRouter::new(base_url, None, storage.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request1 = ResponsesRequest {
|
||||
model: Some("gpt-4o-mini".to_string()),
|
||||
input: ResponseInput::Text("Say hi".to_string()),
|
||||
store: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response1 = router.route_responses(None, &request1, None).await;
|
||||
assert_eq!(response1.status(), StatusCode::OK);
|
||||
let body1_bytes = axum::body::to_bytes(response1.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body1: serde_json::Value = serde_json::from_slice(&body1_bytes).unwrap();
|
||||
let resp1_id = body1["id"].as_str().expect("id missing").to_string();
|
||||
assert_eq!(body1["previous_response_id"], serde_json::Value::Null);
|
||||
|
||||
let request2 = ResponsesRequest {
|
||||
model: Some("gpt-4o-mini".to_string()),
|
||||
input: ResponseInput::Text("Thanks".to_string()),
|
||||
store: true,
|
||||
previous_response_id: Some(resp1_id.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response2 = router.route_responses(None, &request2, None).await;
|
||||
assert_eq!(response2.status(), StatusCode::OK);
|
||||
let body2_bytes = axum::body::to_bytes(response2.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body2: serde_json::Value = serde_json::from_slice(&body2_bytes).unwrap();
|
||||
let resp2_id = body2["id"].as_str().expect("second id missing");
|
||||
assert_eq!(
|
||||
body2["previous_response_id"].as_str(),
|
||||
Some(resp1_id.as_str())
|
||||
);
|
||||
|
||||
let stored1 = storage
|
||||
.get_response(&ResponseId::from_string(resp1_id.clone()))
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("first response missing");
|
||||
assert_eq!(stored1.input, "Say hi");
|
||||
assert_eq!(stored1.output, "mock_output_1");
|
||||
assert!(stored1.previous_response_id.is_none());
|
||||
|
||||
let stored2 = storage
|
||||
.get_response(&ResponseId::from_string(resp2_id.to_string()))
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("second response missing");
|
||||
assert_eq!(stored2.previous_response_id.unwrap().0, resp1_id);
|
||||
assert_eq!(stored2.output, "mock_output_2");
|
||||
|
||||
let get1 = router
|
||||
.get_response(None, &stored1.id.0, &ResponsesGetParams::default())
|
||||
.await;
|
||||
assert_eq!(get1.status(), StatusCode::OK);
|
||||
let get1_body_bytes = axum::body::to_bytes(get1.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let get1_json: serde_json::Value = serde_json::from_slice(&get1_body_bytes).unwrap();
|
||||
assert_eq!(get1_json, body1);
|
||||
|
||||
let get2 = router
|
||||
.get_response(None, &stored2.id.0, &ResponsesGetParams::default())
|
||||
.await;
|
||||
assert_eq!(get2.status(), StatusCode::OK);
|
||||
let get2_body_bytes = axum::body::to_bytes(get2.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let get2_json: serde_json::Value = serde_json::from_slice(&get2_body_bytes).unwrap();
|
||||
assert_eq!(get2_json, body2);
|
||||
|
||||
server.abort();
|
||||
}
|
||||
|
||||
/// Test router factory with OpenAI routing mode
|
||||
#[tokio::test]
|
||||
async fn test_router_factory_openai_mode() {
|
||||
@@ -179,9 +334,13 @@ async fn test_router_factory_openai_mode() {
|
||||
/// Test that unsupported endpoints return proper error codes
|
||||
#[tokio::test]
|
||||
async fn test_unsupported_endpoints() {
|
||||
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let router = OpenAIRouter::new(
|
||||
"https://api.openai.com".to_string(),
|
||||
None,
|
||||
Arc::new(MemoryResponseStorage::new()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Test generate endpoint (SGLang-specific, should not be supported)
|
||||
let generate_request = GenerateRequest {
|
||||
@@ -219,7 +378,9 @@ async fn test_openai_router_chat_completion_with_mock() {
|
||||
let base_url = mock_server.base_url();
|
||||
|
||||
// Create router pointing to mock server
|
||||
let router = OpenAIRouter::new(base_url, None).await.unwrap();
|
||||
let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Create a minimal chat completion request
|
||||
let mut chat_request = create_minimal_chat_request();
|
||||
@@ -255,7 +416,9 @@ async fn test_openai_e2e_with_server() {
|
||||
let base_url = mock_server.base_url();
|
||||
|
||||
// Create router
|
||||
let router = OpenAIRouter::new(base_url, None).await.unwrap();
|
||||
let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Create Axum app with chat completions endpoint
|
||||
let app = Router::new().route(
|
||||
@@ -319,7 +482,9 @@ async fn test_openai_e2e_with_server() {
|
||||
async fn test_openai_router_chat_streaming_with_mock() {
|
||||
let mock_server = MockOpenAIServer::new().await;
|
||||
let base_url = mock_server.base_url();
|
||||
let router = OpenAIRouter::new(base_url, None).await.unwrap();
|
||||
let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Build a streaming chat request
|
||||
let val = json!({
|
||||
@@ -368,6 +533,7 @@ async fn test_openai_router_circuit_breaker() {
|
||||
let router = OpenAIRouter::new(
|
||||
"http://invalid-url-that-will-fail".to_string(),
|
||||
Some(cb_config),
|
||||
Arc::new(MemoryResponseStorage::new()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -391,9 +557,13 @@ async fn test_openai_router_models_auth_forwarding() {
|
||||
// Start a mock server that requires Authorization
|
||||
let expected_auth = "Bearer test-token".to_string();
|
||||
let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
|
||||
let router = OpenAIRouter::new(mock_server.base_url(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let router = OpenAIRouter::new(
|
||||
mock_server.base_url(),
|
||||
None,
|
||||
Arc::new(MemoryResponseStorage::new()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// 1) Without auth header -> expect 401
|
||||
let req = Request::builder()
|
||||
|
||||
Reference in New Issue
Block a user