[router] Add Rerank Routing Logic in Regular Router (#10219)
This commit is contained in:
@@ -1752,3 +1752,332 @@ mod request_id_tests {
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod rerank_tests {
|
||||
use super::*;
|
||||
// Note: RerankRequest and RerankResult are available for future use
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rerank_success() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18105,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"query": "machine learning algorithms",
|
||||
"documents": [
|
||||
"Introduction to machine learning concepts",
|
||||
"Deep learning neural networks tutorial"
|
||||
],
|
||||
"model": "test-rerank-model",
|
||||
"top_k": 2,
|
||||
"return_documents": true,
|
||||
"rid": "test-request-123"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Verify response structure
|
||||
assert!(body_json.get("results").is_some());
|
||||
assert!(body_json.get("model").is_some());
|
||||
assert_eq!(body_json["model"], "test-rerank-model");
|
||||
|
||||
let results = body_json["results"].as_array().unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
|
||||
// Verify results are sorted by score (highest first)
|
||||
assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap());
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rerank_with_top_k() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18106,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"query": "test query",
|
||||
"documents": [
|
||||
"Document 1",
|
||||
"Document 2",
|
||||
"Document 3"
|
||||
],
|
||||
"model": "test-model",
|
||||
"top_k": 1,
|
||||
"return_documents": true
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Should only return top_k results
|
||||
let results = body_json["results"].as_array().unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rerank_without_documents() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18107,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"query": "test query",
|
||||
"documents": ["Document 1", "Document 2"],
|
||||
"model": "test-model",
|
||||
"return_documents": false
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Documents should be null when return_documents is false
|
||||
let results = body_json["results"].as_array().unwrap();
|
||||
for result in results {
|
||||
assert!(result.get("document").is_none());
|
||||
}
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rerank_worker_failure() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18108,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 1.0, // Always fail
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"query": "test query",
|
||||
"documents": ["Document 1"],
|
||||
"model": "test-model"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
// Should return the worker's error response
|
||||
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_v1_rerank_compatibility() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18110,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// Test V1 API format (simplified input)
|
||||
let payload = json!({
|
||||
"query": "machine learning algorithms",
|
||||
"documents": [
|
||||
"Introduction to machine learning concepts",
|
||||
"Deep learning neural networks tutorial",
|
||||
"Statistical learning theory basics"
|
||||
]
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Verify response structure
|
||||
assert!(body_json.get("results").is_some());
|
||||
assert!(body_json.get("model").is_some());
|
||||
|
||||
// V1 API should use default model name
|
||||
assert_eq!(body_json["model"], "default");
|
||||
|
||||
let results = body_json["results"].as_array().unwrap();
|
||||
assert_eq!(results.len(), 3); // All documents should be returned
|
||||
|
||||
// Verify results are sorted by score (highest first)
|
||||
assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap());
|
||||
assert!(results[1]["score"].as_f64().unwrap() >= results[2]["score"].as_f64().unwrap());
|
||||
|
||||
// V1 API should return documents by default
|
||||
for result in results {
|
||||
assert!(result.get("document").is_some());
|
||||
}
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rerank_invalid_request() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18111,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// Test empty query string (validation should fail)
|
||||
let payload = json!({
|
||||
"query": "",
|
||||
"documents": ["Document 1", "Document 2"],
|
||||
"model": "test-model"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Test query with only whitespace (validation should fail)
|
||||
let payload = json!({
|
||||
"query": " ",
|
||||
"documents": ["Document 1", "Document 2"],
|
||||
"model": "test-model"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Test empty documents list (validation should fail)
|
||||
let payload = json!({
|
||||
"query": "test query",
|
||||
"documents": [],
|
||||
"model": "test-model"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Test invalid top_k (validation should fail)
|
||||
let payload = json!({
|
||||
"query": "test query",
|
||||
"documents": ["Document 1", "Document 2"],
|
||||
"model": "test-model",
|
||||
"top_k": 0
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user