[router][protocols] Add Axum validate extractor and use it for /v1/chat/completions endpoint (#11588)

This commit is contained in:
Chang Su
2025-10-13 22:51:15 -07:00
committed by GitHub
parent e4358a4585
commit 27ef1459e6
21 changed files with 1982 additions and 2003 deletions

View File

@@ -0,0 +1,575 @@
use serde_json::json;
use sglang_router_rs::protocols::spec::{
ChatCompletionRequest, ChatMessage, Function, FunctionCall, FunctionChoice, StreamOptions,
Tool, ToolChoice, ToolChoiceValue, ToolReference, UserMessageContent,
};
use sglang_router_rs::protocols::validated::Normalizable;
use validator::Validate;
// Deprecated fields normalization tests
#[test]
fn test_max_tokens_normalizes_to_max_completion_tokens() {
#[allow(deprecated)]
let mut req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
max_tokens: Some(100),
max_completion_tokens: None,
..Default::default()
};
req.normalize();
assert_eq!(
req.max_completion_tokens,
Some(100),
"max_tokens should be copied to max_completion_tokens"
);
#[allow(deprecated)]
{
assert!(
req.max_tokens.is_none(),
"Deprecated field should be cleared"
);
}
assert!(
req.validate().is_ok(),
"Should be valid after normalization"
);
}
#[test]
fn test_max_completion_tokens_takes_precedence() {
#[allow(deprecated)]
let mut req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
max_tokens: Some(100),
max_completion_tokens: Some(200),
..Default::default()
};
req.normalize();
assert_eq!(
req.max_completion_tokens,
Some(200),
"max_completion_tokens should take precedence"
);
assert!(
req.validate().is_ok(),
"Should be valid after normalization"
);
}
#[test]
fn test_functions_normalizes_to_tools() {
#[allow(deprecated)]
let mut req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
functions: Some(vec![Function {
name: "test_func".to_string(),
description: Some("Test function".to_string()),
parameters: json!({}),
strict: None,
}]),
tools: None,
..Default::default()
};
req.normalize();
assert!(req.tools.is_some(), "functions should be migrated to tools");
assert_eq!(req.tools.as_ref().unwrap().len(), 1);
assert_eq!(req.tools.as_ref().unwrap()[0].function.name, "test_func");
#[allow(deprecated)]
{
assert!(
req.functions.is_none(),
"Deprecated field should be cleared"
);
}
assert!(
req.validate().is_ok(),
"Should be valid after normalization"
);
}
#[test]
fn test_function_call_normalizes_to_tool_choice() {
#[allow(deprecated)]
let mut req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
function_call: Some(FunctionCall::None),
tool_choice: None,
..Default::default()
};
req.normalize();
assert!(
req.tool_choice.is_some(),
"function_call should be migrated to tool_choice"
);
assert!(matches!(
req.tool_choice,
Some(ToolChoice::Value(ToolChoiceValue::None))
));
#[allow(deprecated)]
{
assert!(
req.function_call.is_none(),
"Deprecated field should be cleared"
);
}
assert!(
req.validate().is_ok(),
"Should be valid after normalization"
);
}
#[test]
fn test_function_call_function_variant_normalizes() {
#[allow(deprecated)]
let mut req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
function_call: Some(FunctionCall::Function {
name: "my_function".to_string(),
}),
tool_choice: None,
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "my_function".to_string(),
description: None,
parameters: json!({}),
strict: None,
},
}]),
..Default::default()
};
req.normalize();
assert!(
req.tool_choice.is_some(),
"function_call should be migrated to tool_choice"
);
match &req.tool_choice {
Some(ToolChoice::Function { function, .. }) => {
assert_eq!(function.name, "my_function");
}
_ => panic!("Expected ToolChoice::Function variant"),
}
#[allow(deprecated)]
{
assert!(
req.function_call.is_none(),
"Deprecated field should be cleared"
);
}
assert!(
req.validate().is_ok(),
"Should be valid after normalization"
);
}
// Stream options validation tests
#[test]
fn test_stream_options_requires_stream_enabled() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
stream: false,
stream_options: Some(StreamOptions {
include_usage: Some(true),
}),
..Default::default()
};
let result = req.validate();
assert!(
result.is_err(),
"Should reject stream_options when stream is false"
);
let err = result.unwrap_err().to_string();
assert!(
err.contains("stream_options") && err.contains("stream") && err.contains("enabled"),
"Error should mention stream dependency: {}",
err
);
}
#[test]
fn test_stream_options_valid_when_stream_enabled() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
stream: true,
stream_options: Some(StreamOptions {
include_usage: Some(true),
}),
..Default::default()
};
let result = req.validate();
assert!(
result.is_ok(),
"Should accept stream_options when stream is true"
);
}
#[test]
fn test_no_stream_options_valid_when_stream_disabled() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
stream: false,
stream_options: None,
..Default::default()
};
let result = req.validate();
assert!(
result.is_ok(),
"Should accept no stream_options when stream is false"
);
}
// Tool choice validation tests
#[test]
fn test_tool_choice_function_not_found() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}]),
tool_choice: Some(ToolChoice::Function {
function: FunctionChoice {
name: "nonexistent_function".to_string(),
},
tool_type: "function".to_string(),
}),
..Default::default()
};
let result = req.validate();
assert!(result.is_err(), "Should reject nonexistent function name");
let err = result.unwrap_err().to_string();
assert!(
err.contains("function 'nonexistent_function' not found"),
"Error should mention the missing function: {}",
err
);
}
#[test]
fn test_tool_choice_function_exists_valid() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}]),
tool_choice: Some(ToolChoice::Function {
function: FunctionChoice {
name: "get_weather".to_string(),
},
tool_type: "function".to_string(),
}),
..Default::default()
};
let result = req.validate();
assert!(result.is_ok(), "Should accept existing function name");
}
#[test]
fn test_tool_choice_allowed_tools_invalid_mode() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}]),
tool_choice: Some(ToolChoice::AllowedTools {
mode: "invalid_mode".to_string(),
tools: vec![ToolReference {
tool_type: "function".to_string(),
name: "get_weather".to_string(),
}],
tool_type: "function".to_string(),
}),
..Default::default()
};
let result = req.validate();
assert!(result.is_err(), "Should reject invalid mode");
let err = result.unwrap_err().to_string();
assert!(
err.contains("must be 'auto' or 'required'"),
"Error should mention valid modes: {}",
err
);
}
#[test]
fn test_tool_choice_allowed_tools_valid_mode_auto() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}]),
tool_choice: Some(ToolChoice::AllowedTools {
mode: "auto".to_string(),
tools: vec![ToolReference {
tool_type: "function".to_string(),
name: "get_weather".to_string(),
}],
tool_type: "function".to_string(),
}),
..Default::default()
};
let result = req.validate();
assert!(result.is_ok(), "Should accept 'auto' mode");
}
#[test]
fn test_tool_choice_allowed_tools_valid_mode_required() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}]),
tool_choice: Some(ToolChoice::AllowedTools {
mode: "required".to_string(),
tools: vec![ToolReference {
tool_type: "function".to_string(),
name: "get_weather".to_string(),
}],
tool_type: "function".to_string(),
}),
..Default::default()
};
let result = req.validate();
assert!(result.is_ok(), "Should accept 'required' mode");
}
#[test]
fn test_tool_choice_allowed_tools_tool_not_found() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}]),
tool_choice: Some(ToolChoice::AllowedTools {
mode: "auto".to_string(),
tools: vec![ToolReference {
tool_type: "function".to_string(),
name: "nonexistent_tool".to_string(),
}],
tool_type: "function".to_string(),
}),
..Default::default()
};
let result = req.validate();
assert!(result.is_err(), "Should reject nonexistent tool name");
let err = result.unwrap_err().to_string();
assert!(
err.contains("tool 'nonexistent_tool' not found"),
"Error should mention the missing tool: {}",
err
);
}
#[test]
fn test_tool_choice_allowed_tools_multiple_tools_valid() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
tools: Some(vec![
Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_time".to_string(),
description: Some("Get time".to_string()),
parameters: json!({}),
strict: None,
},
},
]),
tool_choice: Some(ToolChoice::AllowedTools {
mode: "auto".to_string(),
tools: vec![
ToolReference {
tool_type: "function".to_string(),
name: "get_weather".to_string(),
},
ToolReference {
tool_type: "function".to_string(),
name: "get_time".to_string(),
},
],
tool_type: "function".to_string(),
}),
..Default::default()
};
let result = req.validate();
assert!(result.is_ok(), "Should accept all valid tool references");
}
#[test]
fn test_tool_choice_allowed_tools_one_invalid_among_valid() {
let req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
content: UserMessageContent::Text("hello".to_string()),
name: None,
}],
tools: Some(vec![
Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_time".to_string(),
description: Some("Get time".to_string()),
parameters: json!({}),
strict: None,
},
},
]),
tool_choice: Some(ToolChoice::AllowedTools {
mode: "auto".to_string(),
tools: vec![
ToolReference {
tool_type: "function".to_string(),
name: "get_weather".to_string(),
},
ToolReference {
tool_type: "function".to_string(),
name: "nonexistent_tool".to_string(),
},
],
tool_type: "function".to_string(),
}),
..Default::default()
};
let result = req.validate();
assert!(
result.is_err(),
"Should reject if any tool reference is invalid"
);
let err = result.unwrap_err().to_string();
assert!(
err.contains("tool 'nonexistent_tool' not found"),
"Error should mention the missing tool: {}",
err
);
}

View File

@@ -0,0 +1,83 @@
use serde_json::json;
use sglang_router_rs::protocols::spec::{ChatMessage, UserMessageContent};
#[test]
fn test_chat_message_tagged_by_role_system() {
let json = json!({
"role": "system",
"content": "You are a helpful assistant"
});
let msg: ChatMessage = serde_json::from_value(json).unwrap();
match msg {
ChatMessage::System { content, .. } => {
assert_eq!(content, "You are a helpful assistant");
}
_ => panic!("Expected System variant"),
}
}
#[test]
fn test_chat_message_tagged_by_role_user() {
let json = json!({
"role": "user",
"content": "Hello"
});
let msg: ChatMessage = serde_json::from_value(json).unwrap();
match msg {
ChatMessage::User { content, .. } => match content {
UserMessageContent::Text(text) => assert_eq!(text, "Hello"),
_ => panic!("Expected text content"),
},
_ => panic!("Expected User variant"),
}
}
#[test]
fn test_chat_message_tagged_by_role_assistant() {
let json = json!({
"role": "assistant",
"content": "Hi there!"
});
let msg: ChatMessage = serde_json::from_value(json).unwrap();
match msg {
ChatMessage::Assistant { content, .. } => {
assert_eq!(content, Some("Hi there!".to_string()));
}
_ => panic!("Expected Assistant variant"),
}
}
#[test]
fn test_chat_message_tagged_by_role_tool() {
let json = json!({
"role": "tool",
"content": "Tool result",
"tool_call_id": "call_123"
});
let msg: ChatMessage = serde_json::from_value(json).unwrap();
match msg {
ChatMessage::Tool {
content,
tool_call_id,
} => {
assert_eq!(content, "Tool result");
assert_eq!(tool_call_id, "call_123");
}
_ => panic!("Expected Tool variant"),
}
}
#[test]
fn test_chat_message_wrong_role_rejected() {
let json = json!({
"role": "invalid_role",
"content": "test"
});
let result = serde_json::from_value::<ChatMessage>(json);
assert!(result.is_err(), "Should reject invalid role");
}

View File

@@ -0,0 +1,96 @@
use serde_json::{from_str, json, to_string};
use sglang_router_rs::protocols::spec::{EmbeddingRequest, GenerationRequest};
#[test]
fn test_embedding_request_serialization_string_input() {
let req = EmbeddingRequest {
model: "test-emb".to_string(),
input: json!("hello"),
encoding_format: Some("float".to_string()),
user: Some("user-1".to_string()),
dimensions: Some(128),
rid: Some("rid-123".to_string()),
};
let serialized = to_string(&req).unwrap();
let deserialized: EmbeddingRequest = from_str(&serialized).unwrap();
assert_eq!(deserialized.model, req.model);
assert_eq!(deserialized.input, req.input);
assert_eq!(deserialized.encoding_format, req.encoding_format);
assert_eq!(deserialized.user, req.user);
assert_eq!(deserialized.dimensions, req.dimensions);
assert_eq!(deserialized.rid, req.rid);
}
#[test]
fn test_embedding_request_serialization_array_input() {
let req = EmbeddingRequest {
model: "test-emb".to_string(),
input: json!(["a", "b", "c"]),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
let serialized = to_string(&req).unwrap();
let de: EmbeddingRequest = from_str(&serialized).unwrap();
assert_eq!(de.model, req.model);
assert_eq!(de.input, req.input);
}
#[test]
fn test_embedding_generation_request_trait_string() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: json!("hello"),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
assert!(!req.is_stream());
assert_eq!(req.get_model(), Some("emb-model"));
assert_eq!(req.extract_text_for_routing(), "hello");
}
#[test]
fn test_embedding_generation_request_trait_array() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: json!(["hello", "world"]),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
assert_eq!(req.extract_text_for_routing(), "hello world");
}
#[test]
fn test_embedding_generation_request_trait_non_text() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: json!({"tokens": [1, 2, 3]}),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
assert_eq!(req.extract_text_for_routing(), "");
}
#[test]
fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: json!(["a", ["b", "c"], 123, {"k": "v"}]),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
// Only top-level string elements are extracted
assert_eq!(req.extract_text_for_routing(), "a");
}

View File

@@ -0,0 +1,8 @@
// Protocol specification tests
// These tests were originally in src/protocols/spec.rs and have been moved here
// to reduce the size of that file and improve test organization.
mod chat_completion;
mod chat_message;
mod embedding;
mod rerank;

View File

@@ -0,0 +1,613 @@
use serde_json::{from_str, to_string, Number, Value};
use sglang_router_rs::protocols::spec::{
default_model_name, GenerationRequest, RerankRequest, RerankResponse, RerankResult,
StringOrArray, UsageInfo, V1RerankReqInput,
};
use std::collections::HashMap;
#[test]
fn test_rerank_request_serialization() {
let request = RerankRequest {
query: "test query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string()],
model: "test-model".to_string(),
top_k: Some(5),
return_documents: true,
rid: Some(StringOrArray::String("req-123".to_string())),
user: Some("user-456".to_string()),
};
let serialized = to_string(&request).unwrap();
let deserialized: RerankRequest = from_str(&serialized).unwrap();
assert_eq!(deserialized.query, request.query);
assert_eq!(deserialized.documents, request.documents);
assert_eq!(deserialized.model, request.model);
assert_eq!(deserialized.top_k, request.top_k);
assert_eq!(deserialized.return_documents, request.return_documents);
assert_eq!(deserialized.rid, request.rid);
assert_eq!(deserialized.user, request.user);
}
#[test]
fn test_rerank_request_deserialization_with_defaults() {
let json = r#"{
"query": "test query",
"documents": ["doc1", "doc2"]
}"#;
let request: RerankRequest = from_str(json).unwrap();
assert_eq!(request.query, "test query");
assert_eq!(request.documents, vec!["doc1", "doc2"]);
assert_eq!(request.model, default_model_name());
assert_eq!(request.top_k, None);
assert!(request.return_documents);
assert_eq!(request.rid, None);
assert_eq!(request.user, None);
}
#[test]
fn test_rerank_request_validation_success() {
let request = RerankRequest {
query: "valid query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string()],
model: "test-model".to_string(),
top_k: Some(2),
return_documents: true,
rid: None,
user: None,
};
assert!(request.validate().is_ok());
}
#[test]
fn test_rerank_request_validation_empty_query() {
let request = RerankRequest {
query: "".to_string(),
documents: vec!["doc1".to_string()],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: None,
user: None,
};
let result = request.validate();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "Query cannot be empty");
}
#[test]
fn test_rerank_request_validation_whitespace_query() {
let request = RerankRequest {
query: " ".to_string(),
documents: vec!["doc1".to_string()],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: None,
user: None,
};
let result = request.validate();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "Query cannot be empty");
}
#[test]
fn test_rerank_request_validation_empty_documents() {
let request = RerankRequest {
query: "test query".to_string(),
documents: vec![],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: None,
user: None,
};
let result = request.validate();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "Documents list cannot be empty");
}
#[test]
fn test_rerank_request_validation_top_k_zero() {
let request = RerankRequest {
query: "test query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string()],
model: "test-model".to_string(),
top_k: Some(0),
return_documents: true,
rid: None,
user: None,
};
let result = request.validate();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "top_k must be greater than 0");
}
#[test]
fn test_rerank_request_validation_top_k_greater_than_docs() {
let request = RerankRequest {
query: "test query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string()],
model: "test-model".to_string(),
top_k: Some(5),
return_documents: true,
rid: None,
user: None,
};
// This should pass but log a warning
assert!(request.validate().is_ok());
}
#[test]
fn test_rerank_request_effective_top_k() {
let request = RerankRequest {
query: "test query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()],
model: "test-model".to_string(),
top_k: Some(2),
return_documents: true,
rid: None,
user: None,
};
assert_eq!(request.effective_top_k(), 2);
}
#[test]
fn test_rerank_request_effective_top_k_none() {
let request = RerankRequest {
query: "test query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: None,
user: None,
};
assert_eq!(request.effective_top_k(), 3);
}
#[test]
fn test_rerank_response_creation() {
let results = vec![
RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
},
RerankResult {
score: 0.6,
document: Some("doc2".to_string()),
index: 1,
meta_info: None,
},
];
let response = RerankResponse::new(
results.clone(),
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
assert_eq!(response.results.len(), 2);
assert_eq!(response.model, "test-model");
assert_eq!(
response.id,
Some(StringOrArray::String("req-123".to_string()))
);
assert_eq!(response.object, "rerank");
assert!(response.created > 0);
}
#[test]
fn test_rerank_response_serialization() {
let results = vec![RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
}];
let response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
let serialized = to_string(&response).unwrap();
let deserialized: RerankResponse = from_str(&serialized).unwrap();
assert_eq!(deserialized.results.len(), response.results.len());
assert_eq!(deserialized.model, response.model);
assert_eq!(deserialized.id, response.id);
assert_eq!(deserialized.object, response.object);
}
#[test]
fn test_rerank_response_sort_by_score() {
let results = vec![
RerankResult {
score: 0.6,
document: Some("doc2".to_string()),
index: 1,
meta_info: None,
},
RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
},
RerankResult {
score: 0.4,
document: Some("doc3".to_string()),
index: 2,
meta_info: None,
},
];
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.sort_by_score();
assert_eq!(response.results[0].score, 0.8);
assert_eq!(response.results[0].index, 0);
assert_eq!(response.results[1].score, 0.6);
assert_eq!(response.results[1].index, 1);
assert_eq!(response.results[2].score, 0.4);
assert_eq!(response.results[2].index, 2);
}
#[test]
fn test_rerank_response_apply_top_k() {
let results = vec![
RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
},
RerankResult {
score: 0.6,
document: Some("doc2".to_string()),
index: 1,
meta_info: None,
},
RerankResult {
score: 0.4,
document: Some("doc3".to_string()),
index: 2,
meta_info: None,
},
];
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.apply_top_k(2);
assert_eq!(response.results.len(), 2);
assert_eq!(response.results[0].score, 0.8);
assert_eq!(response.results[1].score, 0.6);
}
#[test]
fn test_rerank_response_apply_top_k_larger_than_results() {
let results = vec![RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
}];
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.apply_top_k(5);
assert_eq!(response.results.len(), 1);
}
#[test]
fn test_rerank_response_drop_documents() {
let results = vec![RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
}];
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.drop_documents();
assert_eq!(response.results[0].document, None);
}
#[test]
fn test_rerank_result_serialization() {
let result = RerankResult {
score: 0.85,
document: Some("test document".to_string()),
index: 42,
meta_info: Some(HashMap::from([
("confidence".to_string(), Value::String("high".to_string())),
(
"processing_time".to_string(),
Value::Number(Number::from(150)),
),
])),
};
let serialized = to_string(&result).unwrap();
let deserialized: RerankResult = from_str(&serialized).unwrap();
assert_eq!(deserialized.score, result.score);
assert_eq!(deserialized.document, result.document);
assert_eq!(deserialized.index, result.index);
assert_eq!(deserialized.meta_info, result.meta_info);
}
#[test]
fn test_rerank_result_serialization_without_document() {
let result = RerankResult {
score: 0.85,
document: None,
index: 42,
meta_info: None,
};
let serialized = to_string(&result).unwrap();
let deserialized: RerankResult = from_str(&serialized).unwrap();
assert_eq!(deserialized.score, result.score);
assert_eq!(deserialized.document, result.document);
assert_eq!(deserialized.index, result.index);
assert_eq!(deserialized.meta_info, result.meta_info);
}
#[test]
fn test_v1_rerank_req_input_serialization() {
let v1_input = V1RerankReqInput {
query: "test query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string()],
};
let serialized = to_string(&v1_input).unwrap();
let deserialized: V1RerankReqInput = from_str(&serialized).unwrap();
assert_eq!(deserialized.query, v1_input.query);
assert_eq!(deserialized.documents, v1_input.documents);
}
#[test]
fn test_v1_to_rerank_request_conversion() {
let v1_input = V1RerankReqInput {
query: "test query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string()],
};
let request: RerankRequest = v1_input.into();
assert_eq!(request.query, "test query");
assert_eq!(request.documents, vec!["doc1", "doc2"]);
assert_eq!(request.model, default_model_name());
assert_eq!(request.top_k, None);
assert!(request.return_documents);
assert_eq!(request.rid, None);
assert_eq!(request.user, None);
}
#[test]
fn test_rerank_request_generation_request_trait() {
let request = RerankRequest {
query: "test query".to_string(),
documents: vec!["doc1".to_string()],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: None,
user: None,
};
assert_eq!(request.get_model(), Some("test-model"));
assert!(!request.is_stream());
assert_eq!(request.extract_text_for_routing(), "test query");
}
#[test]
fn test_rerank_request_very_long_query() {
let long_query = "a".repeat(100000);
let request = RerankRequest {
query: long_query,
documents: vec!["doc1".to_string()],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: None,
user: None,
};
assert!(request.validate().is_ok());
}
#[test]
fn test_rerank_request_many_documents() {
let documents: Vec<String> = (0..1000).map(|i| format!("doc{}", i)).collect();
let request = RerankRequest {
query: "test query".to_string(),
documents,
model: "test-model".to_string(),
top_k: Some(100),
return_documents: true,
rid: None,
user: None,
};
assert!(request.validate().is_ok());
assert_eq!(request.effective_top_k(), 100);
}
#[test]
fn test_rerank_request_special_characters() {
let request = RerankRequest {
query: "query with émojis 🚀 and unicode: 测试".to_string(),
documents: vec![
"doc with émojis 🎉".to_string(),
"doc with unicode: 测试".to_string(),
],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: Some(StringOrArray::String("req-🚀-123".to_string())),
user: Some("user-🎉-456".to_string()),
};
assert!(request.validate().is_ok());
}
#[test]
fn test_rerank_request_rid_array() {
let request = RerankRequest {
query: "test query".to_string(),
documents: vec!["doc1".to_string()],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: Some(StringOrArray::Array(vec![
"req1".to_string(),
"req2".to_string(),
])),
user: None,
};
assert!(request.validate().is_ok());
}
#[test]
fn test_rerank_response_with_usage_info() {
let results = vec![RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
}];
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.usage = Some(UsageInfo {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
reasoning_tokens: None,
prompt_tokens_details: None,
});
let serialized = to_string(&response).unwrap();
let deserialized: RerankResponse = from_str(&serialized).unwrap();
assert!(deserialized.usage.is_some());
let usage = deserialized.usage.unwrap();
assert_eq!(usage.prompt_tokens, 100);
assert_eq!(usage.completion_tokens, 50);
assert_eq!(usage.total_tokens, 150);
}
#[test]
fn test_full_rerank_workflow() {
// Create request
let request = RerankRequest {
query: "machine learning".to_string(),
documents: vec![
"Introduction to machine learning algorithms".to_string(),
"Deep learning for computer vision".to_string(),
"Natural language processing basics".to_string(),
"Statistics and probability theory".to_string(),
],
model: "rerank-model".to_string(),
top_k: Some(2),
return_documents: true,
rid: Some(StringOrArray::String("req-123".to_string())),
user: Some("user-456".to_string()),
};
// Validate request
assert!(request.validate().is_ok());
// Simulate reranking results (in real scenario, this would come from the model)
let results = vec![
RerankResult {
score: 0.95,
document: Some("Introduction to machine learning algorithms".to_string()),
index: 0,
meta_info: None,
},
RerankResult {
score: 0.87,
document: Some("Deep learning for computer vision".to_string()),
index: 1,
meta_info: None,
},
RerankResult {
score: 0.72,
document: Some("Natural language processing basics".to_string()),
index: 2,
meta_info: None,
},
RerankResult {
score: 0.45,
document: Some("Statistics and probability theory".to_string()),
index: 3,
meta_info: None,
},
];
// Create response
let mut response = RerankResponse::new(results, request.model.clone(), request.rid.clone());
// Sort by score
response.sort_by_score();
// Apply top_k
response.apply_top_k(request.effective_top_k());
assert_eq!(response.results.len(), 2);
assert_eq!(response.results[0].score, 0.95);
assert_eq!(response.results[0].index, 0);
assert_eq!(response.results[1].score, 0.87);
assert_eq!(response.results[1].index, 1);
assert_eq!(response.model, "rerank-model");
// Serialize and deserialize
let serialized = to_string(&response).unwrap();
let deserialized: RerankResponse = from_str(&serialized).unwrap();
assert_eq!(deserialized.results.len(), 2);
assert_eq!(deserialized.model, response.model);
}