diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index a704bf185..a7c896f75 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -38,7 +38,10 @@ use std::collections::HashMap; // - Sampling Parameters // - Request/Response structures // -// 6. **COMMON** +// 6. **SGLANG SPEC - RERANK API** +// - Request/Response structures +// +// 7. **COMMON** // - GenerationRequest trait // - StringOrArray & LoRAPath types // - Helper functions @@ -1805,6 +1808,196 @@ impl GenerationRequest for GenerateRequest { } } +// ================================================================== +// = SGLANG SPEC - RERANK API = +// ================================================================== + +// Constants for rerank API +pub const DEFAULT_MODEL_NAME: &str = "default"; + +/// Rerank request for scoring documents against a query +/// Used for RAG systems and document relevance scoring +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RerankRequest { + /// The query text to rank documents against + pub query: String, + + /// List of documents to be ranked + pub documents: Vec, + + /// Model to use for reranking + #[serde(default = "default_model_name")] + pub model: String, + + /// Maximum number of documents to return (optional) + pub top_k: Option, + + /// Whether to return documents in addition to scores + #[serde(default = "default_return_documents")] + pub return_documents: bool, + + // SGLang specific extensions + /// Request ID for tracking + pub rid: Option, + + /// User identifier + pub user: Option, +} + +fn default_model_name() -> String { + DEFAULT_MODEL_NAME.to_string() +} + +fn default_return_documents() -> bool { + true +} + +/// Individual rerank result +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RerankResult { + /// Relevance score for the document + pub score: f32, + + /// The document text (if return_documents was true) + #[serde(skip_serializing_if = "Option::is_none")] + pub document: Option, + + /// Original index of the document in the request + pub index: usize, + + /// Additional metadata about the ranking + #[serde(skip_serializing_if = "Option::is_none")] + pub meta_info: Option>, +} + +/// Rerank response containing sorted results +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RerankResponse { + /// Ranked results sorted by score (highest first) + pub results: Vec, + + /// Model used for reranking + pub model: String, + + /// Usage information + pub usage: Option, + + /// Response object type + #[serde(default = "default_rerank_object")] + pub object: String, + + /// Response ID + pub id: String, + + /// Creation timestamp + pub created: i64, +} + +fn default_rerank_object() -> String { + "rerank".to_string() +} + +/// V1 API compatibility format for rerank requests +/// Matches Python's V1RerankReqInput +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct V1RerankReqInput { + pub query: String, + pub documents: Vec, +} + +/// Convert V1RerankReqInput to RerankRequest +impl From for RerankRequest { + fn from(v1: V1RerankReqInput) -> Self { + RerankRequest { + query: v1.query, + documents: v1.documents, + model: default_model_name(), + top_k: None, + return_documents: true, + rid: None, + user: None, + } + } +} + +/// Implementation of GenerationRequest trait for RerankRequest +impl GenerationRequest for RerankRequest { + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn is_stream(&self) -> bool { + false // Reranking doesn't support streaming + } + + fn extract_text_for_routing(&self) -> String { + self.query.clone() + } +} + +impl RerankRequest { + pub fn validate(&self) -> Result<(), String> { + // Validate query is not empty + if self.query.trim().is_empty() { + return Err("Query cannot be empty".to_string()); + } + + // Validate documents list + if self.documents.is_empty() { + return Err("Documents list cannot be empty".to_string()); + } + + // Validate top_k if specified + if let Some(k) = self.top_k { + if k == 0 { + return Err("top_k must be greater than 0".to_string()); + } + if k > self.documents.len() { + // This is allowed but we log a warning + tracing::warn!( + "top_k ({}) is greater than number of documents ({})", + k, + self.documents.len() + ); + } + } + + Ok(()) + } + + /// Get the effective top_k value + pub fn effective_top_k(&self) -> usize { + self.top_k.unwrap_or(self.documents.len()) + } +} + +impl RerankResponse { + pub fn new(results: Vec, model: String, request_id: String) -> Self { + RerankResponse { + results, + model, + usage: None, + object: default_rerank_object(), + id: request_id, + created: current_timestamp(), + } + } + + /// Sort results by score in descending order + pub fn sort_by_score(&mut self) { + self.results.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + } + + /// Apply top_k limit to results + pub fn apply_top_k(&mut self, k: usize) { + self.results.truncate(k); + } +} + // ================================================================== // = COMMON = // ================================================================== @@ -1827,7 +2020,7 @@ pub trait GenerationRequest: Send + Sync { } /// Helper type for string or array of strings -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] #[serde(untagged)] pub enum StringOrArray { String(String), @@ -1866,3 +2059,619 @@ pub enum LoRAPath { Single(Option), Batch(Vec>), } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json; + + // ================================================================== + // = RERANK REQUEST TESTS = + // ================================================================== + + #[test] + fn test_rerank_request_serialization() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + model: "test-model".to_string(), + top_k: Some(5), + return_documents: true, + rid: Some(StringOrArray::String("req-123".to_string())), + user: Some("user-456".to_string()), + }; + + let serialized = serde_json::to_string(&request).unwrap(); + let deserialized: RerankRequest = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(deserialized.query, request.query); + assert_eq!(deserialized.documents, request.documents); + assert_eq!(deserialized.model, request.model); + assert_eq!(deserialized.top_k, request.top_k); + assert_eq!(deserialized.return_documents, request.return_documents); + assert_eq!(deserialized.rid, request.rid); + assert_eq!(deserialized.user, request.user); + } + + #[test] + fn test_rerank_request_deserialization_with_defaults() { + let json = r#"{ + "query": "test query", + "documents": ["doc1", "doc2"] + }"#; + + let request: RerankRequest = serde_json::from_str(json).unwrap(); + + assert_eq!(request.query, "test query"); + assert_eq!(request.documents, vec!["doc1", "doc2"]); + assert_eq!(request.model, default_model_name()); + assert_eq!(request.top_k, None); + assert!(request.return_documents); + assert_eq!(request.rid, None); + assert_eq!(request.user, None); + } + + #[test] + fn test_rerank_request_validation_success() { + let request = RerankRequest { + query: "valid query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + model: "test-model".to_string(), + top_k: Some(2), + return_documents: true, + rid: None, + user: None, + }; + + assert!(request.validate().is_ok()); + } + + #[test] + fn test_rerank_request_validation_empty_query() { + let request = RerankRequest { + query: "".to_string(), + documents: vec!["doc1".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + let result = request.validate(); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Query cannot be empty"); + } + + #[test] + fn test_rerank_request_validation_whitespace_query() { + let request = RerankRequest { + query: " ".to_string(), + documents: vec!["doc1".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + let result = request.validate(); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Query cannot be empty"); + } + + #[test] + fn test_rerank_request_validation_empty_documents() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec![], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + let result = request.validate(); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Documents list cannot be empty"); + } + + #[test] + fn test_rerank_request_validation_top_k_zero() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + model: "test-model".to_string(), + top_k: Some(0), + return_documents: true, + rid: None, + user: None, + }; + + let result = request.validate(); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "top_k must be greater than 0"); + } + + #[test] + fn test_rerank_request_validation_top_k_greater_than_docs() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + model: "test-model".to_string(), + top_k: Some(5), + return_documents: true, + rid: None, + user: None, + }; + + // This should pass but log a warning + assert!(request.validate().is_ok()); + } + + #[test] + fn test_rerank_request_effective_top_k() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()], + model: "test-model".to_string(), + top_k: Some(2), + return_documents: true, + rid: None, + user: None, + }; + + assert_eq!(request.effective_top_k(), 2); + } + + #[test] + fn test_rerank_request_effective_top_k_none() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + assert_eq!(request.effective_top_k(), 3); + } + + // ================================================================== + // = RERANK RESPONSE TESTS = + // ================================================================== + + #[test] + fn test_rerank_response_creation() { + let results = vec![ + RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }, + RerankResult { + score: 0.6, + document: Some("doc2".to_string()), + index: 1, + meta_info: None, + }, + ]; + + let response = RerankResponse::new( + results.clone(), + "test-model".to_string(), + "req-123".to_string(), + ); + + assert_eq!(response.results.len(), 2); + assert_eq!(response.model, "test-model"); + assert_eq!(response.id, "req-123"); + assert_eq!(response.object, "rerank"); + assert!(response.created > 0); + } + + #[test] + fn test_rerank_response_serialization() { + let results = vec![RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }]; + + let response = + RerankResponse::new(results, "test-model".to_string(), "req-123".to_string()); + + let serialized = serde_json::to_string(&response).unwrap(); + let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(deserialized.results.len(), response.results.len()); + assert_eq!(deserialized.model, response.model); + assert_eq!(deserialized.id, response.id); + assert_eq!(deserialized.object, response.object); + } + + #[test] + fn test_rerank_response_sort_by_score() { + let results = vec![ + RerankResult { + score: 0.6, + document: Some("doc2".to_string()), + index: 1, + meta_info: None, + }, + RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }, + RerankResult { + score: 0.4, + document: Some("doc3".to_string()), + index: 2, + meta_info: None, + }, + ]; + + let mut response = + RerankResponse::new(results, "test-model".to_string(), "req-123".to_string()); + + response.sort_by_score(); + + assert_eq!(response.results[0].score, 0.8); + assert_eq!(response.results[0].index, 0); + assert_eq!(response.results[1].score, 0.6); + assert_eq!(response.results[1].index, 1); + assert_eq!(response.results[2].score, 0.4); + assert_eq!(response.results[2].index, 2); + } + + #[test] + fn test_rerank_response_apply_top_k() { + let results = vec![ + RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }, + RerankResult { + score: 0.6, + document: Some("doc2".to_string()), + index: 1, + meta_info: None, + }, + RerankResult { + score: 0.4, + document: Some("doc3".to_string()), + index: 2, + meta_info: None, + }, + ]; + + let mut response = + RerankResponse::new(results, "test-model".to_string(), "req-123".to_string()); + + response.apply_top_k(2); + + assert_eq!(response.results.len(), 2); + assert_eq!(response.results[0].score, 0.8); + assert_eq!(response.results[1].score, 0.6); + } + + #[test] + fn test_rerank_response_apply_top_k_larger_than_results() { + let results = vec![RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }]; + + let mut response = + RerankResponse::new(results, "test-model".to_string(), "req-123".to_string()); + + response.apply_top_k(5); + + assert_eq!(response.results.len(), 1); + } + + // ================================================================== + // = RERANK RESULT TESTS = + // ================================================================== + + #[test] + fn test_rerank_result_serialization() { + let result = RerankResult { + score: 0.85, + document: Some("test document".to_string()), + index: 42, + meta_info: Some(HashMap::from([ + ("confidence".to_string(), Value::String("high".to_string())), + ( + "processing_time".to_string(), + Value::Number(serde_json::Number::from(150)), + ), + ])), + }; + + let serialized = serde_json::to_string(&result).unwrap(); + let deserialized: RerankResult = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(deserialized.score, result.score); + assert_eq!(deserialized.document, result.document); + assert_eq!(deserialized.index, result.index); + assert_eq!(deserialized.meta_info, result.meta_info); + } + + #[test] + fn test_rerank_result_serialization_without_document() { + let result = RerankResult { + score: 0.85, + document: None, + index: 42, + meta_info: None, + }; + + let serialized = serde_json::to_string(&result).unwrap(); + let deserialized: RerankResult = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(deserialized.score, result.score); + assert_eq!(deserialized.document, result.document); + assert_eq!(deserialized.index, result.index); + assert_eq!(deserialized.meta_info, result.meta_info); + } + + // ================================================================== + // = V1 COMPATIBILITY TESTS = + // ================================================================== + + #[test] + fn test_v1_rerank_req_input_serialization() { + let v1_input = V1RerankReqInput { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + }; + + let serialized = serde_json::to_string(&v1_input).unwrap(); + let deserialized: V1RerankReqInput = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(deserialized.query, v1_input.query); + assert_eq!(deserialized.documents, v1_input.documents); + } + + #[test] + fn test_v1_to_rerank_request_conversion() { + let v1_input = V1RerankReqInput { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + }; + + let request: RerankRequest = v1_input.into(); + + assert_eq!(request.query, "test query"); + assert_eq!(request.documents, vec!["doc1", "doc2"]); + assert_eq!(request.model, default_model_name()); + assert_eq!(request.top_k, None); + assert!(request.return_documents); + assert_eq!(request.rid, None); + assert_eq!(request.user, None); + } + + // ================================================================== + // = GENERATION REQUEST TRAIT TESTS = + // ================================================================== + + #[test] + fn test_rerank_request_generation_request_trait() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + assert_eq!(request.get_model(), Some("test-model")); + assert!(!request.is_stream()); + assert_eq!(request.extract_text_for_routing(), "test query"); + } + + // ================================================================== + // = EDGE CASES AND STRESS TESTS = + // ================================================================== + + #[test] + fn test_rerank_request_very_long_query() { + let long_query = "a".repeat(100000); + let request = RerankRequest { + query: long_query, + documents: vec!["doc1".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + assert!(request.validate().is_ok()); + } + + #[test] + fn test_rerank_request_many_documents() { + let documents: Vec = (0..1000).map(|i| format!("doc{}", i)).collect(); + let request = RerankRequest { + query: "test query".to_string(), + documents, + model: "test-model".to_string(), + top_k: Some(100), + return_documents: true, + rid: None, + user: None, + }; + + assert!(request.validate().is_ok()); + assert_eq!(request.effective_top_k(), 100); + } + + #[test] + fn test_rerank_request_special_characters() { + let request = RerankRequest { + query: "query with émojis 🚀 and unicode: 测试".to_string(), + documents: vec![ + "doc with émojis 🎉".to_string(), + "doc with unicode: 测试".to_string(), + ], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: Some(StringOrArray::String("req-🚀-123".to_string())), + user: Some("user-🎉-456".to_string()), + }; + + assert!(request.validate().is_ok()); + } + + #[test] + fn test_rerank_request_rid_array() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: Some(StringOrArray::Array(vec![ + "req1".to_string(), + "req2".to_string(), + ])), + user: None, + }; + + assert!(request.validate().is_ok()); + } + + #[test] + fn test_rerank_response_with_usage_info() { + let results = vec![RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }]; + + let mut response = + RerankResponse::new(results, "test-model".to_string(), "req-123".to_string()); + + response.usage = Some(UsageInfo { + prompt_tokens: 100, + completion_tokens: 50, + total_tokens: 150, + reasoning_tokens: None, + prompt_tokens_details: None, + }); + + let serialized = serde_json::to_string(&response).unwrap(); + let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap(); + + assert!(deserialized.usage.is_some()); + let usage = deserialized.usage.unwrap(); + assert_eq!(usage.prompt_tokens, 100); + assert_eq!(usage.completion_tokens, 50); + assert_eq!(usage.total_tokens, 150); + } + + // ================================================================== + // = INTEGRATION TESTS = + // ================================================================== + + #[test] + fn test_full_rerank_workflow() { + // Create request + let request = RerankRequest { + query: "machine learning".to_string(), + documents: vec![ + "Introduction to machine learning algorithms".to_string(), + "Deep learning for computer vision".to_string(), + "Natural language processing basics".to_string(), + "Statistics and probability theory".to_string(), + ], + model: "rerank-model".to_string(), + top_k: Some(2), + return_documents: true, + rid: Some(StringOrArray::String("req-123".to_string())), + user: Some("user-456".to_string()), + }; + + // Validate request + assert!(request.validate().is_ok()); + + // Simulate reranking results (in real scenario, this would come from the model) + let results = vec![ + RerankResult { + score: 0.95, + document: Some("Introduction to machine learning algorithms".to_string()), + index: 0, + meta_info: None, + }, + RerankResult { + score: 0.87, + document: Some("Deep learning for computer vision".to_string()), + index: 1, + meta_info: None, + }, + RerankResult { + score: 0.72, + document: Some("Natural language processing basics".to_string()), + index: 2, + meta_info: None, + }, + RerankResult { + score: 0.45, + document: Some("Statistics and probability theory".to_string()), + index: 3, + meta_info: None, + }, + ]; + + // Create response + let mut response = RerankResponse::new( + results, + request.model.clone(), + request + .rid + .as_ref() + .and_then(|r| match r { + StringOrArray::String(s) => Some(s.clone()), + StringOrArray::Array(arr) => arr.first().cloned(), + }) + .unwrap_or_else(|| "unknown".to_string()), + ); + + // Sort by score + response.sort_by_score(); + + // Apply top_k + response.apply_top_k(request.effective_top_k()); + + // Verify results + assert_eq!(response.results.len(), 2); + assert_eq!(response.results[0].score, 0.95); + assert_eq!(response.results[0].index, 0); + assert_eq!(response.results[1].score, 0.87); + assert_eq!(response.results[1].index, 1); + assert_eq!(response.model, "rerank-model"); + + // Serialize and deserialize + let serialized = serde_json::to_string(&response).unwrap(); + let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized.results.len(), 2); + assert_eq!(deserialized.model, response.model); + } +}