From 4c9bcb9d5679aa90fc0813861c41bf2a76975f58 Mon Sep 17 00:00:00 2001 From: Keyang Ru Date: Thu, 16 Oct 2025 13:44:44 -0700 Subject: [PATCH] [Router] Refactor protocol definitions: split spec.rs into modular files (#11677) Co-authored-by: Chang Su --- sgl-router/benches/request_processing.rs | 9 +- sgl-router/benches/tool_parser_benchmark.rs | 2 +- .../src/grpc_client/sglang_scheduler.rs | 12 +- sgl-router/src/protocols/chat.rs | 682 ++++ sgl-router/src/protocols/common.rs | 345 +++ sgl-router/src/protocols/completion.rs | 213 ++ sgl-router/src/protocols/embedding.rs | 57 + sgl-router/src/protocols/generate.rs | 288 ++ sgl-router/src/protocols/mod.rs | 9 +- sgl-router/src/protocols/rerank.rs | 211 ++ sgl-router/src/protocols/responses.rs | 821 +++++ sgl-router/src/protocols/sampling_params.rs | 119 + sgl-router/src/protocols/spec.rs | 2739 ----------------- sgl-router/src/protocols/validated.rs | 11 +- sgl-router/src/routers/grpc/context.rs | 5 +- sgl-router/src/routers/grpc/mod.rs | 2 +- sgl-router/src/routers/grpc/pd_router.rs | 10 +- sgl-router/src/routers/grpc/pipeline.rs | 4 +- sgl-router/src/routers/grpc/processing.rs | 8 +- sgl-router/src/routers/grpc/router.rs | 10 +- sgl-router/src/routers/grpc/streaming.rs | 9 +- sgl-router/src/routers/grpc/utils.rs | 11 +- sgl-router/src/routers/http/pd_router.rs | 20 +- sgl-router/src/routers/http/router.rs | 16 +- sgl-router/src/routers/mod.rs | 10 +- .../src/routers/openai/conversations.rs | 4 +- sgl-router/src/routers/openai/mcp.rs | 6 +- sgl-router/src/routers/openai/responses.rs | 2 +- sgl-router/src/routers/openai/router.rs | 8 +- sgl-router/src/routers/openai/streaming.rs | 2 +- sgl-router/src/routers/router_manager.rs | 10 +- sgl-router/src/server.rs | 12 +- .../tool_parser/parsers/deepseek_parser.rs | 2 +- .../tool_parser/parsers/glm4_moe_parser.rs | 2 +- .../parsers/gpt_oss_harmony_parser.rs | 2 +- .../src/tool_parser/parsers/gpt_oss_parser.rs | 2 +- sgl-router/src/tool_parser/parsers/helpers.rs | 2 +- .../src/tool_parser/parsers/json_parser.rs | 2 +- .../src/tool_parser/parsers/kimik2_parser.rs | 2 +- .../src/tool_parser/parsers/llama_parser.rs | 2 +- .../src/tool_parser/parsers/mistral_parser.rs | 2 +- .../tool_parser/parsers/passthrough_parser.rs | 2 +- .../tool_parser/parsers/pythonic_parser.rs | 2 +- .../src/tool_parser/parsers/qwen_parser.rs | 2 +- .../src/tool_parser/parsers/step3_parser.rs | 2 +- sgl-router/src/tool_parser/traits.rs | 2 +- .../tests/chat_template_format_detection.rs | 12 +- sgl-router/tests/chat_template_integration.rs | 41 +- sgl-router/tests/chat_template_loading.rs | 18 +- sgl-router/tests/common/mod.rs | 2 +- sgl-router/tests/responses_api_test.rs | 36 +- sgl-router/tests/spec/chat_completion.rs | 7 +- sgl-router/tests/spec/chat_message.rs | 2 +- sgl-router/tests/spec/embedding.rs | 3 +- sgl-router/tests/spec/rerank.rs | 19 +- sgl-router/tests/test_openai_routing.rs | 18 +- 56 files changed, 2939 insertions(+), 2914 deletions(-) create mode 100644 sgl-router/src/protocols/chat.rs create mode 100644 sgl-router/src/protocols/common.rs create mode 100644 sgl-router/src/protocols/completion.rs create mode 100644 sgl-router/src/protocols/embedding.rs create mode 100644 sgl-router/src/protocols/generate.rs create mode 100644 sgl-router/src/protocols/rerank.rs create mode 100644 sgl-router/src/protocols/responses.rs create mode 100644 sgl-router/src/protocols/sampling_params.rs delete mode 100644 sgl-router/src/protocols/spec.rs diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index b2d6d7430..f9bce7942 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -3,10 +3,11 @@ use serde_json::{from_str, to_string, to_value, to_vec}; use std::time::Instant; use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType}; -use sglang_router_rs::protocols::spec::{ - ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, SamplingParams, - StringOrArray, UserMessageContent, -}; +use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}; +use sglang_router_rs::protocols::common::StringOrArray; +use sglang_router_rs::protocols::completion::CompletionRequest; +use sglang_router_rs::protocols::generate::GenerateRequest; +use sglang_router_rs::protocols::sampling_params::SamplingParams; use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap}; fn create_test_worker() -> BasicWorker { diff --git a/sgl-router/benches/tool_parser_benchmark.rs b/sgl-router/benches/tool_parser_benchmark.rs index 96f7d6f69..55c965fb7 100644 --- a/sgl-router/benches/tool_parser_benchmark.rs +++ b/sgl-router/benches/tool_parser_benchmark.rs @@ -9,7 +9,7 @@ use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput}; use serde_json::json; -use sglang_router_rs::protocols::spec::{Function, Tool}; +use sglang_router_rs::protocols::common::{Function, Tool}; use sglang_router_rs::tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser}; use std::collections::BTreeMap; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs index 9a5ef9a1f..92413116b 100644 --- a/sgl-router/src/grpc_client/sglang_scheduler.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -7,10 +7,10 @@ use std::time::Duration; use tonic::{transport::Channel, Request, Streaming}; use tracing::{debug, warn}; -use crate::protocols::spec::{ - ChatCompletionRequest, GenerateRequest, ResponseFormat, - SamplingParams as GenerateSamplingParams, StringOrArray, -}; +use crate::protocols::chat::ChatCompletionRequest; +use crate::protocols::common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue}; +use crate::protocols::generate::GenerateRequest; +use crate::protocols::sampling_params::SamplingParams as GenerateSamplingParams; // Include the generated protobuf code pub mod proto { @@ -306,9 +306,7 @@ impl SglangSchedulerClient { // Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none" let skip_special_tokens = if request.tools.is_some() { match &request.tool_choice { - Some(crate::protocols::spec::ToolChoice::Value( - crate::protocols::spec::ToolChoiceValue::None, - )) => request.skip_special_tokens, + Some(ToolChoice::Value(ToolChoiceValue::None)) => request.skip_special_tokens, Some(_) => false, // tool_choice is not "none" None => false, // TODO: this assumes tool_choice defaults to "auto" when tools present } diff --git a/sgl-router/src/protocols/chat.rs b/sgl-router/src/protocols/chat.rs new file mode 100644 index 000000000..362b9d524 --- /dev/null +++ b/sgl-router/src/protocols/chat.rs @@ -0,0 +1,682 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use validator::Validate; + +use super::common::*; +use super::sampling_params::{validate_top_k_value, validate_top_p_value}; +use crate::protocols::validated::Normalizable; + +// ============================================================================ +// Chat Messages +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "role")] +pub enum ChatMessage { + #[serde(rename = "system")] + System { + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + #[serde(rename = "user")] + User { + content: UserMessageContent, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + #[serde(rename = "assistant")] + Assistant { + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + /// Reasoning content for O1-style models (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_content: Option, + }, + #[serde(rename = "tool")] + Tool { + content: String, + tool_call_id: String, + }, + #[serde(rename = "function")] + Function { content: String, name: String }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum UserMessageContent { + Text(String), + Parts(Vec), +} + +// ============================================================================ +// Chat Completion Request +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)] +#[validate(schema(function = "validate_chat_cross_parameters"))] +pub struct ChatCompletionRequest { + /// A list of messages comprising the conversation so far + #[validate(custom(function = "validate_messages"))] + pub messages: Vec, + + /// ID of the model to use + #[serde(default = "default_model")] + pub model: String, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = -2.0, max = 2.0))] + pub frequency_penalty: Option, + + /// Deprecated: Replaced by tool_choice + #[serde(skip_serializing_if = "Option::is_none")] + #[deprecated(note = "Use tool_choice instead")] + pub function_call: Option, + + /// Deprecated: Replaced by tools + #[serde(skip_serializing_if = "Option::is_none")] + #[deprecated(note = "Use tools instead")] + pub functions: Option>, + + /// Modify the likelihood of specified tokens appearing in the completion + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + + /// Whether to return log probabilities of the output tokens + #[serde(default)] + pub logprobs: bool, + + /// Deprecated: Replaced by max_completion_tokens + #[serde(skip_serializing_if = "Option::is_none")] + #[deprecated(note = "Use max_completion_tokens instead")] + #[validate(range(min = 1))] + pub max_tokens: Option, + + /// An upper bound for the number of tokens that can be generated for a completion + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 1))] + pub max_completion_tokens: Option, + + /// Developer-defined tags and values used for filtering completions in the dashboard + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + + /// Output types that you would like the model to generate for this request + #[serde(skip_serializing_if = "Option::is_none")] + pub modalities: Option>, + + /// How many chat completion choices to generate for each input message + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 1, max = 10))] + pub n: Option, + + /// Whether to enable parallel function calling during tool use + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = -2.0, max = 2.0))] + pub presence_penalty: Option, + + /// Cache key for prompts (beta feature) + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_cache_key: Option, + + /// Effort level for reasoning models (low, medium, high) + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + + /// An object specifying the format that the model must output + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + + /// Safety identifier for content moderation + #[serde(skip_serializing_if = "Option::is_none")] + pub safety_identifier: Option, + + /// Deprecated: This feature is in Legacy mode + #[serde(skip_serializing_if = "Option::is_none")] + #[deprecated(note = "This feature is in Legacy mode")] + pub seed: Option, + + /// The service tier to use for this request + #[serde(skip_serializing_if = "Option::is_none")] + pub service_tier: Option, + + /// Up to 4 sequences where the API will stop generating further tokens + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(custom(function = "validate_stop"))] + pub stop: Option, + + /// If set, partial message deltas will be sent + #[serde(default)] + pub stream: bool, + + /// Options for streaming response + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + + /// What sampling temperature to use, between 0 and 2 + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0, max = 2.0))] + pub temperature: Option, + + /// Controls which (if any) tool is called by the model + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + + /// A list of tools the model may call + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// An integer between 0 and 20 specifying the number of most likely tokens to return + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0, max = 20))] + pub top_logprobs: Option, + + /// An alternative to sampling with temperature + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(custom(function = "validate_top_p_value"))] + pub top_p: Option, + + /// Verbosity level for debugging + #[serde(skip_serializing_if = "Option::is_none")] + pub verbosity: Option, + + // ============================================================================= + // Engine-Specific Sampling Parameters + // ============================================================================= + // These parameters are extensions beyond the OpenAI API specification and + // control model generation behavior in engine-specific ways. + // ============================================================================= + /// Top-k sampling parameter (-1 to disable) + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(custom(function = "validate_top_k_value"))] + pub top_k: Option, + + /// Min-p nucleus sampling parameter + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0, max = 1.0))] + pub min_p: Option, + + /// Minimum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 1))] + pub min_tokens: Option, + + /// Repetition penalty for reducing repetitive text + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0, max = 2.0))] + pub repetition_penalty: Option, + + /// Regex constraint for output generation + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + + /// EBNF grammar constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + + /// Specific token IDs to use as stop conditions + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + + /// Skip trimming stop tokens from output + #[serde(default)] + pub no_stop_trim: bool, + + /// Ignore end-of-sequence tokens during generation + #[serde(default)] + pub ignore_eos: bool, + + /// Continue generating from final assistant message + #[serde(default)] + pub continue_final_message: bool, + + /// Skip special tokens during detokenization + #[serde(default = "default_true")] + pub skip_special_tokens: bool, + + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Separate reasoning content from final answer (O1-style models) + #[serde(default = "default_true")] + pub separate_reasoning: bool, + + /// Stream reasoning tokens during generation + #[serde(default = "default_true")] + pub stream_reasoning: bool, + + /// Chat template kwargs + #[serde(skip_serializing_if = "Option::is_none")] + pub chat_template_kwargs: Option>, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, + + /// Random seed for sampling for deterministic outputs + #[serde(skip_serializing_if = "Option::is_none")] + pub sampling_seed: Option, +} + +// ============================================================================ +// Validation Functions +// ============================================================================ + +/// Validates stop sequences (max 4, non-empty strings) +fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> { + match stop { + StringOrArray::String(s) => { + if s.is_empty() { + return Err(validator::ValidationError::new( + "stop sequences cannot be empty", + )); + } + } + StringOrArray::Array(arr) => { + if arr.len() > 4 { + return Err(validator::ValidationError::new( + "maximum 4 stop sequences allowed", + )); + } + for s in arr { + if s.is_empty() { + return Err(validator::ValidationError::new( + "stop sequences cannot be empty", + )); + } + } + } + } + Ok(()) +} + +/// Validates messages array is not empty and has valid content +fn validate_messages(messages: &[ChatMessage]) -> Result<(), validator::ValidationError> { + if messages.is_empty() { + return Err(validator::ValidationError::new("messages cannot be empty")); + } + + for msg in messages.iter() { + if let ChatMessage::User { content, .. } = msg { + match content { + UserMessageContent::Text(text) if text.is_empty() => { + return Err(validator::ValidationError::new( + "message content cannot be empty", + )); + } + UserMessageContent::Parts(parts) if parts.is_empty() => { + return Err(validator::ValidationError::new( + "message content parts cannot be empty", + )); + } + _ => {} + } + } + } + Ok(()) +} + +/// Schema-level validation for cross-field dependencies +fn validate_chat_cross_parameters( + req: &ChatCompletionRequest, +) -> Result<(), validator::ValidationError> { + // 1. Validate logprobs dependency + if req.top_logprobs.is_some() && !req.logprobs { + let mut e = validator::ValidationError::new("top_logprobs_requires_logprobs"); + e.message = Some("top_logprobs is only allowed when logprobs is enabled".into()); + return Err(e); + } + + // 2. Validate stream_options dependency + if req.stream_options.is_some() && !req.stream { + let mut e = validator::ValidationError::new("stream_options_requires_stream"); + e.message = + Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into()); + return Err(e); + } + + // 3. Validate token limits - min <= max + if let (Some(min), Some(max)) = (req.min_tokens, req.max_completion_tokens) { + if min > max { + let mut e = validator::ValidationError::new("min_tokens_exceeds_max"); + e.message = Some("min_tokens cannot exceed max_tokens/max_completion_tokens".into()); + return Err(e); + } + } + + // 4. Validate structured output conflicts + let has_json_format = matches!( + req.response_format, + Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) + ); + + if has_json_format && req.regex.is_some() { + let mut e = validator::ValidationError::new("regex_conflicts_with_json"); + e.message = Some("cannot use regex constraint with JSON response format".into()); + return Err(e); + } + + if has_json_format && req.ebnf.is_some() { + let mut e = validator::ValidationError::new("ebnf_conflicts_with_json"); + e.message = Some("cannot use EBNF constraint with JSON response format".into()); + return Err(e); + } + + // 5. Validate mutually exclusive structured output constraints + let constraint_count = [ + req.regex.is_some(), + req.ebnf.is_some(), + matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })), + ] + .iter() + .filter(|&&x| x) + .count(); + + if constraint_count > 1 { + let mut e = validator::ValidationError::new("multiple_constraints"); + e.message = Some("only one structured output constraint (regex, ebnf, or json_schema) can be active at a time".into()); + return Err(e); + } + + // 6. Validate response format JSON schema name + if let Some(ResponseFormat::JsonSchema { json_schema }) = &req.response_format { + if json_schema.name.is_empty() { + let mut e = validator::ValidationError::new("json_schema_name_empty"); + e.message = Some("JSON schema name cannot be empty".into()); + return Err(e); + } + } + + // 7. Validate tool_choice requires tools (except for "none") + if let Some(ref tool_choice) = req.tool_choice { + let has_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty()); + + // Check if tool_choice is anything other than "none" + let is_some_choice = !matches!(tool_choice, ToolChoice::Value(ToolChoiceValue::None)); + + if is_some_choice && !has_tools { + let mut e = validator::ValidationError::new("tool_choice_requires_tools"); + e.message = Some("Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.".into()); + return Err(e); + } + + // Additional validation when tools are present + if has_tools { + let tools = req.tools.as_ref().unwrap(); + + match tool_choice { + ToolChoice::Function { function, .. } => { + // Validate that the specified function name exists in tools + let function_exists = tools.iter().any(|tool| { + tool.tool_type == "function" && tool.function.name == function.name + }); + + if !function_exists { + let mut e = + validator::ValidationError::new("tool_choice_function_not_found"); + e.message = Some( + format!( + "Invalid value for 'tool_choice': function '{}' not found in 'tools'.", + function.name + ) + .into(), + ); + return Err(e); + } + } + ToolChoice::AllowedTools { + mode, + tools: allowed_tools, + .. + } => { + // Validate mode is "auto" or "required" + if mode != "auto" && mode != "required" { + let mut e = validator::ValidationError::new("tool_choice_invalid_mode"); + e.message = Some(format!( + "Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{}'.", + mode + ).into()); + return Err(e); + } + + // Validate that all referenced tool names exist in tools + for tool_ref in allowed_tools { + let tool_exists = tools.iter().any(|tool| { + tool.tool_type == tool_ref.tool_type + && tool.function.name == tool_ref.name + }); + + if !tool_exists { + let mut e = + validator::ValidationError::new("tool_choice_tool_not_found"); + e.message = Some(format!( + "Invalid value for 'tool_choice.tools': tool '{}' not found in 'tools'.", + tool_ref.name + ).into()); + return Err(e); + } + } + } + _ => {} + } + } + } + + Ok(()) +} + +// ============================================================================ +// Normalizable Implementation +// ============================================================================ + +impl Normalizable for ChatCompletionRequest { + /// Normalize the request by applying migrations and defaults: + /// 1. Migrate deprecated fields to their replacements + /// 2. Clear deprecated fields and log warnings + /// 3. Apply OpenAI defaults for tool_choice + fn normalize(&mut self) { + // Migrate deprecated max_tokens → max_completion_tokens + #[allow(deprecated)] + if self.max_completion_tokens.is_none() && self.max_tokens.is_some() { + tracing::warn!("max_tokens is deprecated, use max_completion_tokens instead"); + self.max_completion_tokens = self.max_tokens; + self.max_tokens = None; // Clear deprecated field + } + + // Migrate deprecated functions → tools + #[allow(deprecated)] + if self.tools.is_none() && self.functions.is_some() { + tracing::warn!("functions is deprecated, use tools instead"); + self.tools = self.functions.as_ref().map(|functions| { + functions + .iter() + .map(|func| Tool { + tool_type: "function".to_string(), + function: func.clone(), + }) + .collect() + }); + self.functions = None; // Clear deprecated field + } + + // Migrate deprecated function_call → tool_choice + #[allow(deprecated)] + if self.tool_choice.is_none() && self.function_call.is_some() { + tracing::warn!("function_call is deprecated, use tool_choice instead"); + self.tool_choice = self.function_call.as_ref().map(|fc| match fc { + FunctionCall::None => ToolChoice::Value(ToolChoiceValue::None), + FunctionCall::Auto => ToolChoice::Value(ToolChoiceValue::Auto), + FunctionCall::Function { name } => ToolChoice::Function { + tool_type: "function".to_string(), + function: FunctionChoice { name: name.clone() }, + }, + }); + self.function_call = None; // Clear deprecated field + } + + // Apply tool_choice defaults + if self.tool_choice.is_none() { + let has_tools = self.tools.as_ref().is_some_and(|t| !t.is_empty()); + + self.tool_choice = if has_tools { + Some(ToolChoice::Value(ToolChoiceValue::Auto)) + } else { + Some(ToolChoice::Value(ToolChoiceValue::None)) + }; + } + } +} + +// ============================================================================ +// GenerationRequest Trait Implementation +// ============================================================================ + +impl GenerationRequest for ChatCompletionRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn extract_text_for_routing(&self) -> String { + // Extract text from messages for routing decisions + self.messages + .iter() + .filter_map(|msg| match msg { + ChatMessage::System { content, .. } => Some(content.clone()), + ChatMessage::User { content, .. } => match content { + UserMessageContent::Text(text) => Some(text.clone()), + UserMessageContent::Parts(parts) => { + let texts: Vec = parts + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => Some(text.clone()), + _ => None, + }) + .collect(); + Some(texts.join(" ")) + } + }, + ChatMessage::Assistant { + content, + reasoning_content, + .. + } => { + // Combine content and reasoning content for routing decisions + let main_content = content.clone().unwrap_or_default(); + let reasoning = reasoning_content.clone().unwrap_or_default(); + if main_content.is_empty() && reasoning.is_empty() { + None + } else { + Some(format!("{} {}", main_content, reasoning).trim().to_string()) + } + } + ChatMessage::Tool { content, .. } => Some(content.clone()), + ChatMessage::Function { content, .. } => Some(content.clone()), + }) + .collect::>() + .join(" ") + } +} + +// ============================================================================ +// Response Types +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: String, // "chat.completion" + pub created: u64, + pub model: String, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +/// Response message structure for ChatCompletionResponse (different from request ChatMessage) +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionMessage { + pub role: String, // Always "assistant" for responses + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + /// Reasoning content for O1-style models (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, + // Note: function_call is deprecated and not included + // Note: refusal, annotations, audio are not added yet +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatChoice { + pub index: u32, + pub message: ChatCompletionMessage, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, // "stop", "length", "tool_calls", "content_filter", "function_call" + /// Information about which stop condition was matched + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, // Can be string or integer + /// Hidden states from the model (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + pub hidden_states: Option>, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionStreamResponse { + pub id: String, + pub object: String, // "chat.completion.chunk" + pub created: u64, + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +/// Delta structure for streaming chat completion responses +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatMessageDelta { + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + /// Reasoning content delta for O1-style models (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatStreamChoice { + pub index: u32, + pub delta: ChatMessageDelta, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, +} diff --git a/sgl-router/src/protocols/common.rs b/sgl-router/src/protocols/common.rs new file mode 100644 index 000000000..339045fc1 --- /dev/null +++ b/sgl-router/src/protocols/common.rs @@ -0,0 +1,345 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +// ============================================================================ +// Default value helpers +// ============================================================================ + +/// Default model value when not specified +pub(crate) fn default_model() -> String { + "unknown".to_string() +} + +/// Helper function for serde default value (returns true) +pub fn default_true() -> bool { + true +} + +// ============================================================================ +// GenerationRequest Trait +// ============================================================================ + +/// Trait for unified access to generation request properties +/// Implemented by ChatCompletionRequest, CompletionRequest, GenerateRequest, +/// EmbeddingRequest, RerankRequest, and ResponsesRequest +pub trait GenerationRequest: Send + Sync { + /// Check if the request is for streaming + fn is_stream(&self) -> bool; + + /// Get the model name if specified + fn get_model(&self) -> Option<&str>; + + /// Extract text content for routing decisions + fn extract_text_for_routing(&self) -> String; +} + +// ============================================================================ +// String/Array Utilities +// ============================================================================ + +/// A type that can be either a single string or an array of strings +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] +pub enum StringOrArray { + String(String), + Array(Vec), +} + +impl StringOrArray { + /// Get the number of items in the StringOrArray + pub fn len(&self) -> usize { + match self { + StringOrArray::String(_) => 1, + StringOrArray::Array(arr) => arr.len(), + } + } + + /// Check if the StringOrArray is empty + pub fn is_empty(&self) -> bool { + match self { + StringOrArray::String(s) => s.is_empty(), + StringOrArray::Array(arr) => arr.is_empty(), + } + } + + /// Convert to a vector of strings + pub fn to_vec(&self) -> Vec { + match self { + StringOrArray::String(s) => vec![s.clone()], + StringOrArray::Array(arr) => arr.clone(), + } + } +} + +// ============================================================================ +// Content Parts (for multimodal messages) +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { image_url: ImageUrl }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ImageUrl { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, // "auto", "low", or "high" +} + +// ============================================================================ +// Response Format (for structured outputs) +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +pub enum ResponseFormat { + #[serde(rename = "text")] + Text, + #[serde(rename = "json_object")] + JsonObject, + #[serde(rename = "json_schema")] + JsonSchema { json_schema: JsonSchemaFormat }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct JsonSchemaFormat { + pub name: String, + pub schema: Value, + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +// ============================================================================ +// Streaming +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct StreamOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub include_usage: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolCallDelta { + pub index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "type")] + pub tool_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionCallDelta { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, +} + +// ============================================================================ +// Tools and Function Calling +// ============================================================================ + +/// Tool choice value for simple string options +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoiceValue { + Auto, + Required, + None, +} + +/// Tool choice for both Chat Completion and Responses APIs +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ToolChoice { + Value(ToolChoiceValue), + Function { + #[serde(rename = "type")] + tool_type: String, // "function" + function: FunctionChoice, + }, + AllowedTools { + #[serde(rename = "type")] + tool_type: String, // "allowed_tools" + mode: String, // "auto" | "required" TODO: need validation + tools: Vec, + }, +} + +impl Default for ToolChoice { + fn default() -> Self { + Self::Value(ToolChoiceValue::Auto) + } +} + +/// Function choice specification for ToolChoice::Function +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionChoice { + pub name: String, +} + +/// Tool reference for ToolChoice::AllowedTools +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolReference { + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub name: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Tool { + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub function: Function, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Function { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub parameters: Value, // JSON Schema + /// Whether to enable strict schema adherence (OpenAI structured outputs) + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub function: FunctionCallResponse, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum FunctionCall { + None, + Auto, + Function { name: String }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionCallResponse { + pub name: String, + #[serde(default)] + pub arguments: Option, // JSON string +} + +// ============================================================================ +// Usage and Logging +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_tokens_details: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionTokensDetails { + pub reasoning_tokens: Option, +} + +/// Usage information (used by rerank and other endpoints) +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct UsageInfo { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_tokens_details: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct PromptTokenUsageInfo { + pub cached_tokens: u32, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct LogProbs { + pub tokens: Vec, + pub token_logprobs: Vec>, + pub top_logprobs: Vec>>, + pub text_offset: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ChatLogProbs { + Detailed { + #[serde(skip_serializing_if = "Option::is_none")] + content: Option>, + }, + Raw(Value), +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatLogProbsContent { + pub token: String, + pub logprob: f32, + pub bytes: Option>, + pub top_logprobs: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct TopLogProb { + pub token: String, + pub logprob: f32, + pub bytes: Option>, +} + +// ============================================================================ +// Error Types +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ErrorResponse { + pub error: ErrorDetail, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ErrorDetail { + pub message: String, + #[serde(rename = "type")] + pub error_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub param: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, +} + +// ============================================================================ +// Input Types +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum InputIds { + Single(Vec), + Batch(Vec>), +} + +/// LoRA adapter path - can be single path or batch of paths (SGLang extension) +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum LoRAPath { + Single(Option), + Batch(Vec>), +} diff --git a/sgl-router/src/protocols/completion.rs b/sgl-router/src/protocols/completion.rs new file mode 100644 index 000000000..a7bdfcfde --- /dev/null +++ b/sgl-router/src/protocols/completion.rs @@ -0,0 +1,213 @@ +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; +use std::collections::HashMap; + +use super::common::*; + +// ============================================================================ +// Completions API (v1/completions) - DEPRECATED but still supported +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionRequest { + /// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang) + pub model: String, + + /// The prompt(s) to generate completions for + pub prompt: StringOrArray, + + /// The suffix that comes after a completion of inserted text + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + + /// The maximum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// What sampling temperature to use, between 0 and 2 + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// An alternative to sampling with temperature (nucleus sampling) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// How many completions to generate for each prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + + /// Whether to stream back partial progress + #[serde(default)] + pub stream: bool, + + /// Options for streaming response + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + + /// Include the log probabilities on the logprobs most likely tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + + /// Echo back the prompt in addition to the completion + #[serde(default)] + pub echo: bool, + + /// Up to 4 sequences where the API will stop generating further tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + + /// Generates best_of completions server-side and returns the "best" + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of: Option, + + /// Modify the likelihood of specified tokens appearing in the completion + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + + /// A unique identifier representing your end-user + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// If specified, our system will make a best effort to sample deterministically + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + + // -------- Engine Specific Sampling Parameters -------- + /// Top-k sampling parameter (-1 to disable) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + + /// Min-p nucleus sampling parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + + /// Minimum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + + /// Repetition penalty for reducing repetitive text + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + + /// Regex constraint for output generation + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + + /// EBNF grammar constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + + /// JSON schema constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, + + /// Specific token IDs to use as stop conditions + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + + /// Skip trimming stop tokens from output + #[serde(default)] + pub no_stop_trim: bool, + + /// Ignore end-of-sequence tokens during generation + #[serde(default)] + pub ignore_eos: bool, + + /// Skip special tokens during detokenization + #[serde(default = "default_true")] + pub skip_special_tokens: bool, + + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, + + /// Sampling seed for deterministic outputs + #[serde(skip_serializing_if = "Option::is_none")] + pub sampling_seed: Option, + + /// Additional fields including bootstrap info for PD routing + #[serde(flatten)] + pub other: Map, +} + +impl GenerationRequest for CompletionRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn extract_text_for_routing(&self) -> String { + match &self.prompt { + StringOrArray::String(s) => s.clone(), + StringOrArray::Array(v) => v.join(" "), + } + } +} + +// ============================================================================ +// Response Types +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionResponse { + pub id: String, + pub object: String, // "text_completion" + pub created: u64, + pub model: String, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionChoice { + pub text: String, + pub index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, // "stop", "length", "content_filter", etc. + /// Information about which stop condition was matched + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, // Can be string or integer +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionStreamResponse { + pub id: String, + pub object: String, // "text_completion" + pub created: u64, + pub choices: Vec, + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionStreamChoice { + pub text: String, + pub index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, +} diff --git a/sgl-router/src/protocols/embedding.rs b/sgl-router/src/protocols/embedding.rs new file mode 100644 index 000000000..a76c9b67d --- /dev/null +++ b/sgl-router/src/protocols/embedding.rs @@ -0,0 +1,57 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use super::common::GenerationRequest; + +// ============================================================================ +// Embedding API +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct EmbeddingRequest { + /// ID of the model to use + pub model: String, + + /// Input can be a string, array of strings, tokens, or batch inputs + pub input: Value, + + /// Optional encoding format (e.g., "float", "base64") + #[serde(skip_serializing_if = "Option::is_none")] + pub encoding_format: Option, + + /// Optional user identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// Optional number of dimensions for the embedding + #[serde(skip_serializing_if = "Option::is_none")] + pub dimensions: Option, + + /// SGLang extension: request id for tracking + #[serde(skip_serializing_if = "Option::is_none")] + pub rid: Option, +} + +impl GenerationRequest for EmbeddingRequest { + fn is_stream(&self) -> bool { + // Embeddings are non-streaming + false + } + + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn extract_text_for_routing(&self) -> String { + // Best effort: extract text content for routing decisions + match &self.input { + Value::String(s) => s.clone(), + Value::Array(arr) => arr + .iter() + .filter_map(|v| v.as_str()) + .collect::>() + .join(" "), + _ => String::new(), + } + } +} diff --git a/sgl-router/src/protocols/generate.rs b/sgl-router/src/protocols/generate.rs new file mode 100644 index 000000000..3aac25640 --- /dev/null +++ b/sgl-router/src/protocols/generate.rs @@ -0,0 +1,288 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use validator::Validate; + +use super::common::{default_true, GenerationRequest, InputIds}; +use super::sampling_params::SamplingParams; +use crate::protocols::validated::Normalizable; + +// ============================================================================ +// SGLang Generate API (native format) +// ============================================================================ + +#[derive(Clone, Debug, Serialize, Deserialize, Validate)] +#[validate(schema(function = "validate_generate_request"))] +pub struct GenerateRequest { + /// Text input - SGLang native format + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + + /// Input IDs for tokenized input + #[serde(skip_serializing_if = "Option::is_none")] + pub input_ids: Option, + + /// Input embeddings for direct embedding input + /// Can be a 2D array (single request) or 3D array (batch of requests) + /// Placeholder for future use + #[serde(skip_serializing_if = "Option::is_none")] + pub input_embeds: Option, + + /// Image input data + /// Can be an image instance, file name, URL, or base64 encoded string + /// Supports single images, lists of images, or nested lists for batch processing + /// Placeholder for future use + #[serde(skip_serializing_if = "Option::is_none")] + pub image_data: Option, + + /// Video input data + /// Can be a file name, URL, or base64 encoded string + /// Supports single videos, lists of videos, or nested lists for batch processing + /// Placeholder for future use + #[serde(skip_serializing_if = "Option::is_none")] + pub video_data: Option, + + /// Audio input data + /// Can be a file name, URL, or base64 encoded string + /// Supports single audio files, lists of audio, or nested lists for batch processing + /// Placeholder for future use + #[serde(skip_serializing_if = "Option::is_none")] + pub audio_data: Option, + + /// Sampling parameters (sglang style) + #[serde(skip_serializing_if = "Option::is_none")] + pub sampling_params: Option, + + /// Whether to return logprobs + #[serde(skip_serializing_if = "Option::is_none")] + pub return_logprob: Option, + + /// If return logprobs, the start location in the prompt for returning logprobs. + #[serde(skip_serializing_if = "Option::is_none")] + pub logprob_start_len: Option, + + /// If return logprobs, the number of top logprobs to return at each position. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs_num: Option, + + /// If return logprobs, the token ids to return logprob for. + #[serde(skip_serializing_if = "Option::is_none")] + pub token_ids_logprob: Option>, + + /// Whether to detokenize tokens in text in the returned logprobs. + #[serde(default)] + pub return_text_in_logprobs: bool, + + /// Whether to stream the response + #[serde(default)] + pub stream: bool, + + /// Whether to log metrics for this request (e.g. health_generate calls do not log metrics) + #[serde(default = "default_true")] + pub log_metrics: bool, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, + + /// The modalities of the image data [image, multi-images, video] + #[serde(skip_serializing_if = "Option::is_none")] + pub modalities: Option>, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// LoRA adapter ID (if pre-loaded) + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_id: Option, + + /// Custom logit processor for advanced sampling control. Must be a serialized instance + /// of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + /// Use the processor's `to_str()` method to generate the serialized string. + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_logit_processor: Option, + + /// For disaggregated inference + #[serde(skip_serializing_if = "Option::is_none")] + pub bootstrap_host: Option, + + /// For disaggregated inference + #[serde(skip_serializing_if = "Option::is_none")] + pub bootstrap_port: Option, + + /// For disaggregated inference + #[serde(skip_serializing_if = "Option::is_none")] + pub bootstrap_room: Option, + + /// For disaggregated inference + #[serde(skip_serializing_if = "Option::is_none")] + pub bootstrap_pair_key: Option, + + /// Data parallel rank routing + #[serde(skip_serializing_if = "Option::is_none")] + pub data_parallel_rank: Option, + + /// Background response + #[serde(default)] + pub background: bool, + + /// Conversation ID for tracking + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation_id: Option, + + /// Priority for the request + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + + /// Extra key for classifying the request (e.g. cache_salt) + #[serde(skip_serializing_if = "Option::is_none")] + pub extra_key: Option, + + /// Whether to disallow logging for this request (e.g. due to ZDR) + #[serde(default)] + pub no_logs: bool, + + /// Custom metric labels + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_labels: Option>, + + /// Whether to return bytes for image generation + #[serde(default)] + pub return_bytes: bool, + + /// Whether to return entropy + #[serde(default)] + pub return_entropy: bool, + + /// Request ID for tracking (inherited from BaseReq in Python) + #[serde(skip_serializing_if = "Option::is_none")] + pub rid: Option, +} + +impl Normalizable for GenerateRequest { + // Use default no-op implementation - no normalization needed for GenerateRequest +} + +/// Validation function for GenerateRequest - ensure exactly one input type is provided +fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> { + // Exactly one of text or input_ids must be provided + // Note: input_embeds not yet supported in Rust implementation + let has_text = req.text.is_some(); + let has_input_ids = req.input_ids.is_some(); + + let count = [has_text, has_input_ids].iter().filter(|&&x| x).count(); + + if count == 0 { + return Err(validator::ValidationError::new( + "Either text or input_ids should be provided.", + )); + } + + if count > 1 { + return Err(validator::ValidationError::new( + "Either text or input_ids should be provided.", + )); + } + + Ok(()) +} + +impl GenerationRequest for GenerateRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + // Generate requests typically don't have a model field + None + } + + fn extract_text_for_routing(&self) -> String { + // Check fields in priority order: text, input_ids + if let Some(ref text) = self.text { + return text.clone(); + } + + if let Some(ref input_ids) = self.input_ids { + return match input_ids { + InputIds::Single(ids) => ids + .iter() + .map(|&id| id.to_string()) + .collect::>() + .join(" "), + InputIds::Batch(batches) => batches + .iter() + .flat_map(|batch| batch.iter().map(|&id| id.to_string())) + .collect::>() + .join(" "), + }; + } + + // No text input found + String::new() + } +} + +// ============================================================================ +// SGLang Generate Response Types +// ============================================================================ + +/// SGLang generate response (single completion or array for n>1) +/// +/// Format for n=1: +/// ```json +/// { +/// "text": "...", +/// "output_ids": [...], +/// "meta_info": { ... } +/// } +/// ``` +/// +/// Format for n>1: +/// ```json +/// [ +/// {"text": "...", "output_ids": [...], "meta_info": {...}}, +/// {"text": "...", "output_ids": [...], "meta_info": {...}} +/// ] +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GenerateResponse { + pub text: String, + pub output_ids: Vec, + pub meta_info: GenerateMetaInfo, +} + +/// Metadata for a single generate completion +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GenerateMetaInfo { + pub id: String, + pub finish_reason: GenerateFinishReason, + pub prompt_tokens: u32, + pub weight_version: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub input_token_logprobs: Option>>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_token_logprobs: Option>>>, + pub completion_tokens: u32, + pub cached_tokens: u32, + pub e2e_latency: f64, + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, +} + +/// Finish reason for generate endpoint +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum GenerateFinishReason { + Length { + length: u32, + }, + Stop, + #[serde(untagged)] + Other(Value), +} diff --git a/sgl-router/src/protocols/mod.rs b/sgl-router/src/protocols/mod.rs index 418c2568b..5ba6b1893 100644 --- a/sgl-router/src/protocols/mod.rs +++ b/sgl-router/src/protocols/mod.rs @@ -1,6 +1,13 @@ // Protocol definitions and validation for various LLM APIs // This module provides a structured approach to handling different API protocols -pub mod spec; +pub mod chat; +pub mod common; +pub mod completion; +pub mod embedding; +pub mod generate; +pub mod rerank; +pub mod responses; +pub mod sampling_params; pub mod validated; pub mod worker_spec; diff --git a/sgl-router/src/protocols/rerank.rs b/sgl-router/src/protocols/rerank.rs new file mode 100644 index 000000000..584a66b4b --- /dev/null +++ b/sgl-router/src/protocols/rerank.rs @@ -0,0 +1,211 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use validator::Validate; + +use super::common::{default_model, default_true, GenerationRequest, StringOrArray, UsageInfo}; + +fn default_rerank_object() -> String { + "rerank".to_string() +} + +/// TODO: Create timestamp should not be in protocol layer +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs() as i64 +} + +// ============================================================================ +// Rerank API +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize, Validate)] +#[validate(schema(function = "validate_rerank_request"))] +pub struct RerankRequest { + /// The query text to rank documents against + #[validate(custom(function = "validate_query"))] + pub query: String, + + /// List of documents to be ranked + #[validate(custom(function = "validate_documents"))] + pub documents: Vec, + + /// Model to use for reranking + #[serde(default = "default_model")] + pub model: String, + + /// Maximum number of documents to return (optional) + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 1))] + pub top_k: Option, + + /// Whether to return documents in addition to scores + #[serde(default = "default_true")] + pub return_documents: bool, + + // SGLang specific extensions + /// Request ID for tracking + pub rid: Option, + + /// User identifier + pub user: Option, +} + +impl GenerationRequest for RerankRequest { + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn is_stream(&self) -> bool { + false // Reranking doesn't support streaming + } + + fn extract_text_for_routing(&self) -> String { + self.query.clone() + } +} + +impl super::validated::Normalizable for RerankRequest { + // Use default no-op normalization +} + +// ============================================================================ +// Validation Functions +// ============================================================================ + +/// Validates that the query is not empty +fn validate_query(query: &str) -> Result<(), validator::ValidationError> { + if query.trim().is_empty() { + return Err(validator::ValidationError::new("query cannot be empty")); + } + Ok(()) +} + +/// Validates that the documents list is not empty +fn validate_documents(documents: &[String]) -> Result<(), validator::ValidationError> { + if documents.is_empty() { + return Err(validator::ValidationError::new( + "documents list cannot be empty", + )); + } + Ok(()) +} + +/// Schema-level validation for cross-field dependencies +fn validate_rerank_request(req: &RerankRequest) -> Result<(), validator::ValidationError> { + // Validate top_k if specified + if let Some(k) = req.top_k { + if k > req.documents.len() { + // This is allowed but we log a warning + tracing::warn!( + "top_k ({}) is greater than number of documents ({})", + k, + req.documents.len() + ); + } + } + Ok(()) +} + +impl RerankRequest { + /// Get the effective top_k value + pub fn effective_top_k(&self) -> usize { + self.top_k.unwrap_or(self.documents.len()) + } +} + +/// Individual rerank result +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RerankResult { + /// Relevance score for the document + pub score: f32, + + /// The document text (if return_documents was true) + #[serde(skip_serializing_if = "Option::is_none")] + pub document: Option, + + /// Original index of the document in the request + pub index: usize, + + /// Additional metadata about the ranking + #[serde(skip_serializing_if = "Option::is_none")] + pub meta_info: Option>, +} + +/// Rerank response containing sorted results +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RerankResponse { + /// Ranked results sorted by score (highest first) + pub results: Vec, + + /// Model used for reranking + pub model: String, + + /// Usage information + pub usage: Option, + + /// Response object type + #[serde(default = "default_rerank_object")] + pub object: String, + + /// Response ID + pub id: Option, + + /// Creation timestamp + pub created: i64, +} + +impl RerankResponse { + /// Create a new RerankResponse with the given results and model + pub fn new( + results: Vec, + model: String, + request_id: Option, + ) -> Self { + RerankResponse { + results, + model, + usage: None, + object: default_rerank_object(), + id: request_id, + created: current_timestamp(), + } + } + + /// Apply top_k limit to results + pub fn apply_top_k(&mut self, k: usize) { + self.results.truncate(k); + } + + /// Drop documents from results (when return_documents is false) + pub fn drop_documents(&mut self) { + for result in &mut self.results { + result.document = None; + } + } +} + +/// V1 API compatibility format for rerank requests +/// Matches Python's V1RerankReqInput +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct V1RerankReqInput { + pub query: String, + pub documents: Vec, +} + +/// Convert V1RerankReqInput to RerankRequest +impl From for RerankRequest { + fn from(v1: V1RerankReqInput) -> Self { + RerankRequest { + query: v1.query, + documents: v1.documents, + model: default_model(), + top_k: None, + return_documents: true, + rid: None, + user: None, + } + } +} diff --git a/sgl-router/src/protocols/responses.rs b/sgl-router/src/protocols/responses.rs new file mode 100644 index 000000000..1fbda00ad --- /dev/null +++ b/sgl-router/src/protocols/responses.rs @@ -0,0 +1,821 @@ +// OpenAI Responses API types +// https://platform.openai.com/docs/api-reference/responses + +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +// Import shared types from common module +use super::common::{ + default_true, ChatLogProbs, GenerationRequest, PromptTokenUsageInfo, StringOrArray, ToolChoice, + UsageInfo, +}; + +// ============================================================================ +// Response Tools (MCP and others) +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponseTool { + #[serde(rename = "type")] + pub r#type: ResponseToolType, + // MCP-specific fields (used when type == "mcp") + #[serde(skip_serializing_if = "Option::is_none")] + pub server_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub authorization: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub server_label: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub server_description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub require_approval: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub allowed_tools: Option>, +} + +impl Default for ResponseTool { + fn default() -> Self { + Self { + r#type: ResponseToolType::WebSearchPreview, + server_url: None, + authorization: None, + server_label: None, + server_description: None, + require_approval: None, + allowed_tools: None, + } + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ResponseToolType { + WebSearchPreview, + CodeInterpreter, + Mcp, +} + +// ============================================================================ +// Reasoning Parameters +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponseReasoningParam { + #[serde(default = "default_reasoning_effort")] + #[serde(skip_serializing_if = "Option::is_none")] + pub effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +fn default_reasoning_effort() -> Option { + Some(ReasoningEffort::Medium) +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ReasoningEffort { + Low, + Medium, + High, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ReasoningSummary { + Auto, + Concise, + Detailed, +} + +// ============================================================================ +// Input/Output Items +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ResponseInputOutputItem { + #[serde(rename = "message")] + Message { + id: String, + role: String, + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, + #[serde(rename = "reasoning")] + Reasoning { + id: String, + #[serde(skip_serializing_if = "Vec::is_empty")] + summary: Vec, + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, + #[serde(rename = "function_tool_call")] + FunctionToolCall { + id: String, + name: String, + arguments: String, + #[serde(skip_serializing_if = "Option::is_none")] + output: Option, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ResponseContentPart { + #[serde(rename = "output_text")] + OutputText { + text: String, + #[serde(skip_serializing_if = "Vec::is_empty")] + annotations: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + logprobs: Option, + }, + #[serde(rename = "input_text")] + InputText { text: String }, + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ResponseReasoningContent { + #[serde(rename = "reasoning_text")] + ReasoningText { text: String }, +} + +/// MCP Tool information for the mcp_list_tools output item +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct McpToolInfo { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub input_schema: Value, + #[serde(skip_serializing_if = "Option::is_none")] + pub annotations: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ResponseOutputItem { + #[serde(rename = "message")] + Message { + id: String, + role: String, + content: Vec, + status: String, + }, + #[serde(rename = "reasoning")] + Reasoning { + id: String, + #[serde(skip_serializing_if = "Vec::is_empty")] + summary: Vec, + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, + #[serde(rename = "function_tool_call")] + FunctionToolCall { + id: String, + name: String, + arguments: String, + #[serde(skip_serializing_if = "Option::is_none")] + output: Option, + status: String, + }, + #[serde(rename = "mcp_list_tools")] + McpListTools { + id: String, + server_label: String, + tools: Vec, + }, + #[serde(rename = "mcp_call")] + McpCall { + id: String, + status: String, + #[serde(skip_serializing_if = "Option::is_none")] + approval_request_id: Option, + arguments: String, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + name: String, + output: String, + server_label: String, + }, +} + +// ============================================================================ +// Configuration Enums +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ServiceTier { + Auto, + Default, + Flex, + Scale, + Priority, +} + +impl Default for ServiceTier { + fn default() -> Self { + Self::Auto + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum Truncation { + Auto, + Disabled, +} + +impl Default for Truncation { + fn default() -> Self { + Self::Disabled + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ResponseStatus { + Queued, + InProgress, + Completed, + Failed, + Cancelled, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ReasoningInfo { + #[serde(skip_serializing_if = "Option::is_none")] + pub effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponseTextFormat { + pub format: TextFormatType, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct TextFormatType { + #[serde(rename = "type")] + pub format_type: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum IncludeField { + #[serde(rename = "code_interpreter_call.outputs")] + CodeInterpreterCallOutputs, + #[serde(rename = "computer_call_output.output.image_url")] + ComputerCallOutputImageUrl, + #[serde(rename = "file_search_call.results")] + FileSearchCallResults, + #[serde(rename = "message.input_image.image_url")] + MessageInputImageUrl, + #[serde(rename = "message.output_text.logprobs")] + MessageOutputTextLogprobs, + #[serde(rename = "reasoning.encrypted_content")] + ReasoningEncryptedContent, +} + +// ============================================================================ +// Usage Types (Responses API format) +// ============================================================================ + +/// OpenAI Responses API usage format (different from standard UsageInfo) +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponseUsage { + pub input_tokens: u32, + pub output_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub input_tokens_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_tokens_details: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ResponsesUsage { + Classic(UsageInfo), + Modern(ResponseUsage), +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct InputTokensDetails { + pub cached_tokens: u32, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct OutputTokensDetails { + pub reasoning_tokens: u32, +} + +impl UsageInfo { + /// Convert to OpenAI Responses API format + pub fn to_response_usage(&self) -> ResponseUsage { + ResponseUsage { + input_tokens: self.prompt_tokens, + output_tokens: self.completion_tokens, + total_tokens: self.total_tokens, + input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| { + InputTokensDetails { + cached_tokens: details.cached_tokens, + } + }), + output_tokens_details: self.reasoning_tokens.map(|tokens| OutputTokensDetails { + reasoning_tokens: tokens, + }), + } + } +} + +impl From for ResponseUsage { + fn from(usage: UsageInfo) -> Self { + usage.to_response_usage() + } +} + +impl ResponseUsage { + /// Convert back to standard UsageInfo format + pub fn to_usage_info(&self) -> UsageInfo { + UsageInfo { + prompt_tokens: self.input_tokens, + completion_tokens: self.output_tokens, + total_tokens: self.total_tokens, + reasoning_tokens: self + .output_tokens_details + .as_ref() + .map(|details| details.reasoning_tokens), + prompt_tokens_details: self.input_tokens_details.as_ref().map(|details| { + PromptTokenUsageInfo { + cached_tokens: details.cached_tokens, + } + }), + } + } +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct ResponsesGetParams { + #[serde(default)] + pub include: Vec, + #[serde(default)] + pub include_obfuscation: Option, + #[serde(default)] + pub starting_after: Option, + #[serde(default)] + pub stream: Option, +} + +impl ResponsesUsage { + pub fn to_response_usage(&self) -> ResponseUsage { + match self { + ResponsesUsage::Classic(usage) => usage.to_response_usage(), + ResponsesUsage::Modern(usage) => usage.clone(), + } + } + + pub fn to_usage_info(&self) -> UsageInfo { + match self { + ResponsesUsage::Classic(usage) => usage.clone(), + ResponsesUsage::Modern(usage) => usage.to_usage_info(), + } + } +} + +// ============================================================================ +// Helper Functions for Defaults +// ============================================================================ + +fn default_top_k() -> i32 { + -1 +} + +fn default_repetition_penalty() -> f32 { + 1.0 +} + +// ============================================================================ +// Request/Response Types +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponsesRequest { + /// Run the request in the background + #[serde(skip_serializing_if = "Option::is_none")] + pub background: Option, + + /// Fields to include in the response + #[serde(skip_serializing_if = "Option::is_none")] + pub include: Option>, + + /// Input content - can be string or structured items + pub input: ResponseInput, + + /// System instructions for the model + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + + /// Maximum number of output tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + + /// Maximum number of tool calls + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tool_calls: Option, + + /// Additional metadata + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + + /// Model to use (optional to match vLLM) + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + + /// Optional conversation id to persist input/output as items + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation: Option, + + /// Whether to enable parallel tool calls + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + + /// ID of previous response to continue from + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_response_id: Option, + + /// Reasoning configuration + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, + + /// Service tier + #[serde(skip_serializing_if = "Option::is_none")] + pub service_tier: Option, + + /// Whether to store the response + #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option, + + /// Whether to stream the response + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + + /// Temperature for sampling + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// Tool choice behavior + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + + /// Available tools + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// Number of top logprobs to return + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs: Option, + + /// Top-p sampling parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// Truncation behavior + #[serde(skip_serializing_if = "Option::is_none")] + pub truncation: Option, + + /// User identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// Request ID + #[serde(skip_serializing_if = "Option::is_none")] + pub request_id: Option, + + /// Request priority + #[serde(default)] + pub priority: i32, + + /// Frequency penalty + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + + /// Presence penalty + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + + /// Stop sequences + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + + /// Top-k sampling parameter (SGLang extension) + #[serde(default = "default_top_k")] + pub top_k: i32, + + /// Min-p sampling parameter (SGLang extension) + #[serde(default)] + pub min_p: f32, + + /// Repetition penalty (SGLang extension) + #[serde(default = "default_repetition_penalty")] + pub repetition_penalty: f32, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ResponseInput { + Text(String), + Items(Vec), +} + +impl Default for ResponsesRequest { + fn default() -> Self { + Self { + background: None, + include: None, + input: ResponseInput::Text(String::new()), + instructions: None, + max_output_tokens: None, + max_tool_calls: None, + metadata: None, + model: None, + conversation: None, + parallel_tool_calls: None, + previous_response_id: None, + reasoning: None, + service_tier: None, + store: None, + stream: None, + temperature: None, + tool_choice: None, + tools: None, + top_logprobs: None, + top_p: None, + truncation: None, + user: None, + request_id: None, + priority: 0, + frequency_penalty: None, + presence_penalty: None, + stop: None, + top_k: default_top_k(), + min_p: 0.0, + repetition_penalty: default_repetition_penalty(), + } + } +} + +impl GenerationRequest for ResponsesRequest { + fn is_stream(&self) -> bool { + self.stream.unwrap_or(false) + } + + fn get_model(&self) -> Option<&str> { + self.model.as_deref() + } + + fn extract_text_for_routing(&self) -> String { + match &self.input { + ResponseInput::Text(text) => text.clone(), + ResponseInput::Items(items) => items + .iter() + .filter_map(|item| match item { + ResponseInputOutputItem::Message { content, .. } => { + let texts: Vec = content + .iter() + .filter_map(|part| match part { + ResponseContentPart::OutputText { text, .. } => Some(text.clone()), + ResponseContentPart::InputText { text } => Some(text.clone()), + ResponseContentPart::Unknown => None, + }) + .collect(); + if texts.is_empty() { + None + } else { + Some(texts.join(" ")) + } + } + ResponseInputOutputItem::Reasoning { content, .. } => { + let texts: Vec = content + .iter() + .map(|part| match part { + ResponseReasoningContent::ReasoningText { text } => text.clone(), + }) + .collect(); + if texts.is_empty() { + None + } else { + Some(texts.join(" ")) + } + } + ResponseInputOutputItem::FunctionToolCall { arguments, .. } => { + Some(arguments.clone()) + } + }) + .collect::>() + .join(" "), + } + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponsesResponse { + /// Response ID + pub id: String, + + /// Object type + #[serde(default = "default_object_type")] + pub object: String, + + /// Creation timestamp + pub created_at: i64, + + /// Response status + pub status: ResponseStatus, + + /// Error information if status is failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + + /// Incomplete details if response was truncated + #[serde(skip_serializing_if = "Option::is_none")] + pub incomplete_details: Option, + + /// System instructions used + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + + /// Max output tokens setting + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + + /// Model name + pub model: String, + + /// Output items + #[serde(default)] + pub output: Vec, + + /// Whether parallel tool calls are enabled + #[serde(default = "default_true")] + pub parallel_tool_calls: bool, + + /// Previous response ID if this is a continuation + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_response_id: Option, + + /// Reasoning information + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, + + /// Whether the response is stored + #[serde(default = "default_true")] + pub store: bool, + + /// Temperature setting used + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// Text format settings + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + + /// Tool choice setting + #[serde(default = "default_tool_choice")] + pub tool_choice: String, + + /// Available tools + #[serde(default)] + pub tools: Vec, + + /// Top-p setting used + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// Truncation strategy used + #[serde(skip_serializing_if = "Option::is_none")] + pub truncation: Option, + + /// Usage information + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + + /// User identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// Additional metadata + #[serde(default)] + pub metadata: HashMap, +} + +fn default_object_type() -> String { + "response".to_string() +} + +fn default_tool_choice() -> String { + "auto".to_string() +} + +impl ResponsesResponse { + /// Check if the response is complete + pub fn is_complete(&self) -> bool { + matches!(self.status, ResponseStatus::Completed) + } + + /// Check if the response is in progress + pub fn is_in_progress(&self) -> bool { + matches!(self.status, ResponseStatus::InProgress) + } + + /// Check if the response failed + pub fn is_failed(&self) -> bool { + matches!(self.status, ResponseStatus::Failed) + } +} + +impl ResponseOutputItem { + /// Create a new message output item + pub fn new_message( + id: String, + role: String, + content: Vec, + status: String, + ) -> Self { + Self::Message { + id, + role, + content, + status, + } + } + + /// Create a new reasoning output item + pub fn new_reasoning( + id: String, + summary: Vec, + content: Vec, + status: Option, + ) -> Self { + Self::Reasoning { + id, + summary, + content, + status, + } + } + + /// Create a new function tool call output item + pub fn new_function_tool_call( + id: String, + name: String, + arguments: String, + output: Option, + status: String, + ) -> Self { + Self::FunctionToolCall { + id, + name, + arguments, + output, + status, + } + } +} + +impl ResponseContentPart { + /// Create a new text content part + pub fn new_text( + text: String, + annotations: Vec, + logprobs: Option, + ) -> Self { + Self::OutputText { + text, + annotations, + logprobs, + } + } +} + +impl ResponseReasoningContent { + /// Create a new reasoning text content + pub fn new_reasoning_text(text: String) -> Self { + Self::ReasoningText { text } + } +} diff --git a/sgl-router/src/protocols/sampling_params.rs b/sgl-router/src/protocols/sampling_params.rs new file mode 100644 index 000000000..9055a53dd --- /dev/null +++ b/sgl-router/src/protocols/sampling_params.rs @@ -0,0 +1,119 @@ +use serde::{Deserialize, Serialize}; +use validator::Validate; + +use super::common::StringOrArray; + +/// Sampling parameters for text generation +#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)] +#[validate(schema(function = "validate_sampling_params"))] +pub struct SamplingParams { + /// Temperature for sampling (must be >= 0.0, no upper limit) + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0))] + pub temperature: Option, + /// Maximum number of new tokens to generate (must be >= 0) + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0))] + pub max_new_tokens: Option, + /// Top-p nucleus sampling (0.0 < top_p <= 1.0) + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(custom(function = "validate_top_p_value"))] + pub top_p: Option, + /// Top-k sampling (-1 to disable, or >= 1) + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(custom(function = "validate_top_k_value"))] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = -2.0, max = 2.0))] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = -2.0, max = 2.0))] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0, max = 2.0))] + pub repetition_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ignore_eos: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub skip_special_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[validate(range(min = 0.0, max = 1.0))] + pub min_p: Option, + /// Minimum number of new tokens (validated in schema function for cross-field check with max_new_tokens) + #[serde(skip_serializing_if = "Option::is_none")] + pub min_new_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub no_stop_trim: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sampling_seed: Option, +} + +// ============================================================================ +// Shared Validation Functions +// ============================================================================ + +/// Validates top_p: 0.0 < top_p <= 1.0 (can't use range validator for open interval) +pub fn validate_top_p_value(top_p: f32) -> Result<(), validator::ValidationError> { + if !(top_p > 0.0 && top_p <= 1.0) { + return Err(validator::ValidationError::new( + "top_p must be in (0, 1] - greater than 0.0 and at most 1.0", + )); + } + Ok(()) +} + +/// Validates top_k: -1 (disabled) or >= 1 (special -1 case - can't use range validator) +pub fn validate_top_k_value(top_k: i32) -> Result<(), validator::ValidationError> { + if top_k != -1 && top_k < 1 { + return Err(validator::ValidationError::new( + "top_k must be -1 (disabled) or at least 1", + )); + } + Ok(()) +} + +// ============================================================================ +// SamplingParams-Specific Validation +// ============================================================================ + +/// Validation function for SamplingParams - cross-field validation only +fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::ValidationError> { + // 1. Cross-field validation: min_new_tokens <= max_new_tokens + if let (Some(min), Some(max)) = (params.min_new_tokens, params.max_new_tokens) { + if min > max { + return Err(validator::ValidationError::new( + "min_new_tokens cannot exceed max_new_tokens", + )); + } + } + + // 2. Validate mutually exclusive structured output constraints + let constraint_count = [ + params.regex.is_some(), + params.ebnf.is_some(), + params.json_schema.is_some(), + ] + .iter() + .filter(|&&x| x) + .count(); + + if constraint_count > 1 { + return Err(validator::ValidationError::new( + "only one of regex, ebnf, or json_schema can be set", + )); + } + + Ok(()) +} diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs deleted file mode 100644 index 394b0d28d..000000000 --- a/sgl-router/src/protocols/spec.rs +++ /dev/null @@ -1,2739 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::{to_value, Map, Value}; -use std::collections::HashMap; -use validator::Validate; - -use crate::protocols::validated::Normalizable; - -// Default model value when not specified -fn default_model() -> String { - "unknown".to_string() -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "role")] -pub enum ChatMessage { - #[serde(rename = "system")] - System { - content: String, - #[serde(skip_serializing_if = "Option::is_none")] - name: Option, - }, - #[serde(rename = "user")] - User { - content: UserMessageContent, - #[serde(skip_serializing_if = "Option::is_none")] - name: Option, - }, - #[serde(rename = "assistant")] - Assistant { - #[serde(skip_serializing_if = "Option::is_none")] - content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>, - /// Reasoning content for O1-style models (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - reasoning_content: Option, - }, - #[serde(rename = "tool")] - Tool { - content: String, - tool_call_id: String, - }, - #[serde(rename = "function")] - Function { content: String, name: String }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum UserMessageContent { - Text(String), - Parts(Vec), -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -pub enum ContentPart { - #[serde(rename = "text")] - Text { text: String }, - #[serde(rename = "image_url")] - ImageUrl { image_url: ImageUrl }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ImageUrl { - pub url: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub detail: Option, // "auto", "low", or "high" -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -pub enum ResponseFormat { - #[serde(rename = "text")] - Text, - #[serde(rename = "json_object")] - JsonObject, - #[serde(rename = "json_schema")] - JsonSchema { json_schema: JsonSchemaFormat }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct JsonSchemaFormat { - pub name: String, - pub schema: Value, - #[serde(skip_serializing_if = "Option::is_none")] - pub strict: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatMessageDelta { - #[serde(skip_serializing_if = "Option::is_none")] - pub role: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, - /// Reasoning content delta for O1-style models (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning_content: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ToolCallDelta { - pub index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(rename = "type")] - pub tool_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub function: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct FunctionCallDelta { - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)] -#[validate(schema(function = "validate_chat_cross_parameters"))] -pub struct ChatCompletionRequest { - /// A list of messages comprising the conversation so far - #[validate(custom(function = "validate_messages"))] - pub messages: Vec, - - /// ID of the model to use - #[serde(default = "default_model")] - pub model: String, - - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = -2.0, max = 2.0))] - pub frequency_penalty: Option, - - /// Deprecated: Replaced by tool_choice - #[serde(skip_serializing_if = "Option::is_none")] - #[deprecated(note = "Use tool_choice instead")] - pub function_call: Option, - - /// Deprecated: Replaced by tools - #[serde(skip_serializing_if = "Option::is_none")] - #[deprecated(note = "Use tools instead")] - pub functions: Option>, - - /// Modify the likelihood of specified tokens appearing in the completion - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, - - /// Whether to return log probabilities of the output tokens - #[serde(default)] - pub logprobs: bool, - - /// Deprecated: Replaced by max_completion_tokens - #[serde(skip_serializing_if = "Option::is_none")] - #[deprecated(note = "Use max_completion_tokens instead")] - #[validate(range(min = 1))] - pub max_tokens: Option, - - /// An upper bound for the number of tokens that can be generated for a completion - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 1))] - pub max_completion_tokens: Option, - - /// Developer-defined tags and values used for filtering completions in the dashboard - #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option>, - - /// Output types that you would like the model to generate for this request - #[serde(skip_serializing_if = "Option::is_none")] - pub modalities: Option>, - - /// How many chat completion choices to generate for each input message - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 1, max = 10))] - pub n: Option, - - /// Whether to enable parallel function calling during tool use - #[serde(skip_serializing_if = "Option::is_none")] - pub parallel_tool_calls: Option, - - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = -2.0, max = 2.0))] - pub presence_penalty: Option, - - /// Cache key for prompts (beta feature) - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt_cache_key: Option, - - /// Effort level for reasoning models (low, medium, high) - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning_effort: Option, - - /// An object specifying the format that the model must output - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, - - /// Safety identifier for content moderation - #[serde(skip_serializing_if = "Option::is_none")] - pub safety_identifier: Option, - - /// Deprecated: This feature is in Legacy mode - #[serde(skip_serializing_if = "Option::is_none")] - #[deprecated(note = "This feature is in Legacy mode")] - pub seed: Option, - - /// The service tier to use for this request - #[serde(skip_serializing_if = "Option::is_none")] - pub service_tier: Option, - - /// Up to 4 sequences where the API will stop generating further tokens - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(custom(function = "validate_stop"))] - pub stop: Option, - - /// If set, partial message deltas will be sent - #[serde(default)] - pub stream: bool, - - /// Options for streaming response - #[serde(skip_serializing_if = "Option::is_none")] - pub stream_options: Option, - - /// What sampling temperature to use, between 0 and 2 - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 0.0, max = 2.0))] - pub temperature: Option, - - /// Controls which (if any) tool is called by the model - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_choice: Option, - - /// A list of tools the model may call - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, - - /// An integer between 0 and 20 specifying the number of most likely tokens to return - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 0, max = 20))] - pub top_logprobs: Option, - - /// An alternative to sampling with temperature - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(custom(function = "validate_top_p_value"))] - pub top_p: Option, - - /// Verbosity level for debugging - #[serde(skip_serializing_if = "Option::is_none")] - pub verbosity: Option, - - // ============================================================================= - // Engine-Specific Sampling Parameters - // ============================================================================= - // These parameters are extensions beyond the OpenAI API specification and - // control model generation behavior in engine-specific ways. - // ============================================================================= - /// Top-k sampling parameter (-1 to disable) - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(custom(function = "validate_top_k_value"))] - pub top_k: Option, - - /// Min-p nucleus sampling parameter - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 0.0, max = 1.0))] - pub min_p: Option, - - /// Minimum number of tokens to generate - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 1))] - pub min_tokens: Option, - - /// Repetition penalty for reducing repetitive text - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 0.0, max = 2.0))] - pub repetition_penalty: Option, - - /// Regex constraint for output generation - #[serde(skip_serializing_if = "Option::is_none")] - pub regex: Option, - - /// EBNF grammar constraint for structured output - #[serde(skip_serializing_if = "Option::is_none")] - pub ebnf: Option, - - /// Specific token IDs to use as stop conditions - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, - - /// Skip trimming stop tokens from output - #[serde(default)] - pub no_stop_trim: bool, - - /// Ignore end-of-sequence tokens during generation - #[serde(default)] - pub ignore_eos: bool, - - /// Continue generating from final assistant message - #[serde(default)] - pub continue_final_message: bool, - - /// Skip special tokens during detokenization - #[serde(default = "default_true")] - pub skip_special_tokens: bool, - - /// Path to LoRA adapter(s) for model customization - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, - - /// Session parameters for continual prompting - #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option>, - - /// Separate reasoning content from final answer (O1-style models) - #[serde(default = "default_true")] - pub separate_reasoning: bool, - - /// Stream reasoning tokens during generation - #[serde(default = "default_true")] - pub stream_reasoning: bool, - - /// Chat template kwargs - #[serde(skip_serializing_if = "Option::is_none")] - pub chat_template_kwargs: Option>, - - /// Return model hidden states - #[serde(default)] - pub return_hidden_states: bool, - - /// Random seed for sampling for deterministic outputs - #[serde(skip_serializing_if = "Option::is_none")] - pub sampling_seed: Option, -} - -// Validation functions for ChatCompletionRequest -// These are automatically called by the validator derive macro - -/// Validates stop sequences (max 4, non-empty strings) -fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> { - match stop { - StringOrArray::String(s) => { - if s.is_empty() { - return Err(validator::ValidationError::new( - "stop sequences cannot be empty", - )); - } - } - StringOrArray::Array(arr) => { - if arr.len() > 4 { - return Err(validator::ValidationError::new( - "maximum 4 stop sequences allowed", - )); - } - for s in arr { - if s.is_empty() { - return Err(validator::ValidationError::new( - "stop sequences cannot be empty", - )); - } - } - } - } - Ok(()) -} - -/// Validates messages array is not empty and has valid content -fn validate_messages(messages: &[ChatMessage]) -> Result<(), validator::ValidationError> { - if messages.is_empty() { - return Err(validator::ValidationError::new("messages cannot be empty")); - } - - for msg in messages.iter() { - if let ChatMessage::User { content, .. } = msg { - match content { - UserMessageContent::Text(text) if text.is_empty() => { - return Err(validator::ValidationError::new( - "message content cannot be empty", - )); - } - UserMessageContent::Parts(parts) if parts.is_empty() => { - return Err(validator::ValidationError::new( - "message content parts cannot be empty", - )); - } - _ => {} - } - } - } - Ok(()) -} - -/// Validates top_p: 0.0 < top_p <= 1.0 (exclusive lower bound - can't use range validator) -fn validate_top_p_value(top_p: f32) -> Result<(), validator::ValidationError> { - if !(top_p > 0.0 && top_p <= 1.0) { - return Err(validator::ValidationError::new( - "top_p must be in (0, 1] - greater than 0.0 and at most 1.0", - )); - } - Ok(()) -} - -/// Validates top_k: -1 (disabled) or >= 1 (special -1 case - can't use range validator) -fn validate_top_k_value(top_k: i32) -> Result<(), validator::ValidationError> { - if top_k != -1 && top_k < 1 { - return Err(validator::ValidationError::new( - "top_k must be -1 (disabled) or at least 1", - )); - } - Ok(()) -} - -/// Schema-level validation for cross-field dependencies -fn validate_chat_cross_parameters( - req: &ChatCompletionRequest, -) -> Result<(), validator::ValidationError> { - // 1. Validate logprobs dependency - if req.top_logprobs.is_some() && !req.logprobs { - let mut e = validator::ValidationError::new("top_logprobs_requires_logprobs"); - e.message = Some("top_logprobs is only allowed when logprobs is enabled".into()); - return Err(e); - } - - // 2. Validate stream_options dependency - if req.stream_options.is_some() && !req.stream { - let mut e = validator::ValidationError::new("stream_options_requires_stream"); - e.message = - Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into()); - return Err(e); - } - - // 3. Validate token limits - min <= max - if let (Some(min), Some(max)) = (req.min_tokens, req.max_completion_tokens) { - if min > max { - let mut e = validator::ValidationError::new("min_tokens_exceeds_max"); - e.message = Some("min_tokens cannot exceed max_tokens/max_completion_tokens".into()); - return Err(e); - } - } - - // 4. Validate structured output conflicts - let has_json_format = matches!( - req.response_format, - Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) - ); - - if has_json_format && req.regex.is_some() { - let mut e = validator::ValidationError::new("regex_conflicts_with_json"); - e.message = Some("cannot use regex constraint with JSON response format".into()); - return Err(e); - } - - if has_json_format && req.ebnf.is_some() { - let mut e = validator::ValidationError::new("ebnf_conflicts_with_json"); - e.message = Some("cannot use EBNF constraint with JSON response format".into()); - return Err(e); - } - - // 5. Validate mutually exclusive structured output constraints - let constraint_count = [ - req.regex.is_some(), - req.ebnf.is_some(), - matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })), - ] - .iter() - .filter(|&&x| x) - .count(); - - if constraint_count > 1 { - let mut e = validator::ValidationError::new("multiple_constraints"); - e.message = Some("only one structured output constraint (regex, ebnf, or json_schema) can be active at a time".into()); - return Err(e); - } - - // 6. Validate response format JSON schema name - if let Some(ResponseFormat::JsonSchema { json_schema }) = &req.response_format { - if json_schema.name.is_empty() { - let mut e = validator::ValidationError::new("json_schema_name_empty"); - e.message = Some("JSON schema name cannot be empty".into()); - return Err(e); - } - } - - // 7. Validate tool_choice requires tools (except for "none") - if let Some(ref tool_choice) = req.tool_choice { - let has_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty()); - - // Check if tool_choice is anything other than "none" - let is_some_choice = !matches!(tool_choice, ToolChoice::Value(ToolChoiceValue::None)); - - if is_some_choice && !has_tools { - let mut e = validator::ValidationError::new("tool_choice_requires_tools"); - e.message = Some("Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.".into()); - return Err(e); - } - - // Additional validation when tools are present - if has_tools { - let tools = req.tools.as_ref().unwrap(); - - match tool_choice { - ToolChoice::Function { function, .. } => { - // Validate that the specified function name exists in tools - let function_exists = tools.iter().any(|tool| { - tool.tool_type == "function" && tool.function.name == function.name - }); - - if !function_exists { - let mut e = - validator::ValidationError::new("tool_choice_function_not_found"); - e.message = Some( - format!( - "Invalid value for 'tool_choice': function '{}' not found in 'tools'.", - function.name - ) - .into(), - ); - return Err(e); - } - } - ToolChoice::AllowedTools { - mode, - tools: allowed_tools, - .. - } => { - // Validate mode is "auto" or "required" - if mode != "auto" && mode != "required" { - let mut e = validator::ValidationError::new("tool_choice_invalid_mode"); - e.message = Some(format!( - "Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{}'.", - mode - ).into()); - return Err(e); - } - - // Validate that all referenced tool names exist in tools - for tool_ref in allowed_tools { - let tool_exists = tools.iter().any(|tool| { - tool.tool_type == tool_ref.tool_type - && tool.function.name == tool_ref.name - }); - - if !tool_exists { - let mut e = - validator::ValidationError::new("tool_choice_tool_not_found"); - e.message = Some(format!( - "Invalid value for 'tool_choice.tools': tool '{}' not found in 'tools'.", - tool_ref.name - ).into()); - return Err(e); - } - } - } - _ => {} - } - } - } - - Ok(()) -} - -impl Normalizable for ChatCompletionRequest { - /// Normalize the request by applying migrations and defaults: - /// 1. Migrate deprecated fields to their replacements - /// 2. Clear deprecated fields and log warnings - /// 3. Apply OpenAI defaults for tool_choice - fn normalize(&mut self) { - // Migrate deprecated max_tokens → max_completion_tokens - #[allow(deprecated)] - if self.max_completion_tokens.is_none() && self.max_tokens.is_some() { - tracing::warn!("max_tokens is deprecated, use max_completion_tokens instead"); - self.max_completion_tokens = self.max_tokens; - self.max_tokens = None; // Clear deprecated field - } - - // Migrate deprecated functions → tools - #[allow(deprecated)] - if self.tools.is_none() && self.functions.is_some() { - tracing::warn!("functions is deprecated, use tools instead"); - self.tools = self.functions.as_ref().map(|functions| { - functions - .iter() - .map(|func| Tool { - tool_type: "function".to_string(), - function: func.clone(), - }) - .collect() - }); - self.functions = None; // Clear deprecated field - } - - // Migrate deprecated function_call → tool_choice - #[allow(deprecated)] - if self.tool_choice.is_none() && self.function_call.is_some() { - tracing::warn!("function_call is deprecated, use tool_choice instead"); - self.tool_choice = self.function_call.as_ref().map(|fc| match fc { - FunctionCall::None => ToolChoice::Value(ToolChoiceValue::None), - FunctionCall::Auto => ToolChoice::Value(ToolChoiceValue::Auto), - FunctionCall::Function { name } => ToolChoice::Function { - tool_type: "function".to_string(), - function: FunctionChoice { name: name.clone() }, - }, - }); - self.function_call = None; // Clear deprecated field - } - - // Apply tool_choice defaults - if self.tool_choice.is_none() { - let has_tools = self.tools.as_ref().is_some_and(|t| !t.is_empty()); - - self.tool_choice = if has_tools { - Some(ToolChoice::Value(ToolChoiceValue::Auto)) - } else { - Some(ToolChoice::Value(ToolChoiceValue::None)) - }; - } - } -} - -impl GenerationRequest for ChatCompletionRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_model(&self) -> Option<&str> { - Some(&self.model) - } - - fn extract_text_for_routing(&self) -> String { - // Extract text from messages for routing decisions - self.messages - .iter() - .filter_map(|msg| match msg { - ChatMessage::System { content, .. } => Some(content.clone()), - ChatMessage::User { content, .. } => match content { - UserMessageContent::Text(text) => Some(text.clone()), - UserMessageContent::Parts(parts) => { - let texts: Vec = parts - .iter() - .filter_map(|part| match part { - ContentPart::Text { text } => Some(text.clone()), - _ => None, - }) - .collect(); - Some(texts.join(" ")) - } - }, - ChatMessage::Assistant { - content, - reasoning_content, - .. - } => { - // Combine content and reasoning content for routing decisions - let main_content = content.clone().unwrap_or_default(); - let reasoning = reasoning_content.clone().unwrap_or_default(); - if main_content.is_empty() && reasoning.is_empty() { - None - } else { - Some(format!("{} {}", main_content, reasoning).trim().to_string()) - } - } - ChatMessage::Tool { content, .. } => Some(content.clone()), - ChatMessage::Function { content, .. } => Some(content.clone()), - }) - .collect::>() - .join(" ") - } -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionResponse { - pub id: String, - pub object: String, // "chat.completion" - pub created: u64, - pub model: String, - pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, -} - -/// Response message structure for ChatCompletionResponse (different from request ChatMessage) -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionMessage { - pub role: String, // Always "assistant" for responses - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, - /// Reasoning content for O1-style models (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning_content: Option, - // Note: function_call is deprecated and not included - // Note: refusal, annotations, audio are not added yet -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatChoice { - pub index: u32, - pub message: ChatCompletionMessage, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, // "stop", "length", "tool_calls", "content_filter", "function_call" - /// Information about which stop condition was matched - #[serde(skip_serializing_if = "Option::is_none")] - pub matched_stop: Option, // Can be string or integer - /// Hidden states from the model (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - pub hidden_states: Option>, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionStreamResponse { - pub id: String, - pub object: String, // "chat.completion.chunk" - pub created: u64, - pub model: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, - pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatStreamChoice { - pub index: u32, - pub delta: ChatMessageDelta, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub matched_stop: Option, -} - -// Completions API request types (v1/completions) - DEPRECATED but still supported - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionRequest { - /// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang) - pub model: String, - - /// The prompt(s) to generate completions for - pub prompt: StringOrArray, - - /// The suffix that comes after a completion of inserted text - #[serde(skip_serializing_if = "Option::is_none")] - pub suffix: Option, - - /// The maximum number of tokens to generate - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - - /// What sampling temperature to use, between 0 and 2 - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - - /// An alternative to sampling with temperature (nucleus sampling) - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - - /// How many completions to generate for each prompt - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - - /// Whether to stream back partial progress - #[serde(default)] - pub stream: bool, - - /// Options for streaming response - #[serde(skip_serializing_if = "Option::is_none")] - pub stream_options: Option, - - /// Include the log probabilities on the logprobs most likely tokens - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - - /// Echo back the prompt in addition to the completion - #[serde(default)] - pub echo: bool, - - /// Up to 4 sequences where the API will stop generating further tokens - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - - /// Generates best_of completions server-side and returns the "best" - #[serde(skip_serializing_if = "Option::is_none")] - pub best_of: Option, - - /// Modify the likelihood of specified tokens appearing in the completion - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, - - /// A unique identifier representing your end-user - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, - - /// If specified, our system will make a best effort to sample deterministically - #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, - - // -------- Engine Specific Sampling Parameters -------- - /// Top-k sampling parameter (-1 to disable) - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - - /// Min-p nucleus sampling parameter - #[serde(skip_serializing_if = "Option::is_none")] - pub min_p: Option, - - /// Minimum number of tokens to generate - #[serde(skip_serializing_if = "Option::is_none")] - pub min_tokens: Option, - - /// Repetition penalty for reducing repetitive text - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - - /// Regex constraint for output generation - #[serde(skip_serializing_if = "Option::is_none")] - pub regex: Option, - - /// EBNF grammar constraint for structured output - #[serde(skip_serializing_if = "Option::is_none")] - pub ebnf: Option, - - /// JSON schema constraint for structured output - #[serde(skip_serializing_if = "Option::is_none")] - pub json_schema: Option, - - /// Specific token IDs to use as stop conditions - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, - - /// Skip trimming stop tokens from output - #[serde(default)] - pub no_stop_trim: bool, - - /// Ignore end-of-sequence tokens during generation - #[serde(default)] - pub ignore_eos: bool, - - /// Skip special tokens during detokenization - #[serde(default = "default_true")] - pub skip_special_tokens: bool, - - /// Path to LoRA adapter(s) for model customization - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, - - /// Session parameters for continual prompting - #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option>, - - /// Return model hidden states - #[serde(default)] - pub return_hidden_states: bool, - - /// Sampling seed for deterministic outputs - #[serde(skip_serializing_if = "Option::is_none")] - pub sampling_seed: Option, - - /// Additional fields including bootstrap info for PD routing - #[serde(flatten)] - pub other: Map, -} - -impl GenerationRequest for CompletionRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_model(&self) -> Option<&str> { - Some(&self.model) - } - - fn extract_text_for_routing(&self) -> String { - match &self.prompt { - StringOrArray::String(s) => s.clone(), - StringOrArray::Array(v) => v.join(" "), - } - } -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionResponse { - pub id: String, - pub object: String, // "text_completion" - pub created: u64, - pub model: String, - pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionChoice { - pub text: String, - pub index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, // "stop", "length", "content_filter", etc. - /// Information about which stop condition was matched - #[serde(skip_serializing_if = "Option::is_none")] - pub matched_stop: Option, // Can be string or integer - /// Hidden states from the model (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - pub hidden_states: Option>, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionStreamResponse { - pub id: String, - pub object: String, // "text_completion" - pub created: u64, - pub choices: Vec, - pub model: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionStreamChoice { - pub text: String, - pub index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponseTool { - #[serde(rename = "type")] - pub r#type: ResponseToolType, - // MCP-specific fields (used when type == "mcp") - #[serde(skip_serializing_if = "Option::is_none")] - pub server_url: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub authorization: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub server_label: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub server_description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub require_approval: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub allowed_tools: Option>, -} - -impl Default for ResponseTool { - fn default() -> Self { - Self { - r#type: ResponseToolType::WebSearchPreview, - server_url: None, - authorization: None, - server_label: None, - server_description: None, - require_approval: None, - allowed_tools: None, - } - } -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ResponseToolType { - WebSearchPreview, - CodeInterpreter, - Mcp, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponseReasoningParam { - #[serde(default = "default_reasoning_effort")] - #[serde(skip_serializing_if = "Option::is_none")] - pub effort: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub summary: Option, -} - -fn default_reasoning_effort() -> Option { - Some(ReasoningEffort::Medium) -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ReasoningEffort { - Low, - Medium, - High, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ReasoningSummary { - Auto, - Concise, - Detailed, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -#[serde(rename_all = "snake_case")] -pub enum ResponseInputOutputItem { - #[serde(rename = "message")] - Message { - id: String, - role: String, - content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - status: Option, - }, - #[serde(rename = "reasoning")] - Reasoning { - id: String, - #[serde(skip_serializing_if = "Vec::is_empty")] - summary: Vec, - content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - status: Option, - }, - #[serde(rename = "function_tool_call")] - FunctionToolCall { - id: String, - name: String, - arguments: String, - #[serde(skip_serializing_if = "Option::is_none")] - output: Option, - #[serde(skip_serializing_if = "Option::is_none")] - status: Option, - }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -#[serde(rename_all = "snake_case")] -pub enum ResponseContentPart { - #[serde(rename = "output_text")] - OutputText { - text: String, - #[serde(skip_serializing_if = "Vec::is_empty")] - annotations: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - logprobs: Option, - }, - #[serde(rename = "input_text")] - InputText { text: String }, - #[serde(other)] - Unknown, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -#[serde(rename_all = "snake_case")] -pub enum ResponseReasoningContent { - #[serde(rename = "reasoning_text")] - ReasoningText { text: String }, -} - -/// MCP Tool information for the mcp_list_tools output item -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct McpToolInfo { - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - pub input_schema: Value, - #[serde(skip_serializing_if = "Option::is_none")] - pub annotations: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -#[serde(rename_all = "snake_case")] -pub enum ResponseOutputItem { - #[serde(rename = "message")] - Message { - id: String, - role: String, - content: Vec, - status: String, - }, - #[serde(rename = "reasoning")] - Reasoning { - id: String, - #[serde(skip_serializing_if = "Vec::is_empty")] - summary: Vec, - content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - status: Option, - }, - #[serde(rename = "function_tool_call")] - FunctionToolCall { - id: String, - name: String, - arguments: String, - #[serde(skip_serializing_if = "Option::is_none")] - output: Option, - status: String, - }, - #[serde(rename = "mcp_list_tools")] - McpListTools { - id: String, - server_label: String, - tools: Vec, - }, - #[serde(rename = "mcp_call")] - McpCall { - id: String, - status: String, - #[serde(skip_serializing_if = "Option::is_none")] - approval_request_id: Option, - arguments: String, - #[serde(skip_serializing_if = "Option::is_none")] - error: Option, - name: String, - output: String, - server_label: String, - }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ServiceTier { - Auto, - Default, - Flex, - Scale, - Priority, -} - -impl Default for ServiceTier { - fn default() -> Self { - Self::Auto - } -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum Truncation { - Auto, - Disabled, -} - -impl Default for Truncation { - fn default() -> Self { - Self::Disabled - } -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ResponseStatus { - Queued, - InProgress, - Completed, - Failed, - Cancelled, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ReasoningInfo { - #[serde(skip_serializing_if = "Option::is_none")] - pub effort: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub summary: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponseTextFormat { - pub format: TextFormatType, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct TextFormatType { - #[serde(rename = "type")] - pub format_type: String, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum IncludeField { - #[serde(rename = "code_interpreter_call.outputs")] - CodeInterpreterCallOutputs, - #[serde(rename = "computer_call_output.output.image_url")] - ComputerCallOutputImageUrl, - #[serde(rename = "file_search_call.results")] - FileSearchCallResults, - #[serde(rename = "message.input_image.image_url")] - MessageInputImageUrl, - #[serde(rename = "message.output_text.logprobs")] - MessageOutputTextLogprobs, - #[serde(rename = "reasoning.encrypted_content")] - ReasoningEncryptedContent, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct UsageInfo { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt_tokens_details: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct PromptTokenUsageInfo { - pub cached_tokens: u32, -} - -/// OpenAI Responses API usage format (different from standard UsageInfo) -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponseUsage { - pub input_tokens: u32, - pub output_tokens: u32, - pub total_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub input_tokens_details: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub output_tokens_details: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum ResponsesUsage { - Classic(UsageInfo), - Modern(ResponseUsage), -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct InputTokensDetails { - pub cached_tokens: u32, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct OutputTokensDetails { - pub reasoning_tokens: u32, -} - -impl UsageInfo { - /// Convert to OpenAI Responses API format - pub fn to_response_usage(&self) -> ResponseUsage { - ResponseUsage { - input_tokens: self.prompt_tokens, - output_tokens: self.completion_tokens, - total_tokens: self.total_tokens, - input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| { - InputTokensDetails { - cached_tokens: details.cached_tokens, - } - }), - output_tokens_details: self.reasoning_tokens.map(|tokens| OutputTokensDetails { - reasoning_tokens: tokens, - }), - } - } -} - -impl From for ResponseUsage { - fn from(usage: UsageInfo) -> Self { - usage.to_response_usage() - } -} - -impl ResponseUsage { - /// Convert back to standard UsageInfo format - pub fn to_usage_info(&self) -> UsageInfo { - UsageInfo { - prompt_tokens: self.input_tokens, - completion_tokens: self.output_tokens, - total_tokens: self.total_tokens, - reasoning_tokens: self - .output_tokens_details - .as_ref() - .map(|details| details.reasoning_tokens), - prompt_tokens_details: self.input_tokens_details.as_ref().map(|details| { - PromptTokenUsageInfo { - cached_tokens: details.cached_tokens, - } - }), - } - } -} - -#[derive(Debug, Clone, Default, Deserialize, Serialize)] -pub struct ResponsesGetParams { - #[serde(default)] - pub include: Vec, - #[serde(default)] - pub include_obfuscation: Option, - #[serde(default)] - pub starting_after: Option, - #[serde(default)] - pub stream: Option, -} - -impl ResponsesUsage { - pub fn to_response_usage(&self) -> ResponseUsage { - match self { - ResponsesUsage::Classic(usage) => usage.to_response_usage(), - ResponsesUsage::Modern(usage) => usage.clone(), - } - } - - pub fn to_usage_info(&self) -> UsageInfo { - match self { - ResponsesUsage::Classic(usage) => usage.clone(), - ResponsesUsage::Modern(usage) => usage.to_usage_info(), - } - } -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponsesRequest { - /// Run the request in the background - #[serde(skip_serializing_if = "Option::is_none")] - pub background: Option, - - /// Fields to include in the response - #[serde(skip_serializing_if = "Option::is_none")] - pub include: Option>, - - /// Input content - can be string or structured items - pub input: ResponseInput, - - /// System instructions for the model - #[serde(skip_serializing_if = "Option::is_none")] - pub instructions: Option, - - /// Maximum number of output tokens - #[serde(skip_serializing_if = "Option::is_none")] - pub max_output_tokens: Option, - - /// Maximum number of tool calls - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tool_calls: Option, - - /// Additional metadata - #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option>, - - /// Model to use (optional to match vLLM) - #[serde(skip_serializing_if = "Option::is_none")] - pub model: Option, - - /// Optional conversation id to persist input/output as items - #[serde(skip_serializing_if = "Option::is_none")] - pub conversation: Option, - - /// Whether to enable parallel tool calls - #[serde(skip_serializing_if = "Option::is_none")] - pub parallel_tool_calls: Option, - - /// ID of previous response to continue from - #[serde(skip_serializing_if = "Option::is_none")] - pub previous_response_id: Option, - - /// Reasoning configuration - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning: Option, - - /// Service tier - #[serde(skip_serializing_if = "Option::is_none")] - pub service_tier: Option, - - /// Whether to store the response - #[serde(skip_serializing_if = "Option::is_none")] - pub store: Option, - - /// Whether to stream the response - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, - - /// Temperature for sampling - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - - /// Tool choice behavior - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_choice: Option, - - /// Available tools - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, - - /// Number of top logprobs to return - #[serde(skip_serializing_if = "Option::is_none")] - pub top_logprobs: Option, - - /// Top-p sampling parameter - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - - /// Truncation behavior - #[serde(skip_serializing_if = "Option::is_none")] - pub truncation: Option, - - /// User identifier - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, - - /// Request ID - #[serde(skip_serializing_if = "Option::is_none")] - pub request_id: Option, - - /// Request priority - #[serde(default)] - pub priority: i32, - - /// Frequency penalty - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - - /// Presence penalty - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - - /// Stop sequences - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - - /// Top-k sampling parameter (SGLang extension) - #[serde(default = "default_top_k")] - pub top_k: i32, - - /// Min-p sampling parameter (SGLang extension) - #[serde(default)] - pub min_p: f32, - - /// Repetition penalty (SGLang extension) - #[serde(default = "default_repetition_penalty")] - pub repetition_penalty: f32, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum ResponseInput { - Text(String), - Items(Vec), -} - -fn default_top_k() -> i32 { - -1 -} - -fn default_repetition_penalty() -> f32 { - 1.0 -} - -impl Default for ResponsesRequest { - fn default() -> Self { - Self { - background: None, - include: None, - input: ResponseInput::Text(String::new()), - instructions: None, - max_output_tokens: None, - max_tool_calls: None, - metadata: None, - model: None, - conversation: None, - parallel_tool_calls: None, - previous_response_id: None, - reasoning: None, - service_tier: None, - store: None, - stream: None, - temperature: None, - tool_choice: None, - tools: None, - top_logprobs: None, - top_p: None, - truncation: None, - user: None, - request_id: None, - priority: 0, - frequency_penalty: None, - presence_penalty: None, - stop: None, - top_k: default_top_k(), - min_p: 0.0, - repetition_penalty: default_repetition_penalty(), - } - } -} - -impl GenerationRequest for ResponsesRequest { - fn is_stream(&self) -> bool { - self.stream.unwrap_or(false) - } - - fn get_model(&self) -> Option<&str> { - self.model.as_deref() - } - - fn extract_text_for_routing(&self) -> String { - match &self.input { - ResponseInput::Text(text) => text.clone(), - ResponseInput::Items(items) => items - .iter() - .filter_map(|item| match item { - ResponseInputOutputItem::Message { content, .. } => { - let texts: Vec = content - .iter() - .filter_map(|part| match part { - ResponseContentPart::OutputText { text, .. } => Some(text.clone()), - ResponseContentPart::InputText { text } => Some(text.clone()), - ResponseContentPart::Unknown => None, - }) - .collect(); - if texts.is_empty() { - None - } else { - Some(texts.join(" ")) - } - } - ResponseInputOutputItem::Reasoning { content, .. } => { - let texts: Vec = content - .iter() - .map(|part| match part { - ResponseReasoningContent::ReasoningText { text } => text.clone(), - }) - .collect(); - if texts.is_empty() { - None - } else { - Some(texts.join(" ")) - } - } - ResponseInputOutputItem::FunctionToolCall { arguments, .. } => { - Some(arguments.clone()) - } - }) - .collect::>() - .join(" "), - } - } -} - -fn generate_response_id() -> String { - format!("resp_{}", uuid::Uuid::new_v4().simple()) -} - -fn current_timestamp() -> i64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs() as i64 -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponsesResponse { - /// Response ID - #[serde(default = "generate_response_id")] - pub id: String, - - /// Object type - #[serde(default = "default_object_type")] - pub object: String, - - /// Creation timestamp - #[serde(default = "current_timestamp")] - pub created_at: i64, - - /// Response status - pub status: ResponseStatus, - - /// Error information if status is failed - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, - - /// Incomplete details if response was truncated - #[serde(skip_serializing_if = "Option::is_none")] - pub incomplete_details: Option, - - /// System instructions used - #[serde(skip_serializing_if = "Option::is_none")] - pub instructions: Option, - - /// Max output tokens setting - #[serde(skip_serializing_if = "Option::is_none")] - pub max_output_tokens: Option, - - /// Model name - pub model: String, - - /// Output items - #[serde(default)] - pub output: Vec, - - /// Whether parallel tool calls are enabled - #[serde(default = "default_true")] - pub parallel_tool_calls: bool, - - /// Previous response ID if this is a continuation - #[serde(skip_serializing_if = "Option::is_none")] - pub previous_response_id: Option, - - /// Reasoning information - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning: Option, - - /// Whether the response is stored - #[serde(default = "default_true")] - pub store: bool, - - /// Temperature setting used - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - - /// Text format settings - #[serde(skip_serializing_if = "Option::is_none")] - pub text: Option, - - /// Tool choice setting - #[serde(default = "default_tool_choice")] - pub tool_choice: String, - - /// Available tools - #[serde(default)] - pub tools: Vec, - - /// Top-p setting used - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - - /// Truncation strategy used - #[serde(skip_serializing_if = "Option::is_none")] - pub truncation: Option, - - /// Usage information - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, - - /// User identifier - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, - - /// Additional metadata - #[serde(default)] - pub metadata: HashMap, -} - -fn default_object_type() -> String { - "response".to_string() -} - -fn default_tool_choice() -> String { - "auto".to_string() -} - -impl ResponsesResponse { - /// Create a response from a request - #[allow(clippy::too_many_arguments)] - pub fn from_request( - request: &ResponsesRequest, - _sampling_params: &HashMap, - model_name: String, - created_time: i64, - output: Vec, - status: ResponseStatus, - usage: Option, - ) -> Self { - Self { - id: request - .request_id - .clone() - .expect("request_id should be set by middleware"), - object: "response".to_string(), - created_at: created_time, - status, - error: None, - incomplete_details: None, - instructions: request.instructions.clone(), - max_output_tokens: request.max_output_tokens, - model: model_name, - output, - parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true), - previous_response_id: request.previous_response_id.clone(), - reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo { - effort: r.effort.as_ref().map(|e| format!("{:?}", e)), - summary: None, - }), - store: request.store.unwrap_or(false), - temperature: request.temperature, - text: Some(ResponseTextFormat { - format: TextFormatType { - format_type: "text".to_string(), - }, - }), - tool_choice: match &request.tool_choice { - Some(ToolChoice::Value(ToolChoiceValue::Auto)) => "auto".to_string(), - Some(ToolChoice::Value(ToolChoiceValue::Required)) => "required".to_string(), - Some(ToolChoice::Value(ToolChoiceValue::None)) => "none".to_string(), - Some(ToolChoice::Function { .. }) => "function".to_string(), - Some(ToolChoice::AllowedTools { mode, .. }) => mode.clone(), - None => "auto".to_string(), - }, - tools: request.tools.clone().unwrap_or_default(), - top_p: request.top_p, - truncation: match &request.truncation { - Some(Truncation::Auto) => Some("auto".to_string()), - Some(Truncation::Disabled) => Some("disabled".to_string()), - None => None, - }, - usage: usage.map(ResponsesUsage::Classic), - user: request.user.clone(), - metadata: request.metadata.clone().unwrap_or_default(), - } - } - - /// Create a new response with default values - pub fn new(request_id: String, model: String, status: ResponseStatus) -> Self { - Self { - id: request_id, - object: "response".to_string(), - created_at: current_timestamp(), - status, - error: None, - incomplete_details: None, - instructions: None, - max_output_tokens: None, - model, - output: Vec::new(), - parallel_tool_calls: true, - previous_response_id: None, - reasoning: None, - store: true, - temperature: None, - text: None, - tool_choice: "auto".to_string(), - tools: Vec::new(), - top_p: None, - truncation: None, - usage: None, - user: None, - metadata: HashMap::new(), - } - } - - /// Add an output item to the response - pub fn add_output(&mut self, item: ResponseOutputItem) { - self.output.push(item); - } - - /// Set the usage information - pub fn set_usage(&mut self, usage: UsageInfo) { - self.usage = Some(ResponsesUsage::Classic(usage)); - } - - /// Update the status - pub fn set_status(&mut self, status: ResponseStatus) { - self.status = status; - } - - /// Check if the response is complete - pub fn is_complete(&self) -> bool { - matches!(self.status, ResponseStatus::Completed) - } - - /// Check if the response is in progress - pub fn is_in_progress(&self) -> bool { - matches!(self.status, ResponseStatus::InProgress) - } - - /// Check if the response failed - pub fn is_failed(&self) -> bool { - matches!(self.status, ResponseStatus::Failed) - } - - /// Check if the response was cancelled - pub fn is_cancelled(&self) -> bool { - matches!(self.status, ResponseStatus::Cancelled) - } - - /// Check if the response is queued - pub fn is_queued(&self) -> bool { - matches!(self.status, ResponseStatus::Queued) - } - - /// Convert usage to OpenAI Responses API format - pub fn usage_in_response_format(&self) -> Option { - self.usage.as_ref().map(|usage| usage.to_response_usage()) - } - - /// Get the response as a JSON value with usage in response format - pub fn to_response_format(&self) -> Value { - let mut response = to_value(self).unwrap_or(Value::Null); - - // Convert usage to response format if present - if let Some(usage) = &self.usage { - if let Ok(usage_value) = to_value(usage.to_response_usage()) { - response["usage"] = usage_value; - } - } - - response - } -} - -impl ResponseOutputItem { - /// Create a new message output item - pub fn new_message( - id: String, - role: String, - content: Vec, - status: String, - ) -> Self { - Self::Message { - id, - role, - content, - status, - } - } - - /// Create a new reasoning output item - pub fn new_reasoning( - id: String, - summary: Vec, - content: Vec, - status: Option, - ) -> Self { - Self::Reasoning { - id, - summary, - content, - status, - } - } - - /// Create a new function tool call output item - pub fn new_function_tool_call( - id: String, - name: String, - arguments: String, - output: Option, - status: String, - ) -> Self { - Self::FunctionToolCall { - id, - name, - arguments, - output, - status, - } - } -} - -impl ResponseContentPart { - /// Create a new text content part - pub fn new_text( - text: String, - annotations: Vec, - logprobs: Option, - ) -> Self { - Self::OutputText { - text, - annotations, - logprobs, - } - } -} - -impl ResponseReasoningContent { - /// Create a new reasoning text content - pub fn new_reasoning_text(text: String) -> Self { - Self::ReasoningText { text } - } -} - -impl UsageInfo { - /// Create a new usage info with token counts - pub fn new(prompt_tokens: u32, completion_tokens: u32, reasoning_tokens: Option) -> Self { - Self { - prompt_tokens, - completion_tokens, - total_tokens: prompt_tokens + completion_tokens, - reasoning_tokens, - prompt_tokens_details: None, - } - } - - /// Create usage info with cached token details - pub fn new_with_cached( - prompt_tokens: u32, - completion_tokens: u32, - reasoning_tokens: Option, - cached_tokens: u32, - ) -> Self { - Self { - prompt_tokens, - completion_tokens, - total_tokens: prompt_tokens + completion_tokens, - reasoning_tokens, - prompt_tokens_details: Some(PromptTokenUsageInfo { cached_tokens }), - } - } -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct StreamOptions { - #[serde(skip_serializing_if = "Option::is_none")] - pub include_usage: Option, -} - -/// Tool choice value for simple string options -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ToolChoiceValue { - Auto, - Required, - None, -} - -/// Tool choice for both Chat Completion and Responses APIs -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum ToolChoice { - Value(ToolChoiceValue), - Function { - #[serde(rename = "type")] - tool_type: String, // "function" - function: FunctionChoice, - }, - AllowedTools { - #[serde(rename = "type")] - tool_type: String, // "allowed_tools" - mode: String, // "auto" | "required" TODO: need validation - tools: Vec, - }, -} - -impl Default for ToolChoice { - fn default() -> Self { - Self::Value(ToolChoiceValue::Auto) - } -} - -/// Function choice specification for ToolChoice::Function -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct FunctionChoice { - pub name: String, -} - -/// Tool reference for ToolChoice::AllowedTools -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ToolReference { - #[serde(rename = "type")] - pub tool_type: String, // "function" - pub name: String, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Tool { - #[serde(rename = "type")] - pub tool_type: String, // "function" - pub function: Function, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Function { - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - pub parameters: Value, // JSON Schema - /// Whether to enable strict schema adherence (OpenAI structured outputs) - #[serde(skip_serializing_if = "Option::is_none")] - pub strict: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ToolCall { - pub id: String, - #[serde(rename = "type")] - pub tool_type: String, // "function" - pub function: FunctionCallResponse, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum FunctionCall { - None, - Auto, - Function { name: String }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct FunctionCallResponse { - pub name: String, - #[serde(default)] - pub arguments: Option, // JSON string -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub completion_tokens_details: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionTokensDetails { - pub reasoning_tokens: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct LogProbs { - pub tokens: Vec, - pub token_logprobs: Vec>, - pub top_logprobs: Vec>>, - pub text_offset: Vec, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum ChatLogProbs { - Detailed { - #[serde(skip_serializing_if = "Option::is_none")] - content: Option>, - }, - Raw(Value), -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatLogProbsContent { - pub token: String, - pub logprob: f32, - pub bytes: Option>, - pub top_logprobs: Vec, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct TopLogProb { - pub token: String, - pub logprob: f32, - pub bytes: Option>, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ErrorResponse { - pub error: ErrorDetail, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ErrorDetail { - pub message: String, - #[serde(rename = "type")] - pub error_type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub param: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub code: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum InputIds { - Single(Vec), - Batch(Vec>), -} - -#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)] -#[validate(schema(function = "validate_sampling_params"))] -pub struct SamplingParams { - /// Temperature for sampling (must be >= 0.0, no upper limit) - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 0.0))] - pub temperature: Option, - /// Maximum number of new tokens to generate (must be >= 0) - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 0))] - pub max_new_tokens: Option, - /// Top-p nucleus sampling (0.0 < top_p <= 1.0) - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(custom(function = "validate_top_p_value"))] - pub top_p: Option, - /// Top-k sampling (-1 to disable, or >= 1) - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(custom(function = "validate_top_k_value"))] - pub top_k: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = -2.0, max = 2.0))] - pub frequency_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = -2.0, max = 2.0))] - pub presence_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 0.0, max = 2.0))] - pub repetition_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub ignore_eos: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub skip_special_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub json_schema: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub regex: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub ebnf: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[validate(range(min = 0.0, max = 1.0))] - pub min_p: Option, - /// Minimum number of new tokens (validated in schema function for cross-field check with max_new_tokens) - #[serde(skip_serializing_if = "Option::is_none")] - pub min_new_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub no_stop_trim: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub sampling_seed: Option, -} - -/// Validation function for SamplingParams - cross-field validation only -fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::ValidationError> { - // 1. Cross-field validation: min_new_tokens <= max_new_tokens - if let (Some(min), Some(max)) = (params.min_new_tokens, params.max_new_tokens) { - if min > max { - return Err(validator::ValidationError::new( - "min_new_tokens cannot exceed max_new_tokens", - )); - } - } - - // 2. Validate mutually exclusive structured output constraints - let constraint_count = [ - params.regex.is_some(), - params.ebnf.is_some(), - params.json_schema.is_some(), - ] - .iter() - .filter(|&&x| x) - .count(); - - if constraint_count > 1 { - return Err(validator::ValidationError::new( - "only one of regex, ebnf, or json_schema can be set", - )); - } - - Ok(()) -} - -#[derive(Clone, Debug, Serialize, Deserialize, Validate)] -#[validate(schema(function = "validate_generate_request"))] -pub struct GenerateRequest { - /// Text input - SGLang native format - #[serde(skip_serializing_if = "Option::is_none")] - pub text: Option, - - /// Input IDs for tokenized input - #[serde(skip_serializing_if = "Option::is_none")] - pub input_ids: Option, - - /// Input embeddings for direct embedding input - /// Can be a 2D array (single request) or 3D array (batch of requests) - /// Placeholder for future use - #[serde(skip_serializing_if = "Option::is_none")] - pub input_embeds: Option, - - /// Image input data - /// Can be an image instance, file name, URL, or base64 encoded string - /// Supports single images, lists of images, or nested lists for batch processing - /// Placeholder for future use - #[serde(skip_serializing_if = "Option::is_none")] - pub image_data: Option, - - /// Video input data - /// Can be a file name, URL, or base64 encoded string - /// Supports single videos, lists of videos, or nested lists for batch processing - /// Placeholder for future use - #[serde(skip_serializing_if = "Option::is_none")] - pub video_data: Option, - - /// Audio input data - /// Can be a file name, URL, or base64 encoded string - /// Supports single audio files, lists of audio, or nested lists for batch processing - /// Placeholder for future use - #[serde(skip_serializing_if = "Option::is_none")] - pub audio_data: Option, - - /// Sampling parameters (sglang style) - #[serde(skip_serializing_if = "Option::is_none")] - pub sampling_params: Option, - - /// Whether to return logprobs - #[serde(skip_serializing_if = "Option::is_none")] - pub return_logprob: Option, - - /// If return logprobs, the start location in the prompt for returning logprobs. - #[serde(skip_serializing_if = "Option::is_none")] - pub logprob_start_len: Option, - - /// If return logprobs, the number of top logprobs to return at each position. - #[serde(skip_serializing_if = "Option::is_none")] - pub top_logprobs_num: Option, - - /// If return logprobs, the token ids to return logprob for. - #[serde(skip_serializing_if = "Option::is_none")] - pub token_ids_logprob: Option>, - - /// Whether to detokenize tokens in text in the returned logprobs. - #[serde(default)] - pub return_text_in_logprobs: bool, - - /// Whether to stream the response - #[serde(default)] - pub stream: bool, - - /// Whether to log metrics for this request (e.g. health_generate calls do not log metrics) - #[serde(default = "default_true")] - pub log_metrics: bool, - - /// Return model hidden states - #[serde(default)] - pub return_hidden_states: bool, - - /// The modalities of the image data [image, multi-images, video] - #[serde(skip_serializing_if = "Option::is_none")] - pub modalities: Option>, - - /// Session parameters for continual prompting - #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option>, - - /// Path to LoRA adapter(s) for model customization - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, - - /// LoRA adapter ID (if pre-loaded) - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_id: Option, - - /// Custom logit processor for advanced sampling control. Must be a serialized instance - /// of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py - /// Use the processor's `to_str()` method to generate the serialized string. - #[serde(skip_serializing_if = "Option::is_none")] - pub custom_logit_processor: Option, - - /// For disaggregated inference - #[serde(skip_serializing_if = "Option::is_none")] - pub bootstrap_host: Option, - - /// For disaggregated inference - #[serde(skip_serializing_if = "Option::is_none")] - pub bootstrap_port: Option, - - /// For disaggregated inference - #[serde(skip_serializing_if = "Option::is_none")] - pub bootstrap_room: Option, - - /// For disaggregated inference - #[serde(skip_serializing_if = "Option::is_none")] - pub bootstrap_pair_key: Option, - - /// Data parallel rank routing - #[serde(skip_serializing_if = "Option::is_none")] - pub data_parallel_rank: Option, - - /// Background response - #[serde(default)] - pub background: bool, - - /// Conversation ID for tracking - #[serde(skip_serializing_if = "Option::is_none")] - pub conversation_id: Option, - - /// Priority for the request - #[serde(skip_serializing_if = "Option::is_none")] - pub priority: Option, - - /// Extra key for classifying the request (e.g. cache_salt) - #[serde(skip_serializing_if = "Option::is_none")] - pub extra_key: Option, - - /// Whether to disallow logging for this request (e.g. due to ZDR) - #[serde(default)] - pub no_logs: bool, - - /// Custom metric labels - #[serde(skip_serializing_if = "Option::is_none")] - pub custom_labels: Option>, - - /// Whether to return bytes for image generation - #[serde(default)] - pub return_bytes: bool, - - /// Whether to return entropy - #[serde(default)] - pub return_entropy: bool, - - /// Request ID for tracking (inherited from BaseReq in Python) - #[serde(skip_serializing_if = "Option::is_none")] - pub rid: Option, -} - -impl Normalizable for GenerateRequest { - // Use default no-op implementation - no normalization needed for GenerateRequest -} - -/// Validation function for GenerateRequest - ensure exactly one input type is provided -fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> { - // Exactly one of text or input_ids must be provided - // Note: input_embeds not yet supported in Rust implementation - let has_text = req.text.is_some(); - let has_input_ids = req.input_ids.is_some(); - - let count = [has_text, has_input_ids].iter().filter(|&&x| x).count(); - - if count == 0 { - return Err(validator::ValidationError::new( - "Either text or input_ids should be provided.", - )); - } - - if count > 1 { - return Err(validator::ValidationError::new( - "Either text or input_ids should be provided.", - )); - } - - Ok(()) -} - -impl GenerationRequest for GenerateRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_model(&self) -> Option<&str> { - // Generate requests typically don't have a model field - None - } - - fn extract_text_for_routing(&self) -> String { - // Check fields in priority order: text, input_ids - if let Some(ref text) = self.text { - return text.clone(); - } - - if let Some(ref input_ids) = self.input_ids { - return match input_ids { - InputIds::Single(ids) => ids - .iter() - .map(|&id| id.to_string()) - .collect::>() - .join(" "), - InputIds::Batch(batches) => batches - .iter() - .flat_map(|batch| batch.iter().map(|&id| id.to_string())) - .collect::>() - .join(" "), - }; - } - - // No text input found - String::new() - } -} - -// ============================================================================ -// SGLang Generate Response Types -// ============================================================================ - -/// SGLang generate response (single completion or array for n>1) -/// -/// Format for n=1: -/// ```json -/// { -/// "text": "...", -/// "output_ids": [...], -/// "meta_info": { ... } -/// } -/// ``` -/// -/// Format for n>1: -/// ```json -/// [ -/// {"text": "...", "output_ids": [...], "meta_info": {...}}, -/// {"text": "...", "output_ids": [...], "meta_info": {...}} -/// ] -/// ``` -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GenerateResponse { - pub text: String, - pub output_ids: Vec, - pub meta_info: GenerateMetaInfo, -} - -/// Metadata for a single generate completion -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GenerateMetaInfo { - pub id: String, - pub finish_reason: GenerateFinishReason, - pub prompt_tokens: u32, - pub weight_version: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub input_token_logprobs: Option>>>, - #[serde(skip_serializing_if = "Option::is_none")] - pub output_token_logprobs: Option>>>, - pub completion_tokens: u32, - pub cached_tokens: u32, - pub e2e_latency: f64, - #[serde(skip_serializing_if = "Option::is_none")] - pub matched_stop: Option, -} - -/// Finish reason for generate endpoint -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "lowercase")] -pub enum GenerateFinishReason { - Length { - length: u32, - }, - Stop, - #[serde(untagged)] - Other(Value), -} - -/// Rerank request for scoring documents against a query -/// Used for RAG systems and document relevance scoring -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RerankRequest { - /// The query text to rank documents against - pub query: String, - - /// List of documents to be ranked - pub documents: Vec, - - /// Model to use for reranking - #[serde(default = "default_model")] - pub model: String, - - /// Maximum number of documents to return (optional) - pub top_k: Option, - - /// Whether to return documents in addition to scores - #[serde(default = "default_return_documents")] - pub return_documents: bool, - - // SGLang specific extensions - /// Request ID for tracking - pub rid: Option, - - /// User identifier - pub user: Option, -} - -fn default_return_documents() -> bool { - true -} - -/// Individual rerank result -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RerankResult { - /// Relevance score for the document - pub score: f32, - - /// The document text (if return_documents was true) - #[serde(skip_serializing_if = "Option::is_none")] - pub document: Option, - - /// Original index of the document in the request - pub index: usize, - - /// Additional metadata about the ranking - #[serde(skip_serializing_if = "Option::is_none")] - pub meta_info: Option>, -} - -/// Rerank response containing sorted results -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RerankResponse { - /// Ranked results sorted by score (highest first) - pub results: Vec, - - /// Model used for reranking - pub model: String, - - /// Usage information - pub usage: Option, - - /// Response object type - #[serde(default = "default_rerank_object")] - pub object: String, - - /// Response ID - pub id: Option, - - /// Creation timestamp - pub created: i64, -} - -fn default_rerank_object() -> String { - "rerank".to_string() -} - -/// V1 API compatibility format for rerank requests -/// Matches Python's V1RerankReqInput -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct V1RerankReqInput { - pub query: String, - pub documents: Vec, -} - -/// Convert V1RerankReqInput to RerankRequest -impl From for RerankRequest { - fn from(v1: V1RerankReqInput) -> Self { - RerankRequest { - query: v1.query, - documents: v1.documents, - model: default_model(), - top_k: None, - return_documents: true, - rid: None, - user: None, - } - } -} - -/// Implementation of GenerationRequest trait for RerankRequest -impl GenerationRequest for RerankRequest { - fn get_model(&self) -> Option<&str> { - Some(&self.model) - } - - fn is_stream(&self) -> bool { - false // Reranking doesn't support streaming - } - - fn extract_text_for_routing(&self) -> String { - self.query.clone() - } -} - -impl RerankRequest { - pub fn validate(&self) -> Result<(), String> { - // Validate query is not empty - if self.query.trim().is_empty() { - return Err("Query cannot be empty".to_string()); - } - - // Validate documents list - if self.documents.is_empty() { - return Err("Documents list cannot be empty".to_string()); - } - - // Validate top_k if specified - if let Some(k) = self.top_k { - if k == 0 { - return Err("top_k must be greater than 0".to_string()); - } - if k > self.documents.len() { - // This is allowed but we log a warning - tracing::warn!( - "top_k ({}) is greater than number of documents ({})", - k, - self.documents.len() - ); - } - } - - Ok(()) - } - - /// Get the effective top_k value - pub fn effective_top_k(&self) -> usize { - self.top_k.unwrap_or(self.documents.len()) - } -} - -impl RerankResponse { - pub fn new( - results: Vec, - model: String, - request_id: Option, - ) -> Self { - RerankResponse { - results, - model, - usage: None, - object: default_rerank_object(), - id: request_id, - created: current_timestamp(), - } - } - - /// Sort results by score in descending order - pub fn sort_by_score(&mut self) { - self.results.sort_by(|a, b| { - b.score - .partial_cmp(&a.score) - .unwrap_or(std::cmp::Ordering::Equal) - }); - } - - /// Apply top_k limit to results - pub fn apply_top_k(&mut self, k: usize) { - self.results.truncate(k); - } - - /// Drop documents from results - pub fn drop_documents(&mut self) { - self.results.iter_mut().for_each(|result| { - result.document = None; - }); - } -} - -/// Embeddings request compatible with OpenAI API -/// We intentionally keep fields flexible to pass through to workers. -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct EmbeddingRequest { - /// ID of the model to use - pub model: String, - - /// Input can be a string, array of strings, tokens, or batch inputs - pub input: Value, - - /// Optional encoding format (e.g., "float", "base64") - #[serde(skip_serializing_if = "Option::is_none")] - pub encoding_format: Option, - - /// Optional user identifier - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, - - /// Optional number of dimensions for the embedding - #[serde(skip_serializing_if = "Option::is_none")] - pub dimensions: Option, - - /// SGLang extension: request id for tracking - #[serde(skip_serializing_if = "Option::is_none")] - pub rid: Option, -} - -impl GenerationRequest for EmbeddingRequest { - fn is_stream(&self) -> bool { - // Embeddings are non-streaming - false - } - - fn get_model(&self) -> Option<&str> { - Some(&self.model) - } - - fn extract_text_for_routing(&self) -> String { - // Best effort: extract text content for routing decisions - match &self.input { - Value::String(s) => s.clone(), - Value::Array(arr) => arr - .iter() - .filter_map(|v| v.as_str()) - .collect::>() - .join(" "), - _ => String::new(), - } - } -} - -/// Helper function for serde default value -pub fn default_true() -> bool { - true -} - -/// Common trait for all generation requests across different APIs -pub trait GenerationRequest: Send + Sync { - /// Check if the request is for streaming - fn is_stream(&self) -> bool; - - /// Get the model name if specified - fn get_model(&self) -> Option<&str>; - - /// Extract text content for routing decisions - fn extract_text_for_routing(&self) -> String; -} - -/// Helper type for string or array of strings -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] -#[serde(untagged)] -pub enum StringOrArray { - String(String), - Array(Vec), -} -impl StringOrArray { - /// Get the number of items in the StringOrArray - pub fn len(&self) -> usize { - match self { - StringOrArray::String(_) => 1, - StringOrArray::Array(arr) => arr.len(), - } - } - - /// Check if the StringOrArray is empty - pub fn is_empty(&self) -> bool { - match self { - StringOrArray::String(s) => s.is_empty(), - StringOrArray::Array(arr) => arr.is_empty(), - } - } - - /// Convert to a vector of strings - pub fn to_vec(&self) -> Vec { - match self { - StringOrArray::String(s) => vec![s.clone()], - StringOrArray::Array(arr) => arr.clone(), - } - } -} - -/// LoRA adapter path - can be single path or batch of paths (SGLang extension) -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum LoRAPath { - Single(Option), - Batch(Vec>), -} diff --git a/sgl-router/src/protocols/validated.rs b/sgl-router/src/protocols/validated.rs index 58958f66b..4e88def58 100644 --- a/sgl-router/src/protocols/validated.rs +++ b/sgl-router/src/protocols/validated.rs @@ -83,20 +83,11 @@ where // Then, automatically validate the data data.validate().map_err(|validation_errors| { - // Extract the first error message from the validation errors - let error_message = validation_errors - .field_errors() - .values() - .flat_map(|errors| errors.iter()) - .find_map(|e| e.message.as_ref()) - .map(|m| m.to_string()) - .unwrap_or_else(|| "Validation failed".to_string()); - ( StatusCode::BAD_REQUEST, Json(json!({ "error": { - "message": error_message, + "message": validation_errors.to_string(), "type": "invalid_request_error", "code": 400 } diff --git a/sgl-router/src/routers/grpc/context.rs b/sgl-router/src/routers/grpc/context.rs index edd5a94d7..dc1f7a3c2 100644 --- a/sgl-router/src/routers/grpc/context.rs +++ b/sgl-router/src/routers/grpc/context.rs @@ -12,9 +12,8 @@ use serde_json::Value; use crate::core::Worker; use crate::grpc_client::{proto, SglangSchedulerClient}; -use crate::protocols::spec::{ - ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, GenerateResponse, -}; +use crate::protocols::chat::{ChatCompletionRequest, ChatCompletionResponse}; +use crate::protocols::generate::{GenerateRequest, GenerateResponse}; use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::tokenizer::stop::StopSequenceDecoder; use crate::tokenizer::traits::Tokenizer; diff --git a/sgl-router/src/routers/grpc/mod.rs b/sgl-router/src/routers/grpc/mod.rs index 2378ae9b9..14ed36de4 100644 --- a/sgl-router/src/routers/grpc/mod.rs +++ b/sgl-router/src/routers/grpc/mod.rs @@ -1,7 +1,7 @@ //! gRPC router implementations use crate::grpc_client::proto; -use crate::protocols::spec::StringOrArray; +use crate::protocols::common::StringOrArray; pub mod context; pub mod pd_router; diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index ade564e51..1e524c27c 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -3,10 +3,12 @@ use crate::config::types::RetryConfig; use crate::core::{ConnectionMode, WorkerRegistry, WorkerType}; use crate::policies::PolicyRegistry; -use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, - ResponsesGetParams, ResponsesRequest, -}; +use crate::protocols::chat::ChatCompletionRequest; +use crate::protocols::completion::CompletionRequest; +use crate::protocols::embedding::EmbeddingRequest; +use crate::protocols::generate::GenerateRequest; +use crate::protocols::rerank::RerankRequest; +use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::routers::RouterTrait; use crate::server::AppContext; diff --git a/sgl-router/src/routers/grpc/pipeline.rs b/sgl-router/src/routers/grpc/pipeline.rs index ec46757aa..dc595f673 100644 --- a/sgl-router/src/routers/grpc/pipeline.rs +++ b/sgl-router/src/routers/grpc/pipeline.rs @@ -14,7 +14,9 @@ use super::utils; use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; use crate::grpc_client::proto; use crate::policies::PolicyRegistry; -use crate::protocols::spec::{ChatCompletionRequest, GenerateRequest, InputIds}; +use crate::protocols::chat::ChatCompletionRequest; +use crate::protocols::common::InputIds; +use crate::protocols::generate::GenerateRequest; use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::tokenizer::traits::Tokenizer; use crate::tool_parser::ParserFactory as ToolParserFactory; diff --git a/sgl-router/src/routers/grpc/processing.rs b/sgl-router/src/routers/grpc/processing.rs index 294b7d6af..886ec0d95 100644 --- a/sgl-router/src/routers/grpc/processing.rs +++ b/sgl-router/src/routers/grpc/processing.rs @@ -9,11 +9,13 @@ use serde_json::Value; use tracing::error; use crate::grpc_client::proto; -use crate::protocols::spec::{ +use crate::protocols::chat::{ ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, - FunctionCallResponse, GenerateMetaInfo, GenerateRequest, GenerateResponse, ToolCall, - ToolChoice, ToolChoiceValue, Usage, }; +use crate::protocols::common::{ + FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage, +}; +use crate::protocols::generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse}; use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; use crate::tokenizer::traits::Tokenizer; diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index d167e7036..d798c851b 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -14,10 +14,12 @@ use tracing::debug; use crate::config::types::RetryConfig; use crate::core::WorkerRegistry; use crate::policies::PolicyRegistry; -use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, - ResponsesGetParams, ResponsesRequest, -}; +use crate::protocols::chat::ChatCompletionRequest; +use crate::protocols::completion::CompletionRequest; +use crate::protocols::embedding::EmbeddingRequest; +use crate::protocols::generate::GenerateRequest; +use crate::protocols::rerank::RerankRequest; +use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::routers::RouterTrait; use crate::server::AppContext; diff --git a/sgl-router/src/routers/grpc/streaming.rs b/sgl-router/src/routers/grpc/streaming.rs index d27d95f98..dcc127dab 100644 --- a/sgl-router/src/routers/grpc/streaming.rs +++ b/sgl-router/src/routers/grpc/streaming.rs @@ -19,7 +19,14 @@ use tracing::{debug, error, warn}; use super::context; use super::utils; use crate::grpc_client::proto; -use crate::protocols::spec::*; +use crate::protocols::chat::{ + ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, +}; +use crate::protocols::common::{ + ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice, + ToolChoiceValue, Usage, +}; +use crate::protocols::generate::GenerateRequest; use crate::reasoning_parser::ReasoningParser; use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; use crate::tokenizer::traits::Tokenizer; diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs index 9b4891a66..86e1532d5 100644 --- a/sgl-router/src/routers/grpc/utils.rs +++ b/sgl-router/src/routers/grpc/utils.rs @@ -4,10 +4,12 @@ use super::ProcessedMessages; use crate::core::Worker; use crate::grpc_client::sglang_scheduler::AbortOnDropStream; use crate::grpc_client::{proto, SglangSchedulerClient}; -use crate::protocols::spec::{ - ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse, - GenerateFinishReason, StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb, +use crate::protocols::chat::{ChatCompletionRequest, ChatMessage}; +use crate::protocols::common::{ + ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall, + ToolChoice, ToolChoiceValue, TopLogProb, }; +use crate::protocols::generate::GenerateFinishReason; use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::HuggingFaceTokenizer; @@ -952,7 +954,8 @@ pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> Generate #[cfg(test)] mod tests { use super::*; - use crate::protocols::spec::{ChatMessage, ContentPart, ImageUrl, UserMessageContent}; + use crate::protocols::chat::{ChatMessage, UserMessageContent}; + use crate::protocols::common::{ContentPart, ImageUrl}; use crate::tokenizer::chat_template::ChatTemplateContentFormat; use serde_json::json; diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 5a1a6fc2c..8b6864a62 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -5,10 +5,13 @@ use crate::core::{ }; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; -use crate::protocols::spec::{ - ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, RerankRequest, - ResponsesGetParams, ResponsesRequest, StringOrArray, UserMessageContent, -}; +use crate::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}; +use crate::protocols::common::{InputIds, StringOrArray}; +use crate::protocols::completion::CompletionRequest; +use crate::protocols::embedding::EmbeddingRequest; +use crate::protocols::generate::GenerateRequest; +use crate::protocols::rerank::RerankRequest; +use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; use crate::routers::header_utils; use crate::routers::RouterTrait; use async_trait::async_trait; @@ -150,9 +153,10 @@ impl PDRouter { } fn get_generate_batch_size(req: &GenerateRequest) -> Option { - if let Some(text) = &req.text { - if text.contains("[") && text.contains("]") { - return None; + // GenerateRequest doesn't support batch via arrays, only via input_ids + if let Some(InputIds::Batch(batches)) = &req.input_ids { + if !batches.is_empty() { + return Some(batches.len()); } } None @@ -1185,7 +1189,7 @@ impl RouterTrait for PDRouter { async fn route_embeddings( &self, _headers: Option<&HeaderMap>, - _body: &crate::protocols::spec::EmbeddingRequest, + _body: &EmbeddingRequest, _model_id: Option<&str>, ) -> Response { ( diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index 1ac198d43..b20203166 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -4,10 +4,13 @@ use crate::core::{ }; use crate::metrics::RouterMetrics; use crate::policies::PolicyRegistry; -use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest, - RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest, -}; +use crate::protocols::chat::ChatCompletionRequest; +use crate::protocols::common::GenerationRequest; +use crate::protocols::completion::CompletionRequest; +use crate::protocols::embedding::EmbeddingRequest; +use crate::protocols::generate::GenerateRequest; +use crate::protocols::rerank::{RerankRequest, RerankResponse, RerankResult}; +use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; use crate::routers::header_utils; use crate::routers::RouterTrait; use axum::body::to_bytes; @@ -628,7 +631,7 @@ impl Router { let rerank_results = serde_json::from_slice::>(&body_bytes)?; let mut rerank_response = RerankResponse::new(rerank_results, req.model.clone(), req.rid.clone()); - rerank_response.sort_by_score(); + // Sorting is handled by Python worker (serving_rerank.py) if let Some(top_k) = req.top_k { rerank_response.apply_top_k(top_k); } @@ -748,9 +751,6 @@ impl RouterTrait for Router { body: &RerankRequest, model_id: Option<&str>, ) -> Response { - if let Err(e) = body.validate() { - return (StatusCode::BAD_REQUEST, e).into_response(); - } let response = self .route_typed_request(headers, body, "/v1/rerank", model_id) .await; diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 58274de3f..410976089 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -9,10 +9,12 @@ use axum::{ }; use std::fmt::Debug; -use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, - ResponsesGetParams, ResponsesRequest, -}; +use crate::protocols::chat::ChatCompletionRequest; +use crate::protocols::completion::CompletionRequest; +use crate::protocols::embedding::EmbeddingRequest; +use crate::protocols::generate::GenerateRequest; +use crate::protocols::rerank::RerankRequest; +use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; use serde_json::Value; pub mod factory; diff --git a/sgl-router/src/routers/openai/conversations.rs b/sgl-router/src/routers/openai/conversations.rs index 4f534b943..6fdadde53 100644 --- a/sgl-router/src/routers/openai/conversations.rs +++ b/sgl-router/src/routers/openai/conversations.rs @@ -6,7 +6,7 @@ use crate::data_connector::{ NewConversationItem, ResponseId, ResponseStorage, SharedConversationItemStorage, SharedConversationStorage, }; -use crate::protocols::spec::{ResponseInput, ResponsesRequest}; +use crate::protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest}; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::Json; @@ -1028,7 +1028,7 @@ async fn persist_items_with_storages( ResponseInput::Items(items_array) => { for input_item in items_array { match input_item { - crate::protocols::spec::ResponseInputOutputItem::Message { + ResponseInputOutputItem::Message { role, content, status, diff --git a/sgl-router/src/routers/openai/mcp.rs b/sgl-router/src/routers/openai/mcp.rs index d23ca396a..de86690d7 100644 --- a/sgl-router/src/routers/openai/mcp.rs +++ b/sgl-router/src/routers/openai/mcp.rs @@ -9,7 +9,9 @@ //! - Metadata injection for MCP operations use crate::mcp::McpClientManager; -use crate::protocols::spec::{ResponseInput, ResponseToolType, ResponsesRequest}; +use crate::protocols::responses::{ + ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest, +}; use crate::routers::header_utils::apply_request_headers; use axum::http::HeaderMap; use bytes::Bytes; @@ -127,7 +129,7 @@ impl FunctionCallInProgress { /// Build a request-scoped MCP manager from request tools, if present. pub(super) async fn mcp_manager_from_request_tools( - tools: &[crate::protocols::spec::ResponseTool], + tools: &[ResponseTool], ) -> Option> { let tool = tools .iter() diff --git a/sgl-router/src/routers/openai/responses.rs b/sgl-router/src/routers/openai/responses.rs index 3c5a73d28..fbd3a1ee2 100644 --- a/sgl-router/src/routers/openai/responses.rs +++ b/sgl-router/src/routers/openai/responses.rs @@ -1,7 +1,7 @@ //! Response storage, patching, and extraction utilities use crate::data_connector::{ResponseId, StoredResponse}; -use crate::protocols::spec::{ResponseInput, ResponseToolType, ResponsesRequest}; +use crate::protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest}; use serde_json::{json, Value}; use std::collections::HashMap; use tracing::warn; diff --git a/sgl-router/src/routers/openai/router.rs b/sgl-router/src/routers/openai/router.rs index 0648c1652..dd6d2307e 100644 --- a/sgl-router/src/routers/openai/router.rs +++ b/sgl-router/src/routers/openai/router.rs @@ -6,8 +6,12 @@ use crate::data_connector::{ conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId, SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, }; -use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, +use crate::protocols::chat::ChatCompletionRequest; +use crate::protocols::completion::CompletionRequest; +use crate::protocols::embedding::EmbeddingRequest; +use crate::protocols::generate::GenerateRequest; +use crate::protocols::rerank::RerankRequest; +use crate::protocols::responses::{ ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams, ResponsesRequest, }; diff --git a/sgl-router/src/routers/openai/streaming.rs b/sgl-router/src/routers/openai/streaming.rs index 9a630ff82..349531d0e 100644 --- a/sgl-router/src/routers/openai/streaming.rs +++ b/sgl-router/src/routers/openai/streaming.rs @@ -10,7 +10,7 @@ use crate::data_connector::{ SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage, }; -use crate::protocols::spec::{ResponseToolType, ResponsesRequest}; +use crate::protocols::responses::{ResponseToolType, ResponsesRequest}; use crate::routers::header_utils::{apply_request_headers, preserve_response_headers}; use axum::{ body::Body, diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index 13ece2c97..23f19f20f 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -6,10 +6,12 @@ use crate::config::{ConnectionMode, RoutingMode}; use crate::core::{WorkerRegistry, WorkerType}; -use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, - ResponsesGetParams, ResponsesRequest, -}; +use crate::protocols::chat::ChatCompletionRequest; +use crate::protocols::completion::CompletionRequest; +use crate::protocols::embedding::EmbeddingRequest; +use crate::protocols::generate::GenerateRequest; +use crate::protocols::rerank::RerankRequest; +use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest}; use crate::routers::RouterTrait; use crate::server::{AppContext, ServerConfig}; use async_trait::async_trait; diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index d8ad5253f..2eb7484b3 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -15,10 +15,12 @@ use crate::{ middleware::{self, AuthConfig, QueuedRequest, TokenBucket}, policies::PolicyRegistry, protocols::{ - spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, - RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput, - }, + chat::ChatCompletionRequest, + completion::CompletionRequest, + embedding::EmbeddingRequest, + generate::GenerateRequest, + rerank::{RerankRequest, V1RerankReqInput}, + responses::{ResponsesGetParams, ResponsesRequest}, validated::ValidatedJson, worker_spec::{WorkerConfigRequest, WorkerErrorResponse, WorkerInfo}, }, @@ -223,7 +225,7 @@ async fn v1_completions( async fn rerank( State(state): State>, headers: http::HeaderMap, - Json(body): Json, + ValidatedJson(body): ValidatedJson, ) -> Response { state.router.route_rerank(Some(&headers), &body, None).await } diff --git a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs index 371be9b68..bb6306043 100644 --- a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs +++ b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::{ParserError, ParserResult}, diff --git a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs index d40273466..38cd86558 100644 --- a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs +++ b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::{ParserError, ParserResult}, diff --git a/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs b/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs index 091971df9..5c66d08ae 100644 --- a/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs +++ b/sgl-router/src/tool_parser/parsers/gpt_oss_harmony_parser.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::ParserResult, diff --git a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs index 0dd58cf87..dcd9afb26 100644 --- a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs +++ b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::{ParserError, ParserResult}, diff --git a/sgl-router/src/tool_parser/parsers/helpers.rs b/sgl-router/src/tool_parser/parsers/helpers.rs index c71cf66a0..109cea53f 100644 --- a/sgl-router/src/tool_parser/parsers/helpers.rs +++ b/sgl-router/src/tool_parser/parsers/helpers.rs @@ -1,4 +1,4 @@ -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use serde_json::Value; use std::collections::HashMap; diff --git a/sgl-router/src/tool_parser/parsers/json_parser.rs b/sgl-router/src/tool_parser/parsers/json_parser.rs index 04b0ca1de..1c7b481ab 100644 --- a/sgl-router/src/tool_parser/parsers/json_parser.rs +++ b/sgl-router/src/tool_parser/parsers/json_parser.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use serde_json::Value; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::{ParserError, ParserResult}, diff --git a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs index 2e2237f0c..9cc11437b 100644 --- a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs +++ b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::ParserResult, diff --git a/sgl-router/src/tool_parser/parsers/llama_parser.rs b/sgl-router/src/tool_parser/parsers/llama_parser.rs index 3af8b9bda..e42c2d679 100644 --- a/sgl-router/src/tool_parser/parsers/llama_parser.rs +++ b/sgl-router/src/tool_parser/parsers/llama_parser.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use serde_json::Value; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::{ParserError, ParserResult}, diff --git a/sgl-router/src/tool_parser/parsers/mistral_parser.rs b/sgl-router/src/tool_parser/parsers/mistral_parser.rs index c87d8ce7a..151e7fccf 100644 --- a/sgl-router/src/tool_parser/parsers/mistral_parser.rs +++ b/sgl-router/src/tool_parser/parsers/mistral_parser.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use serde_json::Value; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::{ParserError, ParserResult}, diff --git a/sgl-router/src/tool_parser/parsers/passthrough_parser.rs b/sgl-router/src/tool_parser/parsers/passthrough_parser.rs index cb793d597..b718bff58 100644 --- a/sgl-router/src/tool_parser/parsers/passthrough_parser.rs +++ b/sgl-router/src/tool_parser/parsers/passthrough_parser.rs @@ -4,7 +4,7 @@ //! tool call parsing should be performed. It simply returns the input text //! with no tool calls detected. -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::errors::ParserResult; use crate::tool_parser::traits::ToolParser; use crate::tool_parser::types::{StreamingParseResult, ToolCall, ToolCallItem}; diff --git a/sgl-router/src/tool_parser/parsers/pythonic_parser.rs b/sgl-router/src/tool_parser/parsers/pythonic_parser.rs index 4c712c7bd..317e5836d 100644 --- a/sgl-router/src/tool_parser/parsers/pythonic_parser.rs +++ b/sgl-router/src/tool_parser/parsers/pythonic_parser.rs @@ -15,7 +15,7 @@ use rustpython_parser::{parse, Mode}; use serde_json::{Map, Number, Value}; use std::sync::OnceLock; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::{ParserError, ParserResult}, diff --git a/sgl-router/src/tool_parser/parsers/qwen_parser.rs b/sgl-router/src/tool_parser/parsers/qwen_parser.rs index e0072debc..a3f5d965a 100644 --- a/sgl-router/src/tool_parser/parsers/qwen_parser.rs +++ b/sgl-router/src/tool_parser/parsers/qwen_parser.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use regex::Regex; use serde_json::Value; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::{ParserError, ParserResult}, diff --git a/sgl-router/src/tool_parser/parsers/step3_parser.rs b/sgl-router/src/tool_parser/parsers/step3_parser.rs index 01f3674aa..1b311cc67 100644 --- a/sgl-router/src/tool_parser/parsers/step3_parser.rs +++ b/sgl-router/src/tool_parser/parsers/step3_parser.rs @@ -3,7 +3,7 @@ use regex::Regex; use serde_json::Value; use std::collections::HashMap; -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::{ParserError, ParserResult}, diff --git a/sgl-router/src/tool_parser/traits.rs b/sgl-router/src/tool_parser/traits.rs index f4e64a053..482f11dea 100644 --- a/sgl-router/src/tool_parser/traits.rs +++ b/sgl-router/src/tool_parser/traits.rs @@ -1,4 +1,4 @@ -use crate::protocols::spec::Tool; +use crate::protocols::common::Tool; use crate::tool_parser::{ errors::ParserResult, types::{StreamingParseResult, ToolCall}, diff --git a/sgl-router/tests/chat_template_format_detection.rs b/sgl-router/tests/chat_template_format_detection.rs index 3efa6676d..64ef20f02 100644 --- a/sgl-router/tests/chat_template_format_detection.rs +++ b/sgl-router/tests/chat_template_format_detection.rs @@ -1,4 +1,4 @@ -use sglang_router_rs::protocols::spec; +use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent}; use sglang_router_rs::tokenizer::chat_template::{ detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, ChatTemplateProcessor, @@ -173,12 +173,12 @@ assistant: let processor = ChatTemplateProcessor::new(template.to_string()); let messages = [ - spec::ChatMessage::System { + ChatMessage::System { content: "You are helpful".to_string(), name: None, }, - spec::ChatMessage::User { - content: spec::UserMessageContent::Text("Hello".to_string()), + ChatMessage::User { + content: UserMessageContent::Text("Hello".to_string()), name: None, }, ]; @@ -213,8 +213,8 @@ fn test_chat_template_with_tokens_unit_test() { let processor = ChatTemplateProcessor::new(template.to_string()); - let messages = [spec::ChatMessage::User { - content: spec::UserMessageContent::Text("Test".to_string()), + let messages = [ChatMessage::User { + content: UserMessageContent::Text("Test".to_string()), name: None, }]; diff --git a/sgl-router/tests/chat_template_integration.rs b/sgl-router/tests/chat_template_integration.rs index 077b813b4..30a0b146a 100644 --- a/sgl-router/tests/chat_template_integration.rs +++ b/sgl-router/tests/chat_template_integration.rs @@ -1,4 +1,5 @@ -use sglang_router_rs::protocols::spec; +use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent}; +use sglang_router_rs::protocols::common::{ContentPart, ImageUrl}; use sglang_router_rs::tokenizer::chat_template::{ detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, ChatTemplateProcessor, @@ -17,8 +18,8 @@ fn test_simple_chat_template() { let processor = ChatTemplateProcessor::new(template.to_string()); - let messages = [spec::ChatMessage::User { - content: spec::UserMessageContent::Text("Test".to_string()), + let messages = [ChatMessage::User { + content: UserMessageContent::Text("Test".to_string()), name: None, }]; @@ -51,8 +52,8 @@ fn test_chat_template_with_tokens() { let processor = ChatTemplateProcessor::new(template.to_string()); - let messages = [spec::ChatMessage::User { - content: spec::UserMessageContent::Text("Test".to_string()), + let messages = [ChatMessage::User { + content: UserMessageContent::Text("Test".to_string()), name: None, }]; @@ -112,12 +113,12 @@ fn test_llama_style_template() { let processor = ChatTemplateProcessor::new(template.to_string()); let messages = [ - spec::ChatMessage::System { + ChatMessage::System { content: "You are a helpful assistant".to_string(), name: None, }, - spec::ChatMessage::User { - content: spec::UserMessageContent::Text("What is 2+2?".to_string()), + ChatMessage::User { + content: UserMessageContent::Text("What is 2+2?".to_string()), name: None, }, ]; @@ -167,18 +168,18 @@ fn test_chatml_template() { let processor = ChatTemplateProcessor::new(template.to_string()); let messages = vec![ - spec::ChatMessage::User { - content: spec::UserMessageContent::Text("Hello".to_string()), + ChatMessage::User { + content: UserMessageContent::Text("Hello".to_string()), name: None, }, - spec::ChatMessage::Assistant { + ChatMessage::Assistant { content: Some("Hi there!".to_string()), name: None, tool_calls: None, reasoning_content: None, }, - spec::ChatMessage::User { - content: spec::UserMessageContent::Text("How are you?".to_string()), + ChatMessage::User { + content: UserMessageContent::Text("How are you?".to_string()), name: None, }, ]; @@ -219,8 +220,8 @@ assistant: let processor = ChatTemplateProcessor::new(template.to_string()); - let messages = [spec::ChatMessage::User { - content: spec::UserMessageContent::Text("Test".to_string()), + let messages = [ChatMessage::User { + content: UserMessageContent::Text("Test".to_string()), name: None, }]; @@ -306,13 +307,13 @@ fn test_template_with_multimodal_content() { let processor = ChatTemplateProcessor::new(template.to_string()); - let messages = [spec::ChatMessage::User { - content: spec::UserMessageContent::Parts(vec![ - spec::ContentPart::Text { + let messages = [ChatMessage::User { + content: UserMessageContent::Parts(vec![ + ContentPart::Text { text: "Look at this:".to_string(), }, - spec::ContentPart::ImageUrl { - image_url: spec::ImageUrl { + ContentPart::ImageUrl { + image_url: ImageUrl { url: "https://example.com/image.jpg".to_string(), detail: None, }, diff --git a/sgl-router/tests/chat_template_loading.rs b/sgl-router/tests/chat_template_loading.rs index 336b17126..5297198c7 100644 --- a/sgl-router/tests/chat_template_loading.rs +++ b/sgl-router/tests/chat_template_loading.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use sglang_router_rs::protocols::spec; + use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent}; use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams; use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; use std::fs; @@ -58,11 +58,11 @@ mod tests { .unwrap(); let messages = [ - spec::ChatMessage::User { - content: spec::UserMessageContent::Text("Hello".to_string()), + ChatMessage::User { + content: UserMessageContent::Text("Hello".to_string()), name: None, }, - spec::ChatMessage::Assistant { + ChatMessage::Assistant { content: Some("Hi there".to_string()), name: None, tool_calls: None, @@ -140,8 +140,8 @@ mod tests { ) .unwrap(); - let messages = [spec::ChatMessage::User { - content: spec::UserMessageContent::Text("Test".to_string()), + let messages = [ChatMessage::User { + content: UserMessageContent::Text("Test".to_string()), name: None, }]; @@ -199,11 +199,11 @@ mod tests { tokenizer.set_chat_template(new_template.to_string()); let messages = [ - spec::ChatMessage::User { - content: spec::UserMessageContent::Text("Hello".to_string()), + ChatMessage::User { + content: UserMessageContent::Text("Hello".to_string()), name: None, }, - spec::ChatMessage::Assistant { + ChatMessage::Assistant { content: Some("World".to_string()), name: None, tool_calls: None, diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs index dc7e7a8d5..8b9a9bd75 100644 --- a/sgl-router/tests/common/mod.rs +++ b/sgl-router/tests/common/mod.rs @@ -15,7 +15,7 @@ use sglang_router_rs::data_connector::{ }; use sglang_router_rs::middleware::TokenBucket; use sglang_router_rs::policies::PolicyRegistry; -use sglang_router_rs::protocols::spec::{Function, Tool}; +use sglang_router_rs::protocols::common::{Function, Tool}; use sglang_router_rs::server::AppContext; use std::fs; use std::path::PathBuf; diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index 60ab83c9f..896d4e484 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -1,10 +1,12 @@ // Integration test for Responses API use axum::http::StatusCode; -use sglang_router_rs::protocols::spec::{ - GenerationRequest, ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseStatus, - ResponseTool, ResponseToolType, ResponsesRequest, ResponsesResponse, ServiceTier, ToolChoice, - ToolChoiceValue, Truncation, UsageInfo, +use sglang_router_rs::protocols::common::{ + GenerationRequest, ToolChoice, ToolChoiceValue, UsageInfo, +}; +use sglang_router_rs::protocols::responses::{ + ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseTool, ResponseToolType, + ResponsesRequest, ServiceTier, Truncation, }; mod common; @@ -430,24 +432,18 @@ fn test_responses_request_sglang_extensions() { assert_eq!(parsed.repetition_penalty, 1.1); } -#[test] -fn test_responses_response_creation() { - let response = ResponsesResponse::new( - "resp_test789".to_string(), - "test-model".to_string(), - ResponseStatus::Completed, - ); - - assert_eq!(response.id, "resp_test789"); - assert_eq!(response.model, "test-model"); - assert!(response.is_complete()); - assert!(!response.is_in_progress()); - assert!(!response.is_failed()); -} - #[test] fn test_usage_conversion() { - let usage_info = UsageInfo::new_with_cached(15, 25, Some(8), 3); + // Construct UsageInfo directly with cached token details + let usage_info = UsageInfo { + prompt_tokens: 15, + completion_tokens: 25, + total_tokens: 40, + reasoning_tokens: Some(8), + prompt_tokens_details: Some(sglang_router_rs::protocols::common::PromptTokenUsageInfo { + cached_tokens: 3, + }), + }; let response_usage = usage_info.to_response_usage(); assert_eq!(response_usage.input_tokens, 15); diff --git a/sgl-router/tests/spec/chat_completion.rs b/sgl-router/tests/spec/chat_completion.rs index 87eade72a..66e417ea0 100644 --- a/sgl-router/tests/spec/chat_completion.rs +++ b/sgl-router/tests/spec/chat_completion.rs @@ -1,7 +1,8 @@ use serde_json::json; -use sglang_router_rs::protocols::spec::{ - ChatCompletionRequest, ChatMessage, Function, FunctionCall, FunctionChoice, StreamOptions, - Tool, ToolChoice, ToolChoiceValue, ToolReference, UserMessageContent, +use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}; +use sglang_router_rs::protocols::common::{ + Function, FunctionCall, FunctionChoice, StreamOptions, Tool, ToolChoice, ToolChoiceValue, + ToolReference, }; use sglang_router_rs::protocols::validated::Normalizable; use validator::Validate; diff --git a/sgl-router/tests/spec/chat_message.rs b/sgl-router/tests/spec/chat_message.rs index 0b158522b..35908b56c 100644 --- a/sgl-router/tests/spec/chat_message.rs +++ b/sgl-router/tests/spec/chat_message.rs @@ -1,5 +1,5 @@ use serde_json::json; -use sglang_router_rs::protocols::spec::{ChatMessage, UserMessageContent}; +use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent}; #[test] fn test_chat_message_tagged_by_role_system() { diff --git a/sgl-router/tests/spec/embedding.rs b/sgl-router/tests/spec/embedding.rs index 718dd5602..e7c832884 100644 --- a/sgl-router/tests/spec/embedding.rs +++ b/sgl-router/tests/spec/embedding.rs @@ -1,5 +1,6 @@ use serde_json::{from_str, json, to_string}; -use sglang_router_rs::protocols::spec::{EmbeddingRequest, GenerationRequest}; +use sglang_router_rs::protocols::common::GenerationRequest; +use sglang_router_rs::protocols::embedding::EmbeddingRequest; #[test] fn test_embedding_request_serialization_string_input() { diff --git a/sgl-router/tests/spec/rerank.rs b/sgl-router/tests/spec/rerank.rs index 3a0ca9aa8..a2e23b9ae 100644 --- a/sgl-router/tests/spec/rerank.rs +++ b/sgl-router/tests/spec/rerank.rs @@ -1,9 +1,10 @@ use serde_json::{from_str, to_string, Number, Value}; -use sglang_router_rs::protocols::spec::{ - GenerationRequest, RerankRequest, RerankResponse, RerankResult, StringOrArray, UsageInfo, - V1RerankReqInput, +use sglang_router_rs::protocols::common::{GenerationRequest, StringOrArray, UsageInfo}; +use sglang_router_rs::protocols::rerank::{ + RerankRequest, RerankResponse, RerankResult, V1RerankReqInput, }; use std::collections::HashMap; +use validator::Validate; #[test] fn test_rerank_request_serialization() { @@ -75,8 +76,7 @@ fn test_rerank_request_validation_empty_query() { }; let result = request.validate(); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Query cannot be empty"); + assert!(result.is_err(), "Should reject empty query"); } #[test] @@ -92,8 +92,7 @@ fn test_rerank_request_validation_whitespace_query() { }; let result = request.validate(); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Query cannot be empty"); + assert!(result.is_err(), "Should reject whitespace-only query"); } #[test] @@ -109,8 +108,7 @@ fn test_rerank_request_validation_empty_documents() { }; let result = request.validate(); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Documents list cannot be empty"); + assert!(result.is_err(), "Should reject empty documents list"); } #[test] @@ -126,8 +124,7 @@ fn test_rerank_request_validation_top_k_zero() { }; let result = request.validate(); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "top_k must be greater than 0"); + assert!(result.is_err(), "Should reject top_k of zero"); } #[test] diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index 5dca1a327..fa249cebb 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ -9,18 +9,20 @@ use axum::{ Json, Router, }; use serde_json::json; -use sglang_router_rs::data_connector::MemoryConversationItemStorage; use sglang_router_rs::{ config::{ ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode, }, data_connector::{ - MemoryConversationStorage, MemoryResponseStorage, ResponseId, ResponseStorage, - StoredResponse, + MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, + ResponseId, ResponseStorage, StoredResponse, }, - protocols::spec::{ - ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput, - ResponsesGetParams, ResponsesRequest, UserMessageContent, + protocols::{ + chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}, + common::StringOrArray, + completion::CompletionRequest, + generate::GenerateRequest, + responses::{ResponseInput, ResponsesGetParams, ResponsesRequest}, }, routers::{openai::OpenAIRouter, RouterTrait}, }; @@ -52,7 +54,7 @@ fn create_minimal_chat_request() -> ChatCompletionRequest { fn create_minimal_completion_request() -> CompletionRequest { CompletionRequest { model: "gpt-3.5-turbo".to_string(), - prompt: sglang_router_rs::protocols::spec::StringOrArray::String("Hello".to_string()), + prompt: StringOrArray::String("Hello".to_string()), suffix: None, max_tokens: Some(100), temperature: None, @@ -605,12 +607,12 @@ async fn test_unsupported_endpoints() { video_data: None, audio_data: None, sampling_params: None, - stream: false, return_logprob: Some(false), logprob_start_len: None, top_logprobs_num: None, token_ids_logprob: None, return_text_in_logprobs: false, + stream: false, log_metrics: true, return_hidden_states: false, modalities: None,