[router] Add Rerank API Specification (#9906)
This commit is contained in:
@@ -38,7 +38,10 @@ use std::collections::HashMap;
|
||||
// - Sampling Parameters
|
||||
// - Request/Response structures
|
||||
//
|
||||
// 6. **COMMON**
|
||||
// 6. **SGLANG SPEC - RERANK API**
|
||||
// - Request/Response structures
|
||||
//
|
||||
// 7. **COMMON**
|
||||
// - GenerationRequest trait
|
||||
// - StringOrArray & LoRAPath types
|
||||
// - Helper functions
|
||||
@@ -1805,6 +1808,196 @@ impl GenerationRequest for GenerateRequest {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = SGLANG SPEC - RERANK API =
|
||||
// ==================================================================
|
||||
|
||||
// Constants for rerank API
|
||||
/// Model name returned by `default_model_name()`, which serde uses for
/// `RerankRequest::model` when the field is missing from the payload.
pub const DEFAULT_MODEL_NAME: &str = "default";
|
||||
|
||||
/// Rerank request for scoring documents against a query
|
||||
/// Used for RAG systems and document relevance scoring
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RerankRequest {
|
||||
/// The query text to rank documents against
|
||||
pub query: String,
|
||||
|
||||
/// List of documents to be ranked
|
||||
pub documents: Vec<String>,
|
||||
|
||||
/// Model to use for reranking
|
||||
#[serde(default = "default_model_name")]
|
||||
pub model: String,
|
||||
|
||||
/// Maximum number of documents to return (optional)
|
||||
pub top_k: Option<usize>,
|
||||
|
||||
/// Whether to return documents in addition to scores
|
||||
#[serde(default = "default_return_documents")]
|
||||
pub return_documents: bool,
|
||||
|
||||
// SGLang specific extensions
|
||||
/// Request ID for tracking
|
||||
pub rid: Option<StringOrArray>,
|
||||
|
||||
/// User identifier
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
fn default_model_name() -> String {
|
||||
DEFAULT_MODEL_NAME.to_string()
|
||||
}
|
||||
|
||||
/// Serde default for `RerankRequest::return_documents`: documents are
/// returned unless the caller explicitly opts out.
fn default_return_documents() -> bool {
    true
}
|
||||
|
||||
/// Individual rerank result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RerankResult {
|
||||
/// Relevance score for the document
|
||||
pub score: f32,
|
||||
|
||||
/// The document text (if return_documents was true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub document: Option<String>,
|
||||
|
||||
/// Original index of the document in the request
|
||||
pub index: usize,
|
||||
|
||||
/// Additional metadata about the ranking
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub meta_info: Option<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
/// Rerank response containing sorted results
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RerankResponse {
|
||||
/// Ranked results sorted by score (highest first)
|
||||
pub results: Vec<RerankResult>,
|
||||
|
||||
/// Model used for reranking
|
||||
pub model: String,
|
||||
|
||||
/// Usage information
|
||||
pub usage: Option<UsageInfo>,
|
||||
|
||||
/// Response object type
|
||||
#[serde(default = "default_rerank_object")]
|
||||
pub object: String,
|
||||
|
||||
/// Response ID
|
||||
pub id: String,
|
||||
|
||||
/// Creation timestamp
|
||||
pub created: i64,
|
||||
}
|
||||
|
||||
/// Serde default for `RerankResponse::object`: the literal tag `"rerank"`.
fn default_rerank_object() -> String {
    String::from("rerank")
}
|
||||
|
||||
/// V1 API compatibility format for rerank requests
|
||||
/// Matches Python's V1RerankReqInput
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct V1RerankReqInput {
|
||||
pub query: String,
|
||||
pub documents: Vec<String>,
|
||||
}
|
||||
|
||||
/// Convert V1RerankReqInput to RerankRequest
|
||||
impl From<V1RerankReqInput> for RerankRequest {
|
||||
fn from(v1: V1RerankReqInput) -> Self {
|
||||
RerankRequest {
|
||||
query: v1.query,
|
||||
documents: v1.documents,
|
||||
model: default_model_name(),
|
||||
top_k: None,
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of GenerationRequest trait for RerankRequest
|
||||
impl GenerationRequest for RerankRequest {
|
||||
fn get_model(&self) -> Option<&str> {
|
||||
Some(&self.model)
|
||||
}
|
||||
|
||||
fn is_stream(&self) -> bool {
|
||||
false // Reranking doesn't support streaming
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
self.query.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl RerankRequest {
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
// Validate query is not empty
|
||||
if self.query.trim().is_empty() {
|
||||
return Err("Query cannot be empty".to_string());
|
||||
}
|
||||
|
||||
// Validate documents list
|
||||
if self.documents.is_empty() {
|
||||
return Err("Documents list cannot be empty".to_string());
|
||||
}
|
||||
|
||||
// Validate top_k if specified
|
||||
if let Some(k) = self.top_k {
|
||||
if k == 0 {
|
||||
return Err("top_k must be greater than 0".to_string());
|
||||
}
|
||||
if k > self.documents.len() {
|
||||
// This is allowed but we log a warning
|
||||
tracing::warn!(
|
||||
"top_k ({}) is greater than number of documents ({})",
|
||||
k,
|
||||
self.documents.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the effective top_k value
|
||||
pub fn effective_top_k(&self) -> usize {
|
||||
self.top_k.unwrap_or(self.documents.len())
|
||||
}
|
||||
}
|
||||
|
||||
impl RerankResponse {
|
||||
pub fn new(results: Vec<RerankResult>, model: String, request_id: String) -> Self {
|
||||
RerankResponse {
|
||||
results,
|
||||
model,
|
||||
usage: None,
|
||||
object: default_rerank_object(),
|
||||
id: request_id,
|
||||
created: current_timestamp(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Sort results by score in descending order
|
||||
pub fn sort_by_score(&mut self) {
|
||||
self.results.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
}
|
||||
|
||||
/// Apply top_k limit to results
|
||||
pub fn apply_top_k(&mut self, k: usize) {
|
||||
self.results.truncate(k);
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = COMMON =
|
||||
// ==================================================================
|
||||
@@ -1827,7 +2020,7 @@ pub trait GenerationRequest: Send + Sync {
|
||||
}
|
||||
|
||||
/// Helper type for string or array of strings
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum StringOrArray {
|
||||
String(String),
|
||||
@@ -1866,3 +2059,619 @@ pub enum LoRAPath {
|
||||
Single(Option<String>),
|
||||
Batch(Vec<Option<String>>),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json;
|
||||
|
||||
// ==================================================================
|
||||
// = RERANK REQUEST TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
fn test_rerank_request_serialization() {
    // Round-trip a fully populated request through JSON.
    let original = RerankRequest {
        query: String::from("test query"),
        documents: vec![String::from("doc1"), String::from("doc2")],
        model: String::from("test-model"),
        top_k: Some(5),
        return_documents: true,
        rid: Some(StringOrArray::String(String::from("req-123"))),
        user: Some(String::from("user-456")),
    };

    let json = serde_json::to_string(&original).unwrap();
    let restored: RerankRequest = serde_json::from_str(&json).unwrap();

    assert_eq!(restored.query, original.query);
    assert_eq!(restored.documents, original.documents);
    assert_eq!(restored.model, original.model);
    assert_eq!(restored.top_k, original.top_k);
    assert_eq!(restored.return_documents, original.return_documents);
    assert_eq!(restored.rid, original.rid);
    assert_eq!(restored.user, original.user);
}
|
||||
|
||||
#[test]
fn test_rerank_request_deserialization_with_defaults() {
    // Only the two required fields are supplied; everything else must fall
    // back to its default.
    let payload = r#"{
        "query": "test query",
        "documents": ["doc1", "doc2"]
    }"#;

    let parsed: RerankRequest = serde_json::from_str(payload).unwrap();

    assert_eq!(parsed.query, "test query");
    assert_eq!(parsed.documents, vec!["doc1", "doc2"]);
    assert_eq!(parsed.model, default_model_name());
    assert!(parsed.top_k.is_none());
    assert!(parsed.return_documents);
    assert!(parsed.rid.is_none());
    assert!(parsed.user.is_none());
}
|
||||
|
||||
#[test]
fn test_rerank_request_validation_success() {
    // A non-empty query, two documents and an in-range top_k must validate.
    let well_formed = RerankRequest {
        query: String::from("valid query"),
        documents: vec![String::from("doc1"), String::from("doc2")],
        model: String::from("test-model"),
        top_k: Some(2),
        return_documents: true,
        rid: None,
        user: None,
    };

    assert_eq!(well_formed.validate(), Ok(()));
}
|
||||
|
||||
#[test]
fn test_rerank_request_validation_empty_query() {
    // An empty query is rejected with a dedicated message.
    let missing_query = RerankRequest {
        query: String::new(),
        documents: vec![String::from("doc1")],
        model: String::from("test-model"),
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };

    let outcome = missing_query.validate();
    assert_eq!(outcome, Err(String::from("Query cannot be empty")));
}
|
||||
|
||||
#[test]
fn test_rerank_request_validation_whitespace_query() {
    // A whitespace-only query counts as empty.
    let blank_query = RerankRequest {
        query: String::from(" "),
        documents: vec![String::from("doc1")],
        model: String::from("test-model"),
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };

    let outcome = blank_query.validate();
    assert_eq!(outcome, Err(String::from("Query cannot be empty")));
}
|
||||
|
||||
#[test]
fn test_rerank_request_validation_empty_documents() {
    // A request without any documents is rejected.
    let no_documents = RerankRequest {
        query: String::from("test query"),
        documents: Vec::new(),
        model: String::from("test-model"),
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };

    let outcome = no_documents.validate();
    assert_eq!(outcome, Err(String::from("Documents list cannot be empty")));
}
|
||||
|
||||
#[test]
fn test_rerank_request_validation_top_k_zero() {
    // top_k = 0 is never meaningful and must be rejected.
    let zero_top_k = RerankRequest {
        query: String::from("test query"),
        documents: vec![String::from("doc1"), String::from("doc2")],
        model: String::from("test-model"),
        top_k: Some(0),
        return_documents: true,
        rid: None,
        user: None,
    };

    let outcome = zero_top_k.validate();
    assert_eq!(outcome, Err(String::from("top_k must be greater than 0")));
}
|
||||
|
||||
#[test]
fn test_rerank_request_validation_top_k_greater_than_docs() {
    let oversized_top_k = RerankRequest {
        query: String::from("test query"),
        documents: vec![String::from("doc1"), String::from("doc2")],
        model: String::from("test-model"),
        top_k: Some(5),
        return_documents: true,
        rid: None,
        user: None,
    };

    // An oversized top_k is tolerated: validation succeeds and only a
    // warning is logged.
    assert_eq!(oversized_top_k.validate(), Ok(()));
}
|
||||
|
||||
#[test]
fn test_rerank_request_effective_top_k() {
    // With an explicit top_k below the document count, that value wins.
    let documents: Vec<String> = ["doc1", "doc2", "doc3"]
        .iter()
        .map(|d| d.to_string())
        .collect();
    let limited = RerankRequest {
        query: String::from("test query"),
        documents,
        model: String::from("test-model"),
        top_k: Some(2),
        return_documents: true,
        rid: None,
        user: None,
    };

    assert_eq!(limited.effective_top_k(), 2);
}
|
||||
|
||||
#[test]
fn test_rerank_request_effective_top_k_none() {
    // Without top_k, every document is eligible.
    let documents: Vec<String> = ["doc1", "doc2", "doc3"]
        .iter()
        .map(|d| d.to_string())
        .collect();
    let unlimited = RerankRequest {
        query: String::from("test query"),
        documents,
        model: String::from("test-model"),
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };

    assert_eq!(unlimited.effective_top_k(), 3);
}
|
||||
|
||||
// ==================================================================
|
||||
// = RERANK RESPONSE TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
fn test_rerank_response_creation() {
    // Two pre-scored results are handed to the constructor unchanged.
    let first = RerankResult {
        score: 0.8,
        document: Some(String::from("doc1")),
        index: 0,
        meta_info: None,
    };
    let second = RerankResult {
        score: 0.6,
        document: Some(String::from("doc2")),
        index: 1,
        meta_info: None,
    };

    let response = RerankResponse::new(
        vec![first, second],
        String::from("test-model"),
        String::from("req-123"),
    );

    assert_eq!(response.results.len(), 2);
    assert_eq!(response.model, "test-model");
    assert_eq!(response.id, "req-123");
    assert_eq!(response.object, "rerank");
    assert!(response.created > 0);
}
|
||||
|
||||
#[test]
fn test_rerank_response_serialization() {
    // Round-trip a single-result response through JSON.
    let only_result = RerankResult {
        score: 0.8,
        document: Some(String::from("doc1")),
        index: 0,
        meta_info: None,
    };
    let original = RerankResponse::new(
        vec![only_result],
        String::from("test-model"),
        String::from("req-123"),
    );

    let json = serde_json::to_string(&original).unwrap();
    let restored: RerankResponse = serde_json::from_str(&json).unwrap();

    assert_eq!(restored.results.len(), original.results.len());
    assert_eq!(restored.model, original.model);
    assert_eq!(restored.id, original.id);
    assert_eq!(restored.object, original.object);
}
|
||||
|
||||
#[test]
fn test_rerank_response_sort_by_score() {
    // (score, document, original index), deliberately out of order.
    let unsorted = vec![(0.6, "doc2", 1), (0.8, "doc1", 0), (0.4, "doc3", 2)];
    let results: Vec<RerankResult> = unsorted
        .into_iter()
        .map(|(score, text, index)| RerankResult {
            score,
            document: Some(text.to_string()),
            index,
            meta_info: None,
        })
        .collect();

    let mut response =
        RerankResponse::new(results, String::from("test-model"), String::from("req-123"));
    response.sort_by_score();

    // Highest score first, each result still carrying its original index.
    let ranked: Vec<(f32, usize)> = response
        .results
        .iter()
        .map(|r| (r.score, r.index))
        .collect();
    assert_eq!(ranked, vec![(0.8, 0), (0.6, 1), (0.4, 2)]);
}
|
||||
|
||||
#[test]
fn test_rerank_response_apply_top_k() {
    // Three results, already in descending score order.
    let scored = vec![(0.8, "doc1"), (0.6, "doc2"), (0.4, "doc3")];
    let results: Vec<RerankResult> = scored
        .into_iter()
        .enumerate()
        .map(|(index, (score, text))| RerankResult {
            score,
            document: Some(text.to_string()),
            index,
            meta_info: None,
        })
        .collect();

    let mut response =
        RerankResponse::new(results, String::from("test-model"), String::from("req-123"));
    response.apply_top_k(2);

    // Only the first two results survive the cut.
    let kept: Vec<f32> = response.results.iter().map(|r| r.score).collect();
    assert_eq!(kept.len(), 2);
    assert_eq!(kept[0], 0.8);
    assert_eq!(kept[1], 0.6);
}
|
||||
|
||||
#[test]
fn test_rerank_response_apply_top_k_larger_than_results() {
    let only_result = RerankResult {
        score: 0.8,
        document: Some(String::from("doc1")),
        index: 0,
        meta_info: None,
    };
    let mut response = RerankResponse::new(
        vec![only_result],
        String::from("test-model"),
        String::from("req-123"),
    );

    // Asking for more results than exist must leave the list untouched.
    response.apply_top_k(5);

    assert_eq!(response.results.len(), 1);
}
|
||||
|
||||
// ==================================================================
|
||||
// = RERANK RESULT TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
fn test_rerank_result_serialization() {
    // Metadata with one string value and one numeric value.
    let mut meta = HashMap::new();
    meta.insert(
        String::from("confidence"),
        Value::String(String::from("high")),
    );
    meta.insert(
        String::from("processing_time"),
        Value::Number(serde_json::Number::from(150)),
    );

    let original = RerankResult {
        score: 0.85,
        document: Some(String::from("test document")),
        index: 42,
        meta_info: Some(meta),
    };

    let json = serde_json::to_string(&original).unwrap();
    let restored: RerankResult = serde_json::from_str(&json).unwrap();

    assert_eq!(restored.score, original.score);
    assert_eq!(restored.document, original.document);
    assert_eq!(restored.index, original.index);
    assert_eq!(restored.meta_info, original.meta_info);
}
|
||||
|
||||
#[test]
fn test_rerank_result_serialization_without_document() {
    // Both optional fields are None, so they are skipped on serialization
    // and must come back as None.
    let original = RerankResult {
        score: 0.85,
        document: None,
        index: 42,
        meta_info: None,
    };

    let json = serde_json::to_string(&original).unwrap();
    let restored: RerankResult = serde_json::from_str(&json).unwrap();

    assert_eq!(restored.score, original.score);
    assert_eq!(restored.document, original.document);
    assert_eq!(restored.index, original.index);
    assert_eq!(restored.meta_info, original.meta_info);
}
|
||||
|
||||
// ==================================================================
|
||||
// = V1 COMPATIBILITY TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
fn test_v1_rerank_req_input_serialization() {
    // Round-trip the V1 request shape through JSON.
    let original = V1RerankReqInput {
        query: String::from("test query"),
        documents: vec![String::from("doc1"), String::from("doc2")],
    };

    let json = serde_json::to_string(&original).unwrap();
    let restored: V1RerankReqInput = serde_json::from_str(&json).unwrap();

    assert_eq!(restored.query, original.query);
    assert_eq!(restored.documents, original.documents);
}
|
||||
|
||||
#[test]
fn test_v1_to_rerank_request_conversion() {
    let legacy = V1RerankReqInput {
        query: String::from("test query"),
        documents: vec![String::from("doc1"), String::from("doc2")],
    };

    let converted = RerankRequest::from(legacy);

    // Payload fields are carried over; everything else takes its default.
    assert_eq!(converted.query, "test query");
    assert_eq!(converted.documents, vec!["doc1", "doc2"]);
    assert_eq!(converted.model, default_model_name());
    assert!(converted.top_k.is_none());
    assert!(converted.return_documents);
    assert!(converted.rid.is_none());
    assert!(converted.user.is_none());
}
|
||||
|
||||
// ==================================================================
|
||||
// = GENERATION REQUEST TRAIT TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
fn test_rerank_request_generation_request_trait() {
    let subject = RerankRequest {
        query: String::from("test query"),
        documents: vec![String::from("doc1")],
        model: String::from("test-model"),
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };

    // The trait exposes the model, never streams, and routes on the query.
    assert_eq!(subject.get_model(), Some("test-model"));
    assert!(!subject.is_stream());
    assert_eq!(subject.extract_text_for_routing(), "test query");
}
|
||||
|
||||
// ==================================================================
|
||||
// = EDGE CASES AND STRESS TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
fn test_rerank_request_very_long_query() {
    // Validation imposes no upper bound on query length.
    let huge_query = RerankRequest {
        query: "a".repeat(100_000),
        documents: vec![String::from("doc1")],
        model: String::from("test-model"),
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };

    assert_eq!(huge_query.validate(), Ok(()));
}
|
||||
|
||||
#[test]
fn test_rerank_request_many_documents() {
    // 1000 generated documents with a top_k well below that count.
    let mut documents = Vec::with_capacity(1000);
    for idx in 0..1000 {
        documents.push(format!("doc{}", idx));
    }

    let bulk = RerankRequest {
        query: String::from("test query"),
        documents,
        model: String::from("test-model"),
        top_k: Some(100),
        return_documents: true,
        rid: None,
        user: None,
    };

    assert_eq!(bulk.validate(), Ok(()));
    assert_eq!(bulk.effective_top_k(), 100);
}
|
||||
|
||||
#[test]
fn test_rerank_request_special_characters() {
    // Non-ASCII text in every string field must not affect validation.
    let documents = vec![
        String::from("doc with émojis 🎉"),
        String::from("doc with unicode: 测试"),
    ];
    let unicode_heavy = RerankRequest {
        query: String::from("query with émojis 🚀 and unicode: 测试"),
        documents,
        model: String::from("test-model"),
        top_k: None,
        return_documents: true,
        rid: Some(StringOrArray::String(String::from("req-🚀-123"))),
        user: Some(String::from("user-🎉-456")),
    };

    assert_eq!(unicode_heavy.validate(), Ok(()));
}
|
||||
|
||||
#[test]
fn test_rerank_request_rid_array() {
    // `rid` may also be an array of request IDs.
    let request_ids = StringOrArray::Array(vec![String::from("req1"), String::from("req2")]);
    let multi_rid = RerankRequest {
        query: String::from("test query"),
        documents: vec![String::from("doc1")],
        model: String::from("test-model"),
        top_k: None,
        return_documents: true,
        rid: Some(request_ids),
        user: None,
    };

    assert_eq!(multi_rid.validate(), Ok(()));
}
|
||||
|
||||
#[test]
fn test_rerank_response_with_usage_info() {
    let only_result = RerankResult {
        score: 0.8,
        document: Some(String::from("doc1")),
        index: 0,
        meta_info: None,
    };
    let mut response = RerankResponse::new(
        vec![only_result],
        String::from("test-model"),
        String::from("req-123"),
    );

    // The constructor leaves usage unset; attach it after the fact.
    response.usage = Some(UsageInfo {
        prompt_tokens: 100,
        completion_tokens: 50,
        total_tokens: 150,
        reasoning_tokens: None,
        prompt_tokens_details: None,
    });

    let json = serde_json::to_string(&response).unwrap();
    let restored: RerankResponse = serde_json::from_str(&json).unwrap();

    // Usage must survive the JSON round trip.
    assert!(restored.usage.is_some());
    let usage = restored.usage.unwrap();
    assert_eq!(usage.prompt_tokens, 100);
    assert_eq!(usage.completion_tokens, 50);
    assert_eq!(usage.total_tokens, 150);
}
|
||||
|
||||
// ==================================================================
|
||||
// = INTEGRATION TESTS =
|
||||
// ==================================================================
|
||||
|
||||
#[test]
fn test_full_rerank_workflow() {
    // 1. Build and validate the request.
    let request = RerankRequest {
        query: String::from("machine learning"),
        documents: vec![
            String::from("Introduction to machine learning algorithms"),
            String::from("Deep learning for computer vision"),
            String::from("Natural language processing basics"),
            String::from("Statistics and probability theory"),
        ],
        model: String::from("rerank-model"),
        top_k: Some(2),
        return_documents: true,
        rid: Some(StringOrArray::String(String::from("req-123"))),
        user: Some(String::from("user-456")),
    };
    assert!(request.validate().is_ok());

    // 2. Simulate model output: one score per document, in request order.
    //    (In a real scenario these would come from the model.)
    let simulated = vec![
        (0.95, "Introduction to machine learning algorithms"),
        (0.87, "Deep learning for computer vision"),
        (0.72, "Natural language processing basics"),
        (0.45, "Statistics and probability theory"),
    ];
    let results: Vec<RerankResult> = simulated
        .into_iter()
        .enumerate()
        .map(|(index, (score, text))| RerankResult {
            score,
            document: Some(text.to_string()),
            index,
            meta_info: None,
        })
        .collect();

    // 3. Derive the response ID from the request's rid (first entry when it
    //    is an array), falling back to "unknown".
    let response_id = match request.rid.as_ref() {
        Some(StringOrArray::String(single)) => Some(single.clone()),
        Some(StringOrArray::Array(many)) => many.first().cloned(),
        None => None,
    }
    .unwrap_or_else(|| "unknown".to_string());

    // 4. Assemble, rank and trim the response.
    let mut response = RerankResponse::new(results, request.model.clone(), response_id);
    response.sort_by_score();
    response.apply_top_k(request.effective_top_k());

    // 5. The two best documents remain, best first.
    assert_eq!(response.results.len(), 2);
    assert_eq!(response.results[0].score, 0.95);
    assert_eq!(response.results[0].index, 0);
    assert_eq!(response.results[1].score, 0.87);
    assert_eq!(response.results[1].index, 1);
    assert_eq!(response.model, "rerank-model");

    // 6. The trimmed response survives a JSON round trip.
    let json = serde_json::to_string(&response).unwrap();
    let restored: RerankResponse = serde_json::from_str(&json).unwrap();
    assert_eq!(restored.results.len(), 2);
    assert_eq!(restored.model, response.model);
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user