[router][protocols] Add Axum validate extractor and use it for /v1/chat/completions endpoint (#11588)
This commit is contained in:
@@ -56,6 +56,7 @@ parking_lot = "0.12.4"
|
||||
thiserror = "2.0.12"
|
||||
regex = "1.10"
|
||||
url = "2.5.4"
|
||||
validator = { version = "0.18", features = ["derive"] }
|
||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
anyhow = "1.0"
|
||||
tokenizers = { version = "0.22.0" }
|
||||
|
||||
@@ -4,8 +4,8 @@ use std::time::Instant;
|
||||
|
||||
use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType};
|
||||
use sglang_router_rs::protocols::spec::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||
SamplingParams, StringOrArray, UserMessageContent,
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, SamplingParams,
|
||||
StringOrArray, UserMessageContent,
|
||||
};
|
||||
use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap};
|
||||
|
||||
@@ -31,7 +31,6 @@ fn default_generate_request() -> GenerateRequest {
|
||||
prompt: None,
|
||||
input_ids: None,
|
||||
stream: false,
|
||||
parameters: None,
|
||||
sampling_params: None,
|
||||
return_logprob: false,
|
||||
// SGLang Extensions
|
||||
@@ -101,14 +100,6 @@ fn default_completion_request() -> CompletionRequest {
|
||||
fn create_sample_generate_request() -> GenerateRequest {
|
||||
GenerateRequest {
|
||||
text: Some("Write a story about artificial intelligence".to_string()),
|
||||
parameters: Some(GenerateParameters {
|
||||
max_new_tokens: Some(100),
|
||||
temperature: Some(0.8),
|
||||
top_p: Some(0.9),
|
||||
top_k: Some(50),
|
||||
repetition_penalty: Some(1.0),
|
||||
..Default::default()
|
||||
}),
|
||||
sampling_params: Some(SamplingParams {
|
||||
temperature: Some(0.8),
|
||||
top_p: Some(0.9),
|
||||
@@ -128,12 +119,10 @@ fn create_sample_chat_completion_request() -> ChatCompletionRequest {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![
|
||||
ChatMessage::System {
|
||||
role: "system".to_string(),
|
||||
content: "You are a helpful assistant".to_string(),
|
||||
name: None,
|
||||
},
|
||||
ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text(
|
||||
"Explain quantum computing in simple terms".to_string(),
|
||||
),
|
||||
@@ -170,7 +159,6 @@ fn create_sample_completion_request() -> CompletionRequest {
|
||||
#[allow(deprecated)]
|
||||
fn create_large_chat_completion_request() -> ChatCompletionRequest {
|
||||
let mut messages = vec![ChatMessage::System {
|
||||
role: "system".to_string(),
|
||||
content: "You are a helpful assistant with extensive knowledge.".to_string(),
|
||||
name: None,
|
||||
}];
|
||||
@@ -178,12 +166,10 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
|
||||
// Add many user/assistant pairs to simulate a long conversation
|
||||
for i in 0..50 {
|
||||
messages.push(ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text(format!("Question {}: What do you think about topic number {} which involves complex reasoning about multiple interconnected systems and their relationships?", i, i)),
|
||||
name: None,
|
||||
});
|
||||
messages.push(ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
|
||||
@@ -123,6 +123,7 @@ fn create_test_tools() -> Vec<Tool> {
|
||||
"limit": {"type": "number"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -137,6 +138,7 @@ fn create_test_tools() -> Vec<Tool> {
|
||||
"code": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
@@ -301,13 +301,7 @@ impl SglangSchedulerClient {
|
||||
) -> Result<proto::SamplingParams, String> {
|
||||
let stop_sequences = self.extract_stop_strings(request);
|
||||
|
||||
// Handle max tokens: prefer max_completion_tokens (new) over max_tokens (deprecated)
|
||||
// If neither is specified, use None to let the backend decide the default
|
||||
#[allow(deprecated)]
|
||||
let max_new_tokens = request
|
||||
.max_completion_tokens
|
||||
.or(request.max_tokens)
|
||||
.map(|v| v as i32);
|
||||
let max_new_tokens = request.max_completion_tokens.map(|v| v as i32);
|
||||
|
||||
// Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none"
|
||||
let skip_special_tokens = if request.tools.is_some() {
|
||||
@@ -322,7 +316,6 @@ impl SglangSchedulerClient {
|
||||
request.skip_special_tokens
|
||||
};
|
||||
|
||||
#[allow(deprecated)]
|
||||
Ok(proto::SamplingParams {
|
||||
temperature: request.temperature.unwrap_or(1.0),
|
||||
top_p: request.top_p.unwrap_or(1.0),
|
||||
@@ -485,10 +478,10 @@ impl SglangSchedulerClient {
|
||||
})?);
|
||||
}
|
||||
|
||||
// Handle min_tokens with conversion
|
||||
if let Some(min_tokens) = p.min_tokens {
|
||||
sampling.min_new_tokens = i32::try_from(min_tokens)
|
||||
.map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
|
||||
// Handle min_new_tokens with conversion
|
||||
if let Some(min_new_tokens) = p.min_new_tokens {
|
||||
sampling.min_new_tokens = i32::try_from(min_new_tokens)
|
||||
.map_err(|_| "min_new_tokens must fit into a 32-bit signed integer".to_string())?;
|
||||
}
|
||||
|
||||
// Handle n with conversion
|
||||
|
||||
@@ -2,5 +2,5 @@
|
||||
// This module provides a structured approach to handling different API protocols
|
||||
|
||||
pub mod spec;
|
||||
pub mod validation;
|
||||
pub mod validated;
|
||||
pub mod worker_spec;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
172
sgl-router/src/protocols/validated.rs
Normal file
172
sgl-router/src/protocols/validated.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
// Validated JSON extractor for automatic request validation
|
||||
//
|
||||
// This module provides a ValidatedJson extractor that automatically validates
|
||||
// requests using the validator crate's Validate trait.
|
||||
|
||||
use axum::{
|
||||
extract::{rejection::JsonRejection, FromRequest, Request},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_json::json;
|
||||
use validator::Validate;
|
||||
|
||||
/// Trait for request types that need post-deserialization normalization
|
||||
pub trait Normalizable {
|
||||
/// Normalize the request by applying defaults and transformations
|
||||
fn normalize(&mut self) {
|
||||
// Default: no-op
|
||||
}
|
||||
}
|
||||
|
||||
/// A JSON extractor that automatically validates and normalizes the request body
|
||||
///
|
||||
/// This extractor deserializes the request body and automatically calls `.validate()`
|
||||
/// on types that implement the `Validate` trait. If validation fails, it returns
|
||||
/// a 400 Bad Request with detailed error information.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// async fn create_chat(
|
||||
/// ValidatedJson(request): ValidatedJson<ChatCompletionRequest>,
|
||||
/// ) -> Response {
|
||||
/// // request is guaranteed to be valid here
|
||||
/// process_request(request).await
|
||||
/// }
|
||||
/// ```
|
||||
pub struct ValidatedJson<T>(pub T);
|
||||
|
||||
impl<S, T> FromRequest<S> for ValidatedJson<T>
|
||||
where
|
||||
T: DeserializeOwned + Validate + Normalizable + Send,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
|
||||
// First, extract and deserialize the JSON
|
||||
let Json(mut data) =
|
||||
Json::<T>::from_request(req, state)
|
||||
.await
|
||||
.map_err(|err: JsonRejection| {
|
||||
let error_message = match err {
|
||||
JsonRejection::JsonDataError(e) => {
|
||||
format!("Invalid JSON data: {}", e)
|
||||
}
|
||||
JsonRejection::JsonSyntaxError(e) => {
|
||||
format!("JSON syntax error: {}", e)
|
||||
}
|
||||
JsonRejection::MissingJsonContentType(_) => {
|
||||
"Missing Content-Type: application/json header".to_string()
|
||||
}
|
||||
_ => format!("Failed to parse JSON: {}", err),
|
||||
};
|
||||
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": {
|
||||
"message": error_message,
|
||||
"type": "invalid_request_error",
|
||||
"code": "json_parse_error"
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
})?;
|
||||
|
||||
// Normalize the request (apply defaults based on other fields)
|
||||
data.normalize();
|
||||
|
||||
// Then, automatically validate the data
|
||||
data.validate().map_err(|validation_errors| {
|
||||
// Extract the first error message from the validation errors
|
||||
let error_message = validation_errors
|
||||
.field_errors()
|
||||
.values()
|
||||
.flat_map(|errors| errors.iter())
|
||||
.find_map(|e| e.message.as_ref())
|
||||
.map(|m| m.to_string())
|
||||
.unwrap_or_else(|| "Validation failed".to_string());
|
||||
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": {
|
||||
"message": error_message,
|
||||
"type": "invalid_request_error",
|
||||
"code": 400
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
})?;
|
||||
|
||||
Ok(ValidatedJson(data))
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Deref to allow transparent access to the inner value
|
||||
impl<T> std::ops::Deref for ValidatedJson<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::DerefMut for ValidatedJson<T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use validator::Validate;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Validate)]
|
||||
struct TestRequest {
|
||||
#[validate(range(min = 0.0, max = 1.0))]
|
||||
value: f32,
|
||||
#[validate(length(min = 1))]
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl Normalizable for TestRequest {
|
||||
// Use default no-op implementation
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validated_json_valid() {
|
||||
// This test is conceptual - actual testing would require Axum test harness
|
||||
let request = TestRequest {
|
||||
value: 0.5,
|
||||
name: "test".to_string(),
|
||||
};
|
||||
assert!(request.validate().is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validated_json_invalid_range() {
|
||||
let request = TestRequest {
|
||||
value: 1.5, // Out of range
|
||||
name: "test".to_string(),
|
||||
};
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validated_json_invalid_length() {
|
||||
let request = TestRequest {
|
||||
value: 0.5,
|
||||
name: "".to_string(), // Empty name
|
||||
};
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -959,7 +959,6 @@ mod tests {
|
||||
#[test]
|
||||
fn test_transform_messages_string_format() {
|
||||
let messages = vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Parts(vec![
|
||||
ContentPart::Text {
|
||||
text: "Hello".to_string(),
|
||||
@@ -993,7 +992,6 @@ mod tests {
|
||||
#[test]
|
||||
fn test_transform_messages_openai_format() {
|
||||
let messages = vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Parts(vec![
|
||||
ContentPart::Text {
|
||||
text: "Describe this image:".to_string(),
|
||||
@@ -1028,7 +1026,6 @@ mod tests {
|
||||
#[test]
|
||||
fn test_transform_messages_simple_string_content() {
|
||||
let messages = vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text("Simple text message".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
@@ -1049,12 +1046,10 @@ mod tests {
|
||||
fn test_transform_messages_multiple_messages() {
|
||||
let messages = vec![
|
||||
ChatMessage::System {
|
||||
role: "system".to_string(),
|
||||
content: "System prompt".to_string(),
|
||||
name: None,
|
||||
},
|
||||
ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Parts(vec![
|
||||
ContentPart::Text {
|
||||
text: "User message".to_string(),
|
||||
@@ -1086,7 +1081,6 @@ mod tests {
|
||||
#[test]
|
||||
fn test_transform_messages_empty_text_parts() {
|
||||
let messages = vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Parts(vec![ContentPart::ImageUrl {
|
||||
image_url: ImageUrl {
|
||||
url: "https://example.com/image.jpg".to_string(),
|
||||
@@ -1109,12 +1103,10 @@ mod tests {
|
||||
fn test_transform_messages_mixed_content_types() {
|
||||
let messages = vec![
|
||||
ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text("Plain text".to_string()),
|
||||
name: None,
|
||||
},
|
||||
ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Parts(vec![
|
||||
ContentPart::Text {
|
||||
text: "With image".to_string(),
|
||||
|
||||
@@ -16,6 +16,7 @@ use crate::{
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest,
|
||||
RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput,
|
||||
},
|
||||
validated::ValidatedJson,
|
||||
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
|
||||
},
|
||||
reasoning_parser::ParserFactory as ReasoningParserFactory,
|
||||
@@ -291,7 +292,7 @@ async fn generate(
|
||||
async fn v1_chat_completions(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<ChatCompletionRequest>,
|
||||
ValidatedJson(body): ValidatedJson<ChatCompletionRequest>,
|
||||
) -> Response {
|
||||
state.router.route_chat(Some(&headers), &body, None).await
|
||||
}
|
||||
|
||||
@@ -1461,39 +1461,6 @@ mod error_tests {
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_missing_required_fields() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18405,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// Missing messages in chat completion
|
||||
let payload = json!({
|
||||
"model": "test-model"
|
||||
// missing "messages"
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/chat/completions")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
// Axum validates JSON schema - returns 422 for validation errors
|
||||
assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_model() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
|
||||
@@ -172,14 +172,12 @@ assistant:
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = vec![
|
||||
let messages = [
|
||||
spec::ChatMessage::System {
|
||||
role: "system".to_string(),
|
||||
content: "You are helpful".to_string(),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
@@ -216,7 +214,6 @@ fn test_chat_template_with_tokens_unit_test() {
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
@@ -18,7 +18,6 @@ fn test_simple_chat_template() {
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
@@ -53,7 +52,6 @@ fn test_chat_template_with_tokens() {
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
@@ -113,14 +111,12 @@ fn test_llama_style_template() {
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = vec![
|
||||
let messages = [
|
||||
spec::ChatMessage::System {
|
||||
role: "system".to_string(),
|
||||
content: "You are a helpful assistant".to_string(),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("What is 2+2?".to_string()),
|
||||
name: None,
|
||||
},
|
||||
@@ -172,19 +168,16 @@ fn test_chatml_template() {
|
||||
|
||||
let messages = vec![
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hi there!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("How are you?".to_string()),
|
||||
name: None,
|
||||
},
|
||||
@@ -227,7 +220,6 @@ assistant:
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
@@ -315,7 +307,6 @@ fn test_template_with_multimodal_content() {
|
||||
let processor = ChatTemplateProcessor::new(template.to_string());
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Parts(vec![
|
||||
spec::ContentPart::Text {
|
||||
text: "Look at this:".to_string(),
|
||||
|
||||
@@ -57,14 +57,12 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let messages = vec![
|
||||
let messages = [
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hi there".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
@@ -143,7 +141,6 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
let messages = [spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Test".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
@@ -201,14 +198,12 @@ mod tests {
|
||||
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}";
|
||||
tokenizer.set_chat_template(new_template.to_string());
|
||||
|
||||
let messages = vec![
|
||||
let messages = [
|
||||
spec::ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: spec::UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
},
|
||||
spec::ChatMessage::Assistant {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("World".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
|
||||
@@ -119,6 +119,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"query": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -135,6 +136,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"units": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -149,6 +151,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"y": {"type": "number"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -164,6 +167,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"target_lang": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -178,6 +182,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"format": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -192,6 +197,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"format": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -206,6 +212,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"notifications": {"type": "boolean"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -214,6 +221,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
name: "ping".to_string(),
|
||||
description: Some("Ping service".to_string()),
|
||||
parameters: json!({"type": "object", "properties": {}}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -222,6 +230,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
name: "test".to_string(),
|
||||
description: Some("Test function".to_string()),
|
||||
parameters: json!({"type": "object", "properties": {}}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -239,6 +248,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"text": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -254,6 +264,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"search_type": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -267,6 +278,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"city": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -282,6 +294,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"optional": {"type": "null"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -297,6 +310,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"none_val": {"type": "null"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -311,6 +325,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"email": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -325,6 +340,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"y": {"type": "number"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -338,6 +354,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"x": {"type": "number"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -346,6 +363,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
name: "func1".to_string(),
|
||||
description: Some("Function 1".to_string()),
|
||||
parameters: json!({"type": "object", "properties": {}}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -359,6 +377,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"y": {"type": "number"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -367,6 +386,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
name: "tool1".to_string(),
|
||||
description: Some("Tool 1".to_string()),
|
||||
parameters: json!({"type": "object", "properties": {}}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
@@ -380,6 +400,7 @@ pub fn create_test_tools() -> Vec<Tool> {
|
||||
"y": {"type": "number"}
|
||||
}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
575
sgl-router/tests/spec/chat_completion.rs
Normal file
575
sgl-router/tests/spec/chat_completion.rs
Normal file
@@ -0,0 +1,575 @@
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::protocols::spec::{
|
||||
ChatCompletionRequest, ChatMessage, Function, FunctionCall, FunctionChoice, StreamOptions,
|
||||
Tool, ToolChoice, ToolChoiceValue, ToolReference, UserMessageContent,
|
||||
};
|
||||
use sglang_router_rs::protocols::validated::Normalizable;
|
||||
use validator::Validate;
|
||||
|
||||
// Deprecated fields normalization tests
|
||||
|
||||
#[test]
|
||||
fn test_max_tokens_normalizes_to_max_completion_tokens() {
|
||||
#[allow(deprecated)]
|
||||
let mut req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
max_tokens: Some(100),
|
||||
max_completion_tokens: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
req.normalize();
|
||||
assert_eq!(
|
||||
req.max_completion_tokens,
|
||||
Some(100),
|
||||
"max_tokens should be copied to max_completion_tokens"
|
||||
);
|
||||
#[allow(deprecated)]
|
||||
{
|
||||
assert!(
|
||||
req.max_tokens.is_none(),
|
||||
"Deprecated field should be cleared"
|
||||
);
|
||||
}
|
||||
assert!(
|
||||
req.validate().is_ok(),
|
||||
"Should be valid after normalization"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_completion_tokens_takes_precedence() {
|
||||
#[allow(deprecated)]
|
||||
let mut req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
max_tokens: Some(100),
|
||||
max_completion_tokens: Some(200),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
req.normalize();
|
||||
assert_eq!(
|
||||
req.max_completion_tokens,
|
||||
Some(200),
|
||||
"max_completion_tokens should take precedence"
|
||||
);
|
||||
assert!(
|
||||
req.validate().is_ok(),
|
||||
"Should be valid after normalization"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_functions_normalizes_to_tools() {
|
||||
#[allow(deprecated)]
|
||||
let mut req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
functions: Some(vec![Function {
|
||||
name: "test_func".to_string(),
|
||||
description: Some("Test function".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
}]),
|
||||
tools: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
req.normalize();
|
||||
assert!(req.tools.is_some(), "functions should be migrated to tools");
|
||||
assert_eq!(req.tools.as_ref().unwrap().len(), 1);
|
||||
assert_eq!(req.tools.as_ref().unwrap()[0].function.name, "test_func");
|
||||
#[allow(deprecated)]
|
||||
{
|
||||
assert!(
|
||||
req.functions.is_none(),
|
||||
"Deprecated field should be cleared"
|
||||
);
|
||||
}
|
||||
assert!(
|
||||
req.validate().is_ok(),
|
||||
"Should be valid after normalization"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_call_normalizes_to_tool_choice() {
|
||||
#[allow(deprecated)]
|
||||
let mut req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
function_call: Some(FunctionCall::None),
|
||||
tool_choice: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
req.normalize();
|
||||
assert!(
|
||||
req.tool_choice.is_some(),
|
||||
"function_call should be migrated to tool_choice"
|
||||
);
|
||||
assert!(matches!(
|
||||
req.tool_choice,
|
||||
Some(ToolChoice::Value(ToolChoiceValue::None))
|
||||
));
|
||||
#[allow(deprecated)]
|
||||
{
|
||||
assert!(
|
||||
req.function_call.is_none(),
|
||||
"Deprecated field should be cleared"
|
||||
);
|
||||
}
|
||||
assert!(
|
||||
req.validate().is_ok(),
|
||||
"Should be valid after normalization"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_call_function_variant_normalizes() {
|
||||
#[allow(deprecated)]
|
||||
let mut req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
function_call: Some(FunctionCall::Function {
|
||||
name: "my_function".to_string(),
|
||||
}),
|
||||
tool_choice: None,
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "my_function".to_string(),
|
||||
description: None,
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
req.normalize();
|
||||
assert!(
|
||||
req.tool_choice.is_some(),
|
||||
"function_call should be migrated to tool_choice"
|
||||
);
|
||||
match &req.tool_choice {
|
||||
Some(ToolChoice::Function { function, .. }) => {
|
||||
assert_eq!(function.name, "my_function");
|
||||
}
|
||||
_ => panic!("Expected ToolChoice::Function variant"),
|
||||
}
|
||||
#[allow(deprecated)]
|
||||
{
|
||||
assert!(
|
||||
req.function_call.is_none(),
|
||||
"Deprecated field should be cleared"
|
||||
);
|
||||
}
|
||||
assert!(
|
||||
req.validate().is_ok(),
|
||||
"Should be valid after normalization"
|
||||
);
|
||||
}
|
||||
|
||||
// Stream options validation tests
|
||||
|
||||
#[test]
|
||||
fn test_stream_options_requires_stream_enabled() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
stream: false,
|
||||
stream_options: Some(StreamOptions {
|
||||
include_usage: Some(true),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"Should reject stream_options when stream is false"
|
||||
);
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("stream_options") && err.contains("stream") && err.contains("enabled"),
|
||||
"Error should mention stream dependency: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_options_valid_when_stream_enabled() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
stream: true,
|
||||
stream_options: Some(StreamOptions {
|
||||
include_usage: Some(true),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Should accept stream_options when stream is true"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_stream_options_valid_when_stream_disabled() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Should accept no stream_options when stream is false"
|
||||
);
|
||||
}
|
||||
|
||||
// Tool choice validation tests
|
||||
#[test]
|
||||
fn test_tool_choice_function_not_found() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get weather".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::Function {
|
||||
function: FunctionChoice {
|
||||
name: "nonexistent_function".to_string(),
|
||||
},
|
||||
tool_type: "function".to_string(),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(result.is_err(), "Should reject nonexistent function name");
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("function 'nonexistent_function' not found"),
|
||||
"Error should mention the missing function: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_choice_function_exists_valid() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get weather".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::Function {
|
||||
function: FunctionChoice {
|
||||
name: "get_weather".to_string(),
|
||||
},
|
||||
tool_type: "function".to_string(),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(result.is_ok(), "Should accept existing function name");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_choice_allowed_tools_invalid_mode() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get weather".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::AllowedTools {
|
||||
mode: "invalid_mode".to_string(),
|
||||
tools: vec![ToolReference {
|
||||
tool_type: "function".to_string(),
|
||||
name: "get_weather".to_string(),
|
||||
}],
|
||||
tool_type: "function".to_string(),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(result.is_err(), "Should reject invalid mode");
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("must be 'auto' or 'required'"),
|
||||
"Error should mention valid modes: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_choice_allowed_tools_valid_mode_auto() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get weather".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::AllowedTools {
|
||||
mode: "auto".to_string(),
|
||||
tools: vec![ToolReference {
|
||||
tool_type: "function".to_string(),
|
||||
name: "get_weather".to_string(),
|
||||
}],
|
||||
tool_type: "function".to_string(),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(result.is_ok(), "Should accept 'auto' mode");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_choice_allowed_tools_valid_mode_required() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get weather".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::AllowedTools {
|
||||
mode: "required".to_string(),
|
||||
tools: vec![ToolReference {
|
||||
tool_type: "function".to_string(),
|
||||
name: "get_weather".to_string(),
|
||||
}],
|
||||
tool_type: "function".to_string(),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(result.is_ok(), "Should accept 'required' mode");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_choice_allowed_tools_tool_not_found() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get weather".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::AllowedTools {
|
||||
mode: "auto".to_string(),
|
||||
tools: vec![ToolReference {
|
||||
tool_type: "function".to_string(),
|
||||
name: "nonexistent_tool".to_string(),
|
||||
}],
|
||||
tool_type: "function".to_string(),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(result.is_err(), "Should reject nonexistent tool name");
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("tool 'nonexistent_tool' not found"),
|
||||
"Error should mention the missing tool: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_choice_allowed_tools_multiple_tools_valid() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
tools: Some(vec![
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get weather".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_time".to_string(),
|
||||
description: Some("Get time".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
]),
|
||||
tool_choice: Some(ToolChoice::AllowedTools {
|
||||
mode: "auto".to_string(),
|
||||
tools: vec![
|
||||
ToolReference {
|
||||
tool_type: "function".to_string(),
|
||||
name: "get_weather".to_string(),
|
||||
},
|
||||
ToolReference {
|
||||
tool_type: "function".to_string(),
|
||||
name: "get_time".to_string(),
|
||||
},
|
||||
],
|
||||
tool_type: "function".to_string(),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(result.is_ok(), "Should accept all valid tool references");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_choice_allowed_tools_one_invalid_among_valid() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
content: UserMessageContent::Text("hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
tools: Some(vec![
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get weather".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_time".to_string(),
|
||||
description: Some("Get time".to_string()),
|
||||
parameters: json!({}),
|
||||
strict: None,
|
||||
},
|
||||
},
|
||||
]),
|
||||
tool_choice: Some(ToolChoice::AllowedTools {
|
||||
mode: "auto".to_string(),
|
||||
tools: vec![
|
||||
ToolReference {
|
||||
tool_type: "function".to_string(),
|
||||
name: "get_weather".to_string(),
|
||||
},
|
||||
ToolReference {
|
||||
tool_type: "function".to_string(),
|
||||
name: "nonexistent_tool".to_string(),
|
||||
},
|
||||
],
|
||||
tool_type: "function".to_string(),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = req.validate();
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"Should reject if any tool reference is invalid"
|
||||
);
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("tool 'nonexistent_tool' not found"),
|
||||
"Error should mention the missing tool: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
83
sgl-router/tests/spec/chat_message.rs
Normal file
83
sgl-router/tests/spec/chat_message.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::protocols::spec::{ChatMessage, UserMessageContent};
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_tagged_by_role_system() {
|
||||
let json = json!({
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant"
|
||||
});
|
||||
|
||||
let msg: ChatMessage = serde_json::from_value(json).unwrap();
|
||||
match msg {
|
||||
ChatMessage::System { content, .. } => {
|
||||
assert_eq!(content, "You are a helpful assistant");
|
||||
}
|
||||
_ => panic!("Expected System variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_tagged_by_role_user() {
|
||||
let json = json!({
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
});
|
||||
|
||||
let msg: ChatMessage = serde_json::from_value(json).unwrap();
|
||||
match msg {
|
||||
ChatMessage::User { content, .. } => match content {
|
||||
UserMessageContent::Text(text) => assert_eq!(text, "Hello"),
|
||||
_ => panic!("Expected text content"),
|
||||
},
|
||||
_ => panic!("Expected User variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_tagged_by_role_assistant() {
|
||||
let json = json!({
|
||||
"role": "assistant",
|
||||
"content": "Hi there!"
|
||||
});
|
||||
|
||||
let msg: ChatMessage = serde_json::from_value(json).unwrap();
|
||||
match msg {
|
||||
ChatMessage::Assistant { content, .. } => {
|
||||
assert_eq!(content, Some("Hi there!".to_string()));
|
||||
}
|
||||
_ => panic!("Expected Assistant variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_tagged_by_role_tool() {
|
||||
let json = json!({
|
||||
"role": "tool",
|
||||
"content": "Tool result",
|
||||
"tool_call_id": "call_123"
|
||||
});
|
||||
|
||||
let msg: ChatMessage = serde_json::from_value(json).unwrap();
|
||||
match msg {
|
||||
ChatMessage::Tool {
|
||||
content,
|
||||
tool_call_id,
|
||||
} => {
|
||||
assert_eq!(content, "Tool result");
|
||||
assert_eq!(tool_call_id, "call_123");
|
||||
}
|
||||
_ => panic!("Expected Tool variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_wrong_role_rejected() {
|
||||
let json = json!({
|
||||
"role": "invalid_role",
|
||||
"content": "test"
|
||||
});
|
||||
|
||||
let result = serde_json::from_value::<ChatMessage>(json);
|
||||
assert!(result.is_err(), "Should reject invalid role");
|
||||
}
|
||||
96
sgl-router/tests/spec/embedding.rs
Normal file
96
sgl-router/tests/spec/embedding.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use serde_json::{from_str, json, to_string};
|
||||
use sglang_router_rs::protocols::spec::{EmbeddingRequest, GenerationRequest};
|
||||
|
||||
#[test]
|
||||
fn test_embedding_request_serialization_string_input() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "test-emb".to_string(),
|
||||
input: json!("hello"),
|
||||
encoding_format: Some("float".to_string()),
|
||||
user: Some("user-1".to_string()),
|
||||
dimensions: Some(128),
|
||||
rid: Some("rid-123".to_string()),
|
||||
};
|
||||
|
||||
let serialized = to_string(&req).unwrap();
|
||||
let deserialized: EmbeddingRequest = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.model, req.model);
|
||||
assert_eq!(deserialized.input, req.input);
|
||||
assert_eq!(deserialized.encoding_format, req.encoding_format);
|
||||
assert_eq!(deserialized.user, req.user);
|
||||
assert_eq!(deserialized.dimensions, req.dimensions);
|
||||
assert_eq!(deserialized.rid, req.rid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_request_serialization_array_input() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "test-emb".to_string(),
|
||||
input: json!(["a", "b", "c"]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
|
||||
let serialized = to_string(&req).unwrap();
|
||||
let de: EmbeddingRequest = from_str(&serialized).unwrap();
|
||||
assert_eq!(de.model, req.model);
|
||||
assert_eq!(de.input, req.input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_generation_request_trait_string() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: json!("hello"),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
assert!(!req.is_stream());
|
||||
assert_eq!(req.get_model(), Some("emb-model"));
|
||||
assert_eq!(req.extract_text_for_routing(), "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_generation_request_trait_array() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: json!(["hello", "world"]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
assert_eq!(req.extract_text_for_routing(), "hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_generation_request_trait_non_text() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: json!({"tokens": [1, 2, 3]}),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
assert_eq!(req.extract_text_for_routing(), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: json!(["a", ["b", "c"], 123, {"k": "v"}]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
// Only top-level string elements are extracted
|
||||
assert_eq!(req.extract_text_for_routing(), "a");
|
||||
}
|
||||
8
sgl-router/tests/spec/mod.rs
Normal file
8
sgl-router/tests/spec/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
// Protocol specification tests
|
||||
// These tests were originally in src/protocols/spec.rs and have been moved here
|
||||
// to reduce the size of that file and improve test organization.
|
||||
|
||||
mod chat_completion;
|
||||
mod chat_message;
|
||||
mod embedding;
|
||||
mod rerank;
|
||||
613
sgl-router/tests/spec/rerank.rs
Normal file
613
sgl-router/tests/spec/rerank.rs
Normal file
@@ -0,0 +1,613 @@
|
||||
use serde_json::{from_str, to_string, Number, Value};
|
||||
use sglang_router_rs::protocols::spec::{
|
||||
default_model_name, GenerationRequest, RerankRequest, RerankResponse, RerankResult,
|
||||
StringOrArray, UsageInfo, V1RerankReqInput,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_serialization() {
|
||||
let request = RerankRequest {
|
||||
query: "test query".to_string(),
|
||||
documents: vec!["doc1".to_string(), "doc2".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: Some(5),
|
||||
return_documents: true,
|
||||
rid: Some(StringOrArray::String("req-123".to_string())),
|
||||
user: Some("user-456".to_string()),
|
||||
};
|
||||
|
||||
let serialized = to_string(&request).unwrap();
|
||||
let deserialized: RerankRequest = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.query, request.query);
|
||||
assert_eq!(deserialized.documents, request.documents);
|
||||
assert_eq!(deserialized.model, request.model);
|
||||
assert_eq!(deserialized.top_k, request.top_k);
|
||||
assert_eq!(deserialized.return_documents, request.return_documents);
|
||||
assert_eq!(deserialized.rid, request.rid);
|
||||
assert_eq!(deserialized.user, request.user);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_deserialization_with_defaults() {
|
||||
let json = r#"{
|
||||
"query": "test query",
|
||||
"documents": ["doc1", "doc2"]
|
||||
}"#;
|
||||
|
||||
let request: RerankRequest = from_str(json).unwrap();
|
||||
|
||||
assert_eq!(request.query, "test query");
|
||||
assert_eq!(request.documents, vec!["doc1", "doc2"]);
|
||||
assert_eq!(request.model, default_model_name());
|
||||
assert_eq!(request.top_k, None);
|
||||
assert!(request.return_documents);
|
||||
assert_eq!(request.rid, None);
|
||||
assert_eq!(request.user, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_validation_success() {
|
||||
let request = RerankRequest {
|
||||
query: "valid query".to_string(),
|
||||
documents: vec!["doc1".to_string(), "doc2".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: Some(2),
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
assert!(request.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_validation_empty_query() {
|
||||
let request = RerankRequest {
|
||||
query: "".to_string(),
|
||||
documents: vec!["doc1".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: None,
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let result = request.validate();
|
||||
assert!(result.is_err());
|
||||
assert_eq!(result.unwrap_err(), "Query cannot be empty");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_validation_whitespace_query() {
|
||||
let request = RerankRequest {
|
||||
query: " ".to_string(),
|
||||
documents: vec!["doc1".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: None,
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let result = request.validate();
|
||||
assert!(result.is_err());
|
||||
assert_eq!(result.unwrap_err(), "Query cannot be empty");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_validation_empty_documents() {
|
||||
let request = RerankRequest {
|
||||
query: "test query".to_string(),
|
||||
documents: vec![],
|
||||
model: "test-model".to_string(),
|
||||
top_k: None,
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let result = request.validate();
|
||||
assert!(result.is_err());
|
||||
assert_eq!(result.unwrap_err(), "Documents list cannot be empty");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_validation_top_k_zero() {
|
||||
let request = RerankRequest {
|
||||
query: "test query".to_string(),
|
||||
documents: vec!["doc1".to_string(), "doc2".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: Some(0),
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let result = request.validate();
|
||||
assert!(result.is_err());
|
||||
assert_eq!(result.unwrap_err(), "top_k must be greater than 0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_validation_top_k_greater_than_docs() {
|
||||
let request = RerankRequest {
|
||||
query: "test query".to_string(),
|
||||
documents: vec!["doc1".to_string(), "doc2".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: Some(5),
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
// This should pass but log a warning
|
||||
assert!(request.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_effective_top_k() {
|
||||
let request = RerankRequest {
|
||||
query: "test query".to_string(),
|
||||
documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: Some(2),
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
assert_eq!(request.effective_top_k(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_effective_top_k_none() {
|
||||
let request = RerankRequest {
|
||||
query: "test query".to_string(),
|
||||
documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: None,
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
assert_eq!(request.effective_top_k(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_response_creation() {
|
||||
let results = vec![
|
||||
RerankResult {
|
||||
score: 0.8,
|
||||
document: Some("doc1".to_string()),
|
||||
index: 0,
|
||||
meta_info: None,
|
||||
},
|
||||
RerankResult {
|
||||
score: 0.6,
|
||||
document: Some("doc2".to_string()),
|
||||
index: 1,
|
||||
meta_info: None,
|
||||
},
|
||||
];
|
||||
|
||||
let response = RerankResponse::new(
|
||||
results.clone(),
|
||||
"test-model".to_string(),
|
||||
Some(StringOrArray::String("req-123".to_string())),
|
||||
);
|
||||
|
||||
assert_eq!(response.results.len(), 2);
|
||||
assert_eq!(response.model, "test-model");
|
||||
assert_eq!(
|
||||
response.id,
|
||||
Some(StringOrArray::String("req-123".to_string()))
|
||||
);
|
||||
assert_eq!(response.object, "rerank");
|
||||
assert!(response.created > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_response_serialization() {
|
||||
let results = vec![RerankResult {
|
||||
score: 0.8,
|
||||
document: Some("doc1".to_string()),
|
||||
index: 0,
|
||||
meta_info: None,
|
||||
}];
|
||||
|
||||
let response = RerankResponse::new(
|
||||
results,
|
||||
"test-model".to_string(),
|
||||
Some(StringOrArray::String("req-123".to_string())),
|
||||
);
|
||||
|
||||
let serialized = to_string(&response).unwrap();
|
||||
let deserialized: RerankResponse = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.results.len(), response.results.len());
|
||||
assert_eq!(deserialized.model, response.model);
|
||||
assert_eq!(deserialized.id, response.id);
|
||||
assert_eq!(deserialized.object, response.object);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_response_sort_by_score() {
|
||||
let results = vec![
|
||||
RerankResult {
|
||||
score: 0.6,
|
||||
document: Some("doc2".to_string()),
|
||||
index: 1,
|
||||
meta_info: None,
|
||||
},
|
||||
RerankResult {
|
||||
score: 0.8,
|
||||
document: Some("doc1".to_string()),
|
||||
index: 0,
|
||||
meta_info: None,
|
||||
},
|
||||
RerankResult {
|
||||
score: 0.4,
|
||||
document: Some("doc3".to_string()),
|
||||
index: 2,
|
||||
meta_info: None,
|
||||
},
|
||||
];
|
||||
|
||||
let mut response = RerankResponse::new(
|
||||
results,
|
||||
"test-model".to_string(),
|
||||
Some(StringOrArray::String("req-123".to_string())),
|
||||
);
|
||||
|
||||
response.sort_by_score();
|
||||
|
||||
assert_eq!(response.results[0].score, 0.8);
|
||||
assert_eq!(response.results[0].index, 0);
|
||||
assert_eq!(response.results[1].score, 0.6);
|
||||
assert_eq!(response.results[1].index, 1);
|
||||
assert_eq!(response.results[2].score, 0.4);
|
||||
assert_eq!(response.results[2].index, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_response_apply_top_k() {
|
||||
let results = vec![
|
||||
RerankResult {
|
||||
score: 0.8,
|
||||
document: Some("doc1".to_string()),
|
||||
index: 0,
|
||||
meta_info: None,
|
||||
},
|
||||
RerankResult {
|
||||
score: 0.6,
|
||||
document: Some("doc2".to_string()),
|
||||
index: 1,
|
||||
meta_info: None,
|
||||
},
|
||||
RerankResult {
|
||||
score: 0.4,
|
||||
document: Some("doc3".to_string()),
|
||||
index: 2,
|
||||
meta_info: None,
|
||||
},
|
||||
];
|
||||
|
||||
let mut response = RerankResponse::new(
|
||||
results,
|
||||
"test-model".to_string(),
|
||||
Some(StringOrArray::String("req-123".to_string())),
|
||||
);
|
||||
|
||||
response.apply_top_k(2);
|
||||
|
||||
assert_eq!(response.results.len(), 2);
|
||||
assert_eq!(response.results[0].score, 0.8);
|
||||
assert_eq!(response.results[1].score, 0.6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_response_apply_top_k_larger_than_results() {
|
||||
let results = vec![RerankResult {
|
||||
score: 0.8,
|
||||
document: Some("doc1".to_string()),
|
||||
index: 0,
|
||||
meta_info: None,
|
||||
}];
|
||||
|
||||
let mut response = RerankResponse::new(
|
||||
results,
|
||||
"test-model".to_string(),
|
||||
Some(StringOrArray::String("req-123".to_string())),
|
||||
);
|
||||
|
||||
response.apply_top_k(5);
|
||||
|
||||
assert_eq!(response.results.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_response_drop_documents() {
|
||||
let results = vec![RerankResult {
|
||||
score: 0.8,
|
||||
document: Some("doc1".to_string()),
|
||||
index: 0,
|
||||
meta_info: None,
|
||||
}];
|
||||
let mut response = RerankResponse::new(
|
||||
results,
|
||||
"test-model".to_string(),
|
||||
Some(StringOrArray::String("req-123".to_string())),
|
||||
);
|
||||
|
||||
response.drop_documents();
|
||||
|
||||
assert_eq!(response.results[0].document, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_result_serialization() {
|
||||
let result = RerankResult {
|
||||
score: 0.85,
|
||||
document: Some("test document".to_string()),
|
||||
index: 42,
|
||||
meta_info: Some(HashMap::from([
|
||||
("confidence".to_string(), Value::String("high".to_string())),
|
||||
(
|
||||
"processing_time".to_string(),
|
||||
Value::Number(Number::from(150)),
|
||||
),
|
||||
])),
|
||||
};
|
||||
|
||||
let serialized = to_string(&result).unwrap();
|
||||
let deserialized: RerankResult = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.score, result.score);
|
||||
assert_eq!(deserialized.document, result.document);
|
||||
assert_eq!(deserialized.index, result.index);
|
||||
assert_eq!(deserialized.meta_info, result.meta_info);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_result_serialization_without_document() {
|
||||
let result = RerankResult {
|
||||
score: 0.85,
|
||||
document: None,
|
||||
index: 42,
|
||||
meta_info: None,
|
||||
};
|
||||
|
||||
let serialized = to_string(&result).unwrap();
|
||||
let deserialized: RerankResult = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.score, result.score);
|
||||
assert_eq!(deserialized.document, result.document);
|
||||
assert_eq!(deserialized.index, result.index);
|
||||
assert_eq!(deserialized.meta_info, result.meta_info);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_v1_rerank_req_input_serialization() {
|
||||
let v1_input = V1RerankReqInput {
|
||||
query: "test query".to_string(),
|
||||
documents: vec!["doc1".to_string(), "doc2".to_string()],
|
||||
};
|
||||
|
||||
let serialized = to_string(&v1_input).unwrap();
|
||||
let deserialized: V1RerankReqInput = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.query, v1_input.query);
|
||||
assert_eq!(deserialized.documents, v1_input.documents);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_v1_to_rerank_request_conversion() {
|
||||
let v1_input = V1RerankReqInput {
|
||||
query: "test query".to_string(),
|
||||
documents: vec!["doc1".to_string(), "doc2".to_string()],
|
||||
};
|
||||
|
||||
let request: RerankRequest = v1_input.into();
|
||||
|
||||
assert_eq!(request.query, "test query");
|
||||
assert_eq!(request.documents, vec!["doc1", "doc2"]);
|
||||
assert_eq!(request.model, default_model_name());
|
||||
assert_eq!(request.top_k, None);
|
||||
assert!(request.return_documents);
|
||||
assert_eq!(request.rid, None);
|
||||
assert_eq!(request.user, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_generation_request_trait() {
|
||||
let request = RerankRequest {
|
||||
query: "test query".to_string(),
|
||||
documents: vec!["doc1".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: None,
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
assert_eq!(request.get_model(), Some("test-model"));
|
||||
assert!(!request.is_stream());
|
||||
assert_eq!(request.extract_text_for_routing(), "test query");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_very_long_query() {
|
||||
let long_query = "a".repeat(100000);
|
||||
let request = RerankRequest {
|
||||
query: long_query,
|
||||
documents: vec!["doc1".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: None,
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
assert!(request.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_many_documents() {
|
||||
let documents: Vec<String> = (0..1000).map(|i| format!("doc{}", i)).collect();
|
||||
let request = RerankRequest {
|
||||
query: "test query".to_string(),
|
||||
documents,
|
||||
model: "test-model".to_string(),
|
||||
top_k: Some(100),
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
assert!(request.validate().is_ok());
|
||||
assert_eq!(request.effective_top_k(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_special_characters() {
|
||||
let request = RerankRequest {
|
||||
query: "query with émojis 🚀 and unicode: 测试".to_string(),
|
||||
documents: vec![
|
||||
"doc with émojis 🎉".to_string(),
|
||||
"doc with unicode: 测试".to_string(),
|
||||
],
|
||||
model: "test-model".to_string(),
|
||||
top_k: None,
|
||||
return_documents: true,
|
||||
rid: Some(StringOrArray::String("req-🚀-123".to_string())),
|
||||
user: Some("user-🎉-456".to_string()),
|
||||
};
|
||||
|
||||
assert!(request.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_request_rid_array() {
|
||||
let request = RerankRequest {
|
||||
query: "test query".to_string(),
|
||||
documents: vec!["doc1".to_string()],
|
||||
model: "test-model".to_string(),
|
||||
top_k: None,
|
||||
return_documents: true,
|
||||
rid: Some(StringOrArray::Array(vec![
|
||||
"req1".to_string(),
|
||||
"req2".to_string(),
|
||||
])),
|
||||
user: None,
|
||||
};
|
||||
|
||||
assert!(request.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rerank_response_with_usage_info() {
|
||||
let results = vec![RerankResult {
|
||||
score: 0.8,
|
||||
document: Some("doc1".to_string()),
|
||||
index: 0,
|
||||
meta_info: None,
|
||||
}];
|
||||
|
||||
let mut response = RerankResponse::new(
|
||||
results,
|
||||
"test-model".to_string(),
|
||||
Some(StringOrArray::String("req-123".to_string())),
|
||||
);
|
||||
|
||||
response.usage = Some(UsageInfo {
|
||||
prompt_tokens: 100,
|
||||
completion_tokens: 50,
|
||||
total_tokens: 150,
|
||||
reasoning_tokens: None,
|
||||
prompt_tokens_details: None,
|
||||
});
|
||||
|
||||
let serialized = to_string(&response).unwrap();
|
||||
let deserialized: RerankResponse = from_str(&serialized).unwrap();
|
||||
|
||||
assert!(deserialized.usage.is_some());
|
||||
let usage = deserialized.usage.unwrap();
|
||||
assert_eq!(usage.prompt_tokens, 100);
|
||||
assert_eq!(usage.completion_tokens, 50);
|
||||
assert_eq!(usage.total_tokens, 150);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_rerank_workflow() {
|
||||
// Create request
|
||||
let request = RerankRequest {
|
||||
query: "machine learning".to_string(),
|
||||
documents: vec![
|
||||
"Introduction to machine learning algorithms".to_string(),
|
||||
"Deep learning for computer vision".to_string(),
|
||||
"Natural language processing basics".to_string(),
|
||||
"Statistics and probability theory".to_string(),
|
||||
],
|
||||
model: "rerank-model".to_string(),
|
||||
top_k: Some(2),
|
||||
return_documents: true,
|
||||
rid: Some(StringOrArray::String("req-123".to_string())),
|
||||
user: Some("user-456".to_string()),
|
||||
};
|
||||
|
||||
// Validate request
|
||||
assert!(request.validate().is_ok());
|
||||
|
||||
// Simulate reranking results (in real scenario, this would come from the model)
|
||||
let results = vec![
|
||||
RerankResult {
|
||||
score: 0.95,
|
||||
document: Some("Introduction to machine learning algorithms".to_string()),
|
||||
index: 0,
|
||||
meta_info: None,
|
||||
},
|
||||
RerankResult {
|
||||
score: 0.87,
|
||||
document: Some("Deep learning for computer vision".to_string()),
|
||||
index: 1,
|
||||
meta_info: None,
|
||||
},
|
||||
RerankResult {
|
||||
score: 0.72,
|
||||
document: Some("Natural language processing basics".to_string()),
|
||||
index: 2,
|
||||
meta_info: None,
|
||||
},
|
||||
RerankResult {
|
||||
score: 0.45,
|
||||
document: Some("Statistics and probability theory".to_string()),
|
||||
index: 3,
|
||||
meta_info: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Create response
|
||||
let mut response = RerankResponse::new(results, request.model.clone(), request.rid.clone());
|
||||
|
||||
// Sort by score
|
||||
response.sort_by_score();
|
||||
|
||||
// Apply top_k
|
||||
response.apply_top_k(request.effective_top_k());
|
||||
|
||||
assert_eq!(response.results.len(), 2);
|
||||
assert_eq!(response.results[0].score, 0.95);
|
||||
assert_eq!(response.results[0].index, 0);
|
||||
assert_eq!(response.results[1].score, 0.87);
|
||||
assert_eq!(response.results[1].index, 1);
|
||||
assert_eq!(response.model, "rerank-model");
|
||||
|
||||
// Serialize and deserialize
|
||||
let serialized = to_string(&response).unwrap();
|
||||
let deserialized: RerankResponse = from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.results.len(), 2);
|
||||
assert_eq!(deserialized.model, response.model);
|
||||
}
|
||||
@@ -601,7 +601,6 @@ async fn test_unsupported_endpoints() {
|
||||
prompt: None,
|
||||
text: Some("Hello world".to_string()),
|
||||
input_ids: None,
|
||||
parameters: None,
|
||||
sampling_params: None,
|
||||
stream: false,
|
||||
return_logprob: false,
|
||||
@@ -642,7 +641,6 @@ async fn test_openai_router_chat_completion_with_mock() {
|
||||
// Create a minimal chat completion request
|
||||
let mut chat_request = create_minimal_chat_request();
|
||||
chat_request.messages = vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text("Hello, how are you?".to_string()),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
Reference in New Issue
Block a user