[router] Add Rerank Routing Logic in Regular Router (#10219)
This commit is contained in:
@@ -1752,3 +1752,332 @@ mod request_id_tests {
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod rerank_tests {
|
||||
use super::*;
|
||||
// Note: RerankRequest and RerankResult are available for future use
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rerank_success() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18105,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"query": "machine learning algorithms",
|
||||
"documents": [
|
||||
"Introduction to machine learning concepts",
|
||||
"Deep learning neural networks tutorial"
|
||||
],
|
||||
"model": "test-rerank-model",
|
||||
"top_k": 2,
|
||||
"return_documents": true,
|
||||
"rid": "test-request-123"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Verify response structure
|
||||
assert!(body_json.get("results").is_some());
|
||||
assert!(body_json.get("model").is_some());
|
||||
assert_eq!(body_json["model"], "test-rerank-model");
|
||||
|
||||
let results = body_json["results"].as_array().unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
|
||||
// Verify results are sorted by score (highest first)
|
||||
assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap());
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rerank_with_top_k() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18106,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"query": "test query",
|
||||
"documents": [
|
||||
"Document 1",
|
||||
"Document 2",
|
||||
"Document 3"
|
||||
],
|
||||
"model": "test-model",
|
||||
"top_k": 1,
|
||||
"return_documents": true
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Should only return top_k results
|
||||
let results = body_json["results"].as_array().unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rerank_without_documents() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18107,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"query": "test query",
|
||||
"documents": ["Document 1", "Document 2"],
|
||||
"model": "test-model",
|
||||
"return_documents": false
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Documents should be null when return_documents is false
|
||||
let results = body_json["results"].as_array().unwrap();
|
||||
for result in results {
|
||||
assert!(result.get("document").is_none());
|
||||
}
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rerank_worker_failure() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18108,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 1.0, // Always fail
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"query": "test query",
|
||||
"documents": ["Document 1"],
|
||||
"model": "test-model"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
// Should return the worker's error response
|
||||
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_v1_rerank_compatibility() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18110,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// Test V1 API format (simplified input)
|
||||
let payload = json!({
|
||||
"query": "machine learning algorithms",
|
||||
"documents": [
|
||||
"Introduction to machine learning concepts",
|
||||
"Deep learning neural networks tutorial",
|
||||
"Statistical learning theory basics"
|
||||
]
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// Verify response structure
|
||||
assert!(body_json.get("results").is_some());
|
||||
assert!(body_json.get("model").is_some());
|
||||
|
||||
// V1 API should use default model name
|
||||
assert_eq!(body_json["model"], "default");
|
||||
|
||||
let results = body_json["results"].as_array().unwrap();
|
||||
assert_eq!(results.len(), 3); // All documents should be returned
|
||||
|
||||
// Verify results are sorted by score (highest first)
|
||||
assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap());
|
||||
assert!(results[1]["score"].as_f64().unwrap() >= results[2]["score"].as_f64().unwrap());
|
||||
|
||||
// V1 API should return documents by default
|
||||
for result in results {
|
||||
assert!(result.get("document").is_some());
|
||||
}
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rerank_invalid_request() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18111,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// Test empty query string (validation should fail)
|
||||
let payload = json!({
|
||||
"query": "",
|
||||
"documents": ["Document 1", "Document 2"],
|
||||
"model": "test-model"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Test query with only whitespace (validation should fail)
|
||||
let payload = json!({
|
||||
"query": " ",
|
||||
"documents": ["Document 1", "Document 2"],
|
||||
"model": "test-model"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Test empty documents list (validation should fail)
|
||||
let payload = json!({
|
||||
"query": "test query",
|
||||
"documents": [],
|
||||
"model": "test-model"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Test invalid top_k (validation should fail)
|
||||
let payload = json!({
|
||||
"query": "test query",
|
||||
"documents": ["Document 1", "Document 2"],
|
||||
"model": "test-model",
|
||||
"top_k": 0
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rerank")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,6 +81,7 @@ impl MockWorker {
|
||||
.route("/generate", post(generate_handler))
|
||||
.route("/v1/chat/completions", post(chat_completions_handler))
|
||||
.route("/v1/completions", post(completions_handler))
|
||||
.route("/v1/rerank", post(rerank_handler))
|
||||
.route("/v1/responses", post(responses_handler))
|
||||
.route("/flush_cache", post(flush_cache_handler))
|
||||
.route("/v1/models", get(v1_models_handler))
|
||||
@@ -687,6 +688,56 @@ async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn rerank_handler(
|
||||
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
||||
Json(payload): Json<serde_json::Value>,
|
||||
) -> impl IntoResponse {
|
||||
let config = config.read().await;
|
||||
|
||||
// Simulate response delay
|
||||
if config.response_delay_ms > 0 {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
|
||||
}
|
||||
|
||||
// Simulate failure rate
|
||||
if rand::random::<f32>() < config.fail_rate {
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, "Simulated failure").into_response();
|
||||
}
|
||||
|
||||
// Extract documents from the request to create mock results
|
||||
let empty_vec = vec![];
|
||||
let documents = payload
|
||||
.get("documents")
|
||||
.and_then(|d| d.as_array())
|
||||
.unwrap_or(&empty_vec);
|
||||
|
||||
// Create mock rerank results with scores based on document index
|
||||
let mut mock_results = Vec::new();
|
||||
for (i, doc) in documents.iter().enumerate() {
|
||||
let score = 0.95 - (i as f32 * 0.1); // Decreasing scores
|
||||
let result = serde_json::json!({
|
||||
"score": score,
|
||||
"document": doc.as_str().unwrap_or(""),
|
||||
"index": i,
|
||||
"meta_info": {
|
||||
"confidence": if score > 0.9 { "high" } else { "medium" }
|
||||
}
|
||||
});
|
||||
mock_results.push(result);
|
||||
}
|
||||
|
||||
// Sort by score (highest first) to simulate proper ranking
|
||||
mock_results.sort_by(|a, b| {
|
||||
b["score"]
|
||||
.as_f64()
|
||||
.unwrap()
|
||||
.partial_cmp(&a["score"].as_f64().unwrap())
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
(StatusCode::OK, Json(mock_results)).into_response()
|
||||
}
|
||||
|
||||
impl Default for MockWorkerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
||||
Reference in New Issue
Block a user