diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs
index 3b979477c..70de06361 100644
--- a/sgl-router/benches/request_processing.rs
+++ b/sgl-router/benches/request_processing.rs
@@ -3,9 +3,13 @@
 use serde_json::{from_str, to_string, to_value, to_vec};
 use std::time::Instant;
 use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
-use sglang_router_rs::openai_api_types::{
-    ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
-    SamplingParams, StringOrArray, UserMessageContent,
+use sglang_router_rs::protocols::{
+    common::StringOrArray,
+    generate::{GenerateParameters, GenerateRequest, SamplingParams},
+    openai::{
+        chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
+        completions::CompletionRequest,
+    },
 };
 use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap};
diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs
index 00c8e910d..ec29a1740 100644
--- a/sgl-router/src/lib.rs
+++ b/sgl-router/src/lib.rs
@@ -5,8 +5,8 @@ use std::collections::HashMap;
 pub mod core;
 pub mod metrics;
 pub mod middleware;
-pub mod openai_api_types;
 pub mod policies;
+pub mod protocols;
 pub mod reasoning_parser;
 pub mod routers;
 pub mod server;
diff --git a/sgl-router/src/openai_api_types.rs b/sgl-router/src/openai_api_types.rs
deleted file mode 100644
index 4a0fb0ee0..000000000
--- a/sgl-router/src/openai_api_types.rs
+++ /dev/null
@@ -1,921 +0,0 @@
-// OpenAI-compatible API types for text generation
-// Based on OpenAI's API specification: https://platform.openai.com/docs/api-reference
-// Reference: Azure OpenAI API documentation which follows OpenAI's specification
-
-use serde::{Deserialize, Serialize};
-use serde_json::Value;
-use std::collections::HashMap;
-
-/// Helper function for serde default value
-fn default_true() -> bool {
-    true
-}
-
-// ============= SGLang-Specific Types =============
-
-/// LoRA adapter path - can be single path or batch of paths
-#[derive(Debug, Clone, Deserialize, Serialize)]
-#[serde(untagged)]
-pub enum LoRAPath {
-    Single(Option<String>),
-    Batch(Vec<Option<String>>),
-}
-
-/// Common trait for all generation requests
-pub trait GenerationRequest: Send + Sync {
-    /// Check if the request is for streaming
-    fn is_stream(&self) -> bool;
-
-    /// Get the model name if specified
-    fn get_model(&self) -> Option<&str>;
-
-    /// Extract text content for routing decisions
-    fn extract_text_for_routing(&self) -> String;
-}
-
-// ============= Completions API (v1/completions) - DEPRECATED but still supported =============
-
-#[derive(Debug, Clone, Deserialize, Serialize)]
-pub struct CompletionRequest {
-    /// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
-    pub model: String,
-
-    /// The prompt(s) to generate completions for
-    pub prompt: StringOrArray,
-
-    /// The suffix that comes after a completion of inserted text
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub suffix: Option<String>,
-
-    /// The maximum number of tokens to generate
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub max_tokens: Option<u32>,
-
-    /// What sampling temperature to use, between 0 and 2
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub temperature: Option<f32>,
-
-    /// An alternative to sampling with temperature (nucleus sampling)
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_p: Option<f32>,
-
-    /// How many completions to generate for each prompt
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub n:
Option, - - /// Whether to stream back partial progress - #[serde(default)] - pub stream: bool, - - /// Options for streaming response - #[serde(skip_serializing_if = "Option::is_none")] - pub stream_options: Option, - - /// Include the log probabilities on the logprobs most likely tokens - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - - /// Echo back the prompt in addition to the completion - #[serde(default)] - pub echo: bool, - - /// Up to 4 sequences where the API will stop generating further tokens - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - - /// Generates best_of completions server-side and returns the "best" - #[serde(skip_serializing_if = "Option::is_none")] - pub best_of: Option, - - /// Modify the likelihood of specified tokens appearing in the completion - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, - - /// A unique identifier representing your end-user - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, - - /// If specified, our system will make a best effort to sample deterministically - #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, - - // ============= SGLang Extensions ============= - /// Top-k sampling parameter (-1 to disable) - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - - /// Min-p nucleus sampling parameter - #[serde(skip_serializing_if = "Option::is_none")] - pub min_p: Option, - - /// Minimum number of tokens to generate - #[serde(skip_serializing_if = "Option::is_none")] - pub min_tokens: Option, - - /// Repetition penalty for reducing repetitive text - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - - /// Regex constraint for output generation - #[serde(skip_serializing_if = "Option::is_none")] - pub regex: Option, - - /// EBNF grammar constraint for structured output - #[serde(skip_serializing_if = "Option::is_none")] - pub ebnf: Option, - - /// JSON schema constraint for structured output - #[serde(skip_serializing_if = "Option::is_none")] - pub json_schema: Option, - - /// Specific token IDs to use as stop conditions - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, - - /// Skip trimming stop tokens from output - #[serde(default)] - pub no_stop_trim: bool, - - /// Ignore end-of-sequence tokens during generation - #[serde(default)] - pub ignore_eos: bool, - - /// Skip special tokens during detokenization - #[serde(default = "default_true")] - pub skip_special_tokens: bool, - - // ============= SGLang Extensions ============= - /// Path to LoRA adapter(s) for model customization - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, - - /// Session parameters for continual prompting - #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option>, - - /// Return model hidden states - #[serde(default)] - pub return_hidden_states: bool, - - /// Additional fields including bootstrap info for PD routing - #[serde(flatten)] - pub other: serde_json::Map, -} - -impl GenerationRequest for 
CompletionRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_model(&self) -> Option<&str> { - Some(&self.model) - } - - fn extract_text_for_routing(&self) -> String { - match &self.prompt { - StringOrArray::String(s) => s.clone(), - StringOrArray::Array(v) => v.join(" "), - } - } -} - -// ============= Chat Completions API (v1/chat/completions) ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionRequest { - /// ID of the model to use - pub model: String, - - /// A list of messages comprising the conversation so far - pub messages: Vec, - - /// What sampling temperature to use, between 0 and 2 - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - - /// An alternative to sampling with temperature - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - - /// How many chat completion choices to generate for each input message - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - - /// If set, partial message deltas will be sent - #[serde(default)] - pub stream: bool, - - /// Options for streaming response - #[serde(skip_serializing_if = "Option::is_none")] - pub stream_options: Option, - - /// Up to 4 sequences where the API will stop generating further tokens - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - - /// The maximum number of tokens to generate - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - - /// An upper bound for the number of tokens that can be generated for a completion - #[serde(skip_serializing_if = "Option::is_none")] - pub max_completion_tokens: Option, - - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - - /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - - /// Modify the likelihood of specified tokens appearing in the completion - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, - - /// A unique identifier representing your end-user - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, - - /// If specified, our system will make a best effort to sample deterministically - #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, - - /// Whether to return log probabilities of the output tokens - #[serde(default)] - pub logprobs: bool, - - /// An integer between 0 and 20 specifying the number of most likely tokens to return - #[serde(skip_serializing_if = "Option::is_none")] - pub top_logprobs: Option, - - /// An object specifying the format that the model must output - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, - - /// A list of tools the model may call - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, - - /// Controls which (if any) tool is called by the model - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_choice: Option, - - /// Whether to enable parallel function calling during tool use - #[serde(skip_serializing_if = "Option::is_none")] - pub parallel_tool_calls: Option, - - /// Deprecated: use tools instead - #[serde(skip_serializing_if = "Option::is_none")] - pub functions: Option>, - - /// Deprecated: use tool_choice instead - #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, - - // ============= SGLang Extensions ============= - /// Top-k sampling parameter (-1 to disable) - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - - /// Min-p nucleus sampling parameter - #[serde(skip_serializing_if = "Option::is_none")] - pub min_p: Option, - - /// Minimum number of tokens to generate - #[serde(skip_serializing_if = "Option::is_none")] - pub min_tokens: Option, - - /// Repetition penalty for reducing repetitive text - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - - /// Regex constraint for output generation - #[serde(skip_serializing_if = "Option::is_none")] - pub regex: Option, - - /// EBNF grammar constraint for structured output - #[serde(skip_serializing_if = "Option::is_none")] - pub ebnf: Option, - - /// Specific token IDs to use as stop conditions - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, - - /// Skip trimming stop tokens from output - #[serde(default)] - pub no_stop_trim: bool, - - /// Ignore end-of-sequence tokens during generation - #[serde(default)] - pub ignore_eos: bool, - - /// Continue generating from final assistant message - #[serde(default)] - pub continue_final_message: bool, - - /// Skip special tokens during detokenization - #[serde(default = "default_true")] - pub skip_special_tokens: bool, - - // ============= SGLang Extensions ============= - /// Path to LoRA adapter(s) for model customization - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, - - /// Session parameters for continual prompting - #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option>, - - /// Separate reasoning content from final answer (O1-style models) - #[serde(default = "default_true")] - pub separate_reasoning: bool, - - /// Stream reasoning tokens during 
generation - #[serde(default = "default_true")] - pub stream_reasoning: bool, - - /// Return model hidden states - #[serde(default)] - pub return_hidden_states: bool, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum ChatMessage { - System { - role: String, // "system" - content: String, - #[serde(skip_serializing_if = "Option::is_none")] - name: Option, - }, - User { - role: String, // "user" - content: UserMessageContent, - #[serde(skip_serializing_if = "Option::is_none")] - name: Option, - }, - Assistant { - role: String, // "assistant" - #[serde(skip_serializing_if = "Option::is_none")] - content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - function_call: Option, - /// Reasoning content for O1-style models (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - reasoning_content: Option, - }, - Tool { - role: String, // "tool" - content: String, - tool_call_id: String, - }, - Function { - role: String, // "function" - content: String, - name: String, - }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum UserMessageContent { - Text(String), - Parts(Vec), -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -pub enum ContentPart { - #[serde(rename = "text")] - Text { text: String }, - #[serde(rename = "image_url")] - ImageUrl { image_url: ImageUrl }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ImageUrl { - pub url: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub detail: Option, // "auto", "low", or "high" -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct StreamOptions { - #[serde(skip_serializing_if = "Option::is_none")] - pub include_usage: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "type")] -pub enum ResponseFormat { - #[serde(rename = "text")] - Text, - #[serde(rename = "json_object")] - JsonObject, - #[serde(rename = "json_schema")] - JsonSchema { json_schema: JsonSchemaFormat }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct JsonSchemaFormat { - pub name: String, - pub schema: Value, - #[serde(skip_serializing_if = "Option::is_none")] - pub strict: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Tool { - #[serde(rename = "type")] - pub tool_type: String, // "function" - pub function: Function, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Function { - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - pub parameters: Value, // JSON Schema -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum ToolChoice { - None, - Auto, - Required, - Function { - #[serde(rename = "type")] - tool_type: String, // "function" - function: FunctionChoice, - }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct FunctionChoice { - pub name: String, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ToolCall { - pub id: String, - #[serde(rename = "type")] - pub tool_type: String, // "function" - pub function: FunctionCallResponse, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum FunctionCall { - None, - Auto, - Function { name: String }, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct FunctionCallResponse { - pub name: String, - 
pub arguments: String, // JSON string -} - -impl GenerationRequest for ChatCompletionRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_model(&self) -> Option<&str> { - Some(&self.model) - } - - fn extract_text_for_routing(&self) -> String { - // Extract text from messages for routing decisions - self.messages - .iter() - .filter_map(|msg| match msg { - ChatMessage::System { content, .. } => Some(content.clone()), - ChatMessage::User { content, .. } => match content { - UserMessageContent::Text(text) => Some(text.clone()), - UserMessageContent::Parts(parts) => { - let texts: Vec = parts - .iter() - .filter_map(|part| match part { - ContentPart::Text { text } => Some(text.clone()), - _ => None, - }) - .collect(); - Some(texts.join(" ")) - } - }, - ChatMessage::Assistant { - content, - reasoning_content, - .. - } => { - // Combine content and reasoning content for routing decisions - let main_content = content.clone().unwrap_or_default(); - let reasoning = reasoning_content.clone().unwrap_or_default(); - if main_content.is_empty() && reasoning.is_empty() { - None - } else { - Some(format!("{} {}", main_content, reasoning).trim().to_string()) - } - } - ChatMessage::Tool { content, .. } => Some(content.clone()), - ChatMessage::Function { content, .. } => Some(content.clone()), - }) - .collect::>() - .join(" ") - } -} - -// ============= Generate API (/generate) ============= - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct GenerateRequest { - /// The prompt to generate from (OpenAI style) - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt: Option, - - /// Text input - SGLang native format - #[serde(skip_serializing_if = "Option::is_none")] - pub text: Option, - - /// Input IDs for tokenized input - #[serde(skip_serializing_if = "Option::is_none")] - pub input_ids: Option, - - /// Generation parameters - #[serde(default, skip_serializing_if = "Option::is_none")] - pub parameters: Option, - - /// Sampling parameters (sglang style) - #[serde(skip_serializing_if = "Option::is_none")] - pub sampling_params: Option, - - /// Whether to stream the response - #[serde(default)] - pub stream: bool, - - /// Whether to return logprobs - #[serde(default)] - pub return_logprob: bool, - - // ============= SGLang Extensions ============= - /// Path to LoRA adapter(s) for model customization - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, - - /// Session parameters for continual prompting - #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option>, - - /// Return model hidden states - #[serde(default)] - pub return_hidden_states: bool, - - /// Request ID for tracking - #[serde(skip_serializing_if = "Option::is_none")] - pub rid: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum InputIds { - Single(Vec), - Batch(Vec>), -} - -#[derive(Debug, Clone, Deserialize, Serialize, Default)] -pub struct GenerateParameters { - #[serde(skip_serializing_if = "Option::is_none")] - pub best_of: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub decoder_input_details: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub details: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub do_sample: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_new_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub return_full_text: Option, - 
#[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub truncate: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub typical_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub watermark: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize, Default)] -pub struct SamplingParams { - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_new_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub ignore_eos: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub skip_special_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub json_schema: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub regex: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub ebnf: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub min_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub min_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub no_stop_trim: Option, -} - -impl GenerationRequest for GenerateRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_model(&self) -> Option<&str> { - // Generate requests typically don't have a model field - None - } - - fn extract_text_for_routing(&self) -> String { - // Check fields in priority order: text, prompt, inputs - if let Some(ref text) = self.text { - return text.clone(); - } - - if let Some(ref prompt) = self.prompt { - return match prompt { - StringOrArray::String(s) => s.clone(), - StringOrArray::Array(v) => v.join(" "), - }; - } - - if let Some(ref input_ids) = self.input_ids { - return match input_ids { - InputIds::Single(ids) => ids - .iter() - .map(|&id| id.to_string()) - .collect::>() - .join(" "), - InputIds::Batch(batches) => batches - .iter() - .flat_map(|batch| batch.iter().map(|&id| id.to_string())) - .collect::>() - .join(" "), - }; - } - - // No text input found - String::new() - } -} - -// ============= Helper Types ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum StringOrArray { - String(String), - Array(Vec), -} - -// ============= Response Types ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionResponse { - pub id: String, - pub object: String, // "text_completion" - pub created: u64, - pub model: String, - pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, -} - 
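For illustration only (not part of the patch): the untagged enums above, such as StringOrArray and InputIds, accept either JSON shape with no discriminator field. A minimal standalone sketch of that behavior, using a local copy of the enum so it compiles on its own with only the serde (derive) and serde_json crates:

use serde::Deserialize;

/// Local copy of the untagged helper defined above, kept self-contained for the example.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum StringOrArray {
    String(String),
    Array(Vec<String>),
}

fn main() {
    // A bare JSON string maps to the String variant...
    let single: StringOrArray = serde_json::from_str(r#""a single prompt""#).unwrap();
    // ...and a JSON array of strings maps to the Array variant, with no tag field needed.
    let batch: StringOrArray = serde_json::from_str(r#"["prompt one", "prompt two"]"#).unwrap();
    println!("{:?} / {:?}", single, batch);
}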
-#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionChoice { - pub text: String, - pub index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, // "stop", "length", "content_filter", etc. - /// Information about which stop condition was matched - #[serde(skip_serializing_if = "Option::is_none")] - pub matched_stop: Option, // Can be string or integer - /// Hidden states from the model (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - pub hidden_states: Option>, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct LogProbs { - pub tokens: Vec, - pub token_logprobs: Vec>, - pub top_logprobs: Vec>>, - pub text_offset: Vec, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionResponse { - pub id: String, - pub object: String, // "chat.completion" - pub created: u64, - pub model: String, - pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatChoice { - pub index: u32, - pub message: ChatMessage, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, // "stop", "length", "tool_calls", "content_filter", "function_call" - /// Information about which stop condition was matched - #[serde(skip_serializing_if = "Option::is_none")] - pub matched_stop: Option, // Can be string or integer - /// Hidden states from the model (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - pub hidden_states: Option>, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatLogProbs { - pub content: Option>, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatLogProbsContent { - pub token: String, - pub logprob: f32, - pub bytes: Option>, - pub top_logprobs: Vec, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct TopLogProb { - pub token: String, - pub logprob: f32, - pub bytes: Option>, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub completion_tokens_details: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionTokensDetails { - pub reasoning_tokens: Option, -} - -// ============= Streaming Response Types ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionStreamResponse { - pub id: String, - pub object: String, // "text_completion" - pub created: u64, - pub choices: Vec, - pub model: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CompletionStreamChoice { - pub text: String, - pub index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionStreamResponse { - pub id: String, - pub object: String, // "chat.completion.chunk" - pub created: u64, - pub model: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, - pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct 
ChatStreamChoice { - pub index: u32, - pub delta: ChatMessageDelta, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - pub finish_reason: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatMessageDelta { - #[serde(skip_serializing_if = "Option::is_none")] - pub role: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, - /// Reasoning content delta for O1-style models (SGLang extension) - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning_content: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ToolCallDelta { - pub index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(rename = "type")] - pub tool_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub function: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct FunctionCallDelta { - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option, -} - -// ============= Error Response Types ============= - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ErrorResponse { - pub error: ErrorDetail, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ErrorDetail { - pub message: String, - #[serde(rename = "type")] - pub error_type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub param: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub code: Option, -} diff --git a/sgl-router/src/protocols/common.rs b/sgl-router/src/protocols/common.rs new file mode 100644 index 000000000..54d67851c --- /dev/null +++ b/sgl-router/src/protocols/common.rs @@ -0,0 +1,36 @@ +// Common types shared across all protocol implementations + +use serde::{Deserialize, Serialize}; + +/// Helper function for serde default value +pub fn default_true() -> bool { + true +} + +/// Common trait for all generation requests across different APIs +pub trait GenerationRequest: Send + Sync { + /// Check if the request is for streaming + fn is_stream(&self) -> bool; + + /// Get the model name if specified + fn get_model(&self) -> Option<&str>; + + /// Extract text content for routing decisions + fn extract_text_for_routing(&self) -> String; +} + +/// Helper type for string or array of strings +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum StringOrArray { + String(String), + Array(Vec), +} + +/// LoRA adapter path - can be single path or batch of paths (SGLang extension) +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum LoRAPath { + Single(Option), + Batch(Vec>), +} diff --git a/sgl-router/src/protocols/generate/mod.rs b/sgl-router/src/protocols/generate/mod.rs new file mode 100644 index 000000000..7b2b1d97e --- /dev/null +++ b/sgl-router/src/protocols/generate/mod.rs @@ -0,0 +1,8 @@ +// SGLang native Generate API module (/generate) + +pub mod request; +pub mod types; + +// Re-export main types for convenience +pub use request::GenerateRequest; +pub use types::{GenerateParameters, InputIds, SamplingParams}; diff --git a/sgl-router/src/protocols/generate/request.rs b/sgl-router/src/protocols/generate/request.rs new file mode 100644 index 000000000..b3bb3fe46 --- /dev/null +++ 
b/sgl-router/src/protocols/generate/request.rs @@ -0,0 +1,97 @@ +// Generate API request types (/generate) + +use crate::protocols::common::{GenerationRequest, LoRAPath, StringOrArray}; +use crate::protocols::generate::types::{GenerateParameters, InputIds, SamplingParams}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GenerateRequest { + /// The prompt to generate from (OpenAI style) + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + + /// Text input - SGLang native format + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + + /// Input IDs for tokenized input + #[serde(skip_serializing_if = "Option::is_none")] + pub input_ids: Option, + + /// Generation parameters + #[serde(default, skip_serializing_if = "Option::is_none")] + pub parameters: Option, + + /// Sampling parameters (sglang style) + #[serde(skip_serializing_if = "Option::is_none")] + pub sampling_params: Option, + + /// Whether to stream the response + #[serde(default)] + pub stream: bool, + + /// Whether to return logprobs + #[serde(default)] + pub return_logprob: bool, + + // ============= SGLang Extensions ============= + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, + + /// Request ID for tracking + #[serde(skip_serializing_if = "Option::is_none")] + pub rid: Option, +} + +impl GenerationRequest for GenerateRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + // Generate requests typically don't have a model field + None + } + + fn extract_text_for_routing(&self) -> String { + // Check fields in priority order: text, prompt, inputs + if let Some(ref text) = self.text { + return text.clone(); + } + + if let Some(ref prompt) = self.prompt { + return match prompt { + StringOrArray::String(s) => s.clone(), + StringOrArray::Array(v) => v.join(" "), + }; + } + + if let Some(ref input_ids) = self.input_ids { + return match input_ids { + InputIds::Single(ids) => ids + .iter() + .map(|&id| id.to_string()) + .collect::>() + .join(" "), + InputIds::Batch(batches) => batches + .iter() + .flat_map(|batch| batch.iter().map(|&id| id.to_string())) + .collect::>() + .join(" "), + }; + } + + // No text input found + String::new() + } +} diff --git a/sgl-router/src/protocols/generate/types.rs b/sgl-router/src/protocols/generate/types.rs new file mode 100644 index 000000000..4ddf363dc --- /dev/null +++ b/sgl-router/src/protocols/generate/types.rs @@ -0,0 +1,82 @@ +// Types for the SGLang native /generate API + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum InputIds { + Single(Vec), + Batch(Vec>), +} + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct GenerateParameters { + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub decoder_input_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub do_sample: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_new_tokens: Option, + 
#[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub return_full_text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub truncate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub typical_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub watermark: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct SamplingParams { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_new_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ignore_eos: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub skip_special_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub no_stop_trim: Option, +} diff --git a/sgl-router/src/protocols/mod.rs b/sgl-router/src/protocols/mod.rs new file mode 100644 index 000000000..ae580546e --- /dev/null +++ b/sgl-router/src/protocols/mod.rs @@ -0,0 +1,6 @@ +// Protocol definitions and validation for various LLM APIs +// This module provides a structured approach to handling different API protocols + +pub mod common; +pub mod generate; +pub mod openai; diff --git a/sgl-router/src/protocols/openai/chat/mod.rs b/sgl-router/src/protocols/openai/chat/mod.rs new file mode 100644 index 000000000..3484ba987 --- /dev/null +++ b/sgl-router/src/protocols/openai/chat/mod.rs @@ -0,0 +1,12 @@ +// Chat Completions API module + +pub mod request; +pub mod response; +pub mod types; + +// Re-export main types for convenience +pub use request::ChatCompletionRequest; +pub use response::{ + ChatChoice, ChatCompletionResponse, ChatCompletionStreamResponse, ChatStreamChoice, +}; +pub use types::*; diff --git a/sgl-router/src/protocols/openai/chat/request.rs b/sgl-router/src/protocols/openai/chat/request.rs new file mode 100644 index 000000000..b7570c676 --- /dev/null +++ b/sgl-router/src/protocols/openai/chat/request.rs @@ -0,0 +1,216 @@ +// Chat Completions API request types + +use crate::protocols::common::{default_true, GenerationRequest, LoRAPath, StringOrArray}; +use crate::protocols::openai::chat::types::*; 
+use crate::protocols::openai::common::StreamOptions; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionRequest { + /// ID of the model to use + pub model: String, + + /// A list of messages comprising the conversation so far + pub messages: Vec, + + /// What sampling temperature to use, between 0 and 2 + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// An alternative to sampling with temperature + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// How many chat completion choices to generate for each input message + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + + /// If set, partial message deltas will be sent + #[serde(default)] + pub stream: bool, + + /// Options for streaming response + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + + /// Up to 4 sequences where the API will stop generating further tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + + /// The maximum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// An upper bound for the number of tokens that can be generated for a completion + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + + /// Modify the likelihood of specified tokens appearing in the completion + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + + /// A unique identifier representing your end-user + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// If specified, our system will make a best effort to sample deterministically + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + + /// Whether to return log probabilities of the output tokens + #[serde(default)] + pub logprobs: bool, + + /// An integer between 0 and 20 specifying the number of most likely tokens to return + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs: Option, + + /// An object specifying the format that the model must output + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + + /// A list of tools the model may call + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// Controls which (if any) tool is called by the model + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + + /// Whether to enable parallel function calling during tool use + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + + /// Deprecated: use tools instead + #[serde(skip_serializing_if = "Option::is_none")] + pub functions: Option>, + + /// Deprecated: use tool_choice instead + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, + + // ============= SGLang Extensions ============= + /// Top-k sampling parameter (-1 to disable) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + + /// Min-p nucleus sampling 
parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + + /// Minimum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + + /// Repetition penalty for reducing repetitive text + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + + /// Regex constraint for output generation + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + + /// EBNF grammar constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + + /// Specific token IDs to use as stop conditions + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + + /// Skip trimming stop tokens from output + #[serde(default)] + pub no_stop_trim: bool, + + /// Ignore end-of-sequence tokens during generation + #[serde(default)] + pub ignore_eos: bool, + + /// Continue generating from final assistant message + #[serde(default)] + pub continue_final_message: bool, + + /// Skip special tokens during detokenization + #[serde(default = "default_true")] + pub skip_special_tokens: bool, + + // ============= SGLang Extensions ============= + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Separate reasoning content from final answer (O1-style models) + #[serde(default = "default_true")] + pub separate_reasoning: bool, + + /// Stream reasoning tokens during generation + #[serde(default = "default_true")] + pub stream_reasoning: bool, + + /// Return model hidden states + #[serde(default)] + pub return_hidden_states: bool, +} + +impl GenerationRequest for ChatCompletionRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn extract_text_for_routing(&self) -> String { + // Extract text from messages for routing decisions + self.messages + .iter() + .filter_map(|msg| match msg { + ChatMessage::System { content, .. } => Some(content.clone()), + ChatMessage::User { content, .. } => match content { + UserMessageContent::Text(text) => Some(text.clone()), + UserMessageContent::Parts(parts) => { + let texts: Vec = parts + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => Some(text.clone()), + _ => None, + }) + .collect(); + Some(texts.join(" ")) + } + }, + ChatMessage::Assistant { + content, + reasoning_content, + .. + } => { + // Combine content and reasoning content for routing decisions + let main_content = content.clone().unwrap_or_default(); + let reasoning = reasoning_content.clone().unwrap_or_default(); + if main_content.is_empty() && reasoning.is_empty() { + None + } else { + Some(format!("{} {}", main_content, reasoning).trim().to_string()) + } + } + ChatMessage::Tool { content, .. } => Some(content.clone()), + ChatMessage::Function { content, .. 
} => Some(content.clone()), + }) + .collect::>() + .join(" ") + } +} diff --git a/sgl-router/src/protocols/openai/chat/response.rs b/sgl-router/src/protocols/openai/chat/response.rs new file mode 100644 index 000000000..3ac480462 --- /dev/null +++ b/sgl-router/src/protocols/openai/chat/response.rs @@ -0,0 +1,59 @@ +// Chat Completions API response types + +use crate::protocols::openai::chat::types::{ChatMessage, ChatMessageDelta}; +use crate::protocols::openai::common::{ChatLogProbs, Usage}; +use serde::{Deserialize, Serialize}; + +// ============= Regular Response ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: String, // "chat.completion" + pub created: u64, + pub model: String, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatChoice { + pub index: u32, + pub message: ChatMessage, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, // "stop", "length", "tool_calls", "content_filter", "function_call" + /// Information about which stop condition was matched + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, // Can be string or integer + /// Hidden states from the model (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + pub hidden_states: Option>, +} + +// ============= Streaming Response ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionStreamResponse { + pub id: String, + pub object: String, // "chat.completion.chunk" + pub created: u64, + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatStreamChoice { + pub index: u32, + pub delta: ChatMessageDelta, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, +} diff --git a/sgl-router/src/protocols/openai/chat/types.rs b/sgl-router/src/protocols/openai/chat/types.rs new file mode 100644 index 000000000..01bf836cf --- /dev/null +++ b/sgl-router/src/protocols/openai/chat/types.rs @@ -0,0 +1,185 @@ +// Types specific to the Chat Completions API + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +// ============= Message Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ChatMessage { + System { + role: String, // "system" + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + User { + role: String, // "user" + content: UserMessageContent, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + Assistant { + role: String, // "assistant" + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + function_call: Option, + /// Reasoning content for O1-style models (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_content: Option, + }, + Tool { + role: String, // "tool" + content: String, + tool_call_id: 
String, + }, + Function { + role: String, // "function" + content: String, + name: String, + }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum UserMessageContent { + Text(String), + Parts(Vec), +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { image_url: ImageUrl }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ImageUrl { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, // "auto", "low", or "high" +} + +// ============= Response Format Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +pub enum ResponseFormat { + #[serde(rename = "text")] + Text, + #[serde(rename = "json_object")] + JsonObject, + #[serde(rename = "json_schema")] + JsonSchema { json_schema: JsonSchemaFormat }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct JsonSchemaFormat { + pub name: String, + pub schema: Value, + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +// ============= Tool/Function Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Tool { + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub function: Function, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Function { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub parameters: Value, // JSON Schema +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ToolChoice { + None, + Auto, + Required, + Function { + #[serde(rename = "type")] + tool_type: String, // "function" + function: FunctionChoice, + }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionChoice { + pub name: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub function: FunctionCallResponse, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum FunctionCall { + None, + Auto, + Function { name: String }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionCallResponse { + pub name: String, + pub arguments: String, // JSON string +} + +// ============= Streaming Delta Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatMessageDelta { + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, + /// Reasoning content delta for O1-style models (SGLang extension) + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolCallDelta { + pub index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "type")] + pub tool_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionCallDelta { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: 
Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, +} diff --git a/sgl-router/src/protocols/openai/common.rs b/sgl-router/src/protocols/openai/common.rs new file mode 100644 index 000000000..69ed6d7b4 --- /dev/null +++ b/sgl-router/src/protocols/openai/common.rs @@ -0,0 +1,58 @@ +// Common types shared across OpenAI API implementations + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +// ============= Shared Request Components ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct StreamOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub include_usage: Option, +} + +// ============= Usage Tracking ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_tokens_details: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionTokensDetails { + pub reasoning_tokens: Option, +} + +// ============= Logprobs Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct LogProbs { + pub tokens: Vec, + pub token_logprobs: Vec>, + pub top_logprobs: Vec>>, + pub text_offset: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatLogProbs { + pub content: Option>, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatLogProbsContent { + pub token: String, + pub logprob: f32, + pub bytes: Option>, + pub top_logprobs: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct TopLogProb { + pub token: String, + pub logprob: f32, + pub bytes: Option>, +} diff --git a/sgl-router/src/protocols/openai/completions/mod.rs b/sgl-router/src/protocols/openai/completions/mod.rs new file mode 100644 index 000000000..c87dbbfe5 --- /dev/null +++ b/sgl-router/src/protocols/openai/completions/mod.rs @@ -0,0 +1,10 @@ +// Completions API module (v1/completions) + +pub mod request; +pub mod response; + +// Re-export main types for convenience +pub use request::CompletionRequest; +pub use response::{ + CompletionChoice, CompletionResponse, CompletionStreamChoice, CompletionStreamResponse, +}; diff --git a/sgl-router/src/protocols/openai/completions/request.rs b/sgl-router/src/protocols/openai/completions/request.rs new file mode 100644 index 000000000..c340dc6a5 --- /dev/null +++ b/sgl-router/src/protocols/openai/completions/request.rs @@ -0,0 +1,158 @@ +// Completions API request types (v1/completions) - DEPRECATED but still supported + +use crate::protocols::common::{default_true, GenerationRequest, LoRAPath, StringOrArray}; +use crate::protocols::openai::common::StreamOptions; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionRequest { + /// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang) + pub model: String, + + /// The prompt(s) to generate completions for + pub prompt: StringOrArray, + + /// The suffix that comes after a completion of inserted text + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + + /// The maximum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// What sampling temperature to use, between 0 and 2 + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// An alternative to 
sampling with temperature (nucleus sampling) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// How many completions to generate for each prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + + /// Whether to stream back partial progress + #[serde(default)] + pub stream: bool, + + /// Options for streaming response + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + + /// Include the log probabilities on the logprobs most likely tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + + /// Echo back the prompt in addition to the completion + #[serde(default)] + pub echo: bool, + + /// Up to 4 sequences where the API will stop generating further tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + + /// Generates best_of completions server-side and returns the "best" + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of: Option, + + /// Modify the likelihood of specified tokens appearing in the completion + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + + /// A unique identifier representing your end-user + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// If specified, our system will make a best effort to sample deterministically + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + + // ============= SGLang Extensions ============= + /// Top-k sampling parameter (-1 to disable) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + + /// Min-p nucleus sampling parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub min_p: Option, + + /// Minimum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + + /// Repetition penalty for reducing repetitive text + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + + /// Regex constraint for output generation + #[serde(skip_serializing_if = "Option::is_none")] + pub regex: Option, + + /// EBNF grammar constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub ebnf: Option, + + /// JSON schema constraint for structured output + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, + + /// Specific token IDs to use as stop conditions + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_token_ids: Option>, + + /// Skip trimming stop tokens from output + #[serde(default)] + pub no_stop_trim: bool, + + /// Ignore end-of-sequence tokens during generation + #[serde(default)] + pub ignore_eos: bool, + + /// Skip special tokens during detokenization + #[serde(default = "default_true")] + pub skip_special_tokens: bool, + + // ============= SGLang Extensions ============= + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + 
diff --git a/sgl-router/src/protocols/openai/completions/response.rs b/sgl-router/src/protocols/openai/completions/response.rs
new file mode 100644
index 000000000..4734ba134
--- /dev/null
+++ b/sgl-router/src/protocols/openai/completions/response.rs
@@ -0,0 +1,56 @@
+// Completions API response types
+
+use crate::protocols::openai::common::{LogProbs, Usage};
+use serde::{Deserialize, Serialize};
+
+// ============= Regular Response =============
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct CompletionResponse {
+    pub id: String,
+    pub object: String, // "text_completion"
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<CompletionChoice>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub usage: Option<Usage>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_fingerprint: Option<String>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct CompletionChoice {
+    pub text: String,
+    pub index: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<LogProbs>,
+    pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
+    /// Information about which stop condition was matched
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub matched_stop: Option<serde_json::Value>, // Can be string or integer
+    /// Hidden states from the model (SGLang extension)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub hidden_states: Option<Vec<f32>>,
+}
+
+// ============= Streaming Response =============
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct CompletionStreamResponse {
+    pub id: String,
+    pub object: String, // "text_completion"
+    pub created: u64,
+    pub choices: Vec<CompletionStreamChoice>,
+    pub model: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_fingerprint: Option<String>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct CompletionStreamChoice {
+    pub text: String,
+    pub index: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<LogProbs>,
+    pub finish_reason: Option<String>,
+}
diff --git a/sgl-router/src/protocols/openai/errors.rs b/sgl-router/src/protocols/openai/errors.rs
new file mode 100644
index 000000000..9ec6b2e0b
--- /dev/null
+++ b/sgl-router/src/protocols/openai/errors.rs
@@ -0,0 +1,19 @@
+// OpenAI API error response types
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ErrorResponse {
+    pub error: ErrorDetail,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ErrorDetail {
+    pub message: String,
+    #[serde(rename = "type")]
+    pub error_type: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub param: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub code: Option<String>,
+}
diff --git a/sgl-router/src/protocols/openai/mod.rs b/sgl-router/src/protocols/openai/mod.rs
new file mode 100644
index 000000000..83c7ddfba
--- /dev/null
+++ b/sgl-router/src/protocols/openai/mod.rs
@@ -0,0 +1,7 @@
+// OpenAI protocol module
+// This module contains all OpenAI API-compatible types and future validation logic
+
+pub mod chat;
+pub mod common;
+pub mod completions;
+pub mod errors;
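(Not part of the diff.) A sketch of the wire shape produced by the error types above, assuming the crate builds with this patch and serde_json is available; field names are taken from the new errors module.

use sglang_router_rs::protocols::openai::errors::{ErrorDetail, ErrorResponse};

fn main() {
    let err = ErrorResponse {
        error: ErrorDetail {
            message: "model not found".to_string(),
            error_type: "invalid_request_error".to_string(),
            param: Some("model".to_string()),
            code: None,
        },
    };
    // #[serde(rename = "type")] emits error_type as "type", matching OpenAI's error schema;
    // the None code field is omitted entirely.
    println!("{}", serde_json::to_string_pretty(&err).unwrap());
}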
diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs
index bfcb5ad2e..83789852b 100644
--- a/sgl-router/src/routers/mod.rs
+++ b/sgl-router/src/routers/mod.rs
@@ -9,7 +9,10 @@ use axum::{
 };
 use std::fmt::Debug;
 
-use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
+use crate::protocols::{
+    generate::GenerateRequest,
+    openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
+};
 
 pub mod factory;
 pub mod header_utils;
diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs
index 0d70f4ab9..cba55c5cd 100644
--- a/sgl-router/src/routers/pd_router.rs
+++ b/sgl-router/src/routers/pd_router.rs
@@ -11,8 +11,15 @@ use crate::core::{
     RetryExecutor, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
 };
 use crate::metrics::RouterMetrics;
-use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
 use crate::policies::LoadBalancingPolicy;
+use crate::protocols::{
+    common::StringOrArray,
+    generate::GenerateRequest,
+    openai::{
+        chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
+        completions::CompletionRequest,
+    },
+};
 use crate::routers::{RouterTrait, WorkerManagement};
 use async_trait::async_trait;
 use axum::{
@@ -616,7 +623,7 @@ impl PDRouter {
     // Helper to determine batch size from a GenerateRequest
     fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
         // Check prompt array
-        if let Some(crate::openai_api_types::StringOrArray::Array(arr)) = &req.prompt {
+        if let Some(StringOrArray::Array(arr)) = &req.prompt {
             if !arr.is_empty() {
                 return Some(arr.len());
             }
@@ -645,7 +652,7 @@ impl PDRouter {
     // Helper to determine batch size from a CompletionRequest
     fn get_completion_batch_size(req: &CompletionRequest) -> Option<usize> {
         // Check prompt array
-        if let crate::openai_api_types::StringOrArray::Array(arr) = &req.prompt {
+        if let StringOrArray::Array(arr) = &req.prompt {
             if !arr.is_empty() {
                 return Some(arr.len());
             }
@@ -1724,10 +1731,8 @@ impl RouterTrait for PDRouter {
            .as_deref()
            .or_else(|| {
                body.prompt.as_ref().and_then(|p| match p {
-                    crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
-                    crate::openai_api_types::StringOrArray::Array(v) => {
-                        v.first().map(|s| s.as_str())
-                    }
+                    StringOrArray::String(s) => Some(s.as_str()),
+                    StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
                })
            })
            .map(|s| s.to_string())
@@ -1763,13 +1768,11 @@
         // Extract text for cache-aware routing
         let request_text = if self.policies_need_request_text() {
             body.messages.first().and_then(|msg| match msg {
-                crate::openai_api_types::ChatMessage::User { content, .. } => match content {
-                    crate::openai_api_types::UserMessageContent::Text(text) => Some(text.clone()),
-                    crate::openai_api_types::UserMessageContent::Parts(_) => None,
+                ChatMessage::User { content, .. } => match content {
+                    UserMessageContent::Text(text) => Some(text.clone()),
+                    UserMessageContent::Parts(_) => None,
                 },
-                crate::openai_api_types::ChatMessage::System { content, .. } => {
-                    Some(content.clone())
-                }
+                ChatMessage::System { content, .. } => Some(content.clone()),
                 _ => None,
             })
         } else {
@@ -1804,10 +1807,8 @@
         // Extract text for cache-aware routing
         let request_text = if self.policies_need_request_text() {
             match &body.prompt {
-                crate::openai_api_types::StringOrArray::String(s) => Some(s.clone()),
-                crate::openai_api_types::StringOrArray::Array(v) => {
-                    v.first().map(|s| s.to_string())
-                }
+                StringOrArray::String(s) => Some(s.clone()),
+                StringOrArray::Array(v) => v.first().map(|s| s.to_string()),
             }
         } else {
             None
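(Not part of the diff.) A standalone sketch of the batch-size rule used by get_completion_batch_size above: a non-empty prompt array means a batched request. StringOrArray comes from this patch; the free function name is hypothetical.

use sglang_router_rs::protocols::common::StringOrArray;

// Returns Some(n) only when the prompt is a non-empty array of n prompts.
fn prompt_batch_size(prompt: &StringOrArray) -> Option<usize> {
    if let StringOrArray::Array(arr) = prompt {
        if !arr.is_empty() {
            return Some(arr.len());
        }
    }
    None
}

fn main() {
    let batch = StringOrArray::Array(vec!["a".into(), "b".into(), "c".into()]);
    assert_eq!(prompt_batch_size(&batch), Some(3));
    assert_eq!(prompt_batch_size(&StringOrArray::String("solo".into())), None);
}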
diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs
index 87c8b70dd..2c5d278ea 100644
--- a/sgl-router/src/routers/router.rs
+++ b/sgl-router/src/routers/router.rs
@@ -8,8 +8,12 @@ use crate::core::{
     RetryExecutor, Worker, WorkerFactory, WorkerType,
 };
 use crate::metrics::RouterMetrics;
-use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
 use crate::policies::LoadBalancingPolicy;
+use crate::protocols::{
+    common::GenerationRequest,
+    generate::GenerateRequest,
+    openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
+};
 use crate::routers::{RouterTrait, WorkerManagement};
 use axum::{
     body::Body,
@@ -453,9 +457,7 @@ impl Router {
         Some(available[idx].clone_worker())
     }
 
-    pub async fn route_typed_request<
-        T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
-    >(
+    pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
         &self,
         headers: Option<&HeaderMap>,
         typed_req: &T,
diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs
index 9746e5845..85e7648af 100644
--- a/sgl-router/src/server.rs
+++ b/sgl-router/src/server.rs
@@ -1,7 +1,10 @@
 use crate::config::RouterConfig;
 use crate::logging::{self, LoggingConfig};
 use crate::metrics::{self, PrometheusConfig};
-use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
+use crate::protocols::{
+    generate::GenerateRequest,
+    openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
+};
 use crate::routers::{RouterFactory, RouterTrait};
 use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
 use axum::{
diff --git a/sgl-router/tests/benchmark_integration.rs b/sgl-router/tests/benchmark_integration.rs
index 16406c461..6787d8695 100644
--- a/sgl-router/tests/benchmark_integration.rs
+++ b/sgl-router/tests/benchmark_integration.rs
@@ -5,9 +5,13 @@
 use serde_json::{from_str, to_string, to_value};
 use sglang_router_rs::core::{BasicWorker, WorkerType};
 
-use sglang_router_rs::openai_api_types::{
-    ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
-    SamplingParams, StringOrArray, UserMessageContent,
+use sglang_router_rs::protocols::{
+    common::StringOrArray,
+    generate::{GenerateParameters, GenerateRequest, SamplingParams},
+    openai::{
+        chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
+        completions::CompletionRequest,
+    },
 };
 
 /// Create a default GenerateRequest for benchmarks with minimal fields set
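(Not part of the diff.) A sketch of what the GenerationRequest bound on route_typed_request buys: request types from the new protocols module can be inspected uniformly for routing. The trait path and methods are taken from this patch; the helper below is illustrative only.

use sglang_router_rs::protocols::common::GenerationRequest;

// Works for any request type routed through route_typed_request in this patch,
// since the bound gives access to the same three accessors on each of them.
fn describe<T: GenerationRequest>(req: &T) -> String {
    format!(
        "model={:?} stream={} routing_text={:?}",
        req.get_model(),
        req.is_stream(),
        req.extract_text_for_routing()
    )
}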