From 4634fd59536938c08385c02c44a0a9bf17b8b17a Mon Sep 17 00:00:00 2001 From: Frank Fang Date: Sat, 13 Sep 2025 00:10:18 +0800 Subject: [PATCH] [router] Add Rerank Routing Logic in Regular Router (#10219) --- sgl-router/src/protocols/spec.rs | 89 +++-- sgl-router/src/routers/grpc/pd_router.rs | 6 +- sgl-router/src/routers/grpc/router.rs | 6 +- sgl-router/src/routers/http/openai_router.rs | 6 +- sgl-router/src/routers/http/pd_router.rs | 25 +- sgl-router/src/routers/http/router.rs | 45 ++- sgl-router/src/routers/mod.rs | 4 +- sgl-router/src/server.rs | 24 +- sgl-router/tests/api_endpoints_test.rs | 329 +++++++++++++++++++ sgl-router/tests/common/mock_worker.rs | 51 +++ 10 files changed, 545 insertions(+), 40 deletions(-) diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index 43e60244c..583829747 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -1891,7 +1891,7 @@ pub struct RerankResponse { pub object: String, /// Response ID - pub id: String, + pub id: Option, /// Creation timestamp pub created: i64, @@ -1976,7 +1976,11 @@ impl RerankRequest { } impl RerankResponse { - pub fn new(results: Vec, model: String, request_id: String) -> Self { + pub fn new( + results: Vec, + model: String, + request_id: Option, + ) -> Self { RerankResponse { results, model, @@ -2000,6 +2004,13 @@ impl RerankResponse { pub fn apply_top_k(&mut self, k: usize) { self.results.truncate(k); } + + /// Drop documents from results + pub fn drop_documents(&mut self) { + self.results.iter_mut().for_each(|result| { + result.document = None; + }); + } } // ================================================================== @@ -2268,12 +2279,15 @@ mod tests { let response = RerankResponse::new( results.clone(), "test-model".to_string(), - "req-123".to_string(), + Some(StringOrArray::String("req-123".to_string())), ); assert_eq!(response.results.len(), 2); assert_eq!(response.model, "test-model"); - assert_eq!(response.id, "req-123"); + assert_eq!( + response.id, + Some(StringOrArray::String("req-123".to_string())) + ); assert_eq!(response.object, "rerank"); assert!(response.created > 0); } @@ -2287,8 +2301,11 @@ mod tests { meta_info: None, }]; - let response = - RerankResponse::new(results, "test-model".to_string(), "req-123".to_string()); + let response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); let serialized = serde_json::to_string(&response).unwrap(); let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap(); @@ -2322,8 +2339,11 @@ mod tests { }, ]; - let mut response = - RerankResponse::new(results, "test-model".to_string(), "req-123".to_string()); + let mut response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); response.sort_by_score(); @@ -2358,8 +2378,11 @@ mod tests { }, ]; - let mut response = - RerankResponse::new(results, "test-model".to_string(), "req-123".to_string()); + let mut response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); response.apply_top_k(2); @@ -2377,14 +2400,36 @@ mod tests { meta_info: None, }]; - let mut response = - RerankResponse::new(results, "test-model".to_string(), "req-123".to_string()); + let mut response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); response.apply_top_k(5); assert_eq!(response.results.len(), 1); } + #[test] + fn test_rerank_response_drop_documents() { + let results = vec![RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }]; + let mut response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); + + response.drop_documents(); + + assert_eq!(response.results[0].document, None); + } + // ================================================================== // = RERANK RESULT TESTS = // ================================================================== @@ -2570,8 +2615,11 @@ mod tests { meta_info: None, }]; - let mut response = - RerankResponse::new(results, "test-model".to_string(), "req-123".to_string()); + let mut response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); response.usage = Some(UsageInfo { prompt_tokens: 100, @@ -2645,18 +2693,7 @@ mod tests { ]; // Create response - let mut response = RerankResponse::new( - results, - request.model.clone(), - request - .rid - .as_ref() - .and_then(|r| match r { - StringOrArray::String(s) => Some(s.clone()), - StringOrArray::Array(arr) => arr.first().cloned(), - }) - .unwrap_or_else(|| "unknown".to_string()), - ); + let mut response = RerankResponse::new(results, request.model.clone(), request.rid.clone()); // Sort by score response.sort_by_score(); diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 8c9645eca..a0a3c7911 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -301,7 +301,11 @@ impl RouterTrait for GrpcPDRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } - async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + async fn route_rerank( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::RerankRequest, + ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index d42753fc1..245513b37 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -234,7 +234,11 @@ impl RouterTrait for GrpcRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } - async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + async fn route_rerank( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::RerankRequest, + ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index b06f20810..0f5a56974 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -2,7 +2,9 @@ use crate::config::CircuitBreakerConfig; use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}; -use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; +use crate::protocols::spec::{ + ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, +}; use async_trait::async_trait; use axum::{ body::Body, @@ -381,7 +383,7 @@ impl super::super::RouterTrait for OpenAIRouter { .into_response() } - async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: &RerankRequest) -> Response { ( StatusCode::NOT_IMPLEMENTED, "Rerank endpoint not implemented for OpenAI backend", diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index d66eb8077..af4d605f0 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -9,8 +9,8 @@ use crate::core::{ use crate::metrics::RouterMetrics; use crate::policies::LoadBalancingPolicy; use crate::protocols::spec::{ - ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponsesRequest, - StringOrArray, UserMessageContent, + ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, RerankRequest, + ResponsesRequest, StringOrArray, UserMessageContent, }; use crate::routers::header_utils; use crate::routers::{RouterTrait, WorkerManagement}; @@ -1946,8 +1946,25 @@ impl RouterTrait for PDRouter { todo!() } - async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { - todo!() + async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response { + // Extract text for cache-aware routing + let req_text = if self.policies_need_request_text() { + Some(body.query.clone()) + } else { + None + }; + + // Create context + let context = PDRequestContext { + route: "/v1/rerank", + batch_size: None, + is_stream: false, + return_logprob: false, + request_text: req_text, + }; + + // Execute with retry and bootstrap injection + self.execute_dual_dispatch(headers, body, context).await } async fn flush_cache(&self) -> Response { diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index f0dc4f3b5..ca1b4d68f 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -6,10 +6,12 @@ use crate::core::{ use crate::metrics::RouterMetrics; use crate::policies::LoadBalancingPolicy; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, ResponsesRequest, + ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, RerankRequest, + RerankResponse, RerankResult, ResponsesRequest, }; use crate::routers::header_utils; use crate::routers::{RouterTrait, WorkerManagement}; +use axum::body::to_bytes; use axum::{ body::Body, extract::Request, @@ -1124,6 +1126,25 @@ impl Router { } } } + + async fn build_rerank_response( + req: &RerankRequest, + response: Response, + ) -> anyhow::Result { + let (_, response_body) = response.into_parts(); + let body_bytes = to_bytes(response_body, usize::MAX).await?; + let rerank_results = serde_json::from_slice::>(&body_bytes)?; + let mut rerank_response = + RerankResponse::new(rerank_results, req.model.clone(), req.rid.clone()); + rerank_response.sort_by_score(); + if let Some(top_k) = req.top_k { + rerank_response.apply_top_k(top_k); + } + if !req.return_documents { + rerank_response.drop_documents(); + } + Ok(Json(rerank_response).into_response()) + } } use async_trait::async_trait; @@ -1223,8 +1244,26 @@ impl RouterTrait for Router { todo!() } - async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { - todo!() + async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response { + if let Err(e) = body.validate() { + return (StatusCode::BAD_REQUEST, e).into_response(); + } + let response = self.route_typed_request(headers, body, "/v1/rerank").await; + if response.status().is_success() { + match Self::build_rerank_response(body, response).await { + Ok(rerank_response) => rerank_response, + Err(e) => { + error!("Failed to build rerank response: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to build rerank response".to_string(), + ) + .into_response(); + } + } + } else { + response + } } async fn flush_cache(&self) -> Response { diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 6c12edbc8..3fe339d8f 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -10,7 +10,7 @@ use axum::{ use std::fmt::Debug; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, GenerateRequest, ResponsesRequest, + ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest, }; pub mod factory; @@ -89,7 +89,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response; - async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response; + async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response; /// Flush cache on all workers async fn flush_cache(&self) -> Response; diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 9aca370ae..f44924e38 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -3,7 +3,8 @@ use crate::logging::{self, LoggingConfig}; use crate::metrics::{self, PrometheusConfig}; use crate::middleware::TokenBucket; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, GenerateRequest, ResponsesRequest, + ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest, + V1RerankReqInput, }; use crate::reasoning_parser::ParserFactory; use crate::routers::{RouterFactory, RouterTrait}; @@ -152,6 +153,25 @@ async fn v1_completions( state.router.route_completion(Some(&headers), &body).await } +async fn rerank( + State(state): State>, + headers: http::HeaderMap, + Json(body): Json, +) -> Response { + state.router.route_rerank(Some(&headers), &body).await +} + +async fn v1_rerank( + State(state): State>, + headers: http::HeaderMap, + Json(body): Json, +) -> Response { + state + .router + .route_rerank(Some(&headers), &body.into()) + .await +} + async fn v1_responses( State(state): State>, headers: http::HeaderMap, @@ -237,6 +257,8 @@ pub fn build_app( .route("/generate", post(generate)) .route("/v1/chat/completions", post(v1_chat_completions)) .route("/v1/completions", post(v1_completions)) + .route("/rerank", post(rerank)) + .route("/v1/rerank", post(v1_rerank)) .route("/v1/responses", post(v1_responses)) .route_layer(axum::middleware::from_fn_with_state( app_state.clone(), diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index 09099a7b8..39911a20d 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -1752,3 +1752,332 @@ mod request_id_tests { ctx.shutdown().await; } } + +#[cfg(test)] +mod rerank_tests { + use super::*; + // Note: RerankRequest and RerankResult are available for future use + + #[tokio::test] + async fn test_rerank_success() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18105, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "query": "machine learning algorithms", + "documents": [ + "Introduction to machine learning concepts", + "Deep learning neural networks tutorial" + ], + "model": "test-rerank-model", + "top_k": 2, + "return_documents": true, + "rid": "test-request-123" + }); + + let req = Request::builder() + .method("POST") + .uri("/rerank") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + // Verify response structure + assert!(body_json.get("results").is_some()); + assert!(body_json.get("model").is_some()); + assert_eq!(body_json["model"], "test-rerank-model"); + + let results = body_json["results"].as_array().unwrap(); + assert_eq!(results.len(), 2); + + // Verify results are sorted by score (highest first) + assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap()); + + ctx.shutdown().await; + } + + #[tokio::test] + async fn test_rerank_with_top_k() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18106, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "query": "test query", + "documents": [ + "Document 1", + "Document 2", + "Document 3" + ], + "model": "test-model", + "top_k": 1, + "return_documents": true + }); + + let req = Request::builder() + .method("POST") + .uri("/rerank") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + // Should only return top_k results + let results = body_json["results"].as_array().unwrap(); + assert_eq!(results.len(), 1); + + ctx.shutdown().await; + } + + #[tokio::test] + async fn test_rerank_without_documents() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18107, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "query": "test query", + "documents": ["Document 1", "Document 2"], + "model": "test-model", + "return_documents": false + }); + + let req = Request::builder() + .method("POST") + .uri("/rerank") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + // Documents should be null when return_documents is false + let results = body_json["results"].as_array().unwrap(); + for result in results { + assert!(result.get("document").is_none()); + } + + ctx.shutdown().await; + } + + #[tokio::test] + async fn test_rerank_worker_failure() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18108, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 1.0, // Always fail + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "query": "test query", + "documents": ["Document 1"], + "model": "test-model" + }); + + let req = Request::builder() + .method("POST") + .uri("/rerank") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + // Should return the worker's error response + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + + ctx.shutdown().await; + } + + #[tokio::test] + async fn test_v1_rerank_compatibility() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18110, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Test V1 API format (simplified input) + let payload = json!({ + "query": "machine learning algorithms", + "documents": [ + "Introduction to machine learning concepts", + "Deep learning neural networks tutorial", + "Statistical learning theory basics" + ] + }); + + let req = Request::builder() + .method("POST") + .uri("/v1/rerank") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + // Verify response structure + assert!(body_json.get("results").is_some()); + assert!(body_json.get("model").is_some()); + + // V1 API should use default model name + assert_eq!(body_json["model"], "default"); + + let results = body_json["results"].as_array().unwrap(); + assert_eq!(results.len(), 3); // All documents should be returned + + // Verify results are sorted by score (highest first) + assert!(results[0]["score"].as_f64().unwrap() >= results[1]["score"].as_f64().unwrap()); + assert!(results[1]["score"].as_f64().unwrap() >= results[2]["score"].as_f64().unwrap()); + + // V1 API should return documents by default + for result in results { + assert!(result.get("document").is_some()); + } + + ctx.shutdown().await; + } + + #[tokio::test] + async fn test_rerank_invalid_request() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18111, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Test empty query string (validation should fail) + let payload = json!({ + "query": "", + "documents": ["Document 1", "Document 2"], + "model": "test-model" + }); + + let req = Request::builder() + .method("POST") + .uri("/rerank") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + // Test query with only whitespace (validation should fail) + let payload = json!({ + "query": " ", + "documents": ["Document 1", "Document 2"], + "model": "test-model" + }); + + let req = Request::builder() + .method("POST") + .uri("/rerank") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + // Test empty documents list (validation should fail) + let payload = json!({ + "query": "test query", + "documents": [], + "model": "test-model" + }); + + let req = Request::builder() + .method("POST") + .uri("/rerank") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + // Test invalid top_k (validation should fail) + let payload = json!({ + "query": "test query", + "documents": ["Document 1", "Document 2"], + "model": "test-model", + "top_k": 0 + }); + + let req = Request::builder() + .method("POST") + .uri("/rerank") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + ctx.shutdown().await; + } +} diff --git a/sgl-router/tests/common/mock_worker.rs b/sgl-router/tests/common/mock_worker.rs index a04a76008..daf2a4889 100755 --- a/sgl-router/tests/common/mock_worker.rs +++ b/sgl-router/tests/common/mock_worker.rs @@ -81,6 +81,7 @@ impl MockWorker { .route("/generate", post(generate_handler)) .route("/v1/chat/completions", post(chat_completions_handler)) .route("/v1/completions", post(completions_handler)) + .route("/v1/rerank", post(rerank_handler)) .route("/v1/responses", post(responses_handler)) .route("/flush_cache", post(flush_cache_handler)) .route("/v1/models", get(v1_models_handler)) @@ -687,6 +688,56 @@ async fn v1_models_handler(State(config): State>>) .into_response() } +async fn rerank_handler( + State(config): State>>, + Json(payload): Json, +) -> impl IntoResponse { + let config = config.read().await; + + // Simulate response delay + if config.response_delay_ms > 0 { + tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await; + } + + // Simulate failure rate + if rand::random::() < config.fail_rate { + return (StatusCode::INTERNAL_SERVER_ERROR, "Simulated failure").into_response(); + } + + // Extract documents from the request to create mock results + let empty_vec = vec![]; + let documents = payload + .get("documents") + .and_then(|d| d.as_array()) + .unwrap_or(&empty_vec); + + // Create mock rerank results with scores based on document index + let mut mock_results = Vec::new(); + for (i, doc) in documents.iter().enumerate() { + let score = 0.95 - (i as f32 * 0.1); // Decreasing scores + let result = serde_json::json!({ + "score": score, + "document": doc.as_str().unwrap_or(""), + "index": i, + "meta_info": { + "confidence": if score > 0.9 { "high" } else { "medium" } + } + }); + mock_results.push(result); + } + + // Sort by score (highest first) to simulate proper ranking + mock_results.sort_by(|a, b| { + b["score"] + .as_f64() + .unwrap() + .partial_cmp(&a["score"].as_f64().unwrap()) + .unwrap() + }); + + (StatusCode::OK, Json(mock_results)).into_response() +} + impl Default for MockWorkerConfig { fn default() -> Self { Self {