From 5ef545e6789da24cf2c86c189846211a0bc664a4 Mon Sep 17 00:00:00 2001 From: Keyang Ru Date: Fri, 22 Aug 2025 14:18:47 -0700 Subject: [PATCH] [router] Move all protocols to spec.rs file (#9519) --- sgl-router/benches/request_processing.rs | 10 +- sgl-router/src/protocols/common.rs | 61 - sgl-router/src/protocols/generate/mod.rs | 8 - sgl-router/src/protocols/generate/request.rs | 97 - sgl-router/src/protocols/generate/types.rs | 82 - sgl-router/src/protocols/mod.rs | 4 +- sgl-router/src/protocols/openai/chat/mod.rs | 13 - .../src/protocols/openai/chat/request.rs | 216 -- .../src/protocols/openai/chat/response.rs | 59 - sgl-router/src/protocols/openai/chat/types.rs | 185 -- .../src/protocols/openai/chat/validation.rs | 477 ----- sgl-router/src/protocols/openai/common.rs | 58 - .../src/protocols/openai/completions/mod.rs | 10 - .../protocols/openai/completions/request.rs | 158 -- .../protocols/openai/completions/response.rs | 56 - sgl-router/src/protocols/openai/errors.rs | 19 - sgl-router/src/protocols/openai/mod.rs | 8 - .../src/protocols/openai/responses/mod.rs | 10 - .../src/protocols/openai/responses/request.rs | 300 --- .../protocols/openai/responses/response.rs | 280 --- .../src/protocols/openai/responses/types.rs | 296 --- sgl-router/src/protocols/spec.rs | 1867 +++++++++++++++++ sgl-router/src/protocols/validation.rs | 680 ++++-- sgl-router/src/routers/mod.rs | 5 +- sgl-router/src/routers/pd_router.rs | 10 +- sgl-router/src/routers/router.rs | 6 +- sgl-router/src/server.rs | 5 +- sgl-router/tests/benchmark_integration.rs | 10 +- sgl-router/tests/responses_api_test.rs | 14 +- 29 files changed, 2432 insertions(+), 2572 deletions(-) delete mode 100644 sgl-router/src/protocols/common.rs delete mode 100644 sgl-router/src/protocols/generate/mod.rs delete mode 100644 sgl-router/src/protocols/generate/request.rs delete mode 100644 sgl-router/src/protocols/generate/types.rs delete mode 100644 sgl-router/src/protocols/openai/chat/mod.rs delete mode 100644 
sgl-router/src/protocols/openai/chat/request.rs delete mode 100644 sgl-router/src/protocols/openai/chat/response.rs delete mode 100644 sgl-router/src/protocols/openai/chat/types.rs delete mode 100644 sgl-router/src/protocols/openai/chat/validation.rs delete mode 100644 sgl-router/src/protocols/openai/common.rs delete mode 100644 sgl-router/src/protocols/openai/completions/mod.rs delete mode 100644 sgl-router/src/protocols/openai/completions/request.rs delete mode 100644 sgl-router/src/protocols/openai/completions/response.rs delete mode 100644 sgl-router/src/protocols/openai/errors.rs delete mode 100644 sgl-router/src/protocols/openai/mod.rs delete mode 100644 sgl-router/src/protocols/openai/responses/mod.rs delete mode 100644 sgl-router/src/protocols/openai/responses/request.rs delete mode 100644 sgl-router/src/protocols/openai/responses/response.rs delete mode 100644 sgl-router/src/protocols/openai/responses/types.rs create mode 100644 sgl-router/src/protocols/spec.rs diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index 70de06361..3edb2fc3d 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -3,13 +3,9 @@ use serde_json::{from_str, to_string, to_value, to_vec}; use std::time::Instant; use sglang_router_rs::core::{BasicWorker, Worker, WorkerType}; -use sglang_router_rs::protocols::{ - common::StringOrArray, - generate::{GenerateParameters, GenerateRequest, SamplingParams}, - openai::{ - chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}, - completions::CompletionRequest, - }, +use sglang_router_rs::protocols::spec::{ + ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, + SamplingParams, StringOrArray, UserMessageContent, }; use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap}; diff --git a/sgl-router/src/protocols/common.rs b/sgl-router/src/protocols/common.rs deleted 
file mode 100644 index 8e7cb729f..000000000 --- a/sgl-router/src/protocols/common.rs +++ /dev/null @@ -1,61 +0,0 @@ -// Common types shared across all protocol implementations - -use serde::{Deserialize, Serialize}; - -/// Helper function for serde default value -pub fn default_true() -> bool { - true -} - -/// Common trait for all generation requests across different APIs -pub trait GenerationRequest: Send + Sync { - /// Check if the request is for streaming - fn is_stream(&self) -> bool; - - /// Get the model name if specified - fn get_model(&self) -> Option<&str>; - - /// Extract text content for routing decisions - fn extract_text_for_routing(&self) -> String; -} - -/// Helper type for string or array of strings -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum StringOrArray { - String(String), - Array(Vec), -} -impl StringOrArray { - /// Get the number of items in the StringOrArray - pub fn len(&self) -> usize { - match self { - StringOrArray::String(_) => 1, - StringOrArray::Array(arr) => arr.len(), - } - } - - /// Check if the StringOrArray is empty - pub fn is_empty(&self) -> bool { - match self { - StringOrArray::String(s) => s.is_empty(), - StringOrArray::Array(arr) => arr.is_empty(), - } - } - - /// Convert to a vector of strings - pub fn to_vec(&self) -> Vec { - match self { - StringOrArray::String(s) => vec![s.clone()], - StringOrArray::Array(arr) => arr.clone(), - } - } -} - -/// LoRA adapter path - can be single path or batch of paths (SGLang extension) -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum LoRAPath { - Single(Option), - Batch(Vec>), -} diff --git a/sgl-router/src/protocols/generate/mod.rs b/sgl-router/src/protocols/generate/mod.rs deleted file mode 100644 index 7b2b1d97e..000000000 --- a/sgl-router/src/protocols/generate/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -// SGLang native Generate API module (/generate) - -pub mod request; -pub mod types; - -// Re-export main types for 
convenience -pub use request::GenerateRequest; -pub use types::{GenerateParameters, InputIds, SamplingParams}; diff --git a/sgl-router/src/protocols/generate/request.rs b/sgl-router/src/protocols/generate/request.rs deleted file mode 100644 index b3bb3fe46..000000000 --- a/sgl-router/src/protocols/generate/request.rs +++ /dev/null @@ -1,97 +0,0 @@ -// Generate API request types (/generate) - -use crate::protocols::common::{GenerationRequest, LoRAPath, StringOrArray}; -use crate::protocols::generate::types::{GenerateParameters, InputIds, SamplingParams}; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct GenerateRequest { - /// The prompt to generate from (OpenAI style) - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt: Option, - - /// Text input - SGLang native format - #[serde(skip_serializing_if = "Option::is_none")] - pub text: Option, - - /// Input IDs for tokenized input - #[serde(skip_serializing_if = "Option::is_none")] - pub input_ids: Option, - - /// Generation parameters - #[serde(default, skip_serializing_if = "Option::is_none")] - pub parameters: Option, - - /// Sampling parameters (sglang style) - #[serde(skip_serializing_if = "Option::is_none")] - pub sampling_params: Option, - - /// Whether to stream the response - #[serde(default)] - pub stream: bool, - - /// Whether to return logprobs - #[serde(default)] - pub return_logprob: bool, - - // ============= SGLang Extensions ============= - /// Path to LoRA adapter(s) for model customization - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, - - /// Session parameters for continual prompting - #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option>, - - /// Return model hidden states - #[serde(default)] - pub return_hidden_states: bool, - - /// Request ID for tracking - #[serde(skip_serializing_if = "Option::is_none")] - pub rid: Option, -} - -impl 
GenerationRequest for GenerateRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_model(&self) -> Option<&str> { - // Generate requests typically don't have a model field - None - } - - fn extract_text_for_routing(&self) -> String { - // Check fields in priority order: text, prompt, inputs - if let Some(ref text) = self.text { - return text.clone(); - } - - if let Some(ref prompt) = self.prompt { - return match prompt { - StringOrArray::String(s) => s.clone(), - StringOrArray::Array(v) => v.join(" "), - }; - } - - if let Some(ref input_ids) = self.input_ids { - return match input_ids { - InputIds::Single(ids) => ids - .iter() - .map(|&id| id.to_string()) - .collect::>() - .join(" "), - InputIds::Batch(batches) => batches - .iter() - .flat_map(|batch| batch.iter().map(|&id| id.to_string())) - .collect::>() - .join(" "), - }; - } - - // No text input found - String::new() - } -} diff --git a/sgl-router/src/protocols/generate/types.rs b/sgl-router/src/protocols/generate/types.rs deleted file mode 100644 index 4ddf363dc..000000000 --- a/sgl-router/src/protocols/generate/types.rs +++ /dev/null @@ -1,82 +0,0 @@ -// Types for the SGLang native /generate API - -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum InputIds { - Single(Vec), - Batch(Vec>), -} - -#[derive(Debug, Clone, Deserialize, Serialize, Default)] -pub struct GenerateParameters { - #[serde(skip_serializing_if = "Option::is_none")] - pub best_of: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub decoder_input_details: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub details: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub do_sample: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_new_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub 
return_full_text: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub truncate: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub typical_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub watermark: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize, Default)] -pub struct SamplingParams { - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_new_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub ignore_eos: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub skip_special_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub json_schema: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub regex: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub ebnf: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub min_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub min_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, - 
#[serde(skip_serializing_if = "Option::is_none")] - pub no_stop_trim: Option, -} diff --git a/sgl-router/src/protocols/mod.rs b/sgl-router/src/protocols/mod.rs index 2b405eed0..5243c645f 100644 --- a/sgl-router/src/protocols/mod.rs +++ b/sgl-router/src/protocols/mod.rs @@ -1,7 +1,5 @@ // Protocol definitions and validation for various LLM APIs // This module provides a structured approach to handling different API protocols -pub mod common; -pub mod generate; -pub mod openai; +pub mod spec; pub mod validation; diff --git a/sgl-router/src/protocols/openai/chat/mod.rs b/sgl-router/src/protocols/openai/chat/mod.rs deleted file mode 100644 index 9a2025ae9..000000000 --- a/sgl-router/src/protocols/openai/chat/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -// Chat Completions API module - -pub mod request; -pub mod response; -pub mod types; -pub mod validation; - -// Re-export main types for convenience -pub use request::ChatCompletionRequest; -pub use response::{ - ChatChoice, ChatCompletionResponse, ChatCompletionStreamResponse, ChatStreamChoice, -}; -pub use types::*; diff --git a/sgl-router/src/protocols/openai/chat/request.rs b/sgl-router/src/protocols/openai/chat/request.rs deleted file mode 100644 index b7570c676..000000000 --- a/sgl-router/src/protocols/openai/chat/request.rs +++ /dev/null @@ -1,216 +0,0 @@ -// Chat Completions API request types - -use crate::protocols::common::{default_true, GenerationRequest, LoRAPath, StringOrArray}; -use crate::protocols::openai::chat::types::*; -use crate::protocols::openai::common::StreamOptions; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionRequest { - /// ID of the model to use - pub model: String, - - /// A list of messages comprising the conversation so far - pub messages: Vec, - - /// What sampling temperature to use, between 0 and 2 - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - - /// An 
alternative to sampling with temperature - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - - /// How many chat completion choices to generate for each input message - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - - /// If set, partial message deltas will be sent - #[serde(default)] - pub stream: bool, - - /// Options for streaming response - #[serde(skip_serializing_if = "Option::is_none")] - pub stream_options: Option, - - /// Up to 4 sequences where the API will stop generating further tokens - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - - /// The maximum number of tokens to generate - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - - /// An upper bound for the number of tokens that can be generated for a completion - #[serde(skip_serializing_if = "Option::is_none")] - pub max_completion_tokens: Option, - - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - - /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - - /// Modify the likelihood of specified tokens appearing in the completion - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, - - /// A unique identifier representing your end-user - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, - - /// If specified, our system will make a best effort to sample deterministically - #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, - - /// Whether to return log probabilities of the output tokens - #[serde(default)] - pub logprobs: bool, - - /// An integer between 0 and 20 specifying the number of most likely tokens to return - #[serde(skip_serializing_if = "Option::is_none")] - pub top_logprobs: Option, - - /// An object specifying the format that the model must output - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, - - /// A list of tools the model may call - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, - - /// Controls which (if any) tool is called by the model - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_choice: Option, - - /// Whether to enable parallel function calling during tool use - #[serde(skip_serializing_if = "Option::is_none")] - pub parallel_tool_calls: Option, - - /// Deprecated: use tools instead - #[serde(skip_serializing_if = "Option::is_none")] - pub functions: Option>, - - /// Deprecated: use tool_choice instead - #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, - - // ============= SGLang Extensions ============= - /// Top-k sampling parameter (-1 to disable) - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - - /// Min-p nucleus sampling parameter - #[serde(skip_serializing_if = "Option::is_none")] - pub min_p: Option, - - /// Minimum 
number of tokens to generate - #[serde(skip_serializing_if = "Option::is_none")] - pub min_tokens: Option, - - /// Repetition penalty for reducing repetitive text - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - - /// Regex constraint for output generation - #[serde(skip_serializing_if = "Option::is_none")] - pub regex: Option, - - /// EBNF grammar constraint for structured output - #[serde(skip_serializing_if = "Option::is_none")] - pub ebnf: Option, - - /// Specific token IDs to use as stop conditions - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, - - /// Skip trimming stop tokens from output - #[serde(default)] - pub no_stop_trim: bool, - - /// Ignore end-of-sequence tokens during generation - #[serde(default)] - pub ignore_eos: bool, - - /// Continue generating from final assistant message - #[serde(default)] - pub continue_final_message: bool, - - /// Skip special tokens during detokenization - #[serde(default = "default_true")] - pub skip_special_tokens: bool, - - // ============= SGLang Extensions ============= - /// Path to LoRA adapter(s) for model customization - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, - - /// Session parameters for continual prompting - #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option>, - - /// Separate reasoning content from final answer (O1-style models) - #[serde(default = "default_true")] - pub separate_reasoning: bool, - - /// Stream reasoning tokens during generation - #[serde(default = "default_true")] - pub stream_reasoning: bool, - - /// Return model hidden states - #[serde(default)] - pub return_hidden_states: bool, -} - -impl GenerationRequest for ChatCompletionRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_model(&self) -> Option<&str> { - Some(&self.model) - } - - fn extract_text_for_routing(&self) -> String { - // Extract text from messages for routing decisions - 
self.messages - .iter() - .filter_map(|msg| match msg { - ChatMessage::System { content, .. } => Some(content.clone()), - ChatMessage::User { content, .. } => match content { - UserMessageContent::Text(text) => Some(text.clone()), - UserMessageContent::Parts(parts) => { - let texts: Vec = parts - .iter() - .filter_map(|part| match part { - ContentPart::Text { text } => Some(text.clone()), - _ => None, - }) - .collect(); - Some(texts.join(" ")) - } - }, - ChatMessage::Assistant { - content, - reasoning_content, - .. - } => { - // Combine content and reasoning content for routing decisions - let main_content = content.clone().unwrap_or_default(); - let reasoning = reasoning_content.clone().unwrap_or_default(); - if main_content.is_empty() && reasoning.is_empty() { - None - } else { - Some(format!("{} {}", main_content, reasoning).trim().to_string()) - } - } - ChatMessage::Tool { content, .. } => Some(content.clone()), - ChatMessage::Function { content, .. } => Some(content.clone()), - }) - .collect::>() - .join(" ") - } -} diff --git a/sgl-router/src/protocols/openai/chat/response.rs b/sgl-router/src/protocols/openai/chat/response.rs deleted file mode 100644 index 3ac480462..000000000 --- a/sgl-router/src/protocols/openai/chat/response.rs +++ /dev/null @@ -1,59 +0,0 @@ -// Chat Completions API response types - -use crate::protocols::openai::chat::types::{ChatMessage, ChatMessageDelta}; -use crate::protocols::openai::common::{ChatLogProbs, Usage}; -use serde::{Deserialize, Serialize}; - -// ============= Regular Response ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionResponse { - pub id: String, - pub object: String, // "chat.completion" - pub created: u64, - pub model: String, - pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct 
ChatChoice { - pub index: u32, - pub message: ChatMessage, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, // "stop", "length", "tool_calls", "content_filter", "function_call" - /// Information about which stop condition was matched - #[serde(skip_serializing_if = "Option::is_none")] - pub matched_stop: Option, // Can be string or integer - /// Hidden states from the model (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - pub hidden_states: Option>, -} - -// ============= Streaming Response ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionStreamResponse { - pub id: String, - pub object: String, // "chat.completion.chunk" - pub created: u64, - pub model: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, - pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatStreamChoice { - pub index: u32, - pub delta: ChatMessageDelta, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, -} diff --git a/sgl-router/src/protocols/openai/chat/types.rs b/sgl-router/src/protocols/openai/chat/types.rs deleted file mode 100644 index 01bf836cf..000000000 --- a/sgl-router/src/protocols/openai/chat/types.rs +++ /dev/null @@ -1,185 +0,0 @@ -// Types specific to the Chat Completions API - -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -// ============= Message Types ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum ChatMessage { - System { - role: String, // "system" - content: String, - #[serde(skip_serializing_if = "Option::is_none")] - name: Option, - }, - User { - role: String, // "user" - content: UserMessageContent, - #[serde(skip_serializing_if = "Option::is_none")] - name: Option, - }, - Assistant { - 
role: String, // "assistant" - #[serde(skip_serializing_if = "Option::is_none")] - content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - function_call: Option, - /// Reasoning content for O1-style models (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - reasoning_content: Option, - }, - Tool { - role: String, // "tool" - content: String, - tool_call_id: String, - }, - Function { - role: String, // "function" - content: String, - name: String, - }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum UserMessageContent { - Text(String), - Parts(Vec), -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -pub enum ContentPart { - #[serde(rename = "text")] - Text { text: String }, - #[serde(rename = "image_url")] - ImageUrl { image_url: ImageUrl }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ImageUrl { - pub url: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub detail: Option, // "auto", "low", or "high" -} - -// ============= Response Format Types ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -pub enum ResponseFormat { - #[serde(rename = "text")] - Text, - #[serde(rename = "json_object")] - JsonObject, - #[serde(rename = "json_schema")] - JsonSchema { json_schema: JsonSchemaFormat }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct JsonSchemaFormat { - pub name: String, - pub schema: Value, - #[serde(skip_serializing_if = "Option::is_none")] - pub strict: Option, -} - -// ============= Tool/Function Types ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Tool { - #[serde(rename = "type")] - pub tool_type: String, // "function" - pub function: Function, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub 
struct Function { - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - pub parameters: Value, // JSON Schema -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum ToolChoice { - None, - Auto, - Required, - Function { - #[serde(rename = "type")] - tool_type: String, // "function" - function: FunctionChoice, - }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct FunctionChoice { - pub name: String, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ToolCall { - pub id: String, - #[serde(rename = "type")] - pub tool_type: String, // "function" - pub function: FunctionCallResponse, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum FunctionCall { - None, - Auto, - Function { name: String }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct FunctionCallResponse { - pub name: String, - pub arguments: String, // JSON string -} - -// ============= Streaming Delta Types ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatMessageDelta { - #[serde(skip_serializing_if = "Option::is_none")] - pub role: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, - /// Reasoning content delta for O1-style models (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning_content: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ToolCallDelta { - pub index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(rename = "type")] - pub tool_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub function: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub 
struct FunctionCallDelta { - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option, -} diff --git a/sgl-router/src/protocols/openai/chat/validation.rs b/sgl-router/src/protocols/openai/chat/validation.rs deleted file mode 100644 index cb9f5071b..000000000 --- a/sgl-router/src/protocols/openai/chat/validation.rs +++ /dev/null @@ -1,477 +0,0 @@ -// Validation implementation for Chat Completions API - -use crate::protocols::common::StringOrArray; -use crate::protocols::openai::chat::request::ChatCompletionRequest; -use crate::protocols::openai::chat::types::{ChatMessage, ResponseFormat, UserMessageContent}; -use crate::protocols::validation::{ - utils::{ - validate_common_request_params, validate_conflicting_parameters, - validate_mutually_exclusive_options, validate_non_empty_array, - }, - CompletionCountProvider, LogProbsProvider, SGLangExtensionsProvider, SamplingOptionsProvider, - StopConditionsProvider, TokenLimitsProvider, ValidatableRequest, ValidationError, -}; - -impl SamplingOptionsProvider for ChatCompletionRequest { - fn get_temperature(&self) -> Option { - self.temperature - } - fn get_top_p(&self) -> Option { - self.top_p - } - fn get_frequency_penalty(&self) -> Option { - self.frequency_penalty - } - fn get_presence_penalty(&self) -> Option { - self.presence_penalty - } -} - -impl StopConditionsProvider for ChatCompletionRequest { - fn get_stop_sequences(&self) -> Option<&StringOrArray> { - self.stop.as_ref() - } -} - -impl TokenLimitsProvider for ChatCompletionRequest { - fn get_max_tokens(&self) -> Option { - // Prefer max_completion_tokens over max_tokens if both are set - self.max_completion_tokens.or(self.max_tokens) - } - - fn get_min_tokens(&self) -> Option { - self.min_tokens - } -} - -impl LogProbsProvider for ChatCompletionRequest { - fn get_logprobs(&self) -> Option { - // For chat API, logprobs is a boolean, return 1 if true for validation purposes - 
if self.logprobs { - Some(1) - } else { - None - } - } - - fn get_top_logprobs(&self) -> Option { - self.top_logprobs - } -} - -impl SGLangExtensionsProvider for ChatCompletionRequest { - fn get_top_k(&self) -> Option { - self.top_k - } - - fn get_min_p(&self) -> Option { - self.min_p - } - - fn get_repetition_penalty(&self) -> Option { - self.repetition_penalty - } -} - -impl CompletionCountProvider for ChatCompletionRequest { - fn get_n(&self) -> Option { - self.n - } -} - -impl ChatCompletionRequest { - /// Validate message-specific requirements - pub fn validate_messages(&self) -> Result<(), ValidationError> { - // Ensure messages array is not empty - validate_non_empty_array(&self.messages, "messages")?; - - // Validate message content is not empty - for (i, msg) in self.messages.iter().enumerate() { - if let ChatMessage::User { content, .. } = msg { - match content { - UserMessageContent::Text(text) if text.is_empty() => { - return Err(ValidationError::InvalidValue { - parameter: format!("messages[{}].content", i), - value: "empty".to_string(), - reason: "message content cannot be empty".to_string(), - }); - } - UserMessageContent::Parts(parts) if parts.is_empty() => { - return Err(ValidationError::InvalidValue { - parameter: format!("messages[{}].content", i), - value: "empty array".to_string(), - reason: "message content parts cannot be empty".to_string(), - }); - } - _ => {} - } - } - } - - Ok(()) - } - - /// Validate response format if specified - pub fn validate_response_format(&self) -> Result<(), ValidationError> { - if let Some(ResponseFormat::JsonSchema { json_schema }) = &self.response_format { - if json_schema.name.is_empty() { - return Err(ValidationError::InvalidValue { - parameter: "response_format.json_schema.name".to_string(), - value: "empty".to_string(), - reason: "JSON schema name cannot be empty".to_string(), - }); - } - } - Ok(()) - } - - /// Validate chat API specific logprobs requirements - pub fn validate_chat_logprobs(&self) -> 
Result<(), ValidationError> { - // In chat API, if logprobs=true, top_logprobs must be specified - if self.logprobs && self.top_logprobs.is_none() { - return Err(ValidationError::MissingRequired { - parameter: "top_logprobs".to_string(), - }); - } - - // If top_logprobs is specified, logprobs should be true - if self.top_logprobs.is_some() && !self.logprobs { - return Err(ValidationError::InvalidValue { - parameter: "logprobs".to_string(), - value: "false".to_string(), - reason: "must be true when top_logprobs is specified".to_string(), - }); - } - - Ok(()) - } - - /// Validate cross-parameter relationships specific to chat completions - pub fn validate_chat_cross_parameters(&self) -> Result<(), ValidationError> { - // Validate that both max_tokens and max_completion_tokens aren't set - validate_conflicting_parameters( - "max_tokens", - self.max_tokens.is_some(), - "max_completion_tokens", - self.max_completion_tokens.is_some(), - "cannot specify both max_tokens and max_completion_tokens", - )?; - - // Validate that tools and functions aren't both specified (deprecated) - validate_conflicting_parameters( - "tools", - self.tools.is_some(), - "functions", - self.functions.is_some(), - "functions is deprecated, use tools instead", - )?; - - // Validate structured output constraints don't conflict with JSON response format - let has_json_format = matches!( - self.response_format, - Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. 
}) - ); - - validate_conflicting_parameters( - "response_format", - has_json_format, - "regex", - self.regex.is_some(), - "cannot use regex constraint with JSON response format", - )?; - - validate_conflicting_parameters( - "response_format", - has_json_format, - "ebnf", - self.ebnf.is_some(), - "cannot use EBNF constraint with JSON response format", - )?; - - // Only one structured output constraint should be active - let structured_constraints = [ - ("regex", self.regex.is_some()), - ("ebnf", self.ebnf.is_some()), - ( - "json_schema", - matches!( - self.response_format, - Some(ResponseFormat::JsonSchema { .. }) - ), - ), - ]; - - validate_mutually_exclusive_options( - &structured_constraints, - "Only one structured output constraint (regex, ebnf, or json_schema) can be active at a time", - )?; - - Ok(()) - } -} - -impl ValidatableRequest for ChatCompletionRequest { - fn validate(&self) -> Result<(), ValidationError> { - // Call the common validation function from the validation module - validate_common_request_params(self)?; - - // Then validate chat-specific parameters - self.validate_messages()?; - self.validate_response_format()?; - self.validate_chat_logprobs()?; - self.validate_chat_cross_parameters()?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::protocols::openai::chat::types::*; - - fn create_valid_request() -> ChatCompletionRequest { - ChatCompletionRequest { - model: "gpt-4".to_string(), - messages: vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Hello".to_string()), - name: None, - }], - temperature: Some(1.0), - top_p: Some(0.9), - n: Some(1), - stream: false, - stream_options: None, - stop: None, - max_tokens: Some(100), - max_completion_tokens: None, - presence_penalty: Some(0.0), - frequency_penalty: Some(0.0), - logit_bias: None, - user: None, - seed: None, - logprobs: false, - top_logprobs: None, - response_format: None, - tools: None, - tool_choice: None, - 
parallel_tool_calls: None, - functions: None, - function_call: None, - // SGLang extensions - top_k: None, - min_p: None, - min_tokens: None, - repetition_penalty: None, - regex: None, - ebnf: None, - stop_token_ids: None, - no_stop_trim: false, - ignore_eos: false, - continue_final_message: false, - skip_special_tokens: true, - lora_path: None, - session_params: None, - separate_reasoning: true, - stream_reasoning: true, - return_hidden_states: false, - } - } - - #[test] - fn test_valid_chat_request() { - let request = create_valid_request(); - assert!(request.validate().is_ok()); - } - - #[test] - fn test_invalid_temperature() { - let mut request = create_valid_request(); - request.temperature = Some(3.0); // Too high - - let result = request.validate(); - assert!(result.is_err()); - match result.unwrap_err() { - ValidationError::OutOfRange { parameter, .. } => { - assert_eq!(parameter, "temperature"); - } - _ => panic!("Expected OutOfRange error"), - } - } - - #[test] - fn test_invalid_top_p() { - let mut request = create_valid_request(); - request.top_p = Some(1.5); // Too high - - assert!(request.validate().is_err()); - } - - #[test] - fn test_too_many_stop_sequences() { - let mut request = create_valid_request(); - request.stop = Some(StringOrArray::Array(vec![ - "stop1".to_string(), - "stop2".to_string(), - "stop3".to_string(), - "stop4".to_string(), - "stop5".to_string(), // Too many - ])); - - let result = request.validate(); - assert!(result.is_err()); - } - - #[test] - fn test_empty_stop_sequence() { - let mut request = create_valid_request(); - request.stop = Some(StringOrArray::String("".to_string())); - - let result = request.validate(); - assert!(result.is_err()); - match result.unwrap_err() { - ValidationError::InvalidValue { - parameter, reason, .. 
- } => { - assert_eq!(parameter, "stop"); - assert!(reason.contains("empty")); - } - _ => panic!("Expected InvalidValue error"), - } - } - - #[test] - fn test_empty_messages() { - let mut request = create_valid_request(); - request.messages = vec![]; - - let result = request.validate(); - assert!(result.is_err()); - match result.unwrap_err() { - ValidationError::MissingRequired { parameter } => { - assert_eq!(parameter, "messages"); - } - _ => panic!("Expected MissingRequired error"), - } - } - - #[test] - fn test_invalid_n_parameter() { - let mut request = create_valid_request(); - request.n = Some(0); - - let result = request.validate(); - assert!(result.is_err()); - - request.n = Some(20); // Too high - assert!(request.validate().is_err()); - } - - #[test] - fn test_conflicting_max_tokens() { - let mut request = create_valid_request(); - request.max_tokens = Some(100); - request.max_completion_tokens = Some(200); - - let result = request.validate(); - assert!(result.is_err()); - match result.unwrap_err() { - ValidationError::ConflictingParameters { - parameter1, - parameter2, - .. 
- } => { - assert!(parameter1.contains("max_tokens")); - assert!(parameter2.contains("max_completion_tokens")); - } - _ => panic!("Expected ConflictingParameters error"), - } - } - - #[test] - fn test_logprobs_without_top_logprobs() { - let mut request = create_valid_request(); - request.logprobs = true; - request.top_logprobs = None; - - let result = request.validate(); - assert!(result.is_err()); - } - - #[test] - fn test_sglang_extensions() { - let mut request = create_valid_request(); - - // Valid top_k - request.top_k = Some(-1); // Disabled - assert!(request.validate().is_ok()); - - request.top_k = Some(50); // Valid positive - assert!(request.validate().is_ok()); - - request.top_k = Some(0); // Invalid - assert!(request.validate().is_err()); - - // Valid min_p - request.top_k = None; - request.min_p = Some(0.1); - assert!(request.validate().is_ok()); - - request.min_p = Some(1.5); // Too high - assert!(request.validate().is_err()); - - // Valid repetition_penalty - request.min_p = None; - request.repetition_penalty = Some(1.2); - assert!(request.validate().is_ok()); - - request.repetition_penalty = Some(0.0); // Valid - minimum value - assert!(request.validate().is_ok()); - - request.repetition_penalty = Some(2.0); // Valid - maximum value - assert!(request.validate().is_ok()); - - request.repetition_penalty = Some(2.1); // Invalid - too high - assert!(request.validate().is_err()); - - request.repetition_penalty = Some(-0.1); // Invalid - negative - assert!(request.validate().is_err()); - } - - #[test] - fn test_structured_output_conflicts() { - let mut request = create_valid_request(); - - // JSON response format with regex should conflict - request.response_format = Some(ResponseFormat::JsonObject); - request.regex = Some(".*".to_string()); - - let result = request.validate(); - assert!(result.is_err()); - - // Multiple structured constraints should conflict - request.response_format = None; - request.regex = Some(".*".to_string()); - request.ebnf = 
Some("grammar".to_string()); - - let result = request.validate(); - assert!(result.is_err()); - } - - #[test] - fn test_min_max_tokens_validation() { - let mut request = create_valid_request(); - request.min_tokens = Some(100); - request.max_tokens = Some(50); // min > max - - let result = request.validate(); - assert!(result.is_err()); - - // Should work with max_completion_tokens too - request.max_tokens = None; - request.max_completion_tokens = Some(200); - request.min_tokens = Some(100); - assert!(request.validate().is_ok()); - } -} diff --git a/sgl-router/src/protocols/openai/common.rs b/sgl-router/src/protocols/openai/common.rs deleted file mode 100644 index 69ed6d7b4..000000000 --- a/sgl-router/src/protocols/openai/common.rs +++ /dev/null @@ -1,58 +0,0 @@ -// Common types shared across OpenAI API implementations - -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -// ============= Shared Request Components ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct StreamOptions { - #[serde(skip_serializing_if = "Option::is_none")] - pub include_usage: Option, -} - -// ============= Usage Tracking ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub completion_tokens_details: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionTokensDetails { - pub reasoning_tokens: Option, -} - -// ============= Logprobs Types ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct LogProbs { - pub tokens: Vec, - pub token_logprobs: Vec>, - pub top_logprobs: Vec>>, - pub text_offset: Vec, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatLogProbs { - pub content: Option>, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatLogProbsContent { - pub token: String, - pub logprob: 
f32, - pub bytes: Option>, - pub top_logprobs: Vec, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct TopLogProb { - pub token: String, - pub logprob: f32, - pub bytes: Option>, -} diff --git a/sgl-router/src/protocols/openai/completions/mod.rs b/sgl-router/src/protocols/openai/completions/mod.rs deleted file mode 100644 index c87dbbfe5..000000000 --- a/sgl-router/src/protocols/openai/completions/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -// Completions API module (v1/completions) - -pub mod request; -pub mod response; - -// Re-export main types for convenience -pub use request::CompletionRequest; -pub use response::{ - CompletionChoice, CompletionResponse, CompletionStreamChoice, CompletionStreamResponse, -}; diff --git a/sgl-router/src/protocols/openai/completions/request.rs b/sgl-router/src/protocols/openai/completions/request.rs deleted file mode 100644 index c340dc6a5..000000000 --- a/sgl-router/src/protocols/openai/completions/request.rs +++ /dev/null @@ -1,158 +0,0 @@ -// Completions API request types (v1/completions) - DEPRECATED but still supported - -use crate::protocols::common::{default_true, GenerationRequest, LoRAPath, StringOrArray}; -use crate::protocols::openai::common::StreamOptions; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionRequest { - /// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang) - pub model: String, - - /// The prompt(s) to generate completions for - pub prompt: StringOrArray, - - /// The suffix that comes after a completion of inserted text - #[serde(skip_serializing_if = "Option::is_none")] - pub suffix: Option, - - /// The maximum number of tokens to generate - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - - /// What sampling temperature to use, between 0 and 2 - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - - /// 
An alternative to sampling with temperature (nucleus sampling) - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - - /// How many completions to generate for each prompt - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - - /// Whether to stream back partial progress - #[serde(default)] - pub stream: bool, - - /// Options for streaming response - #[serde(skip_serializing_if = "Option::is_none")] - pub stream_options: Option, - - /// Include the log probabilities on the logprobs most likely tokens - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - - /// Echo back the prompt in addition to the completion - #[serde(default)] - pub echo: bool, - - /// Up to 4 sequences where the API will stop generating further tokens - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - - /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - - /// Generates best_of completions server-side and returns the "best" - #[serde(skip_serializing_if = "Option::is_none")] - pub best_of: Option, - - /// Modify the likelihood of specified tokens appearing in the completion - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, - - /// A unique identifier representing your end-user - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, - - /// If specified, our system will make a best effort to sample deterministically - #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, - - // ============= SGLang Extensions ============= - /// Top-k sampling parameter (-1 to disable) - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - - /// Min-p nucleus sampling parameter - #[serde(skip_serializing_if = "Option::is_none")] - pub min_p: Option, - - /// Minimum number of tokens to generate - #[serde(skip_serializing_if = "Option::is_none")] - pub min_tokens: Option, - - /// Repetition penalty for reducing repetitive text - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - - /// Regex constraint for output generation - #[serde(skip_serializing_if = "Option::is_none")] - pub regex: Option, - - /// EBNF grammar constraint for structured output - #[serde(skip_serializing_if = "Option::is_none")] - pub ebnf: Option, - - /// JSON schema constraint for structured output - #[serde(skip_serializing_if = "Option::is_none")] - pub json_schema: Option, - - /// Specific token IDs to use as stop conditions - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, - - /// Skip trimming stop tokens from output - #[serde(default)] - pub no_stop_trim: bool, - - /// Ignore end-of-sequence tokens during generation - #[serde(default)] - 
pub ignore_eos: bool, - - /// Skip special tokens during detokenization - #[serde(default = "default_true")] - pub skip_special_tokens: bool, - - // ============= SGLang Extensions ============= - /// Path to LoRA adapter(s) for model customization - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, - - /// Session parameters for continual prompting - #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option>, - - /// Return model hidden states - #[serde(default)] - pub return_hidden_states: bool, - - /// Additional fields including bootstrap info for PD routing - #[serde(flatten)] - pub other: serde_json::Map, -} - -impl GenerationRequest for CompletionRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_model(&self) -> Option<&str> { - Some(&self.model) - } - - fn extract_text_for_routing(&self) -> String { - match &self.prompt { - StringOrArray::String(s) => s.clone(), - StringOrArray::Array(v) => v.join(" "), - } - } -} diff --git a/sgl-router/src/protocols/openai/completions/response.rs b/sgl-router/src/protocols/openai/completions/response.rs deleted file mode 100644 index 4734ba134..000000000 --- a/sgl-router/src/protocols/openai/completions/response.rs +++ /dev/null @@ -1,56 +0,0 @@ -// Completions API response types - -use crate::protocols::openai::common::{LogProbs, Usage}; -use serde::{Deserialize, Serialize}; - -// ============= Regular Response ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionResponse { - pub id: String, - pub object: String, // "text_completion" - pub created: u64, - pub model: String, - pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionChoice { - pub text: String, - pub index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - 
pub logprobs: Option, - pub finish_reason: Option, // "stop", "length", "content_filter", etc. - /// Information about which stop condition was matched - #[serde(skip_serializing_if = "Option::is_none")] - pub matched_stop: Option, // Can be string or integer - /// Hidden states from the model (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - pub hidden_states: Option>, -} - -// ============= Streaming Response ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionStreamResponse { - pub id: String, - pub object: String, // "text_completion" - pub created: u64, - pub choices: Vec, - pub model: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionStreamChoice { - pub text: String, - pub index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, -} diff --git a/sgl-router/src/protocols/openai/errors.rs b/sgl-router/src/protocols/openai/errors.rs deleted file mode 100644 index 9ec6b2e0b..000000000 --- a/sgl-router/src/protocols/openai/errors.rs +++ /dev/null @@ -1,19 +0,0 @@ -// OpenAI API error response types - -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ErrorResponse { - pub error: ErrorDetail, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ErrorDetail { - pub message: String, - #[serde(rename = "type")] - pub error_type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub param: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub code: Option, -} diff --git a/sgl-router/src/protocols/openai/mod.rs b/sgl-router/src/protocols/openai/mod.rs deleted file mode 100644 index 08495b92b..000000000 --- a/sgl-router/src/protocols/openai/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -// OpenAI protocol module -// This module contains all OpenAI 
API-compatible types and future validation logic - -pub mod chat; -pub mod common; -pub mod completions; -pub mod errors; -pub mod responses; diff --git a/sgl-router/src/protocols/openai/responses/mod.rs b/sgl-router/src/protocols/openai/responses/mod.rs deleted file mode 100644 index e513116fd..000000000 --- a/sgl-router/src/protocols/openai/responses/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -// Responses API module - -pub mod request; -pub mod response; -pub mod types; - -// Re-export main types for convenience -pub use request::ResponsesRequest; -pub use response::ResponsesResponse; -pub use types::*; diff --git a/sgl-router/src/protocols/openai/responses/request.rs b/sgl-router/src/protocols/openai/responses/request.rs deleted file mode 100644 index 575b487de..000000000 --- a/sgl-router/src/protocols/openai/responses/request.rs +++ /dev/null @@ -1,300 +0,0 @@ -// Responses API request types - -use crate::protocols::common::{GenerationRequest, StringOrArray}; -use crate::protocols::openai::responses::types::*; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -fn generate_request_id() -> String { - format!("resp_{}", uuid::Uuid::new_v4().simple()) -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponsesRequest { - // ============= Core OpenAI API fields ============= - /// Run the request in the background - #[serde(default)] - pub background: bool, - - /// Fields to include in the response - #[serde(skip_serializing_if = "Option::is_none")] - pub include: Option>, - - /// Input content - can be string or structured items - pub input: ResponseInput, - - /// System instructions for the model - #[serde(skip_serializing_if = "Option::is_none")] - pub instructions: Option, - - /// Maximum number of output tokens - #[serde(skip_serializing_if = "Option::is_none")] - pub max_output_tokens: Option, - - /// Maximum number of tool calls - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tool_calls: Option, - - /// 
Additional metadata - #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option>, - - /// Model to use (optional to match vLLM) - #[serde(skip_serializing_if = "Option::is_none")] - pub model: Option, - - /// Whether to enable parallel tool calls - #[serde(default = "default_true")] - pub parallel_tool_calls: bool, - - /// ID of previous response to continue from - #[serde(skip_serializing_if = "Option::is_none")] - pub previous_response_id: Option, - - /// Reasoning configuration - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning: Option, - - /// Service tier - #[serde(default)] - pub service_tier: ServiceTier, - - /// Whether to store the response - #[serde(default = "default_true")] - pub store: bool, - - /// Whether to stream the response - #[serde(default)] - pub stream: bool, - - /// Temperature for sampling - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - - /// Tool choice behavior - #[serde(default)] - pub tool_choice: ToolChoice, - - /// Available tools - #[serde(default)] - pub tools: Vec, - - /// Number of top logprobs to return - #[serde(default)] - pub top_logprobs: u32, - - /// Top-p sampling parameter - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - - /// Truncation behavior - #[serde(default)] - pub truncation: Truncation, - - /// User identifier - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, - - // ============= SGLang Extensions ============= - /// Request ID - #[serde(default = "generate_request_id")] - pub request_id: String, - - /// Request priority - #[serde(default)] - pub priority: i32, - - /// Frequency penalty - #[serde(default)] - pub frequency_penalty: f32, - - /// Presence penalty - #[serde(default)] - pub presence_penalty: f32, - - /// Stop sequences - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - - /// Top-k sampling parameter - #[serde(default = "default_top_k")] - pub top_k: i32, - - /// Min-p 
sampling parameter - #[serde(default)] - pub min_p: f32, - - /// Repetition penalty - #[serde(default = "default_repetition_penalty")] - pub repetition_penalty: f32, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum ResponseInput { - Text(String), - Items(Vec), -} - -fn default_top_k() -> i32 { - -1 -} - -fn default_repetition_penalty() -> f32 { - 1.0 -} - -fn default_true() -> bool { - true -} - -impl ResponsesRequest { - /// Default sampling parameters - const DEFAULT_TEMPERATURE: f32 = 0.7; - const DEFAULT_TOP_P: f32 = 1.0; - - /// Convert to sampling parameters for generation - pub fn to_sampling_params( - &self, - default_max_tokens: u32, - default_params: Option>, - ) -> HashMap { - let mut params = HashMap::new(); - - // Use max_output_tokens if available - let max_tokens = if let Some(max_output) = self.max_output_tokens { - std::cmp::min(max_output, default_max_tokens) - } else { - default_max_tokens - }; - - // Avoid exceeding context length by minus 1 token - let max_tokens = max_tokens.saturating_sub(1); - - // Temperature - let temperature = self.temperature.unwrap_or_else(|| { - default_params - .as_ref() - .and_then(|p| p.get("temperature")) - .and_then(|v| v.as_f64()) - .map(|v| v as f32) - .unwrap_or(Self::DEFAULT_TEMPERATURE) - }); - - // Top-p - let top_p = self.top_p.unwrap_or_else(|| { - default_params - .as_ref() - .and_then(|p| p.get("top_p")) - .and_then(|v| v.as_f64()) - .map(|v| v as f32) - .unwrap_or(Self::DEFAULT_TOP_P) - }); - - params.insert( - "max_new_tokens".to_string(), - serde_json::Value::Number(serde_json::Number::from(max_tokens)), - ); - params.insert( - "temperature".to_string(), - serde_json::Value::Number(serde_json::Number::from_f64(temperature as f64).unwrap()), - ); - params.insert( - "top_p".to_string(), - serde_json::Value::Number(serde_json::Number::from_f64(top_p as f64).unwrap()), - ); - params.insert( - "frequency_penalty".to_string(), - serde_json::Value::Number( - 
serde_json::Number::from_f64(self.frequency_penalty as f64).unwrap(), - ), - ); - params.insert( - "presence_penalty".to_string(), - serde_json::Value::Number( - serde_json::Number::from_f64(self.presence_penalty as f64).unwrap(), - ), - ); - params.insert( - "top_k".to_string(), - serde_json::Value::Number(serde_json::Number::from(self.top_k)), - ); - params.insert( - "min_p".to_string(), - serde_json::Value::Number(serde_json::Number::from_f64(self.min_p as f64).unwrap()), - ); - params.insert( - "repetition_penalty".to_string(), - serde_json::Value::Number( - serde_json::Number::from_f64(self.repetition_penalty as f64).unwrap(), - ), - ); - - if let Some(ref stop) = self.stop { - match serde_json::to_value(stop) { - Ok(value) => params.insert("stop".to_string(), value), - Err(_) => params.insert("stop".to_string(), serde_json::Value::Null), - }; - } - - // Apply any additional default parameters - if let Some(default_params) = default_params { - for (key, value) in default_params { - params.entry(key).or_insert(value); - } - } - - params - } -} - -impl GenerationRequest for ResponsesRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_model(&self) -> Option<&str> { - self.model.as_deref() - } - - fn extract_text_for_routing(&self) -> String { - match &self.input { - ResponseInput::Text(text) => text.clone(), - ResponseInput::Items(items) => items - .iter() - .filter_map(|item| match item { - ResponseInputOutputItem::Message { content, .. } => { - let texts: Vec = content - .iter() - .map(|part| match part { - ResponseContentPart::OutputText { text, .. } => text.clone(), - }) - .collect(); - if texts.is_empty() { - None - } else { - Some(texts.join(" ")) - } - } - ResponseInputOutputItem::Reasoning { content, .. 
} => { - let texts: Vec = content - .iter() - .map(|part| match part { - ResponseReasoningContent::ReasoningText { text } => text.clone(), - }) - .collect(); - if texts.is_empty() { - None - } else { - Some(texts.join(" ")) - } - } - ResponseInputOutputItem::FunctionToolCall { arguments, .. } => { - Some(arguments.clone()) - } - }) - .collect::>() - .join(" "), - } - } -} diff --git a/sgl-router/src/protocols/openai/responses/response.rs b/sgl-router/src/protocols/openai/responses/response.rs deleted file mode 100644 index b124ce7d4..000000000 --- a/sgl-router/src/protocols/openai/responses/response.rs +++ /dev/null @@ -1,280 +0,0 @@ -// Responses API response types - -use crate::protocols::openai::responses::request::ResponsesRequest; -use crate::protocols::openai::responses::types::*; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -fn generate_response_id() -> String { - format!("resp_{}", uuid::Uuid::new_v4().simple()) -} - -fn current_timestamp() -> i64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs() as i64 -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponsesResponse { - /// Response ID - #[serde(default = "generate_response_id")] - pub id: String, - - /// Object type - #[serde(default = "default_object_type")] - pub object: String, - - /// Creation timestamp - #[serde(default = "current_timestamp")] - pub created_at: i64, - - /// Model name - pub model: String, - - /// Output items - #[serde(default)] - pub output: Vec, - - /// Response status - pub status: ResponseStatus, - - /// Usage information - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, - - /// Whether parallel tool calls are enabled - #[serde(default = "default_true")] - pub parallel_tool_calls: bool, - - /// Tool choice setting - #[serde(default = "default_tool_choice")] - pub tool_choice: String, - - /// Available tools - 
#[serde(default)] - pub tools: Vec, -} - -fn default_object_type() -> String { - "response".to_string() -} - -fn default_true() -> bool { - true -} - -fn default_tool_choice() -> String { - "auto".to_string() -} - -impl ResponsesResponse { - /// Create a response from a request - #[allow(clippy::too_many_arguments)] - pub fn from_request( - request: &ResponsesRequest, - _sampling_params: &HashMap, - model_name: String, - created_time: i64, - output: Vec, - status: ResponseStatus, - usage: Option, - ) -> Self { - Self { - id: request.request_id.clone(), - object: "response".to_string(), - created_at: created_time, - model: model_name, - output, - status, - usage, - parallel_tool_calls: request.parallel_tool_calls, - tool_choice: match request.tool_choice { - ToolChoice::Auto => "auto".to_string(), - ToolChoice::Required => "required".to_string(), - ToolChoice::None => "none".to_string(), - }, - tools: request.tools.clone(), - } - } - - /// Create a new response with default values - pub fn new(request_id: String, model: String, status: ResponseStatus) -> Self { - Self { - id: request_id, - object: "response".to_string(), - created_at: current_timestamp(), - model, - output: Vec::new(), - status, - usage: None, - parallel_tool_calls: true, - tool_choice: "auto".to_string(), - tools: Vec::new(), - } - } - - /// Add an output item to the response - pub fn add_output(&mut self, item: ResponseOutputItem) { - self.output.push(item); - } - - /// Set the usage information - pub fn set_usage(&mut self, usage: UsageInfo) { - self.usage = Some(usage); - } - - /// Update the status - pub fn set_status(&mut self, status: ResponseStatus) { - self.status = status; - } - - /// Check if the response is complete - pub fn is_complete(&self) -> bool { - matches!(self.status, ResponseStatus::Completed) - } - - /// Check if the response is in progress - pub fn is_in_progress(&self) -> bool { - matches!(self.status, ResponseStatus::InProgress) - } - - /// Check if the response failed - 
pub fn is_failed(&self) -> bool { - matches!(self.status, ResponseStatus::Failed) - } - - /// Check if the response was cancelled - pub fn is_cancelled(&self) -> bool { - matches!(self.status, ResponseStatus::Cancelled) - } - - /// Check if the response is queued - pub fn is_queued(&self) -> bool { - matches!(self.status, ResponseStatus::Queued) - } - - /// Convert usage to OpenAI Responses API format - pub fn usage_in_response_format( - &self, - ) -> Option { - self.usage.as_ref().map(|usage| usage.to_response_usage()) - } - - /// Get the response as a JSON value with usage in response format - pub fn to_response_format(&self) -> serde_json::Value { - let mut response = serde_json::to_value(self).unwrap_or(serde_json::Value::Null); - - // Convert usage to response format if present - if let Some(usage) = &self.usage { - if let Ok(usage_value) = serde_json::to_value(usage.to_response_usage()) { - response["usage"] = usage_value; - } - } - - response - } -} - -// ============= Helper Functions ============= - -impl ResponseOutputItem { - /// Create a new message output item - pub fn new_message( - id: String, - role: String, - content: Vec, - status: String, - ) -> Self { - Self::Message { - id, - role, - content, - status, - } - } - - /// Create a new reasoning output item - pub fn new_reasoning( - id: String, - summary: Vec, - content: Vec, - status: Option, - ) -> Self { - Self::Reasoning { - id, - summary, - content, - status, - } - } - - /// Create a new function tool call output item - pub fn new_function_tool_call( - id: String, - name: String, - arguments: String, - output: Option, - status: String, - ) -> Self { - Self::FunctionToolCall { - id, - name, - arguments, - output, - status, - } - } -} - -impl ResponseContentPart { - /// Create a new text content part - pub fn new_text( - text: String, - annotations: Vec, - logprobs: Option, - ) -> Self { - Self::OutputText { - text, - annotations, - logprobs, - } - } -} - -impl ResponseReasoningContent { - /// 
Create a new reasoning text content - pub fn new_reasoning_text(text: String) -> Self { - Self::ReasoningText { text } - } -} - -impl UsageInfo { - /// Create a new usage info with token counts - pub fn new(prompt_tokens: u32, completion_tokens: u32, reasoning_tokens: Option) -> Self { - Self { - prompt_tokens, - completion_tokens, - total_tokens: prompt_tokens + completion_tokens, - reasoning_tokens, - prompt_tokens_details: None, - } - } - - /// Create usage info with cached token details - pub fn new_with_cached( - prompt_tokens: u32, - completion_tokens: u32, - reasoning_tokens: Option, - cached_tokens: u32, - ) -> Self { - Self { - prompt_tokens, - completion_tokens, - total_tokens: prompt_tokens + completion_tokens, - reasoning_tokens, - prompt_tokens_details: Some(PromptTokenUsageInfo { cached_tokens }), - } - } -} diff --git a/sgl-router/src/protocols/openai/responses/types.rs b/sgl-router/src/protocols/openai/responses/types.rs deleted file mode 100644 index 588772662..000000000 --- a/sgl-router/src/protocols/openai/responses/types.rs +++ /dev/null @@ -1,296 +0,0 @@ -// Supporting types for Responses API - -use crate::protocols::openai::common::ChatLogProbs; -use serde::{Deserialize, Serialize}; - -// ============= Tool Definitions ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponseTool { - #[serde(rename = "type")] - pub r#type: ResponseToolType, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ResponseToolType { - WebSearchPreview, - CodeInterpreter, -} - -// ============= Reasoning Configuration ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponseReasoningParam { - #[serde(default = "default_reasoning_effort")] - pub effort: Option, -} - -fn default_reasoning_effort() -> Option { - Some(ReasoningEffort::Medium) -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ReasoningEffort { - Low, - 
Medium, - High, -} - -// ============= Input/Output Items ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -#[serde(rename_all = "snake_case")] -pub enum ResponseInputOutputItem { - #[serde(rename = "message")] - Message { - id: String, - role: String, - content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - status: Option, - }, - #[serde(rename = "reasoning")] - Reasoning { - id: String, - #[serde(skip_serializing_if = "Vec::is_empty")] - summary: Vec, - content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - status: Option, - }, - #[serde(rename = "function_tool_call")] - FunctionToolCall { - id: String, - name: String, - arguments: String, - #[serde(skip_serializing_if = "Option::is_none")] - output: Option, - #[serde(skip_serializing_if = "Option::is_none")] - status: Option, - }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -#[serde(rename_all = "snake_case")] -pub enum ResponseContentPart { - #[serde(rename = "output_text")] - OutputText { - text: String, - #[serde(skip_serializing_if = "Vec::is_empty")] - annotations: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - logprobs: Option, - }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -#[serde(rename_all = "snake_case")] -pub enum ResponseReasoningContent { - #[serde(rename = "reasoning_text")] - ReasoningText { text: String }, -} - -// ============= Output Items for Response ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -#[serde(rename_all = "snake_case")] -pub enum ResponseOutputItem { - #[serde(rename = "message")] - Message { - id: String, - role: String, - content: Vec, - status: String, - }, - #[serde(rename = "reasoning")] - Reasoning { - id: String, - #[serde(skip_serializing_if = "Vec::is_empty")] - summary: Vec, - content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - status: Option, - }, - #[serde(rename = 
"function_tool_call")] - FunctionToolCall { - id: String, - name: String, - arguments: String, - #[serde(skip_serializing_if = "Option::is_none")] - output: Option, - status: String, - }, -} - -// ============= Service Tier ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ServiceTier { - Auto, - Default, - Flex, - Scale, - Priority, -} - -impl Default for ServiceTier { - fn default() -> Self { - Self::Auto - } -} - -// ============= Tool Choice ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ToolChoice { - Auto, - Required, - None, -} - -impl Default for ToolChoice { - fn default() -> Self { - Self::Auto - } -} - -// ============= Truncation ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum Truncation { - Auto, - Disabled, -} - -impl Default for Truncation { - fn default() -> Self { - Self::Disabled - } -} - -// ============= Response Status ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum ResponseStatus { - Queued, - InProgress, - Completed, - Failed, - Cancelled, -} - -// ============= Include Fields ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "snake_case")] -pub enum IncludeField { - #[serde(rename = "code_interpreter_call.outputs")] - CodeInterpreterCallOutputs, - #[serde(rename = "computer_call_output.output.image_url")] - ComputerCallOutputImageUrl, - #[serde(rename = "file_search_call.results")] - FileSearchCallResults, - #[serde(rename = "message.input_image.image_url")] - MessageInputImageUrl, - #[serde(rename = "message.output_text.logprobs")] - MessageOutputTextLogprobs, - #[serde(rename = "reasoning.encrypted_content")] - ReasoningEncryptedContent, -} - -// ============= Usage Info ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct 
UsageInfo { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt_tokens_details: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct PromptTokenUsageInfo { - pub cached_tokens: u32, -} - -// ============= Response Usage Format ============= - -/// OpenAI Responses API usage format (different from standard UsageInfo) -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ResponseUsage { - pub input_tokens: u32, - pub output_tokens: u32, - pub total_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub input_tokens_details: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub output_tokens_details: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct InputTokensDetails { - pub cached_tokens: u32, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct OutputTokensDetails { - pub reasoning_tokens: u32, -} - -impl UsageInfo { - /// Convert to OpenAI Responses API format - pub fn to_response_usage(&self) -> ResponseUsage { - ResponseUsage { - input_tokens: self.prompt_tokens, - output_tokens: self.completion_tokens, - total_tokens: self.total_tokens, - input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| { - InputTokensDetails { - cached_tokens: details.cached_tokens, - } - }), - output_tokens_details: self.reasoning_tokens.map(|tokens| OutputTokensDetails { - reasoning_tokens: tokens, - }), - } - } -} - -impl From for ResponseUsage { - fn from(usage: UsageInfo) -> Self { - usage.to_response_usage() - } -} - -impl ResponseUsage { - /// Convert back to standard UsageInfo format - pub fn to_usage_info(&self) -> UsageInfo { - UsageInfo { - prompt_tokens: self.input_tokens, - completion_tokens: self.output_tokens, - total_tokens: self.total_tokens, - reasoning_tokens: self - 
.output_tokens_details - .as_ref() - .map(|details| details.reasoning_tokens), - prompt_tokens_details: self.input_tokens_details.as_ref().map(|details| { - PromptTokenUsageInfo { - cached_tokens: details.cached_tokens, - } - }), - } - } -} diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs new file mode 100644 index 000000000..986f991cb --- /dev/null +++ b/sgl-router/src/protocols/spec.rs @@ -0,0 +1,1867 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +// # Protocol Specifications +// +// This module contains all protocol definitions for OpenAI and SGLang APIs. +// +// ## Table of Contents +// +// 1. **OPENAI SPEC - Chat Completions API** +// - Message Types +// - Response Format Types +// - Tool/Function Types +// - Streaming Delta Types +// - Request/Response structures +// +// 2. **OPENAI SPEC - Completions API** +// - Request/Response structures +// - Streaming support +// +// 3. **OPENAI SPEC - Responses API** +// - Tool Definitions +// - Reasoning Configuration +// - Input/Output Items +// - Service Tier & Tool Choice +// - Request/Response structures +// +// 4. **OPENAI SPEC - Common** +// - Shared Request Components +// - Tool Choice Types +// - Usage Tracking +// - Logprobs Types +// - Error Response Types +// +// 5. **SGLANG SPEC - GENERATE API** +// - Generate Parameters +// - Sampling Parameters +// - Request/Response structures +// +// 6. 
**COMMON** +// - GenerationRequest trait +// - StringOrArray & LoRAPath types +// - Helper functions + +// ================================================================== +// = OPENAI SPEC - Chat Completions API = +// ================================================================== + +// ============= Message Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ChatMessage { + System { + role: String, + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + User { + role: String, // "user" + content: UserMessageContent, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + Assistant { + role: String, // "assistant" + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + function_call: Option, + /// Reasoning content for O1-style models (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_content: Option, + }, + Tool { + role: String, // "tool" + content: String, + tool_call_id: String, + }, + Function { + role: String, // "function" + content: String, + name: String, + }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum UserMessageContent { + Text(String), + Parts(Vec), +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { image_url: ImageUrl }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ImageUrl { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, // "auto", "low", or "high" +} + +// ============= Response Format Types ============= + +#[derive(Debug, Clone, Deserialize, 
Serialize)] +#[serde(tag = "type")] +pub enum ResponseFormat { + #[serde(rename = "text")] + Text, + #[serde(rename = "json_object")] + JsonObject, + #[serde(rename = "json_schema")] + JsonSchema { json_schema: JsonSchemaFormat }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct JsonSchemaFormat { + pub name: String, + pub schema: Value, + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +// ============= Streaming Delta Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatMessageDelta { + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, + /// Reasoning content delta for O1-style models (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolCallDelta { + pub index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "type")] + pub tool_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionCallDelta { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, +} + +// ============= Request ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionRequest { + /// ID of the model to use + pub model: String, + + /// A list of messages comprising the conversation so far + pub messages: Vec, + + /// What sampling temperature to use, between 0 and 2 + #[serde(skip_serializing_if = "Option::is_none")] + pub 
temperature: Option, + + /// An alternative to sampling with temperature + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// How many chat completion choices to generate for each input message + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + + /// If set, partial message deltas will be sent + #[serde(default)] + pub stream: bool, + + /// Options for streaming response + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + + /// Up to 4 sequences where the API will stop generating further tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + + /// The maximum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// An upper bound for the number of tokens that can be generated for a completion + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + + /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + + /// Modify the likelihood of specified tokens appearing in the completion + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + + /// A unique identifier representing your end-user + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// If specified, our system will make a best effort to sample deterministically + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + + /// Whether to return log probabilities of the output tokens + #[serde(default)] + pub logprobs: bool, + + /// An integer between 0 and 20 specifying the number of most likely tokens to return + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs: Option, + + /// An object specifying the format that the model must output + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + + /// A list of tools the model may call + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// Controls which (if any) tool is called by the model + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + + /// Whether to enable parallel function calling during tool use + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + + /// Deprecated: use tools instead + #[serde(skip_serializing_if = "Option::is_none")] + pub functions: Option>, + + /// Deprecated: use tool_choice instead + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, + + // ============= SGLang Extensions ============= + /// Top-k sampling parameter (-1 to disable) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + + /// Min-p nucleus sampling parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + + /// Minimum 
number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + + /// Repetition penalty for reducing repetitive text + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + + /// Regex constraint for output generation + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + + /// EBNF grammar constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + + /// Specific token IDs to use as stop conditions + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + + /// Skip trimming stop tokens from output + #[serde(default)] + pub no_stop_trim: bool, + + /// Ignore end-of-sequence tokens during generation + #[serde(default)] + pub ignore_eos: bool, + + /// Continue generating from final assistant message + #[serde(default)] + pub continue_final_message: bool, + + /// Skip special tokens during detokenization + #[serde(default = "default_true")] + pub skip_special_tokens: bool, + + // ============= SGLang Extensions ============= + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Separate reasoning content from final answer (O1-style models) + #[serde(default = "default_true")] + pub separate_reasoning: bool, + + /// Stream reasoning tokens during generation + #[serde(default = "default_true")] + pub stream_reasoning: bool, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, +} + +impl GenerationRequest for ChatCompletionRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn extract_text_for_routing(&self) -> String { + // Extract text from messages for routing decisions + 
self.messages + .iter() + .filter_map(|msg| match msg { + ChatMessage::System { content, .. } => Some(content.clone()), + ChatMessage::User { content, .. } => match content { + UserMessageContent::Text(text) => Some(text.clone()), + UserMessageContent::Parts(parts) => { + let texts: Vec = parts + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => Some(text.clone()), + _ => None, + }) + .collect(); + Some(texts.join(" ")) + } + }, + ChatMessage::Assistant { + content, + reasoning_content, + .. + } => { + // Combine content and reasoning content for routing decisions + let main_content = content.clone().unwrap_or_default(); + let reasoning = reasoning_content.clone().unwrap_or_default(); + if main_content.is_empty() && reasoning.is_empty() { + None + } else { + Some(format!("{} {}", main_content, reasoning).trim().to_string()) + } + } + ChatMessage::Tool { content, .. } => Some(content.clone()), + ChatMessage::Function { content, .. } => Some(content.clone()), + }) + .collect::>() + .join(" ") + } +} + +// ============= Regular Response ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: String, // "chat.completion" + pub created: u64, + pub model: String, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatChoice { + pub index: u32, + pub message: ChatMessage, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, // "stop", "length", "tool_calls", "content_filter", "function_call" + /// Information about which stop condition was matched + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, // Can be string or integer + /// Hidden states from the model (SGLang extension) + 
#[serde(skip_serializing_if = "Option::is_none")] + pub hidden_states: Option>, +} + +// ============= Streaming Response ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionStreamResponse { + pub id: String, + pub object: String, // "chat.completion.chunk" + pub created: u64, + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatStreamChoice { + pub index: u32, + pub delta: ChatMessageDelta, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, +} + +// ================================================================== +// = OPENAI SPEC - Completions API = +// ================================================================== +// Completions API request types (v1/completions) - DEPRECATED but still supported + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionRequest { + /// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang) + pub model: String, + + /// The prompt(s) to generate completions for + pub prompt: StringOrArray, + + /// The suffix that comes after a completion of inserted text + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + + /// The maximum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// What sampling temperature to use, between 0 and 2 + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// An alternative to sampling with temperature (nucleus sampling) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// How many completions to generate for each prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + + /// Whether 
to stream back partial progress + #[serde(default)] + pub stream: bool, + + /// Options for streaming response + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + + /// Include the log probabilities on the logprobs most likely tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + + /// Echo back the prompt in addition to the completion + #[serde(default)] + pub echo: bool, + + /// Up to 4 sequences where the API will stop generating further tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + + /// Generates best_of completions server-side and returns the "best" + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of: Option, + + /// Modify the likelihood of specified tokens appearing in the completion + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + + /// A unique identifier representing your end-user + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// If specified, our system will make a best effort to sample deterministically + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + + // ============= SGLang Extensions ============= + /// Top-k sampling parameter (-1 to disable) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + + /// Min-p nucleus sampling parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + + /// Minimum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: 
Option, + + /// Repetition penalty for reducing repetitive text + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + + /// Regex constraint for output generation + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + + /// EBNF grammar constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + + /// JSON schema constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, + + /// Specific token IDs to use as stop conditions + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + + /// Skip trimming stop tokens from output + #[serde(default)] + pub no_stop_trim: bool, + + /// Ignore end-of-sequence tokens during generation + #[serde(default)] + pub ignore_eos: bool, + + /// Skip special tokens during detokenization + #[serde(default = "default_true")] + pub skip_special_tokens: bool, + + // ============= SGLang Extensions ============= + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, + + /// Additional fields including bootstrap info for PD routing + #[serde(flatten)] + pub other: serde_json::Map, +} + +impl GenerationRequest for CompletionRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn extract_text_for_routing(&self) -> String { + match &self.prompt { + StringOrArray::String(s) => s.clone(), + StringOrArray::Array(v) => v.join(" "), + } + } +} + +// ============= Regular Response ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionResponse { + pub id: String, + pub 
object: String, // "text_completion" + pub created: u64, + pub model: String, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionChoice { + pub text: String, + pub index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, // "stop", "length", "content_filter", etc. + /// Information about which stop condition was matched + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, // Can be string or integer + /// Hidden states from the model (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + pub hidden_states: Option>, +} + +// ============= Streaming Response ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionStreamResponse { + pub id: String, + pub object: String, // "text_completion" + pub created: u64, + pub choices: Vec, + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionStreamChoice { + pub text: String, + pub index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, +} + +// ================================================================== +// = OPENAI SPEC - Responses API = +// ================================================================== + +// ============= Tool Definitions ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponseTool { + #[serde(rename = "type")] + pub r#type: ResponseToolType, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ResponseToolType { + WebSearchPreview, + CodeInterpreter, +} + +// ============= Reasoning Configuration 
============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponseReasoningParam { + #[serde(default = "default_reasoning_effort")] + pub effort: Option, +} + +fn default_reasoning_effort() -> Option { + Some(ReasoningEffort::Medium) +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ReasoningEffort { + Low, + Medium, + High, +} + +// ============= Input/Output Items ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ResponseInputOutputItem { + #[serde(rename = "message")] + Message { + id: String, + role: String, + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, + #[serde(rename = "reasoning")] + Reasoning { + id: String, + #[serde(skip_serializing_if = "Vec::is_empty")] + summary: Vec, + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, + #[serde(rename = "function_tool_call")] + FunctionToolCall { + id: String, + name: String, + arguments: String, + #[serde(skip_serializing_if = "Option::is_none")] + output: Option, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ResponseContentPart { + #[serde(rename = "output_text")] + OutputText { + text: String, + #[serde(skip_serializing_if = "Vec::is_empty")] + annotations: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + logprobs: Option, + }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ResponseReasoningContent { + #[serde(rename = "reasoning_text")] + ReasoningText { text: String }, +} + +// ============= Output Items for Response ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = 
"snake_case")] +pub enum ResponseOutputItem { + #[serde(rename = "message")] + Message { + id: String, + role: String, + content: Vec, + status: String, + }, + #[serde(rename = "reasoning")] + Reasoning { + id: String, + #[serde(skip_serializing_if = "Vec::is_empty")] + summary: Vec, + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, + #[serde(rename = "function_tool_call")] + FunctionToolCall { + id: String, + name: String, + arguments: String, + #[serde(skip_serializing_if = "Option::is_none")] + output: Option, + status: String, + }, +} + +// ============= Service Tier ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ServiceTier { + Auto, + Default, + Flex, + Scale, + Priority, +} + +impl Default for ServiceTier { + fn default() -> Self { + Self::Auto + } +} + +// ============= Truncation ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum Truncation { + Auto, + Disabled, +} + +impl Default for Truncation { + fn default() -> Self { + Self::Disabled + } +} + +// ============= Response Status ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ResponseStatus { + Queued, + InProgress, + Completed, + Failed, + Cancelled, +} + +// ============= Include Fields ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum IncludeField { + #[serde(rename = "code_interpreter_call.outputs")] + CodeInterpreterCallOutputs, + #[serde(rename = "computer_call_output.output.image_url")] + ComputerCallOutputImageUrl, + #[serde(rename = "file_search_call.results")] + FileSearchCallResults, + #[serde(rename = "message.input_image.image_url")] + MessageInputImageUrl, + #[serde(rename = "message.output_text.logprobs")] + MessageOutputTextLogprobs, + #[serde(rename = "reasoning.encrypted_content")] + 
ReasoningEncryptedContent, +} + +// ============= Usage Info ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct UsageInfo { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_tokens_details: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct PromptTokenUsageInfo { + pub cached_tokens: u32, +} + +// ============= Response Usage Format ============= + +/// OpenAI Responses API usage format (different from standard UsageInfo) +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponseUsage { + pub input_tokens: u32, + pub output_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub input_tokens_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_tokens_details: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct InputTokensDetails { + pub cached_tokens: u32, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct OutputTokensDetails { + pub reasoning_tokens: u32, +} + +impl UsageInfo { + /// Convert to OpenAI Responses API format + pub fn to_response_usage(&self) -> ResponseUsage { + ResponseUsage { + input_tokens: self.prompt_tokens, + output_tokens: self.completion_tokens, + total_tokens: self.total_tokens, + input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| { + InputTokensDetails { + cached_tokens: details.cached_tokens, + } + }), + output_tokens_details: self.reasoning_tokens.map(|tokens| OutputTokensDetails { + reasoning_tokens: tokens, + }), + } + } +} + +impl From for ResponseUsage { + fn from(usage: UsageInfo) -> Self { + usage.to_response_usage() + } +} + +impl ResponseUsage { + /// Convert back to standard UsageInfo format + pub fn to_usage_info(&self) -> UsageInfo { + UsageInfo { + prompt_tokens: 
self.input_tokens, + completion_tokens: self.output_tokens, + total_tokens: self.total_tokens, + reasoning_tokens: self + .output_tokens_details + .as_ref() + .map(|details| details.reasoning_tokens), + prompt_tokens_details: self.input_tokens_details.as_ref().map(|details| { + PromptTokenUsageInfo { + cached_tokens: details.cached_tokens, + } + }), + } + } +} + +fn generate_request_id() -> String { + format!("resp_{}", uuid::Uuid::new_v4().simple()) +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponsesRequest { + // ============= Core OpenAI API fields ============= + /// Run the request in the background + #[serde(default)] + pub background: bool, + + /// Fields to include in the response + #[serde(skip_serializing_if = "Option::is_none")] + pub include: Option>, + + /// Input content - can be string or structured items + pub input: ResponseInput, + + /// System instructions for the model + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + + /// Maximum number of output tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + + /// Maximum number of tool calls + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tool_calls: Option, + + /// Additional metadata + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + + /// Model to use (optional to match vLLM) + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + + /// Whether to enable parallel tool calls + #[serde(default = "default_true")] + pub parallel_tool_calls: bool, + + /// ID of previous response to continue from + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_response_id: Option, + + /// Reasoning configuration + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, + + /// Service tier + #[serde(default)] + pub service_tier: ServiceTier, + + /// Whether to store the response + #[serde(default = "default_true")] + pub store: 
bool, + + /// Whether to stream the response + #[serde(default)] + pub stream: bool, + + /// Temperature for sampling + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// Tool choice behavior + #[serde(default)] + pub tool_choice: ToolChoice, + + /// Available tools + #[serde(default)] + pub tools: Vec, + + /// Number of top logprobs to return + #[serde(default)] + pub top_logprobs: u32, + + /// Top-p sampling parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// Truncation behavior + #[serde(default)] + pub truncation: Truncation, + + /// User identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + // ============= SGLang Extensions ============= + /// Request ID + #[serde(default = "generate_request_id")] + pub request_id: String, + + /// Request priority + #[serde(default)] + pub priority: i32, + + /// Frequency penalty + #[serde(default)] + pub frequency_penalty: f32, + + /// Presence penalty + #[serde(default)] + pub presence_penalty: f32, + + /// Stop sequences + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + + /// Top-k sampling parameter + #[serde(default = "default_top_k")] + pub top_k: i32, + + /// Min-p sampling parameter + #[serde(default)] + pub min_p: f32, + + /// Repetition penalty + #[serde(default = "default_repetition_penalty")] + pub repetition_penalty: f32, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ResponseInput { + Text(String), + Items(Vec), +} + +fn default_top_k() -> i32 { + -1 +} + +fn default_repetition_penalty() -> f32 { + 1.0 +} + +impl ResponsesRequest { + /// Default sampling parameters + const DEFAULT_TEMPERATURE: f32 = 0.7; + const DEFAULT_TOP_P: f32 = 1.0; + + /// Convert to sampling parameters for generation + pub fn to_sampling_params( + &self, + default_max_tokens: u32, + default_params: Option>, + ) -> HashMap { + let mut params = HashMap::new(); + + // Use 
max_output_tokens if available + let max_tokens = if let Some(max_output) = self.max_output_tokens { + std::cmp::min(max_output, default_max_tokens) + } else { + default_max_tokens + }; + + // Avoid exceeding context length by minus 1 token + let max_tokens = max_tokens.saturating_sub(1); + + // Temperature + let temperature = self.temperature.unwrap_or_else(|| { + default_params + .as_ref() + .and_then(|p| p.get("temperature")) + .and_then(|v| v.as_f64()) + .map(|v| v as f32) + .unwrap_or(Self::DEFAULT_TEMPERATURE) + }); + + // Top-p + let top_p = self.top_p.unwrap_or_else(|| { + default_params + .as_ref() + .and_then(|p| p.get("top_p")) + .and_then(|v| v.as_f64()) + .map(|v| v as f32) + .unwrap_or(Self::DEFAULT_TOP_P) + }); + + params.insert( + "max_new_tokens".to_string(), + serde_json::Value::Number(serde_json::Number::from(max_tokens)), + ); + params.insert( + "temperature".to_string(), + serde_json::Value::Number(serde_json::Number::from_f64(temperature as f64).unwrap()), + ); + params.insert( + "top_p".to_string(), + serde_json::Value::Number(serde_json::Number::from_f64(top_p as f64).unwrap()), + ); + params.insert( + "frequency_penalty".to_string(), + serde_json::Value::Number( + serde_json::Number::from_f64(self.frequency_penalty as f64).unwrap(), + ), + ); + params.insert( + "presence_penalty".to_string(), + serde_json::Value::Number( + serde_json::Number::from_f64(self.presence_penalty as f64).unwrap(), + ), + ); + params.insert( + "top_k".to_string(), + serde_json::Value::Number(serde_json::Number::from(self.top_k)), + ); + params.insert( + "min_p".to_string(), + serde_json::Value::Number(serde_json::Number::from_f64(self.min_p as f64).unwrap()), + ); + params.insert( + "repetition_penalty".to_string(), + serde_json::Value::Number( + serde_json::Number::from_f64(self.repetition_penalty as f64).unwrap(), + ), + ); + + if let Some(ref stop) = self.stop { + match serde_json::to_value(stop) { + Ok(value) => params.insert("stop".to_string(), value), + 
Err(_) => params.insert("stop".to_string(), serde_json::Value::Null), + }; + } + + // Apply any additional default parameters + if let Some(default_params) = default_params { + for (key, value) in default_params { + params.entry(key).or_insert(value); + } + } + + params + } +} + +impl GenerationRequest for ResponsesRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + self.model.as_deref() + } + + fn extract_text_for_routing(&self) -> String { + match &self.input { + ResponseInput::Text(text) => text.clone(), + ResponseInput::Items(items) => items + .iter() + .filter_map(|item| match item { + ResponseInputOutputItem::Message { content, .. } => { + let texts: Vec = content + .iter() + .map(|part| match part { + ResponseContentPart::OutputText { text, .. } => text.clone(), + }) + .collect(); + if texts.is_empty() { + None + } else { + Some(texts.join(" ")) + } + } + ResponseInputOutputItem::Reasoning { content, .. } => { + let texts: Vec = content + .iter() + .map(|part| match part { + ResponseReasoningContent::ReasoningText { text } => text.clone(), + }) + .collect(); + if texts.is_empty() { + None + } else { + Some(texts.join(" ")) + } + } + ResponseInputOutputItem::FunctionToolCall { arguments, .. 
} => { + Some(arguments.clone()) + } + }) + .collect::>() + .join(" "), + } + } +} + +fn generate_response_id() -> String { + format!("resp_{}", uuid::Uuid::new_v4().simple()) +} + +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs() as i64 +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponsesResponse { + /// Response ID + #[serde(default = "generate_response_id")] + pub id: String, + + /// Object type + #[serde(default = "default_object_type")] + pub object: String, + + /// Creation timestamp + #[serde(default = "current_timestamp")] + pub created_at: i64, + + /// Model name + pub model: String, + + /// Output items + #[serde(default)] + pub output: Vec, + + /// Response status + pub status: ResponseStatus, + + /// Usage information + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + + /// Whether parallel tool calls are enabled + #[serde(default = "default_true")] + pub parallel_tool_calls: bool, + + /// Tool choice setting + #[serde(default = "default_tool_choice")] + pub tool_choice: String, + + /// Available tools + #[serde(default)] + pub tools: Vec, +} + +fn default_object_type() -> String { + "response".to_string() +} + +fn default_tool_choice() -> String { + "auto".to_string() +} + +impl ResponsesResponse { + /// Create a response from a request + #[allow(clippy::too_many_arguments)] + pub fn from_request( + request: &ResponsesRequest, + _sampling_params: &HashMap, + model_name: String, + created_time: i64, + output: Vec, + status: ResponseStatus, + usage: Option, + ) -> Self { + Self { + id: request.request_id.clone(), + object: "response".to_string(), + created_at: created_time, + model: model_name, + output, + status, + usage, + parallel_tool_calls: request.parallel_tool_calls, + tool_choice: match &request.tool_choice { + ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(), + 
ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(), + ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(), + ToolChoice::Function { .. } => "function".to_string(), + }, + tools: request.tools.clone(), + } + } + + /// Create a new response with default values + pub fn new(request_id: String, model: String, status: ResponseStatus) -> Self { + Self { + id: request_id, + object: "response".to_string(), + created_at: current_timestamp(), + model, + output: Vec::new(), + status, + usage: None, + parallel_tool_calls: true, + tool_choice: "auto".to_string(), + tools: Vec::new(), + } + } + + /// Add an output item to the response + pub fn add_output(&mut self, item: ResponseOutputItem) { + self.output.push(item); + } + + /// Set the usage information + pub fn set_usage(&mut self, usage: UsageInfo) { + self.usage = Some(usage); + } + + /// Update the status + pub fn set_status(&mut self, status: ResponseStatus) { + self.status = status; + } + + /// Check if the response is complete + pub fn is_complete(&self) -> bool { + matches!(self.status, ResponseStatus::Completed) + } + + /// Check if the response is in progress + pub fn is_in_progress(&self) -> bool { + matches!(self.status, ResponseStatus::InProgress) + } + + /// Check if the response failed + pub fn is_failed(&self) -> bool { + matches!(self.status, ResponseStatus::Failed) + } + + /// Check if the response was cancelled + pub fn is_cancelled(&self) -> bool { + matches!(self.status, ResponseStatus::Cancelled) + } + + /// Check if the response is queued + pub fn is_queued(&self) -> bool { + matches!(self.status, ResponseStatus::Queued) + } + + /// Convert usage to OpenAI Responses API format + pub fn usage_in_response_format(&self) -> Option { + self.usage.as_ref().map(|usage| usage.to_response_usage()) + } + + /// Get the response as a JSON value with usage in response format + pub fn to_response_format(&self) -> serde_json::Value { + let mut response = 
serde_json::to_value(self).unwrap_or(serde_json::Value::Null); + + // Convert usage to response format if present + if let Some(usage) = &self.usage { + if let Ok(usage_value) = serde_json::to_value(usage.to_response_usage()) { + response["usage"] = usage_value; + } + } + + response + } +} + +// ============= Helper Functions ============= + +impl ResponseOutputItem { + /// Create a new message output item + pub fn new_message( + id: String, + role: String, + content: Vec, + status: String, + ) -> Self { + Self::Message { + id, + role, + content, + status, + } + } + + /// Create a new reasoning output item + pub fn new_reasoning( + id: String, + summary: Vec, + content: Vec, + status: Option, + ) -> Self { + Self::Reasoning { + id, + summary, + content, + status, + } + } + + /// Create a new function tool call output item + pub fn new_function_tool_call( + id: String, + name: String, + arguments: String, + output: Option, + status: String, + ) -> Self { + Self::FunctionToolCall { + id, + name, + arguments, + output, + status, + } + } +} + +impl ResponseContentPart { + /// Create a new text content part + pub fn new_text( + text: String, + annotations: Vec, + logprobs: Option, + ) -> Self { + Self::OutputText { + text, + annotations, + logprobs, + } + } +} + +impl ResponseReasoningContent { + /// Create a new reasoning text content + pub fn new_reasoning_text(text: String) -> Self { + Self::ReasoningText { text } + } +} + +impl UsageInfo { + /// Create a new usage info with token counts + pub fn new(prompt_tokens: u32, completion_tokens: u32, reasoning_tokens: Option) -> Self { + Self { + prompt_tokens, + completion_tokens, + total_tokens: prompt_tokens + completion_tokens, + reasoning_tokens, + prompt_tokens_details: None, + } + } + + /// Create usage info with cached token details + pub fn new_with_cached( + prompt_tokens: u32, + completion_tokens: u32, + reasoning_tokens: Option, + cached_tokens: u32, + ) -> Self { + Self { + prompt_tokens, + completion_tokens, + 
total_tokens: prompt_tokens + completion_tokens, + reasoning_tokens, + prompt_tokens_details: Some(PromptTokenUsageInfo { cached_tokens }), + } + } +} + +// ================================================================== +// = OPENAI SPEC - Common = +// ================================================================== + +// ============= Shared Request Components ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct StreamOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub include_usage: Option, +} + +// ============= Tool Choice Types ============= + +/// Tool choice value for simple string options +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoiceValue { + Auto, + Required, + None, +} + +/// Tool choice for both Chat Completion and Responses APIs +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ToolChoice { + Value(ToolChoiceValue), + Function { + #[serde(rename = "type")] + tool_type: String, // "function" + function: FunctionChoice, + }, +} + +impl Default for ToolChoice { + fn default() -> Self { + Self::Value(ToolChoiceValue::Auto) + } +} + +/// Function choice specification for ToolChoice::Function +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionChoice { + pub name: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Tool { + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub function: Function, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Function { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub parameters: Value, // JSON Schema +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub function: FunctionCallResponse, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] 
+#[serde(untagged)] +pub enum FunctionCall { + None, + Auto, + Function { name: String }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionCallResponse { + pub name: String, + pub arguments: String, // JSON string +} + +// ============= Usage Tracking ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_tokens_details: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionTokensDetails { + pub reasoning_tokens: Option, +} + +// ============= Logprobs Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct LogProbs { + pub tokens: Vec, + pub token_logprobs: Vec>, + pub top_logprobs: Vec>>, + pub text_offset: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatLogProbs { + pub content: Option>, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatLogProbsContent { + pub token: String, + pub logprob: f32, + pub bytes: Option>, + pub top_logprobs: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct TopLogProb { + pub token: String, + pub logprob: f32, + pub bytes: Option>, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ErrorResponse { + pub error: ErrorDetail, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ErrorDetail { + pub message: String, + #[serde(rename = "type")] + pub error_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub param: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, +} + +// ================================================================== +// = SGLANG SPEC - GENERATE API = +// ================================================================== + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum InputIds { + 
Single(Vec), + Batch(Vec>), +} + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct GenerateParameters { + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub decoder_input_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub do_sample: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_new_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub return_full_text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub truncate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub typical_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub watermark: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct SamplingParams { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_new_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + 
pub stop: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ignore_eos: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub skip_special_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub no_stop_trim: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GenerateRequest { + /// The prompt to generate from (OpenAI style) + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + + /// Text input - SGLang native format + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + + /// Input IDs for tokenized input + #[serde(skip_serializing_if = "Option::is_none")] + pub input_ids: Option, + + /// Generation parameters + #[serde(default, skip_serializing_if = "Option::is_none")] + pub parameters: Option, + + /// Sampling parameters (sglang style) + #[serde(skip_serializing_if = "Option::is_none")] + pub sampling_params: Option, + + /// Whether to stream the response + #[serde(default)] + pub stream: bool, + + /// Whether to return logprobs + #[serde(default)] + pub return_logprob: bool, + + // ============= SGLang Extensions ============= + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, + + /// Request ID for 
tracking + #[serde(skip_serializing_if = "Option::is_none")] + pub rid: Option, +} + +impl GenerationRequest for GenerateRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + // Generate requests typically don't have a model field + None + } + + fn extract_text_for_routing(&self) -> String { + // Check fields in priority order: text, prompt, inputs + if let Some(ref text) = self.text { + return text.clone(); + } + + if let Some(ref prompt) = self.prompt { + return match prompt { + StringOrArray::String(s) => s.clone(), + StringOrArray::Array(v) => v.join(" "), + }; + } + + if let Some(ref input_ids) = self.input_ids { + return match input_ids { + InputIds::Single(ids) => ids + .iter() + .map(|&id| id.to_string()) + .collect::>() + .join(" "), + InputIds::Batch(batches) => batches + .iter() + .flat_map(|batch| batch.iter().map(|&id| id.to_string())) + .collect::>() + .join(" "), + }; + } + + // No text input found + String::new() + } +} + +// ================================================================== +// = COMMON = +// ================================================================== + +/// Helper function for serde default value +pub fn default_true() -> bool { + true +} + +/// Common trait for all generation requests across different APIs +pub trait GenerationRequest: Send + Sync { + /// Check if the request is for streaming + fn is_stream(&self) -> bool; + + /// Get the model name if specified + fn get_model(&self) -> Option<&str>; + + /// Extract text content for routing decisions + fn extract_text_for_routing(&self) -> String; +} + +/// Helper type for string or array of strings +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum StringOrArray { + String(String), + Array(Vec), +} +impl StringOrArray { + /// Get the number of items in the StringOrArray + pub fn len(&self) -> usize { + match self { + StringOrArray::String(_) => 1, + StringOrArray::Array(arr) => arr.len(), + } + 
} + + /// Check if the StringOrArray is empty + pub fn is_empty(&self) -> bool { + match self { + StringOrArray::String(s) => s.is_empty(), + StringOrArray::Array(arr) => arr.is_empty(), + } + } + + /// Convert to a vector of strings + pub fn to_vec(&self) -> Vec { + match self { + StringOrArray::String(s) => vec![s.clone()], + StringOrArray::Array(arr) => arr.clone(), + } + } +} + +/// LoRA adapter path - can be single path or batch of paths (SGLang extension) +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum LoRAPath { + Single(Option), + Batch(Vec>), +} diff --git a/sgl-router/src/protocols/validation.rs b/sgl-router/src/protocols/validation.rs index 2fe89e228..69f3946ac 100644 --- a/sgl-router/src/protocols/validation.rs +++ b/sgl-router/src/protocols/validation.rs @@ -4,6 +4,11 @@ use anyhow::Result; use serde::{Deserialize, Serialize}; use std::fmt::Display; +// Import types from spec module +use crate::protocols::spec::{ + ChatCompletionRequest, ChatMessage, ResponseFormat, StringOrArray, UserMessageContent, +}; + /// Validation constants for OpenAI API parameters pub mod constants { /// Temperature range: 0.0 to 2.0 (OpenAI spec) @@ -257,7 +262,7 @@ pub mod utils { ) -> Result<(), ValidationError> { if let Some(stop) = request.get_stop_sequences() { match stop { - crate::protocols::common::StringOrArray::String(s) => { + StringOrArray::String(s) => { if s.is_empty() { return Err(ValidationError::InvalidValue { parameter: "stop".to_string(), @@ -266,7 +271,7 @@ pub mod utils { }); } } - crate::protocols::common::StringOrArray::Array(arr) => { + StringOrArray::Array(arr) => { validate_max_items(arr, constants::MAX_STOP_SEQUENCES, "stop")?; for (i, s) in arr.iter().enumerate() { if s.is_empty() { @@ -469,7 +474,7 @@ pub trait SamplingOptionsProvider { /// Trait for validating stop conditions pub trait StopConditionsProvider { /// Get stop sequences - fn get_stop_sequences(&self) -> Option<&crate::protocols::common::StringOrArray>; 
+ fn get_stop_sequences(&self) -> Option<&StringOrArray>; } /// Trait for validating token limits @@ -532,25 +537,237 @@ pub trait ValidatableRequest: } } +// ================================================================== +// = OPENAI CHAT COMPLETION VALIDATION = +// ================================================================== + +impl SamplingOptionsProvider for ChatCompletionRequest { + fn get_temperature(&self) -> Option { + self.temperature + } + fn get_top_p(&self) -> Option { + self.top_p + } + fn get_frequency_penalty(&self) -> Option { + self.frequency_penalty + } + fn get_presence_penalty(&self) -> Option { + self.presence_penalty + } +} + +impl StopConditionsProvider for ChatCompletionRequest { + fn get_stop_sequences(&self) -> Option<&StringOrArray> { + self.stop.as_ref() + } +} + +impl TokenLimitsProvider for ChatCompletionRequest { + fn get_max_tokens(&self) -> Option { + // Prefer max_completion_tokens over max_tokens if both are set + self.max_completion_tokens.or(self.max_tokens) + } + + fn get_min_tokens(&self) -> Option { + self.min_tokens + } +} + +impl LogProbsProvider for ChatCompletionRequest { + fn get_logprobs(&self) -> Option { + // For chat API, logprobs is a boolean, return 1 if true for validation purposes + if self.logprobs { + Some(1) + } else { + None + } + } + + fn get_top_logprobs(&self) -> Option { + self.top_logprobs + } +} + +impl SGLangExtensionsProvider for ChatCompletionRequest { + fn get_top_k(&self) -> Option { + self.top_k + } + + fn get_min_p(&self) -> Option { + self.min_p + } + + fn get_repetition_penalty(&self) -> Option { + self.repetition_penalty + } +} + +impl CompletionCountProvider for ChatCompletionRequest { + fn get_n(&self) -> Option { + self.n + } +} + +impl ChatCompletionRequest { + /// Validate message-specific requirements + pub fn validate_messages(&self) -> Result<(), ValidationError> { + // Ensure messages array is not empty + utils::validate_non_empty_array(&self.messages, "messages")?; + + // 
Validate message content is not empty + for (i, msg) in self.messages.iter().enumerate() { + if let ChatMessage::User { content, .. } = msg { + match content { + UserMessageContent::Text(text) if text.is_empty() => { + return Err(ValidationError::InvalidValue { + parameter: format!("messages[{}].content", i), + value: "empty".to_string(), + reason: "message content cannot be empty".to_string(), + }); + } + UserMessageContent::Parts(parts) if parts.is_empty() => { + return Err(ValidationError::InvalidValue { + parameter: format!("messages[{}].content", i), + value: "empty array".to_string(), + reason: "message content parts cannot be empty".to_string(), + }); + } + _ => {} + } + } + } + + Ok(()) + } + + /// Validate response format if specified + pub fn validate_response_format(&self) -> Result<(), ValidationError> { + if let Some(ResponseFormat::JsonSchema { json_schema }) = &self.response_format { + if json_schema.name.is_empty() { + return Err(ValidationError::InvalidValue { + parameter: "response_format.json_schema.name".to_string(), + value: "empty".to_string(), + reason: "JSON schema name cannot be empty".to_string(), + }); + } + } + Ok(()) + } + + /// Validate chat API specific logprobs requirements + pub fn validate_chat_logprobs(&self) -> Result<(), ValidationError> { + // In chat API, if logprobs=true, top_logprobs must be specified + if self.logprobs && self.top_logprobs.is_none() { + return Err(ValidationError::MissingRequired { + parameter: "top_logprobs".to_string(), + }); + } + + // If top_logprobs is specified, logprobs should be true + if self.top_logprobs.is_some() && !self.logprobs { + return Err(ValidationError::InvalidValue { + parameter: "logprobs".to_string(), + value: "false".to_string(), + reason: "must be true when top_logprobs is specified".to_string(), + }); + } + + Ok(()) + } + + /// Validate cross-parameter relationships specific to chat completions + pub fn validate_chat_cross_parameters(&self) -> Result<(), ValidationError> { + // 
Validate that both max_tokens and max_completion_tokens aren't set + utils::validate_conflicting_parameters( + "max_tokens", + self.max_tokens.is_some(), + "max_completion_tokens", + self.max_completion_tokens.is_some(), + "cannot specify both max_tokens and max_completion_tokens", + )?; + + // Validate that tools and functions aren't both specified (deprecated) + utils::validate_conflicting_parameters( + "tools", + self.tools.is_some(), + "functions", + self.functions.is_some(), + "functions is deprecated, use tools instead", + )?; + + // Validate structured output constraints don't conflict with JSON response format + let has_json_format = matches!( + self.response_format, + Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) + ); + + utils::validate_conflicting_parameters( + "response_format", + has_json_format, + "regex", + self.regex.is_some(), + "cannot use regex constraint with JSON response format", + )?; + + utils::validate_conflicting_parameters( + "response_format", + has_json_format, + "ebnf", + self.ebnf.is_some(), + "cannot use EBNF constraint with JSON response format", + )?; + + // Only one structured output constraint should be active + let structured_constraints = [ + ("regex", self.regex.is_some()), + ("ebnf", self.ebnf.is_some()), + ( + "json_schema", + matches!( + self.response_format, + Some(ResponseFormat::JsonSchema { .. 
}) + ), + ), + ]; + + utils::validate_mutually_exclusive_options( + &structured_constraints, + "Only one structured output constraint (regex, ebnf, or json_schema) can be active at a time", + )?; + + Ok(()) + } +} + +impl ValidatableRequest for ChatCompletionRequest { + fn validate(&self) -> Result<(), ValidationError> { + // Call the common validation function from the validation module + utils::validate_common_request_params(self)?; + + // Then validate chat-specific parameters + self.validate_messages()?; + self.validate_response_format()?; + self.validate_chat_logprobs()?; + self.validate_chat_cross_parameters()?; + + Ok(()) + } +} + #[cfg(test)] mod tests { use super::constants::*; use super::utils::*; use super::*; - use crate::protocols::common::StringOrArray; + use crate::protocols::spec::StringOrArray; // Mock request type for testing validation traits #[derive(Debug, Default)] struct MockRequest { temperature: Option<f32>, - top_p: Option<f32>, - frequency_penalty: Option<f32>, - presence_penalty: Option<f32>, stop: Option<StringOrArray>, max_tokens: Option<u32>, min_tokens: Option<u32>, - logprobs: Option<u32>, - top_logprobs: Option<u32>, } impl SamplingOptionsProvider for MockRequest { @@ -558,13 +775,13 @@ mod tests { self.temperature } fn get_top_p(&self) -> Option<f32> { - self.top_p + None } fn get_frequency_penalty(&self) -> Option<f32> { - self.frequency_penalty + None } fn get_presence_penalty(&self) -> Option<f32> { - self.presence_penalty + None } } @@ -585,97 +802,36 @@ mod tests { impl LogProbsProvider for MockRequest { fn get_logprobs(&self) -> Option<u32> { - self.logprobs + None } fn get_top_logprobs(&self) -> Option<u32> { - self.top_logprobs + None } } - impl SGLangExtensionsProvider for MockRequest { - // Default implementations return None, so no custom logic needed - } - - impl CompletionCountProvider for MockRequest { - // Default implementation returns None, so no custom logic needed - } - + impl SGLangExtensionsProvider for MockRequest {} + impl CompletionCountProvider for MockRequest {} impl
ValidatableRequest for MockRequest {} #[test] - fn test_validate_range_valid() { - let result = validate_range(1.5f32, &TEMPERATURE_RANGE, "temperature"); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), 1.5f32); + fn test_range_validation() { + // Valid range + assert!(validate_range(1.5f32, &TEMPERATURE_RANGE, "temperature").is_ok()); + // Invalid range + assert!(validate_range(-0.1f32, &TEMPERATURE_RANGE, "temperature").is_err()); + assert!(validate_range(3.0f32, &TEMPERATURE_RANGE, "temperature").is_err()); } #[test] - fn test_validate_range_too_low() { - let result = validate_range(-0.1f32, &TEMPERATURE_RANGE, "temperature"); - assert!(result.is_err()); - match result.unwrap_err() { - ValidationError::OutOfRange { parameter, .. } => { - assert_eq!(parameter, "temperature"); - } - _ => panic!("Expected OutOfRange error"), - } - } - - #[test] - fn test_validate_positive_valid() { - let result = validate_positive(5i32, "max_tokens"); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), 5i32); - } - - #[test] - fn test_validate_max_items_valid() { - let items = vec!["stop1", "stop2"]; - let result = validate_max_items(&items, MAX_STOP_SEQUENCES, "stop"); - assert!(result.is_ok()); - } - - #[test] - fn test_validate_top_k() { + fn test_sglang_top_k_validation() { assert!(validate_top_k(-1).is_ok()); // Disabled - assert!(validate_top_k(50).is_ok()); // Positive + assert!(validate_top_k(50).is_ok()); // Valid positive assert!(validate_top_k(0).is_err()); // Invalid assert!(validate_top_k(-5).is_err()); // Invalid } #[test] - fn test_valid_request() { - let request = MockRequest { - temperature: Some(1.0), - top_p: Some(0.9), - frequency_penalty: Some(0.5), - presence_penalty: Some(-0.5), - stop: Some(StringOrArray::Array(vec![ - "stop1".to_string(), - "stop2".to_string(), - ])), - max_tokens: Some(100), - min_tokens: Some(10), - logprobs: Some(3), - top_logprobs: Some(15), - }; - - assert!(request.validate().is_ok()); - } - - #[test] - fn 
test_invalid_temperature() { - let request = MockRequest { - temperature: Some(3.0), // Invalid: too high - ..Default::default() - }; - - let result = request.validate(); - assert!(result.is_err()); - } - - #[test] - fn test_too_many_stop_sequences() { + fn test_stop_sequences_limits() { let request = MockRequest { stop: Some(StringOrArray::Array(vec![ "stop1".to_string(), @@ -686,72 +842,322 @@ mod tests { ])), ..Default::default() }; - - let result = request.validate(); - assert!(result.is_err()); - match result.unwrap_err() { - ValidationError::TooManyItems { - parameter, - count, - max, - } => { - assert_eq!(parameter, "stop"); - assert_eq!(count, 5); - assert_eq!(max, MAX_STOP_SEQUENCES); - } - _ => panic!("Expected TooManyItems error"), - } + assert!(request.validate().is_err()); } #[test] - fn test_conflicting_token_limits() { + fn test_token_limits_conflict() { let request = MockRequest { min_tokens: Some(100), - max_tokens: Some(50), // Invalid: min > max + max_tokens: Some(50), // min > max ..Default::default() }; - - let result = request.validate(); - assert!(result.is_err()); - match result.unwrap_err() { - ValidationError::ConflictingParameters { - parameter1, - parameter2, - .. 
- } => { - assert_eq!(parameter1, "min_tokens"); - assert_eq!(parameter2, "max_tokens"); - } - _ => panic!("Expected ConflictingParameters error"), - } + assert!(request.validate().is_err()); } #[test] - fn test_boundary_values() { + fn test_valid_request() { let request = MockRequest { - temperature: Some(0.0), // Boundary: minimum - top_p: Some(1.0), // Boundary: maximum - frequency_penalty: Some(-2.0), // Boundary: minimum - presence_penalty: Some(2.0), // Boundary: maximum - logprobs: Some(0), // Boundary: minimum - top_logprobs: Some(20), // Boundary: maximum - ..Default::default() + temperature: Some(1.0), + stop: Some(StringOrArray::Array(vec!["stop".to_string()])), + max_tokens: Some(100), + min_tokens: Some(10), }; - assert!(request.validate().is_ok()); } - #[test] - fn test_validation_error_display() { - let error = ValidationError::OutOfRange { - parameter: "temperature".to_string(), - value: "3.0".to_string(), - min: "0.0".to_string(), - max: "2.0".to_string(), - }; + // Chat completion specific tests + #[cfg(test)] + mod chat_tests { + use super::*; - let message = format!("{}", error); - assert!(message.contains("temperature")); - assert!(message.contains("3.0")); + fn create_valid_chat_request() -> ChatCompletionRequest { + ChatCompletionRequest { + model: "gpt-4".to_string(), + messages: vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Hello".to_string()), + name: None, + }], + temperature: Some(1.0), + top_p: Some(0.9), + n: Some(1), + stream: false, + stream_options: None, + stop: None, + max_tokens: Some(100), + max_completion_tokens: None, + presence_penalty: Some(0.0), + frequency_penalty: Some(0.0), + logit_bias: None, + user: None, + seed: None, + logprobs: false, + top_logprobs: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + // SGLang extensions + top_k: None, + min_p: None, + min_tokens: None, + 
repetition_penalty: None, + regex: None, + ebnf: None, + stop_token_ids: None, + no_stop_trim: false, + ignore_eos: false, + continue_final_message: false, + skip_special_tokens: true, + lora_path: None, + session_params: None, + separate_reasoning: true, + stream_reasoning: true, + return_hidden_states: false, + } + } + + #[test] + fn test_chat_validation_basics() { + // Valid request + assert!(create_valid_chat_request().validate().is_ok()); + + // Empty messages + let mut request = create_valid_chat_request(); + request.messages = vec![]; + assert!(request.validate().is_err()); + + // Invalid temperature + let mut request = create_valid_chat_request(); + request.temperature = Some(3.0); + assert!(request.validate().is_err()); + } + + #[test] + fn test_chat_conflicts() { + let mut request = create_valid_chat_request(); + + // Conflicting max_tokens + request.max_tokens = Some(100); + request.max_completion_tokens = Some(200); + assert!(request.validate().is_err()); + + // Logprobs without top_logprobs + request.max_tokens = None; + request.logprobs = true; + request.top_logprobs = None; + assert!(request.validate().is_err()); + } + + #[test] + fn test_sglang_extensions() { + let mut request = create_valid_chat_request(); + + // Valid SGLang parameters + request.top_k = Some(-1); + request.min_p = Some(0.1); + request.repetition_penalty = Some(1.2); + assert!(request.validate().is_ok()); + + // Invalid parameters + request.top_k = Some(0); // Invalid + assert!(request.validate().is_err()); + } + + #[test] + fn test_parameter_ranges() { + let mut request = create_valid_chat_request(); + + // Test temperature range (0.0 to 2.0) + request.temperature = Some(1.5); + assert!(request.validate().is_ok()); + request.temperature = Some(-0.1); + assert!(request.validate().is_err()); + request.temperature = Some(3.0); + assert!(request.validate().is_err()); + + // Test top_p range (0.0 to 1.0) + request.temperature = Some(1.0); // Reset + request.top_p = Some(0.9); + 
assert!(request.validate().is_ok()); + request.top_p = Some(-0.1); + assert!(request.validate().is_err()); + request.top_p = Some(1.5); + assert!(request.validate().is_err()); + + // Test frequency_penalty range (-2.0 to 2.0) + request.top_p = Some(0.9); // Reset + request.frequency_penalty = Some(1.5); + assert!(request.validate().is_ok()); + request.frequency_penalty = Some(-2.5); + assert!(request.validate().is_err()); + request.frequency_penalty = Some(3.0); + assert!(request.validate().is_err()); + + // Test presence_penalty range (-2.0 to 2.0) + request.frequency_penalty = Some(0.0); // Reset + request.presence_penalty = Some(-1.5); + assert!(request.validate().is_ok()); + request.presence_penalty = Some(-3.0); + assert!(request.validate().is_err()); + request.presence_penalty = Some(2.5); + assert!(request.validate().is_err()); + + // Test repetition_penalty range (0.0 to 2.0) + request.presence_penalty = Some(0.0); // Reset + request.repetition_penalty = Some(1.2); + assert!(request.validate().is_ok()); + request.repetition_penalty = Some(-0.1); + assert!(request.validate().is_err()); + request.repetition_penalty = Some(2.1); + assert!(request.validate().is_err()); + + // Test min_p range (0.0 to 1.0) + request.repetition_penalty = Some(1.0); // Reset + request.min_p = Some(0.5); + assert!(request.validate().is_ok()); + request.min_p = Some(-0.1); + assert!(request.validate().is_err()); + request.min_p = Some(1.5); + assert!(request.validate().is_err()); + } + + #[test] + fn test_structured_output_conflicts() { + let mut request = create_valid_chat_request(); + + // JSON response format with regex should conflict + request.response_format = Some(ResponseFormat::JsonObject); + request.regex = Some(".*".to_string()); + assert!(request.validate().is_err()); + + // JSON response format with EBNF should conflict + request.regex = None; + request.ebnf = Some("grammar".to_string()); + assert!(request.validate().is_err()); + + // Multiple structured constraints 
should conflict + request.response_format = None; + request.regex = Some(".*".to_string()); + request.ebnf = Some("grammar".to_string()); + assert!(request.validate().is_err()); + + // Only one constraint should work + request.ebnf = None; + request.regex = Some(".*".to_string()); + assert!(request.validate().is_ok()); + + request.regex = None; + request.ebnf = Some("grammar".to_string()); + assert!(request.validate().is_ok()); + + request.ebnf = None; + request.response_format = Some(ResponseFormat::JsonObject); + assert!(request.validate().is_ok()); + } + + #[test] + fn test_stop_sequences_validation() { + let mut request = create_valid_chat_request(); + + // Valid stop sequences + request.stop = Some(StringOrArray::Array(vec![ + "stop1".to_string(), + "stop2".to_string(), + ])); + assert!(request.validate().is_ok()); + + // Too many stop sequences (max 4) + request.stop = Some(StringOrArray::Array(vec![ + "stop1".to_string(), + "stop2".to_string(), + "stop3".to_string(), + "stop4".to_string(), + "stop5".to_string(), + ])); + assert!(request.validate().is_err()); + + // Empty stop sequence should fail + request.stop = Some(StringOrArray::String("".to_string())); + assert!(request.validate().is_err()); + + // Empty string in array should fail + request.stop = Some(StringOrArray::Array(vec![ + "stop1".to_string(), + "".to_string(), + ])); + assert!(request.validate().is_err()); + } + + #[test] + fn test_logprobs_validation() { + let mut request = create_valid_chat_request(); + + // Valid logprobs configuration + request.logprobs = true; + request.top_logprobs = Some(10); + assert!(request.validate().is_ok()); + + // logprobs=true without top_logprobs should fail + request.top_logprobs = None; + assert!(request.validate().is_err()); + + // top_logprobs without logprobs=true should fail + request.logprobs = false; + request.top_logprobs = Some(10); + assert!(request.validate().is_err()); + + // top_logprobs out of range (0-20) + request.logprobs = true; + 
request.top_logprobs = Some(25); + assert!(request.validate().is_err()); + } + + #[test] + fn test_n_parameter_validation() { + let mut request = create_valid_chat_request(); + + // Valid n values (1-10) + request.n = Some(1); + assert!(request.validate().is_ok()); + request.n = Some(5); + assert!(request.validate().is_ok()); + request.n = Some(10); + assert!(request.validate().is_ok()); + + // Invalid n values + request.n = Some(0); + assert!(request.validate().is_err()); + request.n = Some(15); + assert!(request.validate().is_err()); + } + + #[test] + fn test_min_max_tokens_validation() { + let mut request = create_valid_chat_request(); + + // Valid token limits + request.min_tokens = Some(10); + request.max_tokens = Some(100); + assert!(request.validate().is_ok()); + + // min_tokens > max_tokens should fail + request.min_tokens = Some(150); + request.max_tokens = Some(100); + assert!(request.validate().is_err()); + + // Should work with max_completion_tokens instead + request.max_tokens = None; + request.max_completion_tokens = Some(200); + request.min_tokens = Some(50); + assert!(request.validate().is_ok()); + + // min_tokens > max_completion_tokens should fail + request.min_tokens = Some(250); + assert!(request.validate().is_err()); + } } } diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 83789852b..a0882c176 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -9,10 +9,7 @@ use axum::{ }; use std::fmt::Debug; -use crate::protocols::{ - generate::GenerateRequest, - openai::{chat::ChatCompletionRequest, completions::CompletionRequest}, -}; +use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; pub mod factory; pub mod header_utils; diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index a3e749f93..9dd5ae279 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -12,13 +12,9 @@ use 
crate::core::{ }; use crate::metrics::RouterMetrics; use crate::policies::LoadBalancingPolicy; -use crate::protocols::{ - common::StringOrArray, - generate::GenerateRequest, - openai::{ - chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}, - completions::CompletionRequest, - }, +use crate::protocols::spec::{ + ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray, + UserMessageContent, }; use crate::routers::{RouterTrait, WorkerManagement}; use async_trait::async_trait; diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index 2c5d278ea..00dbe32dc 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -9,10 +9,8 @@ use crate::core::{ }; use crate::metrics::RouterMetrics; use crate::policies::LoadBalancingPolicy; -use crate::protocols::{ - common::GenerationRequest, - generate::GenerateRequest, - openai::{chat::ChatCompletionRequest, completions::CompletionRequest}, +use crate::protocols::spec::{ + ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, }; use crate::routers::{RouterTrait, WorkerManagement}; use axum::{ diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 85e7648af..7ca6b9388 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,10 +1,7 @@ use crate::config::RouterConfig; use crate::logging::{self, LoggingConfig}; use crate::metrics::{self, PrometheusConfig}; -use crate::protocols::{ - generate::GenerateRequest, - openai::{chat::ChatCompletionRequest, completions::CompletionRequest}, -}; +use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::routers::{RouterFactory, RouterTrait}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use axum::{ diff --git a/sgl-router/tests/benchmark_integration.rs b/sgl-router/tests/benchmark_integration.rs index 6787d8695..e40ca08ab 100644 --- 
a/sgl-router/tests/benchmark_integration.rs +++ b/sgl-router/tests/benchmark_integration.rs @@ -5,13 +5,9 @@ use serde_json::{from_str, to_string, to_value}; use sglang_router_rs::core::{BasicWorker, WorkerType}; -use sglang_router_rs::protocols::{ - common::StringOrArray, - generate::{GenerateParameters, GenerateRequest, SamplingParams}, - openai::{ - chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}, - completions::CompletionRequest, - }, +use sglang_router_rs::protocols::spec::{ + ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, + SamplingParams, StringOrArray, UserMessageContent, }; /// Create a default GenerateRequest for benchmarks with minimal fields set diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index a5653edd8..dc2253799 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -1,8 +1,10 @@ // Integration test for Responses API -use sglang_router_rs::protocols::common::GenerationRequest; -use sglang_router_rs::protocols::openai::responses::request::ResponseInput; -use sglang_router_rs::protocols::openai::responses::*; +use sglang_router_rs::protocols::spec::{ + GenerationRequest, ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseStatus, + ResponseTool, ResponseToolType, ResponsesRequest, ResponsesResponse, ServiceTier, ToolChoice, + ToolChoiceValue, Truncation, UsageInfo, +}; #[test] fn test_responses_request_creation() { @@ -24,7 +26,7 @@ fn test_responses_request_creation() { store: true, stream: false, temperature: Some(0.7), - tool_choice: ToolChoice::Auto, + tool_choice: ToolChoice::Value(ToolChoiceValue::Auto), tools: vec![ResponseTool { r#type: ResponseToolType::WebSearchPreview, }], @@ -67,7 +69,7 @@ fn test_sampling_params_conversion() { store: true, // Use default true stream: false, temperature: Some(0.8), - tool_choice: ToolChoice::Auto, + tool_choice: 
ToolChoice::Value(ToolChoiceValue::Auto), tools: vec![], top_logprobs: 0, // Use default 0 top_p: Some(0.95), @@ -177,7 +179,7 @@ fn test_json_serialization() { store: false, stream: true, temperature: Some(0.9), - tool_choice: ToolChoice::Required, + tool_choice: ToolChoice::Value(ToolChoiceValue::Required), tools: vec![ResponseTool { r#type: ResponseToolType::CodeInterpreter, }],