diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 20518902c..0557f0660 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -56,6 +56,7 @@ parking_lot = "0.12.4" thiserror = "2.0.12" regex = "1.10" url = "2.5.4" +validator = { version = "0.18", features = ["derive"] } tokio-stream = { version = "0.1", features = ["sync"] } anyhow = "1.0" tokenizers = { version = "0.22.0" } diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index ea3160218..703ca55fd 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -4,8 +4,8 @@ use std::time::Instant; use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType}; use sglang_router_rs::protocols::spec::{ - ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, - SamplingParams, StringOrArray, UserMessageContent, + ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, SamplingParams, + StringOrArray, UserMessageContent, }; use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap}; @@ -31,7 +31,6 @@ fn default_generate_request() -> GenerateRequest { prompt: None, input_ids: None, stream: false, - parameters: None, sampling_params: None, return_logprob: false, // SGLang Extensions @@ -101,14 +100,6 @@ fn default_completion_request() -> CompletionRequest { fn create_sample_generate_request() -> GenerateRequest { GenerateRequest { text: Some("Write a story about artificial intelligence".to_string()), - parameters: Some(GenerateParameters { - max_new_tokens: Some(100), - temperature: Some(0.8), - top_p: Some(0.9), - top_k: Some(50), - repetition_penalty: Some(1.0), - ..Default::default() - }), sampling_params: Some(SamplingParams { temperature: Some(0.8), top_p: Some(0.9), @@ -128,12 +119,10 @@ fn create_sample_chat_completion_request() -> ChatCompletionRequest { model: "gpt-3.5-turbo".to_string(), messages: 
vec![ ChatMessage::System { - role: "system".to_string(), content: "You are a helpful assistant".to_string(), name: None, }, ChatMessage::User { - role: "user".to_string(), content: UserMessageContent::Text( "Explain quantum computing in simple terms".to_string(), ), @@ -170,7 +159,6 @@ fn create_sample_completion_request() -> CompletionRequest { #[allow(deprecated)] fn create_large_chat_completion_request() -> ChatCompletionRequest { let mut messages = vec![ChatMessage::System { - role: "system".to_string(), content: "You are a helpful assistant with extensive knowledge.".to_string(), name: None, }]; @@ -178,12 +166,10 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest { // Add many user/assistant pairs to simulate a long conversation for i in 0..50 { messages.push(ChatMessage::User { - role: "user".to_string(), content: UserMessageContent::Text(format!("Question {}: What do you think about topic number {} which involves complex reasoning about multiple interconnected systems and their relationships?", i, i)), name: None, }); messages.push(ChatMessage::Assistant { - role: "assistant".to_string(), content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)), name: None, tool_calls: None, diff --git a/sgl-router/benches/tool_parser_benchmark.rs b/sgl-router/benches/tool_parser_benchmark.rs index 636a32366..96f7d6f69 100644 --- a/sgl-router/benches/tool_parser_benchmark.rs +++ b/sgl-router/benches/tool_parser_benchmark.rs @@ -123,6 +123,7 @@ fn create_test_tools() -> Vec { "limit": {"type": "number"} } }), + strict: None, }, }, Tool { @@ -137,6 +138,7 @@ fn create_test_tools() -> Vec { "code": {"type": "string"} } }), + strict: None, }, }, ] diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs index 6f2d9f84f..3ff74d303 100644 --- 
a/sgl-router/src/grpc_client/sglang_scheduler.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -301,13 +301,7 @@ impl SglangSchedulerClient { ) -> Result { let stop_sequences = self.extract_stop_strings(request); - // Handle max tokens: prefer max_completion_tokens (new) over max_tokens (deprecated) - // If neither is specified, use None to let the backend decide the default - #[allow(deprecated)] - let max_new_tokens = request - .max_completion_tokens - .or(request.max_tokens) - .map(|v| v as i32); + let max_new_tokens = request.max_completion_tokens.map(|v| v as i32); // Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none" let skip_special_tokens = if request.tools.is_some() { @@ -322,7 +316,6 @@ impl SglangSchedulerClient { request.skip_special_tokens }; - #[allow(deprecated)] Ok(proto::SamplingParams { temperature: request.temperature.unwrap_or(1.0), top_p: request.top_p.unwrap_or(1.0), @@ -485,10 +478,10 @@ impl SglangSchedulerClient { })?); } - // Handle min_tokens with conversion - if let Some(min_tokens) = p.min_tokens { - sampling.min_new_tokens = i32::try_from(min_tokens) - .map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?; + // Handle min_new_tokens with conversion + if let Some(min_new_tokens) = p.min_new_tokens { + sampling.min_new_tokens = i32::try_from(min_new_tokens) + .map_err(|_| "min_new_tokens must fit into a 32-bit signed integer".to_string())?; } // Handle n with conversion diff --git a/sgl-router/src/protocols/mod.rs b/sgl-router/src/protocols/mod.rs index 7359a3d2e..418c2568b 100644 --- a/sgl-router/src/protocols/mod.rs +++ b/sgl-router/src/protocols/mod.rs @@ -2,5 +2,5 @@ // This module provides a structured approach to handling different API protocols pub mod spec; -pub mod validation; +pub mod validated; pub mod worker_spec; diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index 5a8f3b7d5..c7eb42c98 100644 --- 
a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -1,6 +1,9 @@ use serde::{Deserialize, Serialize}; use serde_json::{to_value, Map, Number, Value}; use std::collections::HashMap; +use validator::Validate; + +use crate::protocols::validated::Normalizable; // Default model value when not specified fn default_model() -> String { @@ -55,22 +58,22 @@ fn default_model() -> String { // - Helper functions #[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] +#[serde(tag = "role")] pub enum ChatMessage { + #[serde(rename = "system")] System { - role: String, content: String, #[serde(skip_serializing_if = "Option::is_none")] name: Option, }, + #[serde(rename = "user")] User { - role: String, // "user" content: UserMessageContent, #[serde(skip_serializing_if = "Option::is_none")] name: Option, }, + #[serde(rename = "assistant")] Assistant { - role: String, // "assistant" #[serde(skip_serializing_if = "Option::is_none")] content: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -81,16 +84,13 @@ pub enum ChatMessage { #[serde(skip_serializing_if = "Option::is_none")] reasoning_content: Option, }, + #[serde(rename = "tool")] Tool { - role: String, // "tool" content: String, tool_call_id: String, }, - Function { - role: String, // "function" - content: String, - name: String, - }, + #[serde(rename = "function")] + Function { content: String, name: String }, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -168,9 +168,11 @@ pub struct FunctionCallDelta { pub arguments: Option, } -#[derive(Debug, Clone, Deserialize, Serialize, Default)] +#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)] +#[validate(schema(function = "validate_chat_cross_parameters"))] pub struct ChatCompletionRequest { /// A list of messages comprising the conversation so far + #[validate(custom(function = "validate_messages"))] pub messages: Vec, /// ID of the model to use @@ -179,6 +181,7 @@ pub struct ChatCompletionRequest { /// Number 
between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = -2.0, max = 2.0))] pub frequency_penalty: Option, /// Deprecated: Replaced by tool_choice @@ -202,10 +205,12 @@ pub struct ChatCompletionRequest { /// Deprecated: Replaced by max_completion_tokens #[serde(skip_serializing_if = "Option::is_none")] #[deprecated(note = "Use max_completion_tokens instead")] + #[validate(range(min = 1))] pub max_tokens: Option, /// An upper bound for the number of tokens that can be generated for a completion #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 1))] pub max_completion_tokens: Option, /// Developer-defined tags and values used for filtering completions in the dashboard @@ -218,6 +223,7 @@ pub struct ChatCompletionRequest { /// How many chat completion choices to generate for each input message #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 1, max = 10))] pub n: Option, /// Whether to enable parallel function calling during tool use @@ -226,6 +232,7 @@ pub struct ChatCompletionRequest { /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based on whether they appear in the text so far #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = -2.0, max = 2.0))] pub presence_penalty: Option, /// Cache key for prompts (beta feature) @@ -255,6 +262,7 @@ pub struct ChatCompletionRequest { /// Up to 4 sequences where the API will stop generating further tokens #[serde(skip_serializing_if = "Option::is_none")] + #[validate(custom(function = "validate_stop"))] pub stop: Option, /// If set, partial message deltas will be sent @@ -267,6 +275,7 @@ pub struct ChatCompletionRequest { /// What sampling temperature to use, between 0 and 2 #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0, max = 2.0))] pub temperature: Option, /// Controls which (if any) tool is called by the model @@ -279,30 +288,42 @@ pub struct ChatCompletionRequest { /// An integer between 0 and 20 specifying the number of most likely tokens to return #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0, max = 20))] pub top_logprobs: Option, /// An alternative to sampling with temperature #[serde(skip_serializing_if = "Option::is_none")] + #[validate(custom(function = "validate_top_p_value"))] pub top_p: Option, /// Verbosity level for debugging #[serde(skip_serializing_if = "Option::is_none")] pub verbosity: Option, + // ============================================================================= + // Engine-Specific Sampling Parameters + // ============================================================================= + // These parameters are extensions beyond the OpenAI API specification and + // control model generation behavior in engine-specific ways. 
+ // ============================================================================= /// Top-k sampling parameter (-1 to disable) #[serde(skip_serializing_if = "Option::is_none")] + #[validate(custom(function = "validate_top_k_value"))] pub top_k: Option, /// Min-p nucleus sampling parameter #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0, max = 1.0))] pub min_p: Option, /// Minimum number of tokens to generate #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 1))] pub min_tokens: Option, /// Repetition penalty for reducing repetitive text #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0, max = 2.0))] pub repetition_penalty: Option, /// Regex constraint for output generation @@ -362,6 +383,290 @@ pub struct ChatCompletionRequest { pub sampling_seed: Option, } +// Validation functions for ChatCompletionRequest +// These are automatically called by the validator derive macro + +/// Validates stop sequences (max 4, non-empty strings) +fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> { + match stop { + StringOrArray::String(s) => { + if s.is_empty() { + return Err(validator::ValidationError::new( + "stop sequences cannot be empty", + )); + } + } + StringOrArray::Array(arr) => { + if arr.len() > 4 { + return Err(validator::ValidationError::new( + "maximum 4 stop sequences allowed", + )); + } + for s in arr { + if s.is_empty() { + return Err(validator::ValidationError::new( + "stop sequences cannot be empty", + )); + } + } + } + } + Ok(()) +} + +/// Validates messages array is not empty and has valid content +fn validate_messages(messages: &[ChatMessage]) -> Result<(), validator::ValidationError> { + if messages.is_empty() { + return Err(validator::ValidationError::new("messages cannot be empty")); + } + + for msg in messages.iter() { + if let ChatMessage::User { content, .. 
} = msg { + match content { + UserMessageContent::Text(text) if text.is_empty() => { + return Err(validator::ValidationError::new( + "message content cannot be empty", + )); + } + UserMessageContent::Parts(parts) if parts.is_empty() => { + return Err(validator::ValidationError::new( + "message content parts cannot be empty", + )); + } + _ => {} + } + } + } + Ok(()) +} + +/// Validates top_p: 0.0 < top_p <= 1.0 (exclusive lower bound - can't use range validator) +fn validate_top_p_value(top_p: f32) -> Result<(), validator::ValidationError> { + if !(top_p > 0.0 && top_p <= 1.0) { + return Err(validator::ValidationError::new( + "top_p must be in (0, 1] - greater than 0.0 and at most 1.0", + )); + } + Ok(()) +} + +/// Validates top_k: -1 (disabled) or >= 1 (special -1 case - can't use range validator) +fn validate_top_k_value(top_k: i32) -> Result<(), validator::ValidationError> { + if top_k != -1 && top_k < 1 { + return Err(validator::ValidationError::new( + "top_k must be -1 (disabled) or at least 1", + )); + } + Ok(()) +} + +/// Schema-level validation for cross-field dependencies +fn validate_chat_cross_parameters( + req: &ChatCompletionRequest, +) -> Result<(), validator::ValidationError> { + // 1. Validate logprobs dependency + if req.top_logprobs.is_some() && !req.logprobs { + let mut e = validator::ValidationError::new("top_logprobs_requires_logprobs"); + e.message = Some("top_logprobs is only allowed when logprobs is enabled".into()); + return Err(e); + } + + // 2. Validate stream_options dependency + if req.stream_options.is_some() && !req.stream { + let mut e = validator::ValidationError::new("stream_options_requires_stream"); + e.message = + Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into()); + return Err(e); + } + + // 3. 
Validate token limits - min <= max + if let (Some(min), Some(max)) = (req.min_tokens, req.max_completion_tokens) { + if min > max { + let mut e = validator::ValidationError::new("min_tokens_exceeds_max"); + e.message = Some("min_tokens cannot exceed max_tokens/max_completion_tokens".into()); + return Err(e); + } + } + + // 4. Validate structured output conflicts + let has_json_format = matches!( + req.response_format, + Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) + ); + + if has_json_format && req.regex.is_some() { + let mut e = validator::ValidationError::new("regex_conflicts_with_json"); + e.message = Some("cannot use regex constraint with JSON response format".into()); + return Err(e); + } + + if has_json_format && req.ebnf.is_some() { + let mut e = validator::ValidationError::new("ebnf_conflicts_with_json"); + e.message = Some("cannot use EBNF constraint with JSON response format".into()); + return Err(e); + } + + // 5. Validate mutually exclusive structured output constraints + let constraint_count = [ + req.regex.is_some(), + req.ebnf.is_some(), + matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })), + ] + .iter() + .filter(|&&x| x) + .count(); + + if constraint_count > 1 { + let mut e = validator::ValidationError::new("multiple_constraints"); + e.message = Some("only one structured output constraint (regex, ebnf, or json_schema) can be active at a time".into()); + return Err(e); + } + + // 6. Validate response format JSON schema name + if let Some(ResponseFormat::JsonSchema { json_schema }) = &req.response_format { + if json_schema.name.is_empty() { + let mut e = validator::ValidationError::new("json_schema_name_empty"); + e.message = Some("JSON schema name cannot be empty".into()); + return Err(e); + } + } + + // 7. 
Validate tool_choice requires tools (except for "none") + if let Some(ref tool_choice) = req.tool_choice { + let has_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty()); + + // Check if tool_choice is anything other than "none" + let is_some_choice = !matches!(tool_choice, ToolChoice::Value(ToolChoiceValue::None)); + + if is_some_choice && !has_tools { + let mut e = validator::ValidationError::new("tool_choice_requires_tools"); + e.message = Some("Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.".into()); + return Err(e); + } + + // Additional validation when tools are present + if has_tools { + let tools = req.tools.as_ref().unwrap(); + + match tool_choice { + ToolChoice::Function { function, .. } => { + // Validate that the specified function name exists in tools + let function_exists = tools.iter().any(|tool| { + tool.tool_type == "function" && tool.function.name == function.name + }); + + if !function_exists { + let mut e = + validator::ValidationError::new("tool_choice_function_not_found"); + e.message = Some( + format!( + "Invalid value for 'tool_choice': function '{}' not found in 'tools'.", + function.name + ) + .into(), + ); + return Err(e); + } + } + ToolChoice::AllowedTools { + mode, + tools: allowed_tools, + .. 
+ } => { + // Validate mode is "auto" or "required" + if mode != "auto" && mode != "required" { + let mut e = validator::ValidationError::new("tool_choice_invalid_mode"); + e.message = Some(format!( + "Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{}'.", + mode + ).into()); + return Err(e); + } + + // Validate that all referenced tool names exist in tools + for tool_ref in allowed_tools { + let tool_exists = tools.iter().any(|tool| { + tool.tool_type == tool_ref.tool_type + && tool.function.name == tool_ref.name + }); + + if !tool_exists { + let mut e = + validator::ValidationError::new("tool_choice_tool_not_found"); + e.message = Some(format!( + "Invalid value for 'tool_choice.tools': tool '{}' not found in 'tools'.", + tool_ref.name + ).into()); + return Err(e); + } + } + } + _ => {} + } + } + } + + Ok(()) +} + +impl Normalizable for ChatCompletionRequest { + /// Normalize the request by applying migrations and defaults: + /// 1. Migrate deprecated fields to their replacements + /// 2. Clear deprecated fields and log warnings + /// 3. 
Apply OpenAI defaults for tool_choice + fn normalize(&mut self) { + // Migrate deprecated max_tokens → max_completion_tokens + #[allow(deprecated)] + if self.max_completion_tokens.is_none() && self.max_tokens.is_some() { + tracing::warn!("max_tokens is deprecated, use max_completion_tokens instead"); + self.max_completion_tokens = self.max_tokens; + self.max_tokens = None; // Clear deprecated field + } + + // Migrate deprecated functions → tools + #[allow(deprecated)] + if self.tools.is_none() && self.functions.is_some() { + tracing::warn!("functions is deprecated, use tools instead"); + self.tools = self.functions.as_ref().map(|functions| { + functions + .iter() + .map(|func| Tool { + tool_type: "function".to_string(), + function: func.clone(), + }) + .collect() + }); + self.functions = None; // Clear deprecated field + } + + // Migrate deprecated function_call → tool_choice + #[allow(deprecated)] + if self.tool_choice.is_none() && self.function_call.is_some() { + tracing::warn!("function_call is deprecated, use tool_choice instead"); + self.tool_choice = self.function_call.as_ref().map(|fc| match fc { + FunctionCall::None => ToolChoice::Value(ToolChoiceValue::None), + FunctionCall::Auto => ToolChoice::Value(ToolChoiceValue::Auto), + FunctionCall::Function { name } => ToolChoice::Function { + tool_type: "function".to_string(), + function: FunctionChoice { name: name.clone() }, + }, + }); + self.function_call = None; // Clear deprecated field + } + + // Apply tool_choice defaults + if self.tool_choice.is_none() { + let has_tools = self.tools.as_ref().is_some_and(|t| !t.is_empty()); + + self.tool_choice = if has_tools { + Some(ToolChoice::Value(ToolChoiceValue::Auto)) + } else { + Some(ToolChoice::Value(ToolChoiceValue::None)) + }; + } + } +} + impl GenerationRequest for ChatCompletionRequest { fn is_stream(&self) -> bool { self.stream @@ -553,6 +858,7 @@ pub struct CompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] pub seed: Option, + // -------- 
Engine Specific Sampling Parameters -------- /// Top-k sampling parameter (-1 to disable) #[serde(skip_serializing_if = "Option::is_none")] pub top_k: Option, @@ -1816,6 +2122,9 @@ pub struct Function { #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, pub parameters: Value, // JSON Schema + /// Whether to enable strict schema adherence (OpenAI structured outputs) + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -1911,55 +2220,33 @@ pub enum InputIds { Batch(Vec>), } -#[derive(Debug, Clone, Deserialize, Serialize, Default)] -pub struct GenerateParameters { - #[serde(skip_serializing_if = "Option::is_none")] - pub best_of: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub decoder_input_details: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub details: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub do_sample: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_new_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub return_full_text: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub truncate: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub typical_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub watermark: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize, Default)] +#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)] +#[validate(schema(function = 
"validate_sampling_params"))] pub struct SamplingParams { + /// Temperature for sampling (must be >= 0.0, no upper limit) #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0))] pub temperature: Option, + /// Maximum number of new tokens to generate (must be >= 0) #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0))] pub max_new_tokens: Option, + /// Top-p nucleus sampling (0.0 < top_p <= 1.0) #[serde(skip_serializing_if = "Option::is_none")] + #[validate(custom(function = "validate_top_p_value"))] pub top_p: Option, + /// Top-k sampling (-1 to disable, or >= 1) #[serde(skip_serializing_if = "Option::is_none")] + #[validate(custom(function = "validate_top_k_value"))] pub top_k: Option, #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = -2.0, max = 2.0))] pub frequency_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = -2.0, max = 2.0))] pub presence_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0, max = 2.0))] pub repetition_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] pub stop: Option, @@ -1974,9 +2261,11 @@ pub struct SamplingParams { #[serde(skip_serializing_if = "Option::is_none")] pub ebnf: Option, #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0, max = 1.0))] pub min_p: Option, + /// Minimum number of new tokens (validated in schema function for cross-field check with max_new_tokens) #[serde(skip_serializing_if = "Option::is_none")] - pub min_tokens: Option, + pub min_new_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] pub stop_token_ids: Option>, #[serde(skip_serializing_if = "Option::is_none")] @@ -1987,7 +2276,38 @@ pub struct SamplingParams { pub sampling_seed: Option, } -#[derive(Clone, Debug, Serialize, Deserialize)] +/// Validation function for SamplingParams - cross-field validation only +fn 
validate_sampling_params(params: &SamplingParams) -> Result<(), validator::ValidationError> { + // 1. Cross-field validation: min_new_tokens <= max_new_tokens + if let (Some(min), Some(max)) = (params.min_new_tokens, params.max_new_tokens) { + if min > max { + return Err(validator::ValidationError::new( + "min_new_tokens cannot exceed max_new_tokens", + )); + } + } + + // 2. Validate mutually exclusive structured output constraints + let constraint_count = [ + params.regex.is_some(), + params.ebnf.is_some(), + params.json_schema.is_some(), + ] + .iter() + .filter(|&&x| x) + .count(); + + if constraint_count > 1 { + return Err(validator::ValidationError::new( + "only one of regex, ebnf, or json_schema can be set", + )); + } + + Ok(()) +} + +#[derive(Clone, Debug, Serialize, Deserialize, Validate)] +#[validate(schema(function = "validate_generate_request"))] pub struct GenerateRequest { /// The prompt to generate from (OpenAI style) #[serde(skip_serializing_if = "Option::is_none")] @@ -2001,10 +2321,6 @@ pub struct GenerateRequest { #[serde(skip_serializing_if = "Option::is_none")] pub input_ids: Option, - /// Generation parameters - #[serde(default, skip_serializing_if = "Option::is_none")] - pub parameters: Option, - /// Sampling parameters (sglang style) #[serde(skip_serializing_if = "Option::is_none")] pub sampling_params: Option, @@ -2034,6 +2350,34 @@ pub struct GenerateRequest { pub rid: Option, } +impl Normalizable for GenerateRequest { + // Use default no-op implementation - no normalization needed for GenerateRequest +} + +/// Validation function for GenerateRequest - ensure exactly one input type is provided +fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> { + // Exactly one of text or input_ids must be provided + // Note: input_embeds not yet supported in Rust implementation + let has_text = req.text.is_some() || req.prompt.is_some(); + let has_input_ids = req.input_ids.is_some(); + + let count = [has_text, 
has_input_ids].iter().filter(|&&x| x).count(); + + if count == 0 { + return Err(validator::ValidationError::new( + "Either text or input_ids should be provided.", + )); + } + + if count > 1 { + return Err(validator::ValidationError::new( + "Either text or input_ids should be provided.", + )); + } + + Ok(()) +} + impl GenerationRequest for GenerateRequest { fn is_stream(&self) -> bool { self.stream @@ -2168,7 +2512,7 @@ pub struct RerankRequest { pub user: Option, } -fn default_model_name() -> String { +pub fn default_model_name() -> String { DEFAULT_MODEL_NAME.to_string() } @@ -2441,710 +2785,3 @@ pub enum LoRAPath { Single(Option), Batch(Vec>), } - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::{from_str, json, to_string}; - - #[test] - fn test_rerank_request_serialization() { - let request = RerankRequest { - query: "test query".to_string(), - documents: vec!["doc1".to_string(), "doc2".to_string()], - model: "test-model".to_string(), - top_k: Some(5), - return_documents: true, - rid: Some(StringOrArray::String("req-123".to_string())), - user: Some("user-456".to_string()), - }; - - let serialized = to_string(&request).unwrap(); - let deserialized: RerankRequest = from_str(&serialized).unwrap(); - - assert_eq!(deserialized.query, request.query); - assert_eq!(deserialized.documents, request.documents); - assert_eq!(deserialized.model, request.model); - assert_eq!(deserialized.top_k, request.top_k); - assert_eq!(deserialized.return_documents, request.return_documents); - assert_eq!(deserialized.rid, request.rid); - assert_eq!(deserialized.user, request.user); - } - - #[test] - fn test_rerank_request_deserialization_with_defaults() { - let json = r#"{ - "query": "test query", - "documents": ["doc1", "doc2"] - }"#; - - let request: RerankRequest = from_str(json).unwrap(); - - assert_eq!(request.query, "test query"); - assert_eq!(request.documents, vec!["doc1", "doc2"]); - assert_eq!(request.model, default_model_name()); - assert_eq!(request.top_k, None); 
- assert!(request.return_documents); - assert_eq!(request.rid, None); - assert_eq!(request.user, None); - } - - #[test] - fn test_rerank_request_validation_success() { - let request = RerankRequest { - query: "valid query".to_string(), - documents: vec!["doc1".to_string(), "doc2".to_string()], - model: "test-model".to_string(), - top_k: Some(2), - return_documents: true, - rid: None, - user: None, - }; - - assert!(request.validate().is_ok()); - } - - #[test] - fn test_rerank_request_validation_empty_query() { - let request = RerankRequest { - query: "".to_string(), - documents: vec!["doc1".to_string()], - model: "test-model".to_string(), - top_k: None, - return_documents: true, - rid: None, - user: None, - }; - - let result = request.validate(); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Query cannot be empty"); - } - - #[test] - fn test_rerank_request_validation_whitespace_query() { - let request = RerankRequest { - query: " ".to_string(), - documents: vec!["doc1".to_string()], - model: "test-model".to_string(), - top_k: None, - return_documents: true, - rid: None, - user: None, - }; - - let result = request.validate(); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Query cannot be empty"); - } - - #[test] - fn test_rerank_request_validation_empty_documents() { - let request = RerankRequest { - query: "test query".to_string(), - documents: vec![], - model: "test-model".to_string(), - top_k: None, - return_documents: true, - rid: None, - user: None, - }; - - let result = request.validate(); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Documents list cannot be empty"); - } - - #[test] - fn test_rerank_request_validation_top_k_zero() { - let request = RerankRequest { - query: "test query".to_string(), - documents: vec!["doc1".to_string(), "doc2".to_string()], - model: "test-model".to_string(), - top_k: Some(0), - return_documents: true, - rid: None, - user: None, - }; - - let result = request.validate(); - 
assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "top_k must be greater than 0"); - } - - #[test] - fn test_rerank_request_validation_top_k_greater_than_docs() { - let request = RerankRequest { - query: "test query".to_string(), - documents: vec!["doc1".to_string(), "doc2".to_string()], - model: "test-model".to_string(), - top_k: Some(5), - return_documents: true, - rid: None, - user: None, - }; - - // This should pass but log a warning - assert!(request.validate().is_ok()); - } - - #[test] - fn test_rerank_request_effective_top_k() { - let request = RerankRequest { - query: "test query".to_string(), - documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()], - model: "test-model".to_string(), - top_k: Some(2), - return_documents: true, - rid: None, - user: None, - }; - - assert_eq!(request.effective_top_k(), 2); - } - - #[test] - fn test_rerank_request_effective_top_k_none() { - let request = RerankRequest { - query: "test query".to_string(), - documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()], - model: "test-model".to_string(), - top_k: None, - return_documents: true, - rid: None, - user: None, - }; - - assert_eq!(request.effective_top_k(), 3); - } - - #[test] - fn test_rerank_response_creation() { - let results = vec![ - RerankResult { - score: 0.8, - document: Some("doc1".to_string()), - index: 0, - meta_info: None, - }, - RerankResult { - score: 0.6, - document: Some("doc2".to_string()), - index: 1, - meta_info: None, - }, - ]; - - let response = RerankResponse::new( - results.clone(), - "test-model".to_string(), - Some(StringOrArray::String("req-123".to_string())), - ); - - assert_eq!(response.results.len(), 2); - assert_eq!(response.model, "test-model"); - assert_eq!( - response.id, - Some(StringOrArray::String("req-123".to_string())) - ); - assert_eq!(response.object, "rerank"); - assert!(response.created > 0); - } - - #[test] - fn test_rerank_response_serialization() { - let results = 
vec![RerankResult { - score: 0.8, - document: Some("doc1".to_string()), - index: 0, - meta_info: None, - }]; - - let response = RerankResponse::new( - results, - "test-model".to_string(), - Some(StringOrArray::String("req-123".to_string())), - ); - - let serialized = to_string(&response).unwrap(); - let deserialized: RerankResponse = from_str(&serialized).unwrap(); - - assert_eq!(deserialized.results.len(), response.results.len()); - assert_eq!(deserialized.model, response.model); - assert_eq!(deserialized.id, response.id); - assert_eq!(deserialized.object, response.object); - } - - #[test] - fn test_rerank_response_sort_by_score() { - let results = vec![ - RerankResult { - score: 0.6, - document: Some("doc2".to_string()), - index: 1, - meta_info: None, - }, - RerankResult { - score: 0.8, - document: Some("doc1".to_string()), - index: 0, - meta_info: None, - }, - RerankResult { - score: 0.4, - document: Some("doc3".to_string()), - index: 2, - meta_info: None, - }, - ]; - - let mut response = RerankResponse::new( - results, - "test-model".to_string(), - Some(StringOrArray::String("req-123".to_string())), - ); - - response.sort_by_score(); - - assert_eq!(response.results[0].score, 0.8); - assert_eq!(response.results[0].index, 0); - assert_eq!(response.results[1].score, 0.6); - assert_eq!(response.results[1].index, 1); - assert_eq!(response.results[2].score, 0.4); - assert_eq!(response.results[2].index, 2); - } - - #[test] - fn test_rerank_response_apply_top_k() { - let results = vec![ - RerankResult { - score: 0.8, - document: Some("doc1".to_string()), - index: 0, - meta_info: None, - }, - RerankResult { - score: 0.6, - document: Some("doc2".to_string()), - index: 1, - meta_info: None, - }, - RerankResult { - score: 0.4, - document: Some("doc3".to_string()), - index: 2, - meta_info: None, - }, - ]; - - let mut response = RerankResponse::new( - results, - "test-model".to_string(), - Some(StringOrArray::String("req-123".to_string())), - ); - - response.apply_top_k(2); 
- - assert_eq!(response.results.len(), 2); - assert_eq!(response.results[0].score, 0.8); - assert_eq!(response.results[1].score, 0.6); - } - - #[test] - fn test_rerank_response_apply_top_k_larger_than_results() { - let results = vec![RerankResult { - score: 0.8, - document: Some("doc1".to_string()), - index: 0, - meta_info: None, - }]; - - let mut response = RerankResponse::new( - results, - "test-model".to_string(), - Some(StringOrArray::String("req-123".to_string())), - ); - - response.apply_top_k(5); - - assert_eq!(response.results.len(), 1); - } - - #[test] - fn test_rerank_response_drop_documents() { - let results = vec![RerankResult { - score: 0.8, - document: Some("doc1".to_string()), - index: 0, - meta_info: None, - }]; - let mut response = RerankResponse::new( - results, - "test-model".to_string(), - Some(StringOrArray::String("req-123".to_string())), - ); - - response.drop_documents(); - - assert_eq!(response.results[0].document, None); - } - - #[test] - fn test_rerank_result_serialization() { - let result = RerankResult { - score: 0.85, - document: Some("test document".to_string()), - index: 42, - meta_info: Some(HashMap::from([ - ("confidence".to_string(), Value::String("high".to_string())), - ( - "processing_time".to_string(), - Value::Number(Number::from(150)), - ), - ])), - }; - - let serialized = to_string(&result).unwrap(); - let deserialized: RerankResult = from_str(&serialized).unwrap(); - - assert_eq!(deserialized.score, result.score); - assert_eq!(deserialized.document, result.document); - assert_eq!(deserialized.index, result.index); - assert_eq!(deserialized.meta_info, result.meta_info); - } - - #[test] - fn test_rerank_result_serialization_without_document() { - let result = RerankResult { - score: 0.85, - document: None, - index: 42, - meta_info: None, - }; - - let serialized = to_string(&result).unwrap(); - let deserialized: RerankResult = from_str(&serialized).unwrap(); - - assert_eq!(deserialized.score, result.score); - 
assert_eq!(deserialized.document, result.document); - assert_eq!(deserialized.index, result.index); - assert_eq!(deserialized.meta_info, result.meta_info); - } - - #[test] - fn test_v1_rerank_req_input_serialization() { - let v1_input = V1RerankReqInput { - query: "test query".to_string(), - documents: vec!["doc1".to_string(), "doc2".to_string()], - }; - - let serialized = to_string(&v1_input).unwrap(); - let deserialized: V1RerankReqInput = from_str(&serialized).unwrap(); - - assert_eq!(deserialized.query, v1_input.query); - assert_eq!(deserialized.documents, v1_input.documents); - } - - #[test] - fn test_v1_to_rerank_request_conversion() { - let v1_input = V1RerankReqInput { - query: "test query".to_string(), - documents: vec!["doc1".to_string(), "doc2".to_string()], - }; - - let request: RerankRequest = v1_input.into(); - - assert_eq!(request.query, "test query"); - assert_eq!(request.documents, vec!["doc1", "doc2"]); - assert_eq!(request.model, default_model_name()); - assert_eq!(request.top_k, None); - assert!(request.return_documents); - assert_eq!(request.rid, None); - assert_eq!(request.user, None); - } - - #[test] - fn test_rerank_request_generation_request_trait() { - let request = RerankRequest { - query: "test query".to_string(), - documents: vec!["doc1".to_string()], - model: "test-model".to_string(), - top_k: None, - return_documents: true, - rid: None, - user: None, - }; - - assert_eq!(request.get_model(), Some("test-model")); - assert!(!request.is_stream()); - assert_eq!(request.extract_text_for_routing(), "test query"); - } - - #[test] - fn test_rerank_request_very_long_query() { - let long_query = "a".repeat(100000); - let request = RerankRequest { - query: long_query, - documents: vec!["doc1".to_string()], - model: "test-model".to_string(), - top_k: None, - return_documents: true, - rid: None, - user: None, - }; - - assert!(request.validate().is_ok()); - } - - #[test] - fn test_rerank_request_many_documents() { - let documents: Vec = 
(0..1000).map(|i| format!("doc{}", i)).collect(); - let request = RerankRequest { - query: "test query".to_string(), - documents, - model: "test-model".to_string(), - top_k: Some(100), - return_documents: true, - rid: None, - user: None, - }; - - assert!(request.validate().is_ok()); - assert_eq!(request.effective_top_k(), 100); - } - - #[test] - fn test_rerank_request_special_characters() { - let request = RerankRequest { - query: "query with émojis 🚀 and unicode: 测试".to_string(), - documents: vec![ - "doc with émojis 🎉".to_string(), - "doc with unicode: 测试".to_string(), - ], - model: "test-model".to_string(), - top_k: None, - return_documents: true, - rid: Some(StringOrArray::String("req-🚀-123".to_string())), - user: Some("user-🎉-456".to_string()), - }; - - assert!(request.validate().is_ok()); - } - - #[test] - fn test_rerank_request_rid_array() { - let request = RerankRequest { - query: "test query".to_string(), - documents: vec!["doc1".to_string()], - model: "test-model".to_string(), - top_k: None, - return_documents: true, - rid: Some(StringOrArray::Array(vec![ - "req1".to_string(), - "req2".to_string(), - ])), - user: None, - }; - - assert!(request.validate().is_ok()); - } - - #[test] - fn test_rerank_response_with_usage_info() { - let results = vec![RerankResult { - score: 0.8, - document: Some("doc1".to_string()), - index: 0, - meta_info: None, - }]; - - let mut response = RerankResponse::new( - results, - "test-model".to_string(), - Some(StringOrArray::String("req-123".to_string())), - ); - - response.usage = Some(UsageInfo { - prompt_tokens: 100, - completion_tokens: 50, - total_tokens: 150, - reasoning_tokens: None, - prompt_tokens_details: None, - }); - - let serialized = to_string(&response).unwrap(); - let deserialized: RerankResponse = from_str(&serialized).unwrap(); - - assert!(deserialized.usage.is_some()); - let usage = deserialized.usage.unwrap(); - assert_eq!(usage.prompt_tokens, 100); - assert_eq!(usage.completion_tokens, 50); - 
assert_eq!(usage.total_tokens, 150); - } - - #[test] - fn test_full_rerank_workflow() { - // Create request - let request = RerankRequest { - query: "machine learning".to_string(), - documents: vec![ - "Introduction to machine learning algorithms".to_string(), - "Deep learning for computer vision".to_string(), - "Natural language processing basics".to_string(), - "Statistics and probability theory".to_string(), - ], - model: "rerank-model".to_string(), - top_k: Some(2), - return_documents: true, - rid: Some(StringOrArray::String("req-123".to_string())), - user: Some("user-456".to_string()), - }; - - // Validate request - assert!(request.validate().is_ok()); - - // Simulate reranking results (in real scenario, this would come from the model) - let results = vec![ - RerankResult { - score: 0.95, - document: Some("Introduction to machine learning algorithms".to_string()), - index: 0, - meta_info: None, - }, - RerankResult { - score: 0.87, - document: Some("Deep learning for computer vision".to_string()), - index: 1, - meta_info: None, - }, - RerankResult { - score: 0.72, - document: Some("Natural language processing basics".to_string()), - index: 2, - meta_info: None, - }, - RerankResult { - score: 0.45, - document: Some("Statistics and probability theory".to_string()), - index: 3, - meta_info: None, - }, - ]; - - // Create response - let mut response = RerankResponse::new(results, request.model.clone(), request.rid.clone()); - - // Sort by score - response.sort_by_score(); - - // Apply top_k - response.apply_top_k(request.effective_top_k()); - - assert_eq!(response.results.len(), 2); - assert_eq!(response.results[0].score, 0.95); - assert_eq!(response.results[0].index, 0); - assert_eq!(response.results[1].score, 0.87); - assert_eq!(response.results[1].index, 1); - assert_eq!(response.model, "rerank-model"); - - // Serialize and deserialize - let serialized = to_string(&response).unwrap(); - let deserialized: RerankResponse = from_str(&serialized).unwrap(); - 
assert_eq!(deserialized.results.len(), 2); - assert_eq!(deserialized.model, response.model); - } - - #[test] - fn test_embedding_request_serialization_string_input() { - let req = EmbeddingRequest { - model: "test-emb".to_string(), - input: Value::String("hello".to_string()), - encoding_format: Some("float".to_string()), - user: Some("user-1".to_string()), - dimensions: Some(128), - rid: Some("rid-123".to_string()), - }; - - let serialized = to_string(&req).unwrap(); - let deserialized: EmbeddingRequest = from_str(&serialized).unwrap(); - - assert_eq!(deserialized.model, req.model); - assert_eq!(deserialized.input, req.input); - assert_eq!(deserialized.encoding_format, req.encoding_format); - assert_eq!(deserialized.user, req.user); - assert_eq!(deserialized.dimensions, req.dimensions); - assert_eq!(deserialized.rid, req.rid); - } - - #[test] - fn test_embedding_request_serialization_array_input() { - let req = EmbeddingRequest { - model: "test-emb".to_string(), - input: json!(["a", "b", "c"]), - encoding_format: None, - user: None, - dimensions: None, - rid: None, - }; - - let serialized = to_string(&req).unwrap(); - let de: EmbeddingRequest = from_str(&serialized).unwrap(); - assert_eq!(de.model, req.model); - assert_eq!(de.input, req.input); - } - - #[test] - fn test_embedding_generation_request_trait_string() { - let req = EmbeddingRequest { - model: "emb-model".to_string(), - input: Value::String("hello".to_string()), - encoding_format: None, - user: None, - dimensions: None, - rid: None, - }; - assert!(!req.is_stream()); - assert_eq!(req.get_model(), Some("emb-model")); - assert_eq!(req.extract_text_for_routing(), "hello"); - } - - #[test] - fn test_embedding_generation_request_trait_array() { - let req = EmbeddingRequest { - model: "emb-model".to_string(), - input: json!(["hello", "world"]), - encoding_format: None, - user: None, - dimensions: None, - rid: None, - }; - assert_eq!(req.extract_text_for_routing(), "hello world"); - } - - #[test] - fn 
test_embedding_generation_request_trait_non_text() { - let req = EmbeddingRequest { - model: "emb-model".to_string(), - input: json!({"tokens": [1, 2, 3]}), - encoding_format: None, - user: None, - dimensions: None, - rid: None, - }; - assert_eq!(req.extract_text_for_routing(), ""); - } - - #[test] - fn test_embedding_generation_request_trait_mixed_array_ignores_nested() { - let req = EmbeddingRequest { - model: "emb-model".to_string(), - input: json!(["a", ["b", "c"], 123, {"k": "v"}]), - encoding_format: None, - user: None, - dimensions: None, - rid: None, - }; - // Only top-level string elements are extracted - assert_eq!(req.extract_text_for_routing(), "a"); - } -} diff --git a/sgl-router/src/protocols/validated.rs b/sgl-router/src/protocols/validated.rs new file mode 100644 index 000000000..58958f66b --- /dev/null +++ b/sgl-router/src/protocols/validated.rs @@ -0,0 +1,172 @@ +// Validated JSON extractor for automatic request validation +// +// This module provides a ValidatedJson extractor that automatically validates +// requests using the validator crate's Validate trait. + +use axum::{ + extract::{rejection::JsonRejection, FromRequest, Request}, + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::de::DeserializeOwned; +use serde_json::json; +use validator::Validate; + +/// Trait for request types that need post-deserialization normalization +pub trait Normalizable { + /// Normalize the request by applying defaults and transformations + fn normalize(&mut self) { + // Default: no-op + } +} + +/// A JSON extractor that automatically validates and normalizes the request body +/// +/// This extractor deserializes the request body and automatically calls `.validate()` +/// on types that implement the `Validate` trait. If validation fails, it returns +/// a 400 Bad Request with detailed error information. 
+/// +/// # Example +/// +/// ```rust,ignore +/// async fn create_chat( +/// ValidatedJson(request): ValidatedJson, +/// ) -> Response { +/// // request is guaranteed to be valid here +/// process_request(request).await +/// } +/// ``` +pub struct ValidatedJson(pub T); + +impl FromRequest for ValidatedJson +where + T: DeserializeOwned + Validate + Normalizable + Send, + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request(req: Request, state: &S) -> Result { + // First, extract and deserialize the JSON + let Json(mut data) = + Json::::from_request(req, state) + .await + .map_err(|err: JsonRejection| { + let error_message = match err { + JsonRejection::JsonDataError(e) => { + format!("Invalid JSON data: {}", e) + } + JsonRejection::JsonSyntaxError(e) => { + format!("JSON syntax error: {}", e) + } + JsonRejection::MissingJsonContentType(_) => { + "Missing Content-Type: application/json header".to_string() + } + _ => format!("Failed to parse JSON: {}", err), + }; + + ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": { + "message": error_message, + "type": "invalid_request_error", + "code": "json_parse_error" + } + })), + ) + .into_response() + })?; + + // Normalize the request (apply defaults based on other fields) + data.normalize(); + + // Then, automatically validate the data + data.validate().map_err(|validation_errors| { + // Extract the first error message from the validation errors + let error_message = validation_errors + .field_errors() + .values() + .flat_map(|errors| errors.iter()) + .find_map(|e| e.message.as_ref()) + .map(|m| m.to_string()) + .unwrap_or_else(|| "Validation failed".to_string()); + + ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": { + "message": error_message, + "type": "invalid_request_error", + "code": 400 + } + })), + ) + .into_response() + })?; + + Ok(ValidatedJson(data)) + } +} + +// Implement Deref to allow transparent access to the inner value +impl std::ops::Deref for ValidatedJson { + type Target = 
T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for ValidatedJson { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + use validator::Validate; + + #[derive(Debug, Deserialize, Serialize, Validate)] + struct TestRequest { + #[validate(range(min = 0.0, max = 1.0))] + value: f32, + #[validate(length(min = 1))] + name: String, + } + + impl Normalizable for TestRequest { + // Use default no-op implementation + } + + #[tokio::test] + async fn test_validated_json_valid() { + // This test is conceptual - actual testing would require Axum test harness + let request = TestRequest { + value: 0.5, + name: "test".to_string(), + }; + assert!(request.validate().is_ok()); + } + + #[tokio::test] + async fn test_validated_json_invalid_range() { + let request = TestRequest { + value: 1.5, // Out of range + name: "test".to_string(), + }; + assert!(request.validate().is_err()); + } + + #[tokio::test] + async fn test_validated_json_invalid_length() { + let request = TestRequest { + value: 0.5, + name: "".to_string(), // Empty name + }; + assert!(request.validate().is_err()); + } +} diff --git a/sgl-router/src/protocols/validation.rs b/sgl-router/src/protocols/validation.rs deleted file mode 100644 index b036c5b4a..000000000 --- a/sgl-router/src/protocols/validation.rs +++ /dev/null @@ -1,1149 +0,0 @@ -// Core validation infrastructure for API parameter validation - -use anyhow::Result; -use serde::{Deserialize, Serialize}; -use std::fmt::Display; - -// Import types from spec module -use crate::protocols::spec::{ - ChatCompletionRequest, ChatMessage, ResponseFormat, StringOrArray, UserMessageContent, -}; - -/// Validation constants for OpenAI API parameters -pub mod constants { - /// Temperature range: 0.0 to 2.0 (OpenAI spec) - pub const TEMPERATURE_RANGE: (f32, f32) = (0.0, 2.0); - - /// Top-p range: 0.0 to 1.0 (exclusive of 0.0) - pub const 
TOP_P_RANGE: (f32, f32) = (0.0, 1.0); - - /// Presence penalty range: -2.0 to 2.0 (OpenAI spec) - pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (-2.0, 2.0); - - /// Frequency penalty range: -2.0 to 2.0 (OpenAI spec) - pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-2.0, 2.0); - - /// Logprobs range for completions API: 0 to 5 - pub const LOGPROBS_RANGE: (u32, u32) = (0, 5); - - /// Top logprobs range for chat completions: 0 to 20 - pub const TOP_LOGPROBS_RANGE: (u32, u32) = (0, 20); - - /// Maximum number of stop sequences allowed - pub const MAX_STOP_SEQUENCES: usize = 4; - - /// SGLang-specific validation constants - pub mod sglang { - /// Min-p range: 0.0 to 1.0 (SGLang extension) - pub const MIN_P_RANGE: (f32, f32) = (0.0, 1.0); - - /// Top-k minimum value: -1 to disable, otherwise positive - pub const TOP_K_MIN: i32 = -1; - - /// Repetition penalty range: 0.0 to 2.0 (SGLang extension) - /// 1.0 = no penalty, >1.0 = discourage repetition, <1.0 = encourage repetition - pub const REPETITION_PENALTY_RANGE: (f32, f32) = (0.0, 2.0); - } -} - -/// Core validation error types -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ValidationError { - /// Parameter value out of valid range - OutOfRange { - parameter: String, - value: String, - min: String, - max: String, - }, - /// Invalid parameter value format or type - InvalidValue { - parameter: String, - value: String, - reason: String, - }, - /// Cross-parameter validation failure - ConflictingParameters { - parameter1: String, - parameter2: String, - reason: String, - }, - /// Required parameter missing - MissingRequired { parameter: String }, - /// Too many items in array parameter - TooManyItems { - parameter: String, - count: usize, - max: usize, - }, - /// Custom validation error - Custom(String), -} - -impl Display for ValidationError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ValidationError::OutOfRange { - parameter, - value, - min, - max, - } => { - 
write!( - f, - "Parameter '{}' must be between {} and {}, got {}", - parameter, min, max, value - ) - } - ValidationError::InvalidValue { - parameter, - value, - reason, - } => { - write!( - f, - "Invalid value for parameter '{}': {} ({})", - parameter, value, reason - ) - } - ValidationError::ConflictingParameters { - parameter1, - parameter2, - reason, - } => { - write!( - f, - "Conflicting parameters '{}' and '{}': {}", - parameter1, parameter2, reason - ) - } - ValidationError::MissingRequired { parameter } => { - write!(f, "Required parameter '{}' is missing", parameter) - } - ValidationError::TooManyItems { - parameter, - count, - max, - } => { - write!( - f, - "Parameter '{}' has too many items: {} (maximum: {})", - parameter, count, max - ) - } - ValidationError::Custom(msg) => write!(f, "{}", msg), - } - } -} - -impl std::error::Error for ValidationError {} - -/// Core validation utility functions -pub mod utils { - use super::*; - - /// Validate that a numeric value is within the specified range (inclusive) - pub fn validate_range( - value: T, - range: &(T, T), - param_name: &str, - ) -> Result - where - T: PartialOrd + Display + Copy, - { - if value >= range.0 && value <= range.1 { - Ok(value) - } else { - Err(ValidationError::OutOfRange { - parameter: param_name.to_string(), - value: value.to_string(), - min: range.0.to_string(), - max: range.1.to_string(), - }) - } - } - - /// Validate that a positive number is actually positive - pub fn validate_positive(value: T, param_name: &str) -> Result - where - T: PartialOrd + Display + Copy + Default, - { - if value > T::default() { - Ok(value) - } else { - Err(ValidationError::InvalidValue { - parameter: param_name.to_string(), - value: value.to_string(), - reason: "must be positive".to_string(), - }) - } - } - - /// Validate that an array doesn't exceed maximum length - pub fn validate_max_items( - items: &[T], - max_count: usize, - param_name: &str, - ) -> Result<(), ValidationError> { - if items.len() <= 
max_count { - Ok(()) - } else { - Err(ValidationError::TooManyItems { - parameter: param_name.to_string(), - count: items.len(), - max: max_count, - }) - } - } - - /// Validate that a required parameter is present - pub fn validate_required<'a, T>( - value: &'a Option, - param_name: &str, - ) -> Result<&'a T, ValidationError> { - value - .as_ref() - .ok_or_else(|| ValidationError::MissingRequired { - parameter: param_name.to_string(), - }) - } - - /// Validate top_k parameter (SGLang extension) - pub fn validate_top_k(top_k: i32) -> Result { - if top_k == constants::sglang::TOP_K_MIN || top_k > 0 { - Ok(top_k) - } else { - Err(ValidationError::InvalidValue { - parameter: "top_k".to_string(), - value: top_k.to_string(), - reason: "must be -1 (disabled) or positive".to_string(), - }) - } - } - - /// Generic validation function for sampling options - pub fn validate_sampling_options( - request: &T, - ) -> Result<(), ValidationError> { - // Validate temperature (0.0 to 2.0) - if let Some(temp) = request.get_temperature() { - validate_range(temp, &constants::TEMPERATURE_RANGE, "temperature")?; - } - - // Validate top_p (0.0 to 1.0) - if let Some(top_p) = request.get_top_p() { - validate_range(top_p, &constants::TOP_P_RANGE, "top_p")?; - } - - // Validate frequency_penalty (-2.0 to 2.0) - if let Some(freq_penalty) = request.get_frequency_penalty() { - validate_range( - freq_penalty, - &constants::FREQUENCY_PENALTY_RANGE, - "frequency_penalty", - )?; - } - - // Validate presence_penalty (-2.0 to 2.0) - if let Some(pres_penalty) = request.get_presence_penalty() { - validate_range( - pres_penalty, - &constants::PRESENCE_PENALTY_RANGE, - "presence_penalty", - )?; - } - - Ok(()) - } - - /// Generic validation function for stop conditions - pub fn validate_stop_conditions( - request: &T, - ) -> Result<(), ValidationError> { - if let Some(stop) = request.get_stop_sequences() { - match stop { - StringOrArray::String(s) => { - if s.is_empty() { - return 
Err(ValidationError::InvalidValue { - parameter: "stop".to_string(), - value: "empty string".to_string(), - reason: "stop sequences cannot be empty".to_string(), - }); - } - } - StringOrArray::Array(arr) => { - validate_max_items(arr, constants::MAX_STOP_SEQUENCES, "stop")?; - for (i, s) in arr.iter().enumerate() { - if s.is_empty() { - return Err(ValidationError::InvalidValue { - parameter: format!("stop[{}]", i), - value: "empty string".to_string(), - reason: "stop sequences cannot be empty".to_string(), - }); - } - } - } - } - } - - Ok(()) - } - - /// Generic validation function for token limits - pub fn validate_token_limits( - request: &T, - ) -> Result<(), ValidationError> { - // Validate max_tokens if provided - if let Some(max_tokens) = request.get_max_tokens() { - validate_positive(max_tokens, "max_tokens")?; - } - - // Validate min_tokens if provided (SGLang extension) - if let Some(min_tokens) = request.get_min_tokens() { - validate_positive(min_tokens, "min_tokens")?; - } - - Ok(()) - } - - /// Generic validation function for logprobs - pub fn validate_logprobs( - request: &T, - ) -> Result<(), ValidationError> { - // Validate logprobs (completions API - 0 to 5) - if let Some(logprobs) = request.get_logprobs() { - validate_range(logprobs, &constants::LOGPROBS_RANGE, "logprobs")?; - } - - // Validate top_logprobs (chat API - 0 to 20) - if let Some(top_logprobs) = request.get_top_logprobs() { - validate_range(top_logprobs, &constants::TOP_LOGPROBS_RANGE, "top_logprobs")?; - } - - Ok(()) - } - - /// Generic cross-parameter validation - pub fn validate_cross_parameters( - request: &T, - ) -> Result<(), ValidationError> { - // Check min_tokens <= max_tokens if both are specified - if let (Some(min_tokens), Some(max_tokens)) = - (request.get_min_tokens(), request.get_max_tokens()) - { - if min_tokens > max_tokens { - return Err(ValidationError::ConflictingParameters { - parameter1: "min_tokens".to_string(), - parameter2: "max_tokens".to_string(), - reason: 
format!( - "min_tokens ({}) cannot be greater than max_tokens ({})", - min_tokens, max_tokens - ), - }); - } - } - - Ok(()) - } - - /// Validate conflicting structured output constraints - pub fn validate_conflicting_parameters( - param1_name: &str, - param1_value: bool, - param2_name: &str, - param2_value: bool, - reason: &str, - ) -> Result<(), ValidationError> { - if param1_value && param2_value { - return Err(ValidationError::ConflictingParameters { - parameter1: param1_name.to_string(), - parameter2: param2_name.to_string(), - reason: reason.to_string(), - }); - } - Ok(()) - } - - /// Validate that only one option from a set is active - pub fn validate_mutually_exclusive_options( - options: &[(&str, bool)], - error_msg: &str, - ) -> Result<(), ValidationError> { - let active_count = options.iter().filter(|(_, is_active)| *is_active).count(); - if active_count > 1 { - return Err(ValidationError::Custom(error_msg.to_string())); - } - Ok(()) - } - - /// Generic validation for SGLang extensions - pub fn validate_sglang_extensions( - request: &T, - ) -> Result<(), ValidationError> { - // Validate top_k (-1 to disable, or positive) - if let Some(top_k) = request.get_top_k() { - validate_top_k(top_k)?; - } - - // Validate min_p (0.0 to 1.0) - if let Some(min_p) = request.get_min_p() { - validate_range(min_p, &constants::sglang::MIN_P_RANGE, "min_p")?; - } - - // Validate repetition_penalty (0.0 to 2.0) - if let Some(rep_penalty) = request.get_repetition_penalty() { - validate_range( - rep_penalty, - &constants::sglang::REPETITION_PENALTY_RANGE, - "repetition_penalty", - )?; - } - - Ok(()) - } - - /// Generic validation for n parameter (number of completions) - pub fn validate_completion_count( - request: &T, - ) -> Result<(), ValidationError> { - const N_RANGE: (u32, u32) = (1, 10); - - if let Some(n) = request.get_n() { - validate_range(n, &N_RANGE, "n")?; - } - - Ok(()) - } - - /// Validate that an array is not empty - pub fn validate_non_empty_array( - items: 
&[T], - param_name: &str, - ) -> Result<(), ValidationError> { - if items.is_empty() { - return Err(ValidationError::MissingRequired { - parameter: param_name.to_string(), - }); - } - Ok(()) - } - - /// Validate common request parameters that are shared across all API types - pub fn validate_common_request_params(request: &T) -> Result<(), ValidationError> - where - T: SamplingOptionsProvider - + StopConditionsProvider - + TokenLimitsProvider - + LogProbsProvider - + SGLangExtensionsProvider - + CompletionCountProvider - + ?Sized, - { - // Validate all standard parameters - validate_sampling_options(request)?; - validate_stop_conditions(request)?; - validate_token_limits(request)?; - validate_logprobs(request)?; - - // Validate SGLang extensions and completion count - validate_sglang_extensions(request)?; - validate_completion_count(request)?; - - // Perform cross-parameter validation - validate_cross_parameters(request)?; - - Ok(()) - } -} - -/// Core validation traits for different parameter categories -pub trait SamplingOptionsProvider { - /// Get temperature parameter - fn get_temperature(&self) -> Option; - - /// Get top_p parameter - fn get_top_p(&self) -> Option; - - /// Get frequency penalty parameter - fn get_frequency_penalty(&self) -> Option; - - /// Get presence penalty parameter - fn get_presence_penalty(&self) -> Option; -} - -/// Trait for validating stop conditions -pub trait StopConditionsProvider { - /// Get stop sequences - fn get_stop_sequences(&self) -> Option<&StringOrArray>; -} - -/// Trait for validating token limits -pub trait TokenLimitsProvider { - /// Get maximum tokens parameter - fn get_max_tokens(&self) -> Option; - - /// Get minimum tokens parameter (SGLang extension) - fn get_min_tokens(&self) -> Option; -} - -/// Trait for validating logprobs parameters -pub trait LogProbsProvider { - /// Get logprobs parameter (completions API) - fn get_logprobs(&self) -> Option; - - /// Get top_logprobs parameter (chat API) - fn 
get_top_logprobs(&self) -> Option; -} - -/// Trait for SGLang-specific extensions -pub trait SGLangExtensionsProvider { - /// Get top_k parameter - fn get_top_k(&self) -> Option { - None - } - - /// Get min_p parameter - fn get_min_p(&self) -> Option { - None - } - - /// Get repetition_penalty parameter - fn get_repetition_penalty(&self) -> Option { - None - } -} - -/// Trait for n parameter (number of completions) -pub trait CompletionCountProvider { - /// Get n parameter - fn get_n(&self) -> Option { - None - } -} - -/// Comprehensive validation trait that combines all validation aspects -pub trait ValidatableRequest: - SamplingOptionsProvider - + StopConditionsProvider - + TokenLimitsProvider - + LogProbsProvider - + SGLangExtensionsProvider - + CompletionCountProvider -{ - /// Perform comprehensive validation of the entire request - fn validate(&self) -> Result<(), ValidationError> { - // Use the common validation function - utils::validate_common_request_params(self) - } -} - -impl SamplingOptionsProvider for ChatCompletionRequest { - fn get_temperature(&self) -> Option { - self.temperature - } - fn get_top_p(&self) -> Option { - self.top_p - } - fn get_frequency_penalty(&self) -> Option { - self.frequency_penalty - } - fn get_presence_penalty(&self) -> Option { - self.presence_penalty - } -} - -impl StopConditionsProvider for ChatCompletionRequest { - fn get_stop_sequences(&self) -> Option<&StringOrArray> { - self.stop.as_ref() - } -} - -impl TokenLimitsProvider for ChatCompletionRequest { - #[allow(deprecated)] - fn get_max_tokens(&self) -> Option { - // Prefer max_completion_tokens over max_tokens if both are set - self.max_completion_tokens.or(self.max_tokens) - } - - fn get_min_tokens(&self) -> Option { - self.min_tokens - } -} - -impl LogProbsProvider for ChatCompletionRequest { - fn get_logprobs(&self) -> Option { - // For chat API, logprobs is a boolean, return 1 if true for validation purposes - if self.logprobs { - Some(1) - } else { - None - } - } - 
- fn get_top_logprobs(&self) -> Option { - self.top_logprobs - } -} - -impl SGLangExtensionsProvider for ChatCompletionRequest { - fn get_top_k(&self) -> Option { - self.top_k - } - - fn get_min_p(&self) -> Option { - self.min_p - } - - fn get_repetition_penalty(&self) -> Option { - self.repetition_penalty - } -} - -impl CompletionCountProvider for ChatCompletionRequest { - fn get_n(&self) -> Option { - self.n - } -} - -impl ChatCompletionRequest { - /// Validate message-specific requirements - pub fn validate_messages(&self) -> Result<(), ValidationError> { - // Ensure messages array is not empty - utils::validate_non_empty_array(&self.messages, "messages")?; - - // Validate message content is not empty - for (i, msg) in self.messages.iter().enumerate() { - if let ChatMessage::User { content, .. } = msg { - match content { - UserMessageContent::Text(text) if text.is_empty() => { - return Err(ValidationError::InvalidValue { - parameter: format!("messages[{}].content", i), - value: "empty".to_string(), - reason: "message content cannot be empty".to_string(), - }); - } - UserMessageContent::Parts(parts) if parts.is_empty() => { - return Err(ValidationError::InvalidValue { - parameter: format!("messages[{}].content", i), - value: "empty array".to_string(), - reason: "message content parts cannot be empty".to_string(), - }); - } - _ => {} - } - } - } - - Ok(()) - } - - /// Validate response format if specified - pub fn validate_response_format(&self) -> Result<(), ValidationError> { - if let Some(ResponseFormat::JsonSchema { json_schema }) = &self.response_format { - if json_schema.name.is_empty() { - return Err(ValidationError::InvalidValue { - parameter: "response_format.json_schema.name".to_string(), - value: "empty".to_string(), - reason: "JSON schema name cannot be empty".to_string(), - }); - } - } - Ok(()) - } - - /// Validate chat API specific logprobs requirements - pub fn validate_chat_logprobs(&self) -> Result<(), ValidationError> { - // OpenAI rule: If 
top_logprobs is specified, logprobs must be true - // But logprobs=true without top_logprobs is valid (returns basic logprobs) - if self.top_logprobs.is_some() && !self.logprobs { - return Err(ValidationError::InvalidValue { - parameter: "top_logprobs".to_string(), - value: self.top_logprobs.unwrap().to_string(), - reason: "top_logprobs is only allowed when logprobs is enabled".to_string(), - }); - } - - Ok(()) - } - - /// Validate cross-parameter relationships specific to chat completions - #[allow(deprecated)] - pub fn validate_chat_cross_parameters(&self) -> Result<(), ValidationError> { - // Validate that both max_tokens and max_completion_tokens aren't set - utils::validate_conflicting_parameters( - "max_tokens", - self.max_tokens.is_some(), - "max_completion_tokens", - self.max_completion_tokens.is_some(), - "cannot specify both max_tokens and max_completion_tokens", - )?; - - // Validate that tools and functions aren't both specified (deprecated) - utils::validate_conflicting_parameters( - "tools", - self.tools.is_some(), - "functions", - self.functions.is_some(), - "functions is deprecated, use tools instead", - )?; - - // Validate structured output constraints don't conflict with JSON response format - let has_json_format = matches!( - self.response_format, - Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) - ); - - utils::validate_conflicting_parameters( - "response_format", - has_json_format, - "regex", - self.regex.is_some(), - "cannot use regex constraint with JSON response format", - )?; - - utils::validate_conflicting_parameters( - "response_format", - has_json_format, - "ebnf", - self.ebnf.is_some(), - "cannot use EBNF constraint with JSON response format", - )?; - - // Only one structured output constraint should be active - let structured_constraints = [ - ("regex", self.regex.is_some()), - ("ebnf", self.ebnf.is_some()), - ( - "json_schema", - matches!( - self.response_format, - Some(ResponseFormat::JsonSchema { .. 
}) - ), - ), - ]; - - utils::validate_mutually_exclusive_options( - &structured_constraints, - "Only one structured output constraint (regex, ebnf, or json_schema) can be active at a time", - )?; - - Ok(()) - } -} - -impl ValidatableRequest for ChatCompletionRequest { - fn validate(&self) -> Result<(), ValidationError> { - // Call the common validation function from the validation module - utils::validate_common_request_params(self)?; - - // Then validate chat-specific parameters - self.validate_messages()?; - self.validate_response_format()?; - self.validate_chat_logprobs()?; - self.validate_chat_cross_parameters()?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::constants::*; - use super::utils::*; - use super::*; - use crate::protocols::spec::StringOrArray; - - // Mock request type for testing validation traits - #[derive(Debug, Default)] - struct MockRequest { - temperature: Option, - stop: Option, - max_tokens: Option, - min_tokens: Option, - } - - impl SamplingOptionsProvider for MockRequest { - fn get_temperature(&self) -> Option { - self.temperature - } - fn get_top_p(&self) -> Option { - None - } - fn get_frequency_penalty(&self) -> Option { - None - } - fn get_presence_penalty(&self) -> Option { - None - } - } - - impl StopConditionsProvider for MockRequest { - fn get_stop_sequences(&self) -> Option<&StringOrArray> { - self.stop.as_ref() - } - } - - impl TokenLimitsProvider for MockRequest { - fn get_max_tokens(&self) -> Option { - self.max_tokens - } - fn get_min_tokens(&self) -> Option { - self.min_tokens - } - } - - impl LogProbsProvider for MockRequest { - fn get_logprobs(&self) -> Option { - None - } - fn get_top_logprobs(&self) -> Option { - None - } - } - - impl SGLangExtensionsProvider for MockRequest {} - impl CompletionCountProvider for MockRequest {} - impl ValidatableRequest for MockRequest {} - - #[test] - fn test_range_validation() { - // Valid range - assert!(validate_range(1.5f32, &TEMPERATURE_RANGE, "temperature").is_ok()); - 
// Invalid range - assert!(validate_range(-0.1f32, &TEMPERATURE_RANGE, "temperature").is_err()); - assert!(validate_range(3.0f32, &TEMPERATURE_RANGE, "temperature").is_err()); - } - - #[test] - fn test_sglang_top_k_validation() { - assert!(validate_top_k(-1).is_ok()); // Disabled - assert!(validate_top_k(50).is_ok()); // Valid positive - assert!(validate_top_k(0).is_err()); // Invalid - assert!(validate_top_k(-5).is_err()); // Invalid - } - - #[test] - fn test_stop_sequences_limits() { - let request = MockRequest { - stop: Some(StringOrArray::Array(vec![ - "stop1".to_string(), - "stop2".to_string(), - "stop3".to_string(), - "stop4".to_string(), - "stop5".to_string(), // Too many - ])), - ..Default::default() - }; - assert!(request.validate().is_err()); - } - - #[test] - fn test_token_limits_conflict() { - let request = MockRequest { - min_tokens: Some(100), - max_tokens: Some(50), // min > max - ..Default::default() - }; - assert!(request.validate().is_err()); - } - - #[test] - fn test_valid_request() { - let request = MockRequest { - temperature: Some(1.0), - stop: Some(StringOrArray::Array(vec!["stop".to_string()])), - max_tokens: Some(100), - min_tokens: Some(10), - }; - assert!(request.validate().is_ok()); - } - - // Chat completion specific tests - #[cfg(test)] - mod chat_tests { - use super::*; - - #[allow(deprecated)] - fn create_valid_chat_request() -> ChatCompletionRequest { - ChatCompletionRequest { - messages: vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Hello".to_string()), - name: None, - }], - model: "gpt-4".to_string(), - // Set specific fields we want to test - temperature: Some(1.0), - top_p: Some(0.9), - n: Some(1), - max_tokens: Some(100), - frequency_penalty: Some(0.0), - presence_penalty: Some(0.0), - // Use default for all other fields - ..Default::default() - } - } - - #[test] - fn test_chat_validation_basics() { - // Valid request - assert!(create_valid_chat_request().validate().is_ok()); - - // 
Empty messages - let mut request = create_valid_chat_request(); - request.messages = vec![]; - assert!(request.validate().is_err()); - - // Invalid temperature - let mut request = create_valid_chat_request(); - request.temperature = Some(3.0); - assert!(request.validate().is_err()); - } - - #[test] - #[allow(deprecated)] - fn test_chat_cross_parameter_conflicts() { - let mut request = create_valid_chat_request(); - - request.max_tokens = Some(100); - request.max_completion_tokens = Some(200); - assert!( - request.validate().is_err(), - "Should reject both max_tokens and max_completion_tokens" - ); - - // Reset for next test - request.max_tokens = None; - request.max_completion_tokens = None; - - request.tools = Some(vec![]); - request.functions = Some(vec![]); - assert!( - request.validate().is_err(), - "Should reject both tools and functions" - ); - - let mut request = create_valid_chat_request(); - request.logprobs = true; - request.top_logprobs = None; - assert!( - request.validate().is_ok(), - "logprobs=true without top_logprobs should be valid" - ); - - let mut request = create_valid_chat_request(); - request.logprobs = false; - request.top_logprobs = Some(5); - assert!( - request.validate().is_err(), - "top_logprobs without logprobs=true should fail" - ); - } - - #[test] - fn test_sglang_extensions() { - let mut request = create_valid_chat_request(); - - // Valid SGLang parameters - request.top_k = Some(-1); - request.min_p = Some(0.1); - request.repetition_penalty = Some(1.2); - assert!(request.validate().is_ok()); - - // Invalid parameters - request.top_k = Some(0); // Invalid - assert!(request.validate().is_err()); - } - - #[test] - fn test_parameter_ranges() { - let mut request = create_valid_chat_request(); - - request.temperature = Some(1.5); - assert!(request.validate().is_ok()); - request.temperature = Some(-0.1); - assert!(request.validate().is_err()); - request.temperature = Some(3.0); - assert!(request.validate().is_err()); - - request.temperature 
= Some(1.0); // Reset - request.top_p = Some(0.9); - assert!(request.validate().is_ok()); - request.top_p = Some(-0.1); - assert!(request.validate().is_err()); - request.top_p = Some(1.5); - assert!(request.validate().is_err()); - - request.top_p = Some(0.9); // Reset - request.frequency_penalty = Some(1.5); - assert!(request.validate().is_ok()); - request.frequency_penalty = Some(-2.5); - assert!(request.validate().is_err()); - request.frequency_penalty = Some(3.0); - assert!(request.validate().is_err()); - - request.frequency_penalty = Some(0.0); // Reset - request.presence_penalty = Some(-1.5); - assert!(request.validate().is_ok()); - request.presence_penalty = Some(-3.0); - assert!(request.validate().is_err()); - request.presence_penalty = Some(2.5); - assert!(request.validate().is_err()); - - request.presence_penalty = Some(0.0); // Reset - request.repetition_penalty = Some(1.2); - assert!(request.validate().is_ok()); - request.repetition_penalty = Some(-0.1); - assert!(request.validate().is_err()); - request.repetition_penalty = Some(2.1); - assert!(request.validate().is_err()); - - request.repetition_penalty = Some(1.0); // Reset - request.min_p = Some(0.5); - assert!(request.validate().is_ok()); - request.min_p = Some(-0.1); - assert!(request.validate().is_err()); - request.min_p = Some(1.5); - assert!(request.validate().is_err()); - } - - #[test] - fn test_structured_output_conflicts() { - let mut request = create_valid_chat_request(); - - // JSON response format with regex should conflict - request.response_format = Some(ResponseFormat::JsonObject); - request.regex = Some(".*".to_string()); - assert!(request.validate().is_err()); - - // JSON response format with EBNF should conflict - request.regex = None; - request.ebnf = Some("grammar".to_string()); - assert!(request.validate().is_err()); - - // Multiple structured constraints should conflict - request.response_format = None; - request.regex = Some(".*".to_string()); - request.ebnf = 
Some("grammar".to_string()); - assert!(request.validate().is_err()); - - // Only one constraint should work - request.ebnf = None; - request.regex = Some(".*".to_string()); - assert!(request.validate().is_ok()); - - request.regex = None; - request.ebnf = Some("grammar".to_string()); - assert!(request.validate().is_ok()); - - request.ebnf = None; - request.response_format = Some(ResponseFormat::JsonObject); - assert!(request.validate().is_ok()); - } - - #[test] - fn test_stop_sequences_validation() { - let mut request = create_valid_chat_request(); - - // Valid stop sequences - request.stop = Some(StringOrArray::Array(vec![ - "stop1".to_string(), - "stop2".to_string(), - ])); - assert!(request.validate().is_ok()); - - // Too many stop sequences (max 4) - request.stop = Some(StringOrArray::Array(vec![ - "stop1".to_string(), - "stop2".to_string(), - "stop3".to_string(), - "stop4".to_string(), - "stop5".to_string(), - ])); - assert!(request.validate().is_err()); - - // Empty stop sequence should fail - request.stop = Some(StringOrArray::String("".to_string())); - assert!(request.validate().is_err()); - - // Empty string in array should fail - request.stop = Some(StringOrArray::Array(vec![ - "stop1".to_string(), - "".to_string(), - ])); - assert!(request.validate().is_err()); - } - - #[test] - fn test_logprobs_validation() { - let mut request = create_valid_chat_request(); - - // Valid logprobs configuration with top_logprobs - request.logprobs = true; - request.top_logprobs = Some(10); - assert!(request.validate().is_ok()); - - // logprobs=true without top_logprobs should be valid (OpenAI behavior) - request.top_logprobs = None; - assert!( - request.validate().is_ok(), - "logprobs=true without top_logprobs should be valid" - ); - - // top_logprobs without logprobs=true should fail - request.logprobs = false; - request.top_logprobs = Some(10); - assert!(request.validate().is_err()); - - // top_logprobs out of range (0-20) - request.logprobs = true; - 
request.top_logprobs = Some(25); - assert!(request.validate().is_err()); - } - - #[test] - fn test_n_parameter_validation() { - let mut request = create_valid_chat_request(); - - // Valid n values (1-10) - request.n = Some(1); - assert!(request.validate().is_ok()); - request.n = Some(5); - assert!(request.validate().is_ok()); - request.n = Some(10); - assert!(request.validate().is_ok()); - - // Invalid n values - request.n = Some(0); - assert!(request.validate().is_err()); - request.n = Some(15); - assert!(request.validate().is_err()); - } - - #[test] - #[allow(deprecated)] - fn test_min_max_tokens_validation() { - let mut request = create_valid_chat_request(); - - // Valid token limits - request.min_tokens = Some(10); - request.max_tokens = Some(100); - assert!(request.validate().is_ok()); - - // min_tokens > max_tokens should fail - request.min_tokens = Some(150); - request.max_tokens = Some(100); - assert!(request.validate().is_err()); - - // Should work with max_completion_tokens instead - request.max_tokens = None; - request.max_completion_tokens = Some(200); - request.min_tokens = Some(50); - assert!(request.validate().is_ok()); - - // min_tokens > max_completion_tokens should fail - request.min_tokens = Some(250); - assert!(request.validate().is_err()); - } - } -} diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs index 709678685..9b4891a66 100644 --- a/sgl-router/src/routers/grpc/utils.rs +++ b/sgl-router/src/routers/grpc/utils.rs @@ -959,7 +959,6 @@ mod tests { #[test] fn test_transform_messages_string_format() { let messages = vec![ChatMessage::User { - role: "user".to_string(), content: UserMessageContent::Parts(vec![ ContentPart::Text { text: "Hello".to_string(), @@ -993,7 +992,6 @@ mod tests { #[test] fn test_transform_messages_openai_format() { let messages = vec![ChatMessage::User { - role: "user".to_string(), content: UserMessageContent::Parts(vec![ ContentPart::Text { text: "Describe this image:".to_string(), 
@@ -1028,7 +1026,6 @@ mod tests { #[test] fn test_transform_messages_simple_string_content() { let messages = vec![ChatMessage::User { - role: "user".to_string(), content: UserMessageContent::Text("Simple text message".to_string()), name: None, }]; @@ -1049,12 +1046,10 @@ mod tests { fn test_transform_messages_multiple_messages() { let messages = vec![ ChatMessage::System { - role: "system".to_string(), content: "System prompt".to_string(), name: None, }, ChatMessage::User { - role: "user".to_string(), content: UserMessageContent::Parts(vec![ ContentPart::Text { text: "User message".to_string(), @@ -1086,7 +1081,6 @@ mod tests { #[test] fn test_transform_messages_empty_text_parts() { let messages = vec![ChatMessage::User { - role: "user".to_string(), content: UserMessageContent::Parts(vec![ContentPart::ImageUrl { image_url: ImageUrl { url: "https://example.com/image.jpg".to_string(), @@ -1109,12 +1103,10 @@ mod tests { fn test_transform_messages_mixed_content_types() { let messages = vec![ ChatMessage::User { - role: "user".to_string(), content: UserMessageContent::Text("Plain text".to_string()), name: None, }, ChatMessage::User { - role: "user".to_string(), content: UserMessageContent::Parts(vec![ ContentPart::Text { text: "With image".to_string(), diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index d1d634e6b..7ae721924 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -16,6 +16,7 @@ use crate::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput, }, + validated::ValidatedJson, worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}, }, reasoning_parser::ParserFactory as ReasoningParserFactory, @@ -291,7 +292,7 @@ async fn generate( async fn v1_chat_completions( State(state): State>, headers: http::HeaderMap, - Json(body): Json, + ValidatedJson(body): ValidatedJson, ) -> Response { 
state.router.route_chat(Some(&headers), &body, None).await } diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index b29f70b48..44392a5e2 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -1461,39 +1461,6 @@ mod error_tests { ctx.shutdown().await; } - #[tokio::test] - async fn test_missing_required_fields() { - let ctx = TestContext::new(vec![MockWorkerConfig { - port: 18405, - worker_type: WorkerType::Regular, - health_status: HealthStatus::Healthy, - response_delay_ms: 0, - fail_rate: 0.0, - }]) - .await; - - let app = ctx.create_app().await; - - // Missing messages in chat completion - let payload = json!({ - "model": "test-model" - // missing "messages" - }); - - let req = Request::builder() - .method("POST") - .uri("/v1/chat/completions") - .header(CONTENT_TYPE, "application/json") - .body(Body::from(serde_json::to_string(&payload).unwrap())) - .unwrap(); - - let resp = app.oneshot(req).await.unwrap(); - // Axum validates JSON schema - returns 422 for validation errors - assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY); - - ctx.shutdown().await; - } - #[tokio::test] async fn test_invalid_model() { let ctx = TestContext::new(vec![MockWorkerConfig { diff --git a/sgl-router/tests/chat_template_format_detection.rs b/sgl-router/tests/chat_template_format_detection.rs index 145cb8227..3efa6676d 100644 --- a/sgl-router/tests/chat_template_format_detection.rs +++ b/sgl-router/tests/chat_template_format_detection.rs @@ -172,14 +172,12 @@ assistant: let processor = ChatTemplateProcessor::new(template.to_string()); - let messages = vec![ + let messages = [ spec::ChatMessage::System { - role: "system".to_string(), content: "You are helpful".to_string(), name: None, }, spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("Hello".to_string()), name: None, }, @@ -216,7 +214,6 @@ fn test_chat_template_with_tokens_unit_test() { let 
processor = ChatTemplateProcessor::new(template.to_string()); let messages = [spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("Test".to_string()), name: None, }]; diff --git a/sgl-router/tests/chat_template_integration.rs b/sgl-router/tests/chat_template_integration.rs index ac25a3f10..077b813b4 100644 --- a/sgl-router/tests/chat_template_integration.rs +++ b/sgl-router/tests/chat_template_integration.rs @@ -18,7 +18,6 @@ fn test_simple_chat_template() { let processor = ChatTemplateProcessor::new(template.to_string()); let messages = [spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("Test".to_string()), name: None, }]; @@ -53,7 +52,6 @@ fn test_chat_template_with_tokens() { let processor = ChatTemplateProcessor::new(template.to_string()); let messages = [spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("Test".to_string()), name: None, }]; @@ -113,14 +111,12 @@ fn test_llama_style_template() { let processor = ChatTemplateProcessor::new(template.to_string()); - let messages = vec![ + let messages = [ spec::ChatMessage::System { - role: "system".to_string(), content: "You are a helpful assistant".to_string(), name: None, }, spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("What is 2+2?".to_string()), name: None, }, @@ -172,19 +168,16 @@ fn test_chatml_template() { let messages = vec![ spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("Hello".to_string()), name: None, }, spec::ChatMessage::Assistant { - role: "assistant".to_string(), content: Some("Hi there!".to_string()), name: None, tool_calls: None, reasoning_content: None, }, spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("How are you?".to_string()), name: None, }, @@ -227,7 +220,6 @@ assistant: let processor = 
ChatTemplateProcessor::new(template.to_string()); let messages = [spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("Test".to_string()), name: None, }]; @@ -315,7 +307,6 @@ fn test_template_with_multimodal_content() { let processor = ChatTemplateProcessor::new(template.to_string()); let messages = [spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Parts(vec![ spec::ContentPart::Text { text: "Look at this:".to_string(), diff --git a/sgl-router/tests/chat_template_loading.rs b/sgl-router/tests/chat_template_loading.rs index b3a5a3e70..336b17126 100644 --- a/sgl-router/tests/chat_template_loading.rs +++ b/sgl-router/tests/chat_template_loading.rs @@ -57,14 +57,12 @@ mod tests { ) .unwrap(); - let messages = vec![ + let messages = [ spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("Hello".to_string()), name: None, }, spec::ChatMessage::Assistant { - role: "assistant".to_string(), content: Some("Hi there".to_string()), name: None, tool_calls: None, @@ -143,7 +141,6 @@ mod tests { .unwrap(); let messages = [spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("Test".to_string()), name: None, }]; @@ -201,14 +198,12 @@ mod tests { "NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}"; tokenizer.set_chat_template(new_template.to_string()); - let messages = vec![ + let messages = [ spec::ChatMessage::User { - role: "user".to_string(), content: spec::UserMessageContent::Text("Hello".to_string()), name: None, }, spec::ChatMessage::Assistant { - role: "assistant".to_string(), content: Some("World".to_string()), name: None, tool_calls: None, diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs index 9288a9b06..4a2ed2f90 100644 --- a/sgl-router/tests/common/mod.rs +++ b/sgl-router/tests/common/mod.rs @@ -119,6 +119,7 @@ pub fn create_test_tools() -> Vec { "query": 
{"type": "string"} } }), + strict: None, }, }, Tool { @@ -135,6 +136,7 @@ pub fn create_test_tools() -> Vec { "units": {"type": "string"} } }), + strict: None, }, }, Tool { @@ -149,6 +151,7 @@ pub fn create_test_tools() -> Vec { "y": {"type": "number"} } }), + strict: None, }, }, Tool { @@ -164,6 +167,7 @@ pub fn create_test_tools() -> Vec { "target_lang": {"type": "string"} } }), + strict: None, }, }, Tool { @@ -178,6 +182,7 @@ pub fn create_test_tools() -> Vec { "format": {"type": "string"} } }), + strict: None, }, }, Tool { @@ -192,6 +197,7 @@ pub fn create_test_tools() -> Vec { "format": {"type": "string"} } }), + strict: None, }, }, Tool { @@ -206,6 +212,7 @@ pub fn create_test_tools() -> Vec { "notifications": {"type": "boolean"} } }), + strict: None, }, }, Tool { @@ -214,6 +221,7 @@ pub fn create_test_tools() -> Vec { name: "ping".to_string(), description: Some("Ping service".to_string()), parameters: json!({"type": "object", "properties": {}}), + strict: None, }, }, Tool { @@ -222,6 +230,7 @@ pub fn create_test_tools() -> Vec { name: "test".to_string(), description: Some("Test function".to_string()), parameters: json!({"type": "object", "properties": {}}), + strict: None, }, }, Tool { @@ -239,6 +248,7 @@ pub fn create_test_tools() -> Vec { "text": {"type": "string"} } }), + strict: None, }, }, Tool { @@ -254,6 +264,7 @@ pub fn create_test_tools() -> Vec { "search_type": {"type": "string"} } }), + strict: None, }, }, Tool { @@ -267,6 +278,7 @@ pub fn create_test_tools() -> Vec { "city": {"type": "string"} } }), + strict: None, }, }, Tool { @@ -282,6 +294,7 @@ pub fn create_test_tools() -> Vec { "optional": {"type": "null"} } }), + strict: None, }, }, Tool { @@ -297,6 +310,7 @@ pub fn create_test_tools() -> Vec { "none_val": {"type": "null"} } }), + strict: None, }, }, Tool { @@ -311,6 +325,7 @@ pub fn create_test_tools() -> Vec { "email": {"type": "string"} } }), + strict: None, }, }, Tool { @@ -325,6 +340,7 @@ pub fn create_test_tools() -> Vec { "y": 
{"type": "number"} } }), + strict: None, }, }, Tool { @@ -338,6 +354,7 @@ pub fn create_test_tools() -> Vec { "x": {"type": "number"} } }), + strict: None, }, }, Tool { @@ -346,6 +363,7 @@ pub fn create_test_tools() -> Vec { name: "func1".to_string(), description: Some("Function 1".to_string()), parameters: json!({"type": "object", "properties": {}}), + strict: None, }, }, Tool { @@ -359,6 +377,7 @@ pub fn create_test_tools() -> Vec { "y": {"type": "number"} } }), + strict: None, }, }, Tool { @@ -367,6 +386,7 @@ pub fn create_test_tools() -> Vec { name: "tool1".to_string(), description: Some("Tool 1".to_string()), parameters: json!({"type": "object", "properties": {}}), + strict: None, }, }, Tool { @@ -380,6 +400,7 @@ pub fn create_test_tools() -> Vec { "y": {"type": "number"} } }), + strict: None, }, }, ] diff --git a/sgl-router/tests/spec/chat_completion.rs b/sgl-router/tests/spec/chat_completion.rs new file mode 100644 index 000000000..87eade72a --- /dev/null +++ b/sgl-router/tests/spec/chat_completion.rs @@ -0,0 +1,575 @@ +use serde_json::json; +use sglang_router_rs::protocols::spec::{ + ChatCompletionRequest, ChatMessage, Function, FunctionCall, FunctionChoice, StreamOptions, + Tool, ToolChoice, ToolChoiceValue, ToolReference, UserMessageContent, +}; +use sglang_router_rs::protocols::validated::Normalizable; +use validator::Validate; + +// Deprecated fields normalization tests + +#[test] +fn test_max_tokens_normalizes_to_max_completion_tokens() { + #[allow(deprecated)] + let mut req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + max_tokens: Some(100), + max_completion_tokens: None, + ..Default::default() + }; + + req.normalize(); + assert_eq!( + req.max_completion_tokens, + Some(100), + "max_tokens should be copied to max_completion_tokens" + ); + #[allow(deprecated)] + { + assert!( + req.max_tokens.is_none(), + "Deprecated field 
should be cleared" + ); + } + assert!( + req.validate().is_ok(), + "Should be valid after normalization" + ); +} + +#[test] +fn test_max_completion_tokens_takes_precedence() { + #[allow(deprecated)] + let mut req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + max_tokens: Some(100), + max_completion_tokens: Some(200), + ..Default::default() + }; + + req.normalize(); + assert_eq!( + req.max_completion_tokens, + Some(200), + "max_completion_tokens should take precedence" + ); + assert!( + req.validate().is_ok(), + "Should be valid after normalization" + ); +} + +#[test] +fn test_functions_normalizes_to_tools() { + #[allow(deprecated)] + let mut req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + functions: Some(vec![Function { + name: "test_func".to_string(), + description: Some("Test function".to_string()), + parameters: json!({}), + strict: None, + }]), + tools: None, + ..Default::default() + }; + + req.normalize(); + assert!(req.tools.is_some(), "functions should be migrated to tools"); + assert_eq!(req.tools.as_ref().unwrap().len(), 1); + assert_eq!(req.tools.as_ref().unwrap()[0].function.name, "test_func"); + #[allow(deprecated)] + { + assert!( + req.functions.is_none(), + "Deprecated field should be cleared" + ); + } + assert!( + req.validate().is_ok(), + "Should be valid after normalization" + ); +} + +#[test] +fn test_function_call_normalizes_to_tool_choice() { + #[allow(deprecated)] + let mut req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + function_call: Some(FunctionCall::None), + tool_choice: None, + ..Default::default() + }; + + req.normalize(); + assert!( + 
req.tool_choice.is_some(), + "function_call should be migrated to tool_choice" + ); + assert!(matches!( + req.tool_choice, + Some(ToolChoice::Value(ToolChoiceValue::None)) + )); + #[allow(deprecated)] + { + assert!( + req.function_call.is_none(), + "Deprecated field should be cleared" + ); + } + assert!( + req.validate().is_ok(), + "Should be valid after normalization" + ); +} + +#[test] +fn test_function_call_function_variant_normalizes() { + #[allow(deprecated)] + let mut req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + function_call: Some(FunctionCall::Function { + name: "my_function".to_string(), + }), + tool_choice: None, + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "my_function".to_string(), + description: None, + parameters: json!({}), + strict: None, + }, + }]), + ..Default::default() + }; + + req.normalize(); + assert!( + req.tool_choice.is_some(), + "function_call should be migrated to tool_choice" + ); + match &req.tool_choice { + Some(ToolChoice::Function { function, .. 
}) => { + assert_eq!(function.name, "my_function"); + } + _ => panic!("Expected ToolChoice::Function variant"), + } + #[allow(deprecated)] + { + assert!( + req.function_call.is_none(), + "Deprecated field should be cleared" + ); + } + assert!( + req.validate().is_ok(), + "Should be valid after normalization" + ); +} + +// Stream options validation tests + +#[test] +fn test_stream_options_requires_stream_enabled() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + stream: false, + stream_options: Some(StreamOptions { + include_usage: Some(true), + }), + ..Default::default() + }; + + let result = req.validate(); + assert!( + result.is_err(), + "Should reject stream_options when stream is false" + ); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("stream_options") && err.contains("stream") && err.contains("enabled"), + "Error should mention stream dependency: {}", + err + ); +} + +#[test] +fn test_stream_options_valid_when_stream_enabled() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + stream: true, + stream_options: Some(StreamOptions { + include_usage: Some(true), + }), + ..Default::default() + }; + + let result = req.validate(); + assert!( + result.is_ok(), + "Should accept stream_options when stream is true" + ); +} + +#[test] +fn test_no_stream_options_valid_when_stream_disabled() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + stream: false, + stream_options: None, + ..Default::default() + }; + + let result = req.validate(); + assert!( + result.is_ok(), + "Should accept no stream_options when stream is false" + ); +} + +// 
Tool choice validation tests +#[test] +fn test_tool_choice_function_not_found() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({}), + strict: None, + }, + }]), + tool_choice: Some(ToolChoice::Function { + function: FunctionChoice { + name: "nonexistent_function".to_string(), + }, + tool_type: "function".to_string(), + }), + ..Default::default() + }; + + let result = req.validate(); + assert!(result.is_err(), "Should reject nonexistent function name"); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("function 'nonexistent_function' not found"), + "Error should mention the missing function: {}", + err + ); +} + +#[test] +fn test_tool_choice_function_exists_valid() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({}), + strict: None, + }, + }]), + tool_choice: Some(ToolChoice::Function { + function: FunctionChoice { + name: "get_weather".to_string(), + }, + tool_type: "function".to_string(), + }), + ..Default::default() + }; + + let result = req.validate(); + assert!(result.is_ok(), "Should accept existing function name"); +} + +#[test] +fn test_tool_choice_allowed_tools_invalid_mode() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + tools: Some(vec![Tool { + 
tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({}), + strict: None, + }, + }]), + tool_choice: Some(ToolChoice::AllowedTools { + mode: "invalid_mode".to_string(), + tools: vec![ToolReference { + tool_type: "function".to_string(), + name: "get_weather".to_string(), + }], + tool_type: "function".to_string(), + }), + ..Default::default() + }; + + let result = req.validate(); + assert!(result.is_err(), "Should reject invalid mode"); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("must be 'auto' or 'required'"), + "Error should mention valid modes: {}", + err + ); +} + +#[test] +fn test_tool_choice_allowed_tools_valid_mode_auto() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({}), + strict: None, + }, + }]), + tool_choice: Some(ToolChoice::AllowedTools { + mode: "auto".to_string(), + tools: vec![ToolReference { + tool_type: "function".to_string(), + name: "get_weather".to_string(), + }], + tool_type: "function".to_string(), + }), + ..Default::default() + }; + + let result = req.validate(); + assert!(result.is_ok(), "Should accept 'auto' mode"); +} + +#[test] +fn test_tool_choice_allowed_tools_valid_mode_required() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({}), + strict: None, + }, + 
}]), + tool_choice: Some(ToolChoice::AllowedTools { + mode: "required".to_string(), + tools: vec![ToolReference { + tool_type: "function".to_string(), + name: "get_weather".to_string(), + }], + tool_type: "function".to_string(), + }), + ..Default::default() + }; + + let result = req.validate(); + assert!(result.is_ok(), "Should accept 'required' mode"); +} + +#[test] +fn test_tool_choice_allowed_tools_tool_not_found() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({}), + strict: None, + }, + }]), + tool_choice: Some(ToolChoice::AllowedTools { + mode: "auto".to_string(), + tools: vec![ToolReference { + tool_type: "function".to_string(), + name: "nonexistent_tool".to_string(), + }], + tool_type: "function".to_string(), + }), + ..Default::default() + }; + + let result = req.validate(); + assert!(result.is_err(), "Should reject nonexistent tool name"); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("tool 'nonexistent_tool' not found"), + "Error should mention the missing tool: {}", + err + ); +} + +#[test] +fn test_tool_choice_allowed_tools_multiple_tools_valid() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + tools: Some(vec![ + Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({}), + strict: None, + }, + }, + Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_time".to_string(), + description: Some("Get time".to_string()), + 
parameters: json!({}), + strict: None, + }, + }, + ]), + tool_choice: Some(ToolChoice::AllowedTools { + mode: "auto".to_string(), + tools: vec![ + ToolReference { + tool_type: "function".to_string(), + name: "get_weather".to_string(), + }, + ToolReference { + tool_type: "function".to_string(), + name: "get_time".to_string(), + }, + ], + tool_type: "function".to_string(), + }), + ..Default::default() + }; + + let result = req.validate(); + assert!(result.is_ok(), "Should accept all valid tool references"); +} + +#[test] +fn test_tool_choice_allowed_tools_one_invalid_among_valid() { + let req = ChatCompletionRequest { + model: "test-model".to_string(), + messages: vec![ChatMessage::User { + content: UserMessageContent::Text("hello".to_string()), + name: None, + }], + tools: Some(vec![ + Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({}), + strict: None, + }, + }, + Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_time".to_string(), + description: Some("Get time".to_string()), + parameters: json!({}), + strict: None, + }, + }, + ]), + tool_choice: Some(ToolChoice::AllowedTools { + mode: "auto".to_string(), + tools: vec![ + ToolReference { + tool_type: "function".to_string(), + name: "get_weather".to_string(), + }, + ToolReference { + tool_type: "function".to_string(), + name: "nonexistent_tool".to_string(), + }, + ], + tool_type: "function".to_string(), + }), + ..Default::default() + }; + + let result = req.validate(); + assert!( + result.is_err(), + "Should reject if any tool reference is invalid" + ); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("tool 'nonexistent_tool' not found"), + "Error should mention the missing tool: {}", + err + ); +} diff --git a/sgl-router/tests/spec/chat_message.rs b/sgl-router/tests/spec/chat_message.rs new file mode 100644 index 000000000..0b158522b --- 
/dev/null +++ b/sgl-router/tests/spec/chat_message.rs @@ -0,0 +1,83 @@ +use serde_json::json; +use sglang_router_rs::protocols::spec::{ChatMessage, UserMessageContent}; + +#[test] +fn test_chat_message_tagged_by_role_system() { + let json = json!({ + "role": "system", + "content": "You are a helpful assistant" + }); + + let msg: ChatMessage = serde_json::from_value(json).unwrap(); + match msg { + ChatMessage::System { content, .. } => { + assert_eq!(content, "You are a helpful assistant"); + } + _ => panic!("Expected System variant"), + } +} + +#[test] +fn test_chat_message_tagged_by_role_user() { + let json = json!({ + "role": "user", + "content": "Hello" + }); + + let msg: ChatMessage = serde_json::from_value(json).unwrap(); + match msg { + ChatMessage::User { content, .. } => match content { + UserMessageContent::Text(text) => assert_eq!(text, "Hello"), + _ => panic!("Expected text content"), + }, + _ => panic!("Expected User variant"), + } +} + +#[test] +fn test_chat_message_tagged_by_role_assistant() { + let json = json!({ + "role": "assistant", + "content": "Hi there!" + }); + + let msg: ChatMessage = serde_json::from_value(json).unwrap(); + match msg { + ChatMessage::Assistant { content, .. 
} => { + assert_eq!(content, Some("Hi there!".to_string())); + } + _ => panic!("Expected Assistant variant"), + } +} + +#[test] +fn test_chat_message_tagged_by_role_tool() { + let json = json!({ + "role": "tool", + "content": "Tool result", + "tool_call_id": "call_123" + }); + + let msg: ChatMessage = serde_json::from_value(json).unwrap(); + match msg { + ChatMessage::Tool { + content, + tool_call_id, + } => { + assert_eq!(content, "Tool result"); + assert_eq!(tool_call_id, "call_123"); + } + _ => panic!("Expected Tool variant"), + } +} + +#[test] +fn test_chat_message_wrong_role_rejected() { + let json = json!({ + "role": "invalid_role", + "content": "test" + }); + + let result = serde_json::from_value::<ChatMessage>(json); + assert!(result.is_err(), "Should reject invalid role"); +} diff --git a/sgl-router/tests/spec/embedding.rs b/sgl-router/tests/spec/embedding.rs new file mode 100644 index 000000000..718dd5602 --- /dev/null +++ b/sgl-router/tests/spec/embedding.rs @@ -0,0 +1,96 @@ +use serde_json::{from_str, json, to_string}; +use sglang_router_rs::protocols::spec::{EmbeddingRequest, GenerationRequest}; + +#[test] +fn test_embedding_request_serialization_string_input() { + let req = EmbeddingRequest { + model: "test-emb".to_string(), + input: json!("hello"), + encoding_format: Some("float".to_string()), + user: Some("user-1".to_string()), + dimensions: Some(128), + rid: Some("rid-123".to_string()), + }; + + let serialized = to_string(&req).unwrap(); + let deserialized: EmbeddingRequest = from_str(&serialized).unwrap(); + + assert_eq!(deserialized.model, req.model); + assert_eq!(deserialized.input, req.input); + assert_eq!(deserialized.encoding_format, req.encoding_format); + assert_eq!(deserialized.user, req.user); + assert_eq!(deserialized.dimensions, req.dimensions); + assert_eq!(deserialized.rid, req.rid); +} + +#[test] +fn test_embedding_request_serialization_array_input() { + let req = EmbeddingRequest { + model: "test-emb".to_string(), + input: json!(["a", "b",
"c"]), + encoding_format: None, + user: None, + dimensions: None, + rid: None, + }; + + let serialized = to_string(&req).unwrap(); + let de: EmbeddingRequest = from_str(&serialized).unwrap(); + assert_eq!(de.model, req.model); + assert_eq!(de.input, req.input); +} + +#[test] +fn test_embedding_generation_request_trait_string() { + let req = EmbeddingRequest { + model: "emb-model".to_string(), + input: json!("hello"), + encoding_format: None, + user: None, + dimensions: None, + rid: None, + }; + assert!(!req.is_stream()); + assert_eq!(req.get_model(), Some("emb-model")); + assert_eq!(req.extract_text_for_routing(), "hello"); +} + +#[test] +fn test_embedding_generation_request_trait_array() { + let req = EmbeddingRequest { + model: "emb-model".to_string(), + input: json!(["hello", "world"]), + encoding_format: None, + user: None, + dimensions: None, + rid: None, + }; + assert_eq!(req.extract_text_for_routing(), "hello world"); +} + +#[test] +fn test_embedding_generation_request_trait_non_text() { + let req = EmbeddingRequest { + model: "emb-model".to_string(), + input: json!({"tokens": [1, 2, 3]}), + encoding_format: None, + user: None, + dimensions: None, + rid: None, + }; + assert_eq!(req.extract_text_for_routing(), ""); +} + +#[test] +fn test_embedding_generation_request_trait_mixed_array_ignores_nested() { + let req = EmbeddingRequest { + model: "emb-model".to_string(), + input: json!(["a", ["b", "c"], 123, {"k": "v"}]), + encoding_format: None, + user: None, + dimensions: None, + rid: None, + }; + // Only top-level string elements are extracted + assert_eq!(req.extract_text_for_routing(), "a"); +} diff --git a/sgl-router/tests/spec/mod.rs b/sgl-router/tests/spec/mod.rs new file mode 100644 index 000000000..3a9582a58 --- /dev/null +++ b/sgl-router/tests/spec/mod.rs @@ -0,0 +1,8 @@ +// Protocol specification tests +// These tests were originally in src/protocols/spec.rs and have been moved here +// to reduce the size of that file and improve test organization. 
+ +mod chat_completion; +mod chat_message; +mod embedding; +mod rerank; diff --git a/sgl-router/tests/spec/rerank.rs b/sgl-router/tests/spec/rerank.rs new file mode 100644 index 000000000..88b296298 --- /dev/null +++ b/sgl-router/tests/spec/rerank.rs @@ -0,0 +1,613 @@ +use serde_json::{from_str, to_string, Number, Value}; +use sglang_router_rs::protocols::spec::{ + default_model_name, GenerationRequest, RerankRequest, RerankResponse, RerankResult, + StringOrArray, UsageInfo, V1RerankReqInput, +}; +use std::collections::HashMap; + +#[test] +fn test_rerank_request_serialization() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + model: "test-model".to_string(), + top_k: Some(5), + return_documents: true, + rid: Some(StringOrArray::String("req-123".to_string())), + user: Some("user-456".to_string()), + }; + + let serialized = to_string(&request).unwrap(); + let deserialized: RerankRequest = from_str(&serialized).unwrap(); + + assert_eq!(deserialized.query, request.query); + assert_eq!(deserialized.documents, request.documents); + assert_eq!(deserialized.model, request.model); + assert_eq!(deserialized.top_k, request.top_k); + assert_eq!(deserialized.return_documents, request.return_documents); + assert_eq!(deserialized.rid, request.rid); + assert_eq!(deserialized.user, request.user); +} + +#[test] +fn test_rerank_request_deserialization_with_defaults() { + let json = r#"{ + "query": "test query", + "documents": ["doc1", "doc2"] + }"#; + + let request: RerankRequest = from_str(json).unwrap(); + + assert_eq!(request.query, "test query"); + assert_eq!(request.documents, vec!["doc1", "doc2"]); + assert_eq!(request.model, default_model_name()); + assert_eq!(request.top_k, None); + assert!(request.return_documents); + assert_eq!(request.rid, None); + assert_eq!(request.user, None); +} + +#[test] +fn test_rerank_request_validation_success() { + let request = RerankRequest { + query: "valid 
query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + model: "test-model".to_string(), + top_k: Some(2), + return_documents: true, + rid: None, + user: None, + }; + + assert!(request.validate().is_ok()); +} + +#[test] +fn test_rerank_request_validation_empty_query() { + let request = RerankRequest { + query: "".to_string(), + documents: vec!["doc1".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + let result = request.validate(); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Query cannot be empty"); +} + +#[test] +fn test_rerank_request_validation_whitespace_query() { + let request = RerankRequest { + query: " ".to_string(), + documents: vec!["doc1".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + let result = request.validate(); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Query cannot be empty"); +} + +#[test] +fn test_rerank_request_validation_empty_documents() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec![], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + let result = request.validate(); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Documents list cannot be empty"); +} + +#[test] +fn test_rerank_request_validation_top_k_zero() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + model: "test-model".to_string(), + top_k: Some(0), + return_documents: true, + rid: None, + user: None, + }; + + let result = request.validate(); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "top_k must be greater than 0"); +} + +#[test] +fn test_rerank_request_validation_top_k_greater_than_docs() { + let request = RerankRequest { + query: "test 
query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + model: "test-model".to_string(), + top_k: Some(5), + return_documents: true, + rid: None, + user: None, + }; + + // This should pass but log a warning + assert!(request.validate().is_ok()); +} + +#[test] +fn test_rerank_request_effective_top_k() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()], + model: "test-model".to_string(), + top_k: Some(2), + return_documents: true, + rid: None, + user: None, + }; + + assert_eq!(request.effective_top_k(), 2); +} + +#[test] +fn test_rerank_request_effective_top_k_none() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + assert_eq!(request.effective_top_k(), 3); +} + +#[test] +fn test_rerank_response_creation() { + let results = vec![ + RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }, + RerankResult { + score: 0.6, + document: Some("doc2".to_string()), + index: 1, + meta_info: None, + }, + ]; + + let response = RerankResponse::new( + results.clone(), + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); + + assert_eq!(response.results.len(), 2); + assert_eq!(response.model, "test-model"); + assert_eq!( + response.id, + Some(StringOrArray::String("req-123".to_string())) + ); + assert_eq!(response.object, "rerank"); + assert!(response.created > 0); +} + +#[test] +fn test_rerank_response_serialization() { + let results = vec![RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }]; + + let response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); + 
+ let serialized = to_string(&response).unwrap(); + let deserialized: RerankResponse = from_str(&serialized).unwrap(); + + assert_eq!(deserialized.results.len(), response.results.len()); + assert_eq!(deserialized.model, response.model); + assert_eq!(deserialized.id, response.id); + assert_eq!(deserialized.object, response.object); +} + +#[test] +fn test_rerank_response_sort_by_score() { + let results = vec![ + RerankResult { + score: 0.6, + document: Some("doc2".to_string()), + index: 1, + meta_info: None, + }, + RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }, + RerankResult { + score: 0.4, + document: Some("doc3".to_string()), + index: 2, + meta_info: None, + }, + ]; + + let mut response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); + + response.sort_by_score(); + + assert_eq!(response.results[0].score, 0.8); + assert_eq!(response.results[0].index, 0); + assert_eq!(response.results[1].score, 0.6); + assert_eq!(response.results[1].index, 1); + assert_eq!(response.results[2].score, 0.4); + assert_eq!(response.results[2].index, 2); +} + +#[test] +fn test_rerank_response_apply_top_k() { + let results = vec![ + RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }, + RerankResult { + score: 0.6, + document: Some("doc2".to_string()), + index: 1, + meta_info: None, + }, + RerankResult { + score: 0.4, + document: Some("doc3".to_string()), + index: 2, + meta_info: None, + }, + ]; + + let mut response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); + + response.apply_top_k(2); + + assert_eq!(response.results.len(), 2); + assert_eq!(response.results[0].score, 0.8); + assert_eq!(response.results[1].score, 0.6); +} + +#[test] +fn test_rerank_response_apply_top_k_larger_than_results() { + let results = vec![RerankResult { + score: 
0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }]; + + let mut response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); + + response.apply_top_k(5); + + assert_eq!(response.results.len(), 1); +} + +#[test] +fn test_rerank_response_drop_documents() { + let results = vec![RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }]; + let mut response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); + + response.drop_documents(); + + assert_eq!(response.results[0].document, None); +} + +#[test] +fn test_rerank_result_serialization() { + let result = RerankResult { + score: 0.85, + document: Some("test document".to_string()), + index: 42, + meta_info: Some(HashMap::from([ + ("confidence".to_string(), Value::String("high".to_string())), + ( + "processing_time".to_string(), + Value::Number(Number::from(150)), + ), + ])), + }; + + let serialized = to_string(&result).unwrap(); + let deserialized: RerankResult = from_str(&serialized).unwrap(); + + assert_eq!(deserialized.score, result.score); + assert_eq!(deserialized.document, result.document); + assert_eq!(deserialized.index, result.index); + assert_eq!(deserialized.meta_info, result.meta_info); +} + +#[test] +fn test_rerank_result_serialization_without_document() { + let result = RerankResult { + score: 0.85, + document: None, + index: 42, + meta_info: None, + }; + + let serialized = to_string(&result).unwrap(); + let deserialized: RerankResult = from_str(&serialized).unwrap(); + + assert_eq!(deserialized.score, result.score); + assert_eq!(deserialized.document, result.document); + assert_eq!(deserialized.index, result.index); + assert_eq!(deserialized.meta_info, result.meta_info); +} + +#[test] +fn test_v1_rerank_req_input_serialization() { + let v1_input = V1RerankReqInput { + query: "test 
query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + }; + + let serialized = to_string(&v1_input).unwrap(); + let deserialized: V1RerankReqInput = from_str(&serialized).unwrap(); + + assert_eq!(deserialized.query, v1_input.query); + assert_eq!(deserialized.documents, v1_input.documents); +} + +#[test] +fn test_v1_to_rerank_request_conversion() { + let v1_input = V1RerankReqInput { + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + }; + + let request: RerankRequest = v1_input.into(); + + assert_eq!(request.query, "test query"); + assert_eq!(request.documents, vec!["doc1", "doc2"]); + assert_eq!(request.model, default_model_name()); + assert_eq!(request.top_k, None); + assert!(request.return_documents); + assert_eq!(request.rid, None); + assert_eq!(request.user, None); +} + +#[test] +fn test_rerank_request_generation_request_trait() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + assert_eq!(request.get_model(), Some("test-model")); + assert!(!request.is_stream()); + assert_eq!(request.extract_text_for_routing(), "test query"); +} + +#[test] +fn test_rerank_request_very_long_query() { + let long_query = "a".repeat(100000); + let request = RerankRequest { + query: long_query, + documents: vec!["doc1".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: None, + user: None, + }; + + assert!(request.validate().is_ok()); +} + +#[test] +fn test_rerank_request_many_documents() { + let documents: Vec<String> = (0..1000).map(|i| format!("doc{}", i)).collect(); + let request = RerankRequest { + query: "test query".to_string(), + documents, + model: "test-model".to_string(), + top_k: Some(100), + return_documents: true, + rid: None, + user: None, + }; + + assert!(request.validate().is_ok()); +
assert_eq!(request.effective_top_k(), 100); +} + +#[test] +fn test_rerank_request_special_characters() { + let request = RerankRequest { + query: "query with émojis 🚀 and unicode: 测试".to_string(), + documents: vec![ + "doc with émojis 🎉".to_string(), + "doc with unicode: 测试".to_string(), + ], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: Some(StringOrArray::String("req-🚀-123".to_string())), + user: Some("user-🎉-456".to_string()), + }; + + assert!(request.validate().is_ok()); +} + +#[test] +fn test_rerank_request_rid_array() { + let request = RerankRequest { + query: "test query".to_string(), + documents: vec!["doc1".to_string()], + model: "test-model".to_string(), + top_k: None, + return_documents: true, + rid: Some(StringOrArray::Array(vec![ + "req1".to_string(), + "req2".to_string(), + ])), + user: None, + }; + + assert!(request.validate().is_ok()); +} + +#[test] +fn test_rerank_response_with_usage_info() { + let results = vec![RerankResult { + score: 0.8, + document: Some("doc1".to_string()), + index: 0, + meta_info: None, + }]; + + let mut response = RerankResponse::new( + results, + "test-model".to_string(), + Some(StringOrArray::String("req-123".to_string())), + ); + + response.usage = Some(UsageInfo { + prompt_tokens: 100, + completion_tokens: 50, + total_tokens: 150, + reasoning_tokens: None, + prompt_tokens_details: None, + }); + + let serialized = to_string(&response).unwrap(); + let deserialized: RerankResponse = from_str(&serialized).unwrap(); + + assert!(deserialized.usage.is_some()); + let usage = deserialized.usage.unwrap(); + assert_eq!(usage.prompt_tokens, 100); + assert_eq!(usage.completion_tokens, 50); + assert_eq!(usage.total_tokens, 150); +} + +#[test] +fn test_full_rerank_workflow() { + // Create request + let request = RerankRequest { + query: "machine learning".to_string(), + documents: vec![ + "Introduction to machine learning algorithms".to_string(), + "Deep learning for computer vision".to_string(), + 
"Natural language processing basics".to_string(), + "Statistics and probability theory".to_string(), + ], + model: "rerank-model".to_string(), + top_k: Some(2), + return_documents: true, + rid: Some(StringOrArray::String("req-123".to_string())), + user: Some("user-456".to_string()), + }; + + // Validate request + assert!(request.validate().is_ok()); + + // Simulate reranking results (in real scenario, this would come from the model) + let results = vec![ + RerankResult { + score: 0.95, + document: Some("Introduction to machine learning algorithms".to_string()), + index: 0, + meta_info: None, + }, + RerankResult { + score: 0.87, + document: Some("Deep learning for computer vision".to_string()), + index: 1, + meta_info: None, + }, + RerankResult { + score: 0.72, + document: Some("Natural language processing basics".to_string()), + index: 2, + meta_info: None, + }, + RerankResult { + score: 0.45, + document: Some("Statistics and probability theory".to_string()), + index: 3, + meta_info: None, + }, + ]; + + // Create response + let mut response = RerankResponse::new(results, request.model.clone(), request.rid.clone()); + + // Sort by score + response.sort_by_score(); + + // Apply top_k + response.apply_top_k(request.effective_top_k()); + + assert_eq!(response.results.len(), 2); + assert_eq!(response.results[0].score, 0.95); + assert_eq!(response.results[0].index, 0); + assert_eq!(response.results[1].score, 0.87); + assert_eq!(response.results[1].index, 1); + assert_eq!(response.model, "rerank-model"); + + // Serialize and deserialize + let serialized = to_string(&response).unwrap(); + let deserialized: RerankResponse = from_str(&serialized).unwrap(); + assert_eq!(deserialized.results.len(), 2); + assert_eq!(deserialized.model, response.model); +} diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index 56f6f64f1..e864b4ec0 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ 
-601,7 +601,6 @@ async fn test_unsupported_endpoints() { prompt: None, text: Some("Hello world".to_string()), input_ids: None, - parameters: None, sampling_params: None, stream: false, return_logprob: false, @@ -642,7 +641,6 @@ async fn test_openai_router_chat_completion_with_mock() { // Create a minimal chat completion request let mut chat_request = create_minimal_chat_request(); chat_request.messages = vec![ChatMessage::User { - role: "user".to_string(), content: UserMessageContent::Text("Hello, how are you?".to_string()), name: None, }];